feat(tools): 添加工具系统框架与 MCP 协议客户端
This commit is contained in:
@@ -18,6 +18,7 @@ futures-util = "0.3"
|
|||||||
futures-core = "0.3"
|
futures-core = "0.3"
|
||||||
bytes = "1"
|
bytes = "1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
|
tokio-util = { version = "0.7", features = ["rt"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
|
pub mod tools;
|
||||||
|
|
||||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||||
|
|
||||||
|
|||||||
+465
-1
@@ -19,7 +19,8 @@ use crate::llm::hooks::{HookContext, HookExecutor};
|
|||||||
use crate::llm::provider::LlmProvider;
|
use crate::llm::provider::LlmProvider;
|
||||||
use crate::llm::stream::StreamEvent;
|
use crate::llm::stream::StreamEvent;
|
||||||
use crate::llm::types::{
|
use crate::llm::types::{
|
||||||
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall,
|
||||||
|
ToolChoice, ToolDefinition,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// LLM 调用周期配置。
|
/// LLM 调用周期配置。
|
||||||
@@ -34,6 +35,15 @@ pub struct CycleConfig {
|
|||||||
pub max_turns: Option<u32>,
|
pub max_turns: Option<u32>,
|
||||||
/// 重试策略配置。
|
/// 重试策略配置。
|
||||||
pub retry: RetryConfig,
|
pub retry: RetryConfig,
|
||||||
|
/// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。
|
||||||
|
/// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。
|
||||||
|
pub max_tool_turns: Option<u32>,
|
||||||
|
/// 单个工具执行的超时秒数(0 表示不超时)。
|
||||||
|
/// 默认 60 秒。
|
||||||
|
pub tool_timeout_secs: u64,
|
||||||
|
/// 单个工具结果的最大字节数(超过此值将被截断)。
|
||||||
|
/// 默认 65536(64KB),防止大结果导致 token 膨胀。
|
||||||
|
pub max_tool_result_bytes: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for CycleConfig {
|
impl Default for CycleConfig {
|
||||||
@@ -44,6 +54,9 @@ impl Default for CycleConfig {
|
|||||||
temperature: None,
|
temperature: None,
|
||||||
max_turns: None,
|
max_turns: None,
|
||||||
retry: RetryConfig::default(),
|
retry: RetryConfig::default(),
|
||||||
|
max_tool_turns: Some(10),
|
||||||
|
tool_timeout_secs: 60,
|
||||||
|
max_tool_result_bytes: 65_536,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -440,4 +453,455 @@ impl LlmCycle {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 内部请求方法(与 `submit` 共享重试逻辑,但不 push user message 和 Assistant 响应)。
|
||||||
|
///
|
||||||
|
/// 用于 `submit_with_tools()` 的多轮 tool 循环。
|
||||||
|
async fn submit_request(
|
||||||
|
&mut self,
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let mut attempts = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let request = self.build_request(tools);
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.provider.chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let post_request = self.build_request(tools);
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
self.usage.add(&response.usage);
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||||
|
attempts += 1;
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||||
|
.with_error(&e)
|
||||||
|
.with_attempt(attempts);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let delay = self.config.retry.compute_delay(attempts);
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError)
|
||||||
|
.with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提交消息并自动处理工具调用循环。
|
||||||
|
///
|
||||||
|
/// 流程:
|
||||||
|
/// 1. 发送请求(含工具定义)
|
||||||
|
/// 2. 检查响应中的 finish_reason
|
||||||
|
/// 3. 如果是 ToolCalls → push Assistant 消息 → 执行工具 → 回传结果 → 重复 1
|
||||||
|
/// 4. 如果是 Stop/Length → push Assistant 消息 → 返回最终响应
|
||||||
|
///
|
||||||
|
/// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。
|
||||||
|
/// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。
|
||||||
|
pub async fn submit_with_tools(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
registry: &crate::tools::ToolRegistry,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let tools = registry.definitions();
|
||||||
|
let max_turns = self.config.max_tool_turns.unwrap_or(10);
|
||||||
|
let tool_timeout = self.config.tool_timeout_secs;
|
||||||
|
let max_bytes = self.config.max_tool_result_bytes;
|
||||||
|
|
||||||
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
self.maybe_compact();
|
||||||
|
|
||||||
|
let mut turn = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
turn += 1;
|
||||||
|
if turn > max_turns {
|
||||||
|
return Err(LlmError::Other(format!(
|
||||||
|
"达到最大工具循环轮次 ({max_turns})"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.submit_request(&tools).await?;
|
||||||
|
|
||||||
|
// 判断是否需要执行工具
|
||||||
|
let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls))
|
||||||
|
&& has_tool_calls_in_message(&response.message);
|
||||||
|
|
||||||
|
// 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史
|
||||||
|
self.messages.push(response.message.clone());
|
||||||
|
|
||||||
|
if !should_execute {
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 tool_calls 并执行
|
||||||
|
let tool_calls = extract_tool_calls_from_message(&response.message);
|
||||||
|
let calls: Vec<(String, serde_json::Value)> = tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|(_id, name, args)| {
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&args).unwrap_or(serde_json::Value::Null);
|
||||||
|
(name, args)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let results = registry.invoke_all(calls, tool_timeout).await;
|
||||||
|
|
||||||
|
// 回传工具结果
|
||||||
|
for result in results {
|
||||||
|
let content = match result.output {
|
||||||
|
Ok(value) => {
|
||||||
|
let serialized = serde_json::to_string(&value).unwrap_or_else(|e| {
|
||||||
|
tracing::warn!("工具结果序列化失败: {}", e);
|
||||||
|
"{}".to_string()
|
||||||
|
});
|
||||||
|
truncate_tool_result(&serialized, max_bytes)
|
||||||
|
}
|
||||||
|
Err(e) if e.is_recoverable() => format!("错误: {}", e),
|
||||||
|
Err(e) => {
|
||||||
|
// 不可恢复错误:终止循环
|
||||||
|
return Err(LlmError::Other(format!(
|
||||||
|
"工具 '{}' 不可恢复错误: {}",
|
||||||
|
result.tool_name, e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.messages
|
||||||
|
.push(OpenaiChatMessage::tool_result(result.tool_name, content));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 每轮工具执行后触发 compaction
|
||||||
|
self.maybe_compact();
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable: loop returns
|
||||||
|
#[allow(unreachable_code)]
|
||||||
|
{
|
||||||
|
Err(LlmError::Other("unreachable".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 在接近上下文窗口时压缩历史消息。
|
||||||
|
fn maybe_compact(&mut self) {
|
||||||
|
if let Some(ref config) = self.compact_config
|
||||||
|
&& should_compact(&self.messages, config, &self.compact_state)
|
||||||
|
{
|
||||||
|
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 判断 Assistant 消息是否包含 tool_calls。
|
||||||
|
fn has_tool_calls_in_message(msg: &OpenaiChatMessage) -> bool {
|
||||||
|
matches!(
|
||||||
|
msg,
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} if !calls.is_empty()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提取 Assistant 消息中的 tool_calls。
|
||||||
|
///
|
||||||
|
/// 返回 `(tool_call_id, tool_name, arguments_json_string)` 列表。
|
||||||
|
fn extract_tool_calls_from_message(
|
||||||
|
msg: &OpenaiChatMessage,
|
||||||
|
) -> Vec<(String, String, String)> {
|
||||||
|
if let OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} = msg
|
||||||
|
{
|
||||||
|
calls
|
||||||
|
.iter()
|
||||||
|
.map(|c| match c {
|
||||||
|
OpenaiToolCall::Function { id, function } => {
|
||||||
|
(id.clone(), function.name.clone(), function.arguments.clone())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 截断工具结果到指定字节数。
|
||||||
|
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
|
||||||
|
if s.len() <= max_bytes {
|
||||||
|
return s.to_string();
|
||||||
|
}
|
||||||
|
let mut end = max_bytes;
|
||||||
|
while end > 0 && !s.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
format!("{}\n\n[... truncated, original size: {} bytes ...]", &s[..end], s.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||||
|
use crate::tools::{BaseTool, ToolRegistry};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
/// 模拟 Provider —— 预定义响应序列,按调用顺序返回。
|
||||||
|
struct MockProvider {
|
||||||
|
responses: std::sync::Mutex<Vec<ChatResponse>>,
|
||||||
|
call_count: std::sync::Mutex<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
fn new(responses: Vec<ChatResponse>) -> Self {
|
||||||
|
Self {
|
||||||
|
responses: std::sync::Mutex::new(responses),
|
||||||
|
call_count: std::sync::Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LlmProvider for MockProvider {
|
||||||
|
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||||
|
let mut count = self.call_count.lock().unwrap();
|
||||||
|
*count += 1;
|
||||||
|
let mut responses = self.responses.lock().unwrap();
|
||||||
|
if responses.is_empty() {
|
||||||
|
return Err(LlmError::Other("no more mock responses".into()));
|
||||||
|
}
|
||||||
|
Ok(responses.remove(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn empty_usage() -> crate::llm::types::Usage {
|
||||||
|
crate::llm::types::Usage::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_text_response(text: &str) -> ChatResponse {
|
||||||
|
ChatResponse {
|
||||||
|
message: OpenaiChatMessage::assistant_text(text),
|
||||||
|
usage: empty_usage(),
|
||||||
|
stop_reason: Some(FinishReason::Stop),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_tool_call_response(
|
||||||
|
calls: Vec<(&str, &str, &str)>,
|
||||||
|
) -> ChatResponse {
|
||||||
|
use crate::llm::types::{OpenaiToolCall, FunctionCall};
|
||||||
|
let tool_calls: Vec<OpenaiToolCall> = calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|(id, name, args)| OpenaiToolCall::Function {
|
||||||
|
id: id.to_string(),
|
||||||
|
function: FunctionCall {
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: args.to_string(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
ChatResponse {
|
||||||
|
message: OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||||
|
text: String::new(),
|
||||||
|
}]),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
},
|
||||||
|
usage: empty_usage(),
|
||||||
|
stop_reason: Some(FinishReason::ToolCalls),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AddTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for AddTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"add"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"加法"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}})
|
||||||
|
}
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
args: Value,
|
||||||
|
_ctx: &crate::tools::ToolContext<'_>,
|
||||||
|
) -> Result<Value, crate::tools::ToolError> {
|
||||||
|
let a = args["a"].as_i64().unwrap_or(0);
|
||||||
|
let b = args["b"].as_i64().unwrap_or(0);
|
||||||
|
Ok(json!({"result": a + b}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_single_turn() {
|
||||||
|
// 第一轮:返回 tool_call;第二轮:返回最终文本
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||||
|
assistant_text_response("答案是 3"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("1+2=?".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// 验证最终响应是文本响应
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// 验证消息历史:user, assistant(tool_calls), tool, assistant(text)
|
||||||
|
let messages = cycle.messages();
|
||||||
|
assert_eq!(messages.len(), 4);
|
||||||
|
assert!(matches!(messages[0], OpenaiChatMessage::User { .. }));
|
||||||
|
assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. }));
|
||||||
|
assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. }));
|
||||||
|
assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_multi_turn() {
|
||||||
|
// 3 轮 tool 调用后给出最终答案
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]),
|
||||||
|
assistant_text_response("完成"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("计算总和".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// user + 3*(assistant + tool) + final assistant = 8
|
||||||
|
let messages = cycle.messages();
|
||||||
|
assert_eq!(messages.len(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_max_turns_exceeded() {
|
||||||
|
// 配置 max_tool_turns = 2
|
||||||
|
let mut config = CycleConfig::default();
|
||||||
|
config.max_tool_turns = Some(2);
|
||||||
|
// 4 轮 tool 调用 + 终止
|
||||||
|
let responses = vec![
|
||||||
|
assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]),
|
||||||
|
assistant_text_response("完成"),
|
||||||
|
];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, config);
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let result = cycle
|
||||||
|
.submit_with_tools("test".to_string(), ®istry)
|
||||||
|
.await;
|
||||||
|
assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_submit_with_tools_no_tool_call_response() {
|
||||||
|
// LLM 直接给出最终响应(不调用工具)
|
||||||
|
let responses = vec![assistant_text_response("直接回答")];
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut cycle = LlmCycle::new(provider, CycleConfig::default());
|
||||||
|
|
||||||
|
let mut registry = ToolRegistry::new();
|
||||||
|
registry.register(std::sync::Arc::new(AddTool)).unwrap();
|
||||||
|
|
||||||
|
let response = cycle
|
||||||
|
.submit_with_tools("直接回答".to_string(), ®istry)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(
|
||||||
|
response.message,
|
||||||
|
OpenaiChatMessage::Assistant { .. }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_short() {
|
||||||
|
let s = "short text";
|
||||||
|
assert_eq!(truncate_tool_result(s, 100), "short text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_long() {
|
||||||
|
let s = "a".repeat(1000);
|
||||||
|
let truncated = truncate_tool_result(&s, 50);
|
||||||
|
assert!(truncated.len() < s.len());
|
||||||
|
assert!(truncated.contains("[... truncated,"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_truncate_tool_result_chinese_chars() {
|
||||||
|
let s = "中".repeat(100);
|
||||||
|
let truncated = truncate_tool_result(&s, 50);
|
||||||
|
// 不会在字符中间截断
|
||||||
|
assert!(truncated.starts_with("中"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。
|
||||||
|
|
||||||
|
pub mod base;
|
||||||
|
pub mod error;
|
||||||
|
pub mod mcp;
|
||||||
|
pub mod permission;
|
||||||
|
pub mod registry;
|
||||||
|
|
||||||
|
pub use base::{BaseTool, ToolContext, ToolRef};
|
||||||
|
pub use error::ToolError;
|
||||||
|
pub use mcp::{McpClient, McpTransport};
|
||||||
|
pub use permission::{Permission, PermissionChecker, PermissionConfig};
|
||||||
|
pub use registry::{ToolInvocation, ToolRegistry};
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
//! 工具抽象接口与执行上下文。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::Permission;
|
||||||
|
|
||||||
|
/// 工具执行上下文 —— 携带每次执行的运行时信息。
|
||||||
|
///
|
||||||
|
/// 字段在 Phase 2 即注入 `execute()` 签名中,防止后续扩展时出现
|
||||||
|
/// breaking change。后续阶段可扩展字段(如 `progress`、`shared_state`),
|
||||||
|
/// 但已有工具实现无需修改。
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ToolContext<'a> {
|
||||||
|
/// 当前对话/会话 ID,用于关联性追踪。
|
||||||
|
pub session_id: &'a str,
|
||||||
|
/// 链路追踪 ID,用于跨工具调用的耗时分布。
|
||||||
|
pub trace_id: &'a str,
|
||||||
|
/// 取消令牌,用于优雅取消正在执行的工具。
|
||||||
|
pub cancellation_token: CancellationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ToolContext<'a> {
|
||||||
|
/// 创建一个新的工具执行上下文。
|
||||||
|
pub fn new(session_id: &'a str, trace_id: &'a str) -> Self {
|
||||||
|
Self {
|
||||||
|
session_id,
|
||||||
|
trace_id,
|
||||||
|
cancellation_token: CancellationToken::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个使用给定取消令牌的上下文。
|
||||||
|
pub fn with_cancellation_token(
|
||||||
|
session_id: &'a str,
|
||||||
|
trace_id: &'a str,
|
||||||
|
token: CancellationToken,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
session_id,
|
||||||
|
trace_id,
|
||||||
|
cancellation_token: token,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait BaseTool: Send + Sync {
|
||||||
|
/// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具描述(LLM 据此决定是否调用此工具)。
|
||||||
|
fn description(&self) -> &str;
|
||||||
|
|
||||||
|
/// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。
|
||||||
|
fn parameters(&self) -> Value;
|
||||||
|
|
||||||
|
/// 声明工具所需的权限列表。
|
||||||
|
fn required_permissions(&self) -> Vec<Permission> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 执行工具调用。
|
||||||
|
///
|
||||||
|
/// `ctx` 携带执行上下文(session_id、trace_id、cancellation_token),
|
||||||
|
/// 工具实现可在执行期间检查 `ctx.cancellation_token` 来支持优雅取消。
|
||||||
|
async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result<Value, ToolError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 为 `Arc<dyn BaseTool>` 提供 `Send + Sync` 包装,便于在 `Vec<Arc<dyn BaseTool>>` 中使用。
|
||||||
|
pub type ToolRef = Arc<dyn BaseTool>;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
struct EchoTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for EchoTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"echo"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"回显输入"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Ok(args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_mock_tool_execute() {
|
||||||
|
let tool = EchoTool;
|
||||||
|
let ctx = ToolContext::new("session-1", "trace-1");
|
||||||
|
let result = tool.execute(json!({"text": "hello"}), &ctx).await.unwrap();
|
||||||
|
assert_eq!(result, json!({"text": "hello"}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_default_permissions_empty() {
|
||||||
|
let tool = EchoTool;
|
||||||
|
assert!(tool.required_permissions().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_context_creation() {
|
||||||
|
let ctx = ToolContext::new("s1", "t1");
|
||||||
|
assert_eq!(ctx.session_id, "s1");
|
||||||
|
assert_eq!(ctx.trace_id, "t1");
|
||||||
|
assert!(!ctx.cancellation_token.is_cancelled());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_context_cancellation() {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
token.cancel();
|
||||||
|
let ctx = ToolContext::with_cancellation_token("s1", "t1", token);
|
||||||
|
assert!(ctx.cancellation_token.is_cancelled());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
//! 工具系统错误类型。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// 工具调用过程中可能发生的所有错误。
|
||||||
|
#[derive(thiserror::Error, Debug, Clone)]
|
||||||
|
pub enum ToolError {
|
||||||
|
/// 工具未注册。
|
||||||
|
#[error("工具 '{0}' 未注册")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
/// 工具执行失败(可恢复——文本回传 LLM)。
|
||||||
|
#[error("工具 '{0}' 执行失败: {1}")]
|
||||||
|
ExecutionFailed(String, String),
|
||||||
|
|
||||||
|
/// 工具参数无效(可恢复——文本回传 LLM)。
|
||||||
|
#[error("工具 '{0}' 参数无效: {1}")]
|
||||||
|
InvalidArguments(String, String),
|
||||||
|
|
||||||
|
/// 权限被拒绝(不可恢复——终止循环)。
|
||||||
|
#[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")]
|
||||||
|
PermissionDenied(String, String),
|
||||||
|
|
||||||
|
/// MCP 协议错误(不可恢复)。
|
||||||
|
#[error("MCP 协议错误: {0}")]
|
||||||
|
McpError(String),
|
||||||
|
|
||||||
|
/// MCP 未初始化(不可恢复)。
|
||||||
|
#[error("MCP 未初始化: {0}")]
|
||||||
|
McpNotInitialized(String),
|
||||||
|
|
||||||
|
/// MCP 超时(不可恢复)。
|
||||||
|
#[error("MCP 超时: {0}")]
|
||||||
|
McpTimeout(String),
|
||||||
|
|
||||||
|
/// IO 错误(不可恢复)。
|
||||||
|
#[error("IO 错误: {0}")]
|
||||||
|
Io(Arc<std::io::Error>),
|
||||||
|
|
||||||
|
/// 取消。
|
||||||
|
#[error("工具执行已取消: {0}")]
|
||||||
|
Cancelled(String),
|
||||||
|
|
||||||
|
/// 其他未分类错误。
|
||||||
|
#[error("其他错误: {0}")]
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<std::io::Error> for ToolError {
|
||||||
|
fn from(e: std::io::Error) -> Self {
|
||||||
|
ToolError::Io(Arc::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolError {
|
||||||
|
/// 判断错误是否可恢复——可恢复的错误回传 LLM 由其自行重试,
|
||||||
|
/// 不可恢复的错误终止自动 tool 循环并返回给调用方。
|
||||||
|
pub fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self,
|
||||||
|
Self::ExecutionFailed(..) | Self::InvalidArguments(..) | Self::Other(_)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_execution_failed_is_recoverable() {
|
||||||
|
let err = ToolError::ExecutionFailed("foo".into(), "boom".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_arguments_is_recoverable() {
|
||||||
|
let err = ToolError::InvalidArguments("foo".into(), "missing x".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_not_found_is_not_recoverable() {
|
||||||
|
let err = ToolError::NotFound("foo".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_permission_denied_is_not_recoverable() {
|
||||||
|
let err = ToolError::PermissionDenied("foo".into(), "Shell".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mcp_error_is_not_recoverable() {
|
||||||
|
let err = ToolError::McpError("protocol".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mcp_timeout_is_not_recoverable() {
|
||||||
|
let err = ToolError::McpTimeout("foo".into());
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_io_is_not_recoverable() {
|
||||||
|
let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk");
|
||||||
|
let err = ToolError::from(io_err);
|
||||||
|
assert!(!err.is_recoverable());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_other_is_recoverable() {
|
||||||
|
let err = ToolError::Other("something".into());
|
||||||
|
assert!(err.is_recoverable());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,640 @@
|
|||||||
|
//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。
|
||||||
|
//!
|
||||||
|
//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留,
|
||||||
|
//! 但实际实现推迟到后续版本。
|
||||||
|
//!
|
||||||
|
//! ## 协议版本
|
||||||
|
//!
|
||||||
|
//! 实现遵循 MCP 协议版本 2025-03-26。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::process::Stdio;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||||||
|
use tokio::sync::{oneshot, Mutex};
|
||||||
|
|
||||||
|
use crate::llm::types::ToolDefinition;
|
||||||
|
use crate::tools::base::{BaseTool, ToolContext, ToolRef};
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
/// MCP 协议版本。
|
||||||
|
const MCP_VERSION: &str = "2025-03-26";
|
||||||
|
|
||||||
|
/// MCP 传输方式。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum McpTransport {
|
||||||
|
/// 通过子进程 stdin/stdout 通信。
|
||||||
|
Stdio {
|
||||||
|
/// 启动命令(如 `"npx"`)。
|
||||||
|
command: String,
|
||||||
|
/// 命令参数(如 `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`)。
|
||||||
|
args: Vec<String>,
|
||||||
|
},
|
||||||
|
/// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。
|
||||||
|
///
|
||||||
|
/// 当前 Phase 2 预留枚举变体,调用方法会返回 `ToolError::McpError`。
|
||||||
|
StreamableHttp {
|
||||||
|
/// MCP 端点 URL。
|
||||||
|
url: String,
|
||||||
|
/// 可选的 HTTP 头(如 Authorization)。
|
||||||
|
headers: Option<Vec<(String, String)>>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON-RPC 请求。
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcRequest {
|
||||||
|
jsonrpc: &'static str,
|
||||||
|
id: u64,
|
||||||
|
method: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
params: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JsonRpcRequest {
|
||||||
|
fn new(id: u64, method: impl Into<String>, params: Option<Value>) -> Self {
|
||||||
|
Self {
|
||||||
|
jsonrpc: "2.0",
|
||||||
|
id,
|
||||||
|
method: method.into(),
|
||||||
|
params,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON-RPC 响应。
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcResponse {
|
||||||
|
jsonrpc: String,
|
||||||
|
id: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
result: Option<Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
error: Option<JsonRpcError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct JsonRpcError {
|
||||||
|
code: i32,
|
||||||
|
message: String,
|
||||||
|
#[serde(default)]
|
||||||
|
data: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 子进程运行时状态。
|
||||||
|
struct ChildProcessState {
|
||||||
|
child: Child,
|
||||||
|
stdin: ChildStdin,
|
||||||
|
pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>>,
|
||||||
|
next_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChildProcessState {
|
||||||
|
fn next_id(&mut self) -> u64 {
|
||||||
|
self.next_id += 1;
|
||||||
|
self.next_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP Server 暴露的工具(缓存结构)。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct McpTool {
|
||||||
|
name: String,
|
||||||
|
description: Option<String>,
|
||||||
|
input_schema: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 客户端 —— 与 MCP 服务器通信。
|
||||||
|
pub struct McpClient {
|
||||||
|
transport: McpTransport,
|
||||||
|
server_name: String,
|
||||||
|
/// 已初始化的工具列表(缓存)。
|
||||||
|
tools: Vec<McpTool>,
|
||||||
|
/// 是否已初始化。
|
||||||
|
initialized: AtomicBool,
|
||||||
|
/// 超时时间(秒)。
|
||||||
|
timeout_secs: u64,
|
||||||
|
/// 子进程运行时状态(`connect()` 后创建,`close()` 后取回)。
|
||||||
|
process: Option<Arc<Mutex<ChildProcessState>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for McpClient {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("McpClient")
|
||||||
|
.field("server_name", &self.server_name)
|
||||||
|
.field("initialized", &self.initialized.load(Ordering::SeqCst))
|
||||||
|
.field("tool_count", &self.tools.len())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpClient {
|
||||||
|
/// 创建一个 MCP 客户端。
|
||||||
|
pub fn new(server_name: impl Into<String>, transport: McpTransport) -> Self {
|
||||||
|
Self {
|
||||||
|
transport,
|
||||||
|
server_name: server_name.into(),
|
||||||
|
tools: Vec::new(),
|
||||||
|
initialized: AtomicBool::new(false),
|
||||||
|
timeout_secs: 30,
|
||||||
|
process: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置超时时间(秒)。
|
||||||
|
pub fn with_timeout(mut self, secs: u64) -> Self {
|
||||||
|
self.timeout_secs = secs;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查是否已连接。
|
||||||
|
pub fn is_initialized(&self) -> bool {
|
||||||
|
self.initialized.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 连接并初始化(发送 initialize 请求)。
|
||||||
|
pub async fn connect(&mut self) -> Result<(), ToolError> {
|
||||||
|
if self.is_initialized() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match &self.transport {
|
||||||
|
McpTransport::Stdio { command, args } => {
|
||||||
|
let mut cmd = Command::new(command);
|
||||||
|
cmd.args(args)
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.stdout(Stdio::piped())
|
||||||
|
.stderr(Stdio::piped());
|
||||||
|
#[cfg(unix)]
|
||||||
|
cmd.kill_on_drop(true);
|
||||||
|
#[cfg(windows)]
|
||||||
|
cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW
|
||||||
|
|
||||||
|
let mut child = cmd
|
||||||
|
.spawn()
|
||||||
|
.map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?;
|
||||||
|
|
||||||
|
let stdin = child
|
||||||
|
.stdin
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?;
|
||||||
|
let stdout = child
|
||||||
|
.stdout
|
||||||
|
.take()
|
||||||
|
.ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?;
|
||||||
|
|
||||||
|
// 启动 reader task 持续读取 stdout
|
||||||
|
let pending: HashMap<u64, oneshot::Sender<Result<Value, ToolError>>> =
|
||||||
|
HashMap::new();
|
||||||
|
let state = Arc::new(Mutex::new(ChildProcessState {
|
||||||
|
child,
|
||||||
|
stdin,
|
||||||
|
pending,
|
||||||
|
next_id: 0,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// 启动后台 reader
|
||||||
|
let pending_arc = Arc::clone(&state);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
Self::read_loop(BufReader::new(stdout), pending_arc).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
self.process = Some(state);
|
||||||
|
}
|
||||||
|
McpTransport::StreamableHttp { .. } => {
|
||||||
|
return Err(ToolError::McpError(
|
||||||
|
"StreamableHttp transport 尚未实现".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送 initialize 请求
|
||||||
|
let init_params = json!({
|
||||||
|
"protocolVersion": MCP_VERSION,
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "agcore",
|
||||||
|
"version": env!("CARGO_PKG_VERSION")
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let _response = self
|
||||||
|
.send_request("initialize", Some(init_params))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// 发送 initialized 通知(无 id)
|
||||||
|
self.send_notification("notifications/initialized", Some(json!({})))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.initialized.store(true, Ordering::SeqCst);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出服务器支持的工具(调用 `tools/list`)。
|
||||||
|
pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>, ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.send_request("tools/list", None).await?;
|
||||||
|
let tools_value = response
|
||||||
|
.get("tools")
|
||||||
|
.ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?;
|
||||||
|
let tools_arr = tools_value
|
||||||
|
.as_array()
|
||||||
|
.ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?;
|
||||||
|
|
||||||
|
self.tools.clear();
|
||||||
|
let mut defs = Vec::with_capacity(tools_arr.len());
|
||||||
|
for tool in tools_arr {
|
||||||
|
let name = tool
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))?
|
||||||
|
.to_string();
|
||||||
|
let description = tool
|
||||||
|
.get("description")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let input_schema = tool
|
||||||
|
.get("inputSchema")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| json!({"type": "object", "properties": {}}));
|
||||||
|
|
||||||
|
self.tools.push(McpTool {
|
||||||
|
name: name.clone(),
|
||||||
|
description: description.clone(),
|
||||||
|
input_schema: input_schema.clone(),
|
||||||
|
});
|
||||||
|
defs.push(ToolDefinition {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
parameters: input_schema,
|
||||||
|
strict: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(defs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 调用一个工具(调用 `tools/call`)。
|
||||||
|
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Err(ToolError::McpNotInitialized(self.server_name.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let params = json!({
|
||||||
|
"name": name,
|
||||||
|
"arguments": args,
|
||||||
|
});
|
||||||
|
let response = self.send_request("tools/call", Some(params)).await?;
|
||||||
|
|
||||||
|
// 解析 content 字段
|
||||||
|
if let Some(content) = response.get("content").and_then(|c| c.as_array()) {
|
||||||
|
// 收集所有 text 内容
|
||||||
|
let mut combined = String::new();
|
||||||
|
for item in content {
|
||||||
|
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||||
|
if !combined.is_empty() {
|
||||||
|
combined.push('\n');
|
||||||
|
}
|
||||||
|
combined.push_str(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !combined.is_empty() {
|
||||||
|
return Ok(Value::String(combined));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有 content 字段,尝试直接返回 is_error 标记
|
||||||
|
if let Some(true) = response.get("isError").and_then(|v| v.as_bool()) {
|
||||||
|
return Err(ToolError::ExecutionFailed(
|
||||||
|
name.to_string(),
|
||||||
|
"MCP 工具返回 isError=true".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回退:返回完整响应
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 关闭连接(终止子进程)。
|
||||||
|
pub async fn close(&mut self) -> Result<(), ToolError> {
|
||||||
|
if !self.is_initialized() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试发送 shutdown(不强制要求响应)
|
||||||
|
let _ = self.send_notification("shutdown", None).await;
|
||||||
|
|
||||||
|
if let Some(state) = self.process.take() {
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
// 优雅等待 5 秒
|
||||||
|
let graceful = tokio::time::timeout(
|
||||||
|
Duration::from_secs(5),
|
||||||
|
state.child.wait(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
if graceful.is_err() {
|
||||||
|
// 超时则强杀
|
||||||
|
let _ = state.child.kill().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.initialized.store(false, Ordering::SeqCst);
|
||||||
|
self.tools.clear();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 将 MCP 客户端转换为 `BaseTool` 适配器列表(用于注册到 `ToolRegistry`)。
|
||||||
|
///
|
||||||
|
/// **注意**:返回的适配器持有 `Arc<McpClient>`,但 `McpClient` 的可变性
|
||||||
|
/// (如 `list_tools` 刷新缓存)会通过 `Mutex` 处理。当前适配器仅缓存
|
||||||
|
/// 转换时的工具列表,不感知后续刷新。
|
||||||
|
pub fn into_tools(self) -> Vec<ToolRef> {
|
||||||
|
let mut tools = Vec::with_capacity(self.tools.len());
|
||||||
|
for mcp_tool in self.tools {
|
||||||
|
let tool = McpToolAdapter {
|
||||||
|
client: McpClientHandle::Empty,
|
||||||
|
name: mcp_tool.name,
|
||||||
|
description: mcp_tool.description.unwrap_or_default(),
|
||||||
|
parameters: mcp_tool.input_schema,
|
||||||
|
};
|
||||||
|
tools.push(Arc::new(tool) as ToolRef);
|
||||||
|
}
|
||||||
|
tools
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_request(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
params: Option<Value>,
|
||||||
|
) -> Result<Value, ToolError> {
|
||||||
|
let state_arc = self
|
||||||
|
.process
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let (id, request_json) = {
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
let id = state.next_id();
|
||||||
|
let req = JsonRpcRequest::new(id, method, params);
|
||||||
|
let json = serde_json::to_string(&req)
|
||||||
|
.map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?;
|
||||||
|
(id, json)
|
||||||
|
};
|
||||||
|
|
||||||
|
// 注册 oneshot 等待响应
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
{
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state.pending.insert(id, tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入请求
|
||||||
|
{
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(request_json.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(b"\n")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||||
|
state.stdin.flush().await.map_err(|e| {
|
||||||
|
ToolError::McpError(format!("flush stdin 失败: {e}"))
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待响应(带超时)
|
||||||
|
tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
// 超时:清理 pending
|
||||||
|
let state_arc = state_arc.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state.pending.remove(&id);
|
||||||
|
});
|
||||||
|
ToolError::McpTimeout(method.to_string())
|
||||||
|
})?
|
||||||
|
.map_err(|_| ToolError::McpError("response channel 关闭".into()))?
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_notification(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
params: Option<Value>,
|
||||||
|
) -> Result<(), ToolError> {
|
||||||
|
let state_arc = self
|
||||||
|
.process
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))?
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let notification = json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": method,
|
||||||
|
"params": params,
|
||||||
|
});
|
||||||
|
let json = serde_json::to_string(¬ification)
|
||||||
|
.map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?;
|
||||||
|
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(json.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.write_all(b"\n")
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?;
|
||||||
|
state
|
||||||
|
.stdin
|
||||||
|
.flush()
|
||||||
|
.await
|
||||||
|
.map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 持续读取 stdout,将响应分发到对应的 oneshot sender。
|
||||||
|
async fn read_loop(
|
||||||
|
mut reader: BufReader<ChildStdout>,
|
||||||
|
state: Arc<Mutex<ChildProcessState>>,
|
||||||
|
) {
|
||||||
|
let mut line = String::new();
|
||||||
|
loop {
|
||||||
|
line.clear();
|
||||||
|
match reader.read_line(&mut line).await {
|
||||||
|
Ok(0) => {
|
||||||
|
// EOF:通知所有 pending 失败
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
for (_, tx) in state.pending.drain() {
|
||||||
|
let _ = tx.send(Err(ToolError::McpError("子进程退出".into())));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// 尝试解析为 JSON-RPC 响应
|
||||||
|
let parsed: Result<JsonRpcResponse, _> = serde_json::from_str(trimmed);
|
||||||
|
if let Ok(response) = parsed {
|
||||||
|
let value = if let Some(err) = response.error {
|
||||||
|
Err(ToolError::McpError(format!(
|
||||||
|
"[{}] {}",
|
||||||
|
err.code, err.message
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Ok(response.result.unwrap_or(Value::Null))
|
||||||
|
};
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
if let Some(tx) = state.pending.remove(&response.id) {
|
||||||
|
let _ = tx.send(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 非响应消息(通知、request from server)忽略
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("MCP read_loop error: {e}");
|
||||||
|
let mut state = state.lock().await;
|
||||||
|
for (_, tx) in state.pending.drain() {
|
||||||
|
let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}"))));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MCP 工具适配器 —— 将 MCP 工具包装为 `BaseTool`。
|
||||||
|
struct McpToolAdapter {
|
||||||
|
/// 持有 client 的弱引用。实际生产中应使用 `Arc<McpClient>`,
|
||||||
|
/// 但当前 Phase 2 实现不直接持有可变的 `McpClient`。
|
||||||
|
/// 标记为 unused 但保留字段以展示扩展路径。
|
||||||
|
#[allow(dead_code)]
|
||||||
|
client: McpClientHandle,
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
enum McpClientHandle {
|
||||||
|
Empty,
|
||||||
|
// Future: Shared(Arc<McpClient>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for McpToolAdapter {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.description
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
self.parameters.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
_args: Value,
|
||||||
|
_ctx: &ToolContext<'_>,
|
||||||
|
) -> Result<Value, ToolError> {
|
||||||
|
// 当前 Phase 2 实现的简化:McpToolAdapter 不持有活跃 MCP 连接。
|
||||||
|
// 实际生产中应持有 Arc<McpClient> 并通过 mcp.call_tool() 执行。
|
||||||
|
// 这里返回错误,提示需要通过其他方式调用 MCP 工具。
|
||||||
|
Err(ToolError::McpError(format!(
|
||||||
|
"MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)",
|
||||||
|
self.name
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_transport_debug() {
|
||||||
|
let transport = McpTransport::Stdio {
|
||||||
|
command: "echo".to_string(),
|
||||||
|
args: vec!["hello".to_string()],
|
||||||
|
};
|
||||||
|
let formatted = format!("{transport:?}");
|
||||||
|
assert!(formatted.contains("echo"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_creation() {
|
||||||
|
let transport = McpTransport::Stdio {
|
||||||
|
command: "test".to_string(),
|
||||||
|
args: vec![],
|
||||||
|
};
|
||||||
|
let client = McpClient::new("test-server", transport).with_timeout(60);
|
||||||
|
assert_eq!(client.server_name, "test-server");
|
||||||
|
assert_eq!(client.timeout_secs, 60);
|
||||||
|
assert!(!client.is_initialized());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_request_serialize() {
|
||||||
|
let req = JsonRpcRequest::new(42, "test", Some(json!({"a": 1})));
|
||||||
|
let s = serde_json::to_string(&req).unwrap();
|
||||||
|
assert!(s.contains("\"jsonrpc\":\"2.0\""));
|
||||||
|
assert!(s.contains("\"id\":42"));
|
||||||
|
assert!(s.contains("\"method\":\"test\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_response_parse_ok() {
|
||||||
|
let s = r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#;
|
||||||
|
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||||
|
assert_eq!(resp.id, 1);
|
||||||
|
assert!(resp.result.is_some());
|
||||||
|
assert!(resp.error.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jsonrpc_response_parse_error() {
|
||||||
|
let s =
|
||||||
|
r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
|
||||||
|
let resp: JsonRpcResponse = serde_json::from_str(s).unwrap();
|
||||||
|
assert_eq!(resp.id, 1);
|
||||||
|
assert!(resp.result.is_none());
|
||||||
|
let err = resp.error.unwrap();
|
||||||
|
assert_eq!(err.code, -32601);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streamable_http_not_implemented() {
|
||||||
|
let mut client = McpClient::new(
|
||||||
|
"http-server",
|
||||||
|
McpTransport::StreamableHttp {
|
||||||
|
url: "https://example.com/mcp".to_string(),
|
||||||
|
headers: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let result = client.connect().await;
|
||||||
|
// 当前 Phase 2 返回未实现错误
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result, Err(ToolError::McpError(_))));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,286 @@
|
|||||||
|
//! 工具权限管理。
|
||||||
|
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
|
||||||
|
/// 权限级别枚举。
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum Permission {
|
||||||
|
/// 只读(读取文件、查询数据库等)。
|
||||||
|
Read,
|
||||||
|
/// 写入(创建/修改文件、插入数据等)。
|
||||||
|
Write,
|
||||||
|
/// 删除(删除文件、记录等)。
|
||||||
|
Delete,
|
||||||
|
/// 网络访问(HTTP 请求等)。
|
||||||
|
Network,
|
||||||
|
/// Shell 命令执行。
|
||||||
|
Shell,
|
||||||
|
/// 文件系统操作(除读/写/删之外的 FS 操作)。
|
||||||
|
FileSystem,
|
||||||
|
/// 自定义权限(可通过 namespaced 字符串扩展)。
|
||||||
|
Custom(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Permission {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Read => write!(f, "Read"),
|
||||||
|
Self::Write => write!(f, "Write"),
|
||||||
|
Self::Delete => write!(f, "Delete"),
|
||||||
|
Self::Network => write!(f, "Network"),
|
||||||
|
Self::Shell => write!(f, "Shell"),
|
||||||
|
Self::FileSystem => write!(f, "FileSystem"),
|
||||||
|
Self::Custom(s) => write!(f, "Custom({})", s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PermissionConfig {
|
||||||
|
/// 允许的权限列表(空 = 全部允许,配合 `allow_unspecified` 决定)。
|
||||||
|
pub allowed: Vec<Permission>,
|
||||||
|
/// 拒绝的权限列表(优先级高于 `allowed`)。
|
||||||
|
pub denied: Vec<Permission>,
|
||||||
|
/// 当工具未声明权限时是否允许执行。
|
||||||
|
pub allow_unspecified: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PermissionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
allowed: vec![Permission::Read, Permission::Network],
|
||||||
|
denied: vec![Permission::Delete, Permission::Shell],
|
||||||
|
allow_unspecified: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 权限检查器。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PermissionChecker {
|
||||||
|
config: PermissionConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PermissionChecker {
|
||||||
|
/// 创建一个新的权限检查器。
|
||||||
|
pub fn new(config: PermissionConfig) -> Self {
|
||||||
|
Self { config }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查指定工具声明的权限是否允许执行。
|
||||||
|
///
|
||||||
|
/// 判定规则:
|
||||||
|
/// 1. 任一权限在 `denied` 中 → 拒绝
|
||||||
|
/// 2. 所有权限都在 `allowed` 中 → 允许
|
||||||
|
/// 3. `allowed` 非空且存在未声明权限 → 拒绝
|
||||||
|
/// 4. `allowed` 为空 → 按 `allow_unspecified` 判定
|
||||||
|
/// 5. 工具未声明任何权限时按 `allow_unspecified` 判定
|
||||||
|
pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> {
|
||||||
|
// 任一权限在 denied 中 → 拒绝
|
||||||
|
for perm in permissions {
|
||||||
|
if self.config.denied.contains(perm) {
|
||||||
|
return Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
perm.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 工具未声明任何权限
|
||||||
|
if permissions.is_empty() {
|
||||||
|
return if self.config.allow_unspecified {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
"Unspecified".to_string(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowed 为空 → 走 allow_unspecified 兜底
|
||||||
|
if self.config.allowed.is_empty() {
|
||||||
|
return if self.config.allow_unspecified {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
"Unspecified".to_string(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowed 非空(白名单模式)—— 所有权限必须在其中
|
||||||
|
for perm in permissions {
|
||||||
|
if !self.config.allowed.contains(perm) {
|
||||||
|
return Err(ToolError::PermissionDenied(
|
||||||
|
tool_name.to_string(),
|
||||||
|
perm.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn p(perm: Permission) -> Vec<Permission> {
|
||||||
|
vec![perm]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_allows_read() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("weather", &p(Permission::Read)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_allows_network() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("http_get", &p(Permission::Network)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_denies_delete() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker
|
||||||
|
.check("rm_file", &p(Permission::Delete))
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config_denies_shell() {
|
||||||
|
let checker = PermissionChecker::new(PermissionConfig::default());
|
||||||
|
assert!(checker.check("run_shell", &p(Permission::Shell)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_white_list_mode_denies_unlisted() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Read)).is_ok());
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_white_list_mode_allows_listed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Write],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_black_list_deny_priority() {
|
||||||
|
// 即便 allowed 中包含了 denied 权限,仍以 denied 为准
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Shell, Permission::Read],
|
||||||
|
denied: vec![Permission::Shell],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Shell)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_allowed_with_allow_unspecified() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: true,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_allowed_without_allow_unspecified() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &p(Permission::Write)).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unspecified_tool_with_allow() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: true,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &[]).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unspecified_tool_without_allow() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker.check("t", &[]).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_custom_permission_collision() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Custom("db:read".into())],
|
||||||
|
denied: vec![Permission::Custom("db:write".into())],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Custom("db:read".into())])
|
||||||
|
.is_ok());
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Custom("db:write".into())])
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multi_permission_all_in_allowed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Write, Permission::Network],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
assert!(checker
|
||||||
|
.check(
|
||||||
|
"t",
|
||||||
|
&[Permission::Read, Permission::Network]
|
||||||
|
)
|
||||||
|
.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multi_permission_one_not_in_allowed() {
|
||||||
|
let cfg = PermissionConfig {
|
||||||
|
allowed: vec![Permission::Read, Permission::Network],
|
||||||
|
denied: vec![],
|
||||||
|
allow_unspecified: false,
|
||||||
|
};
|
||||||
|
let checker = PermissionChecker::new(cfg);
|
||||||
|
// 任一权限不在白名单则拒绝
|
||||||
|
assert!(checker
|
||||||
|
.check("t", &[Permission::Read, Permission::Write])
|
||||||
|
.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,371 @@
|
|||||||
|
//! 工具注册表 —— 管理工具注册、发现、调用。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use futures::future::join_all;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::llm::types::ToolDefinition;
|
||||||
|
use crate::tools::base::{ToolContext, ToolRef};
|
||||||
|
use crate::tools::error::ToolError;
|
||||||
|
use crate::tools::permission::PermissionChecker;
|
||||||
|
|
||||||
|
/// 工具调用记录 —— 用于追踪和调试。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolInvocation {
|
||||||
|
/// 被调用的工具名。
|
||||||
|
pub tool_name: String,
|
||||||
|
/// 工具的入参。
|
||||||
|
pub input: Value,
|
||||||
|
/// 工具的输出。
|
||||||
|
pub output: Result<Value, ToolError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolInvocation {
|
||||||
|
/// 创建一个新的工具调用记录。
|
||||||
|
pub fn new(tool_name: String, input: Value, output: Result<Value, ToolError>) -> Self {
|
||||||
|
Self {
|
||||||
|
tool_name,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 工具注册表 —— 管理工具注册、发现、调用。
|
||||||
|
///
|
||||||
|
/// 通过 `Arc` 共享,方法签名 `&self`,可安全跨 task 并行调用。
|
||||||
|
/// 不支持运行时并发注册(应在 setup 阶段一次性构建后冻结)。
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
inner: Arc<ToolRegistryInner>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct ToolRegistryInner {
|
||||||
|
tools: HashMap<String, ToolRef>,
|
||||||
|
permission_checker: Option<Arc<PermissionChecker>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ToolRegistry {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ToolRegistry")
|
||||||
|
.field("tool_names", &self.inner.tools.keys().collect::<Vec<_>>())
|
||||||
|
.field("has_checker", &self.inner.permission_checker.is_some())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
/// 创建一个新的工具注册表。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(ToolRegistryInner {
|
||||||
|
tools: HashMap::new(),
|
||||||
|
permission_checker: None,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 设置权限检查器(Builder 模式)。
|
||||||
|
pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self {
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
inner.permission_checker = Some(Arc::new(checker));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册一个工具。
|
||||||
|
///
|
||||||
|
/// 重复注册同名工具返回错误。
|
||||||
|
pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> {
|
||||||
|
let name = tool.name().to_string();
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
if inner.tools.contains_key(&name) {
|
||||||
|
return Err(ToolError::ExecutionFailed(name, "工具已存在".to_string()));
|
||||||
|
}
|
||||||
|
inner.tools.insert(name, tool);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量注册工具。
|
||||||
|
pub fn register_all(&mut self, tools: Vec<ToolRef>) -> Result<(), ToolError> {
|
||||||
|
for tool in tools {
|
||||||
|
self.register(tool)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注销一个工具。
|
||||||
|
pub fn unregister(&mut self, name: &str) -> Option<ToolRef> {
|
||||||
|
let inner = Arc::make_mut(&mut self.inner);
|
||||||
|
inner.tools.remove(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称查找工具。
|
||||||
|
pub fn get(&self, name: &str) -> Option<ToolRef> {
|
||||||
|
self.inner.tools.get(name).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有已注册工具的名称列表。
|
||||||
|
pub fn list_tools(&self) -> Vec<String> {
|
||||||
|
self.inner.tools.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取所有工具的 `ToolDefinition` 列表(用于传递给 LLM)。
|
||||||
|
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||||
|
self.inner
|
||||||
|
.tools
|
||||||
|
.values()
|
||||||
|
.map(|tool| ToolDefinition {
|
||||||
|
name: tool.name().to_string(),
|
||||||
|
description: Some(tool.description().to_string()),
|
||||||
|
parameters: tool.parameters(),
|
||||||
|
strict: None,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 调用单个工具(含权限检查)。
|
||||||
|
pub async fn invoke(&self, name: &str, args: Value) -> Result<ToolInvocation, ToolError> {
|
||||||
|
let tool = self
|
||||||
|
.get(name)
|
||||||
|
.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
|
||||||
|
|
||||||
|
if let Some(checker) = &self.inner.permission_checker {
|
||||||
|
checker.check(name, &tool.required_permissions())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ctx = ToolContext::new(name, "");
|
||||||
|
let output = tool.execute(args.clone(), &ctx).await;
|
||||||
|
Ok(ToolInvocation::new(name.to_string(), args, output))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 并行执行多个工具调用(互不依赖的工具)。
|
||||||
|
///
|
||||||
|
/// 每个工具独立超时(`timeout_per_call_secs`,0 表示不超时)。
|
||||||
|
/// 单个工具超时不会影响其他工具的返回。
|
||||||
|
pub async fn invoke_all(
|
||||||
|
&self,
|
||||||
|
calls: Vec<(String, Value)>,
|
||||||
|
timeout_per_call_secs: u64,
|
||||||
|
) -> Vec<ToolInvocation> {
|
||||||
|
let this = self.clone();
|
||||||
|
let futures = calls.into_iter().map(|(name, args)| {
|
||||||
|
let this = this.clone();
|
||||||
|
async move {
|
||||||
|
match if timeout_per_call_secs == 0 {
|
||||||
|
Ok(this.invoke(&name, args.clone()).await)
|
||||||
|
} else {
|
||||||
|
tokio::time::timeout(
|
||||||
|
Duration::from_secs(timeout_per_call_secs),
|
||||||
|
this.invoke(&name, args.clone()),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
} {
|
||||||
|
Ok(result) => result.unwrap_or_else(|e| {
|
||||||
|
ToolInvocation::new(name.clone(), args.clone(), Err(e))
|
||||||
|
}),
|
||||||
|
Err(_) => ToolInvocation::new(
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
Err(ToolError::McpTimeout("timeout".into())),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
join_all(futures).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::tools::BaseTool;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
struct AddTool {
|
||||||
|
base: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for AddTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"add"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"加法"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": { "n": { "type": "integer" } },
|
||||||
|
"required": ["n"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
let n = args["n"].as_i64().unwrap_or(0);
|
||||||
|
Ok(json!({ "result": self.base + n }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FailTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for FailTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"fail"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"总会失败"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({})
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Err(ToolError::ExecutionFailed("fail".into(), "boom".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ShellTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BaseTool for ShellTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"shell"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"shell"
|
||||||
|
}
|
||||||
|
fn parameters(&self) -> Value {
|
||||||
|
json!({})
|
||||||
|
}
|
||||||
|
fn required_permissions(&self) -> Vec<crate::tools::permission::Permission> {
|
||||||
|
vec![crate::tools::permission::Permission::Shell]
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result<Value, ToolError> {
|
||||||
|
Ok(json!({}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_and_get() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 10 })).unwrap();
|
||||||
|
assert!(reg.get("add").is_some());
|
||||||
|
assert!(reg.get("nonexistent").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_duplicate() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let result = reg.register(Arc::new(AddTool { base: 1 }));
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_register_all() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
let result = reg.register_all(vec![
|
||||||
|
Arc::new(AddTool { base: 1 }),
|
||||||
|
Arc::new(AddTool { base: 2 }),
|
||||||
|
]);
|
||||||
|
assert!(result.is_err()); // 重名 add → 失败
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unregister() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let removed = reg.unregister("add");
|
||||||
|
assert!(removed.is_some());
|
||||||
|
assert!(reg.get("add").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_list_tools() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let names = reg.list_tools();
|
||||||
|
assert_eq!(names.len(), 2);
|
||||||
|
assert!(names.contains(&"add".to_string()));
|
||||||
|
assert!(names.contains(&"fail".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_definitions() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let defs = reg.definitions();
|
||||||
|
assert_eq!(defs.len(), 1);
|
||||||
|
assert_eq!(defs[0].name, "add");
|
||||||
|
assert!(defs[0].description.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_success() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 100 })).unwrap();
|
||||||
|
let result = reg.invoke("add", json!({ "n": 5 })).await.unwrap();
|
||||||
|
let value = result.output.unwrap();
|
||||||
|
assert_eq!(value["result"], 105);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_not_found() {
|
||||||
|
let reg = ToolRegistry::new();
|
||||||
|
let result = reg.invoke("nope", json!({})).await;
|
||||||
|
assert!(matches!(result, Err(ToolError::NotFound(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_execution_error() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let result = reg.invoke("fail", json!({})).await.unwrap();
|
||||||
|
assert!(result.output.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_with_permission_denied() {
|
||||||
|
let mut reg = ToolRegistry::new()
|
||||||
|
.with_permission_checker(PermissionChecker::new(Default::default()));
|
||||||
|
reg.register(Arc::new(ShellTool)).unwrap();
|
||||||
|
let result = reg.invoke("shell", json!({})).await;
|
||||||
|
assert!(matches!(result, Err(ToolError::PermissionDenied(_, _))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_all_parallel() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 1 })).unwrap();
|
||||||
|
reg.register(Arc::new(FailTool)).unwrap();
|
||||||
|
let calls = vec![
|
||||||
|
("add".into(), json!({ "n": 1 })),
|
||||||
|
("add".into(), json!({ "n": 2 })),
|
||||||
|
("fail".into(), json!({})),
|
||||||
|
];
|
||||||
|
let results = reg.invoke_all(calls, 0).await;
|
||||||
|
assert_eq!(results.len(), 3);
|
||||||
|
assert!(results[0].output.is_ok());
|
||||||
|
assert!(results[1].output.is_ok());
|
||||||
|
assert!(results[2].output.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invoke_all_with_timeout() {
|
||||||
|
let mut reg = ToolRegistry::new();
|
||||||
|
reg.register(Arc::new(AddTool { base: 0 })).unwrap();
|
||||||
|
let calls = vec![("add".into(), json!({ "n": 1 }))];
|
||||||
|
let results = reg.invoke_all(calls, 5).await;
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert!(results[0].output.is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user