diff --git a/Cargo.toml b/Cargo.toml index 002a5c4..539b965 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ futures-util = "0.3" futures-core = "0.3" bytes = "1" async-stream = "0.3" +tokio-util = { version = "0.7", features = ["rt"] } [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/lib.rs b/src/lib.rs index b873931..c63cb2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod llm; pub mod prompt; +pub mod tools; use tracing_subscriber::{EnvFilter, fmt, prelude::*}; diff --git a/src/llm/cycle.rs b/src/llm/cycle.rs index 06e67cf..122145d 100644 --- a/src/llm/cycle.rs +++ b/src/llm/cycle.rs @@ -19,7 +19,8 @@ use crate::llm::hooks::{HookContext, HookExecutor}; use crate::llm::provider::LlmProvider; use crate::llm::stream::StreamEvent; use crate::llm::types::{ - ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition, + ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, OpenaiTool, OpenaiToolCall, + ToolChoice, ToolDefinition, }; /// LLM 调用周期配置。 @@ -34,6 +35,15 @@ pub struct CycleConfig { pub max_turns: Option, /// 重试策略配置。 pub retry: RetryConfig, + /// 自动 tool 循环的最大轮次(独立于 `max_turns`,避免影响现有 `submit()` 语义)。 + /// 默认 `Some(10)`,防止 LLM 反复调用工具导致无限循环。 + pub max_tool_turns: Option, + /// 单个工具执行的超时秒数(0 表示不超时)。 + /// 默认 60 秒。 + pub tool_timeout_secs: u64, + /// 单个工具结果的最大字节数(超过此值将被截断)。 + /// 默认 65536(64KB),防止大结果导致 token 膨胀。 + pub max_tool_result_bytes: usize, } impl Default for CycleConfig { @@ -44,6 +54,9 @@ impl Default for CycleConfig { temperature: None, max_turns: None, retry: RetryConfig::default(), + max_tool_turns: Some(10), + tool_timeout_secs: 60, + max_tool_result_bytes: 65_536, } } } @@ -440,4 +453,455 @@ impl LlmCycle { ..Default::default() } } + + /// 内部请求方法(与 `submit` 共享重试逻辑,但不 push user message 和 Assistant 响应)。 + /// + /// 用于 `submit_with_tools()` 的多轮 tool 循环。 + async fn submit_request( + &mut self, + tools: &[ToolDefinition], + ) -> Result { + let mut attempts = 0; + + loop { + let request = 
self.build_request(tools); + + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest) + .with_request(&request); + let results = executor + .execute(crate::llm::hooks::HookEvent::PreRequest, &ctx) + .await; + if results.iter().any(|r| r.should_block) { + let reason = results + .iter() + .find(|r| r.should_block) + .and_then(|r| r.reason.clone()) + .unwrap_or_else(|| "Blocked by pre-request hook".to_string()); + return Err(LlmError::Other(reason)); + } + } + + match self.provider.chat(request).await { + Ok(response) => { + if let Some(ref executor) = self.hook_executor { + let post_request = self.build_request(tools); + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest) + .with_request(&post_request); + executor + .execute(crate::llm::hooks::HookEvent::PostRequest, &ctx) + .await; + } + self.usage.add(&response.usage); + return Ok(response); + } + Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => { + attempts += 1; + + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry) + .with_error(&e) + .with_attempt(attempts); + executor + .execute(crate::llm::hooks::HookEvent::OnRetry, &ctx) + .await; + } + + let delay = self.config.retry.compute_delay(attempts); + tokio::time::sleep(delay).await; + } + Err(e) => { + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError) + .with_error(&e); + executor + .execute(crate::llm::hooks::HookEvent::OnError, &ctx) + .await; + } + return Err(e); + } + } + } + } + + /// 提交消息并自动处理工具调用循环。 + /// + /// 流程: + /// 1. 发送请求(含工具定义) + /// 2. 检查响应中的 finish_reason + /// 3. 如果是 ToolCalls → push Assistant 消息 → 执行工具 → 回传结果 → 重复 1 + /// 4. 
如果是 Stop/Length → push Assistant 消息 → 返回最终响应 + /// + /// 注意:OpenAI API 要求 tool 消息必须紧跟在对应的 Assistant(tool_calls)消息之后。 + /// 因此 push 工具结果前必须先 push Assistant 响应,否则 API 拒绝请求。 + pub async fn submit_with_tools( + &mut self, + prompt: String, + registry: &crate::tools::ToolRegistry, + ) -> Result { + let tools = registry.definitions(); + let max_turns = self.config.max_tool_turns.unwrap_or(10); + let tool_timeout = self.config.tool_timeout_secs; + let max_bytes = self.config.max_tool_result_bytes; + + self.messages.push(OpenaiChatMessage::user_text(prompt)); + self.maybe_compact(); + + let mut turn = 0; + + loop { + turn += 1; + if turn > max_turns { + return Err(LlmError::Other(format!( + "达到最大工具循环轮次 ({max_turns})" + ))); + } + + let response = self.submit_request(&tools).await?; + + // 判断是否需要执行工具 + let should_execute = matches!(response.stop_reason, Some(FinishReason::ToolCalls)) + && has_tool_calls_in_message(&response.message); + + // 将 Assistant 响应(含 tool_calls 或最终文本)追加到消息历史 + self.messages.push(response.message.clone()); + + if !should_execute { + return Ok(response); + } + + // 解析 tool_calls 并执行 + let tool_calls = extract_tool_calls_from_message(&response.message); + let calls: Vec<(String, serde_json::Value)> = tool_calls + .into_iter() + .map(|(_id, name, args)| { + let args: serde_json::Value = + serde_json::from_str(&args).unwrap_or(serde_json::Value::Null); + (name, args) + }) + .collect(); + + let results = registry.invoke_all(calls, tool_timeout).await; + + // 回传工具结果 + for result in results { + let content = match result.output { + Ok(value) => { + let serialized = serde_json::to_string(&value).unwrap_or_else(|e| { + tracing::warn!("工具结果序列化失败: {}", e); + "{}".to_string() + }); + truncate_tool_result(&serialized, max_bytes) + } + Err(e) if e.is_recoverable() => format!("错误: {}", e), + Err(e) => { + // 不可恢复错误:终止循环 + return Err(LlmError::Other(format!( + "工具 '{}' 不可恢复错误: {}", + result.tool_name, e + ))); + } + }; + + self.messages + 
/// Truncates a tool result so that the returned string never exceeds `max_bytes`.
///
/// - Input that already fits is returned unchanged.
/// - Otherwise the input is cut on a UTF-8 character boundary and a marker stating the
///   original size is appended. Room for the marker is reserved inside the budget, so the
///   result stays within `max_bytes` and is never longer than the input.
/// - When `max_bytes` is too small to hold the marker at all, the input is cut to
///   `max_bytes` without a marker.
fn truncate_tool_result(s: &str, max_bytes: usize) -> String {
    if s.len() <= max_bytes {
        return s.to_string();
    }

    let marker = format!("\n\n[... truncated, original size: {} bytes ...]", s.len());

    // Bytes of original content that may be kept, and the suffix that goes with them.
    let (budget, suffix) = match max_bytes.checked_sub(marker.len()) {
        Some(room) => (room, marker.as_str()),
        None => (max_bytes, ""),
    };

    // Walk back to the nearest char boundary; index 0 always is one.
    let end = (0..=budget)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    format!("{}{}", &s[..end], suffix)
}
name(&self) -> &str { + "add" + } + fn description(&self) -> &str { + "加法" + } + fn parameters(&self) -> Value { + json!({"type":"object","properties":{"a":{"type":"integer"},"b":{"type":"integer"}}}) + } + async fn execute( + &self, + args: Value, + _ctx: &crate::tools::ToolContext<'_>, + ) -> Result { + let a = args["a"].as_i64().unwrap_or(0); + let b = args["b"].as_i64().unwrap_or(0); + Ok(json!({"result": a + b})) + } + } + + #[tokio::test] + async fn test_submit_with_tools_single_turn() { + // 第一轮:返回 tool_call;第二轮:返回最终文本 + let responses = vec![ + assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]), + assistant_text_response("答案是 3"), + ]; + let provider = Box::new(MockProvider::new(responses)); + let mut cycle = LlmCycle::new(provider, CycleConfig::default()); + + let mut registry = ToolRegistry::new(); + registry.register(std::sync::Arc::new(AddTool)).unwrap(); + + let response = cycle + .submit_with_tools("1+2=?".to_string(), ®istry) + .await + .unwrap(); + // 验证最终响应是文本响应 + assert!(matches!( + response.message, + OpenaiChatMessage::Assistant { .. } + )); + + // 验证消息历史:user, assistant(tool_calls), tool, assistant(text) + let messages = cycle.messages(); + assert_eq!(messages.len(), 4); + assert!(matches!(messages[0], OpenaiChatMessage::User { .. })); + assert!(matches!(messages[1], OpenaiChatMessage::Assistant { .. })); + assert!(matches!(messages[2], OpenaiChatMessage::Tool { .. })); + assert!(matches!(messages[3], OpenaiChatMessage::Assistant { .. 
})); + } + + #[tokio::test] + async fn test_submit_with_tools_multi_turn() { + // 3 轮 tool 调用后给出最终答案 + let responses = vec![ + assistant_tool_call_response(vec![("call_1", "add", r#"{"a":1,"b":2}"#)]), + assistant_tool_call_response(vec![("call_2", "add", r#"{"a":3,"b":4}"#)]), + assistant_tool_call_response(vec![("call_3", "add", r#"{"a":5,"b":6}"#)]), + assistant_text_response("完成"), + ]; + let provider = Box::new(MockProvider::new(responses)); + let mut cycle = LlmCycle::new(provider, CycleConfig::default()); + + let mut registry = ToolRegistry::new(); + registry.register(std::sync::Arc::new(AddTool)).unwrap(); + + let response = cycle + .submit_with_tools("计算总和".to_string(), ®istry) + .await + .unwrap(); + assert!(matches!( + response.message, + OpenaiChatMessage::Assistant { .. } + )); + + // user + 3*(assistant + tool) + final assistant = 8 + let messages = cycle.messages(); + assert_eq!(messages.len(), 8); + } + + #[tokio::test] + async fn test_submit_with_tools_max_turns_exceeded() { + // 配置 max_tool_turns = 2 + let mut config = CycleConfig::default(); + config.max_tool_turns = Some(2); + // 4 轮 tool 调用 + 终止 + let responses = vec![ + assistant_tool_call_response(vec![("c1", "add", r#"{"a":1,"b":1}"#)]), + assistant_tool_call_response(vec![("c2", "add", r#"{"a":1,"b":1}"#)]), + assistant_tool_call_response(vec![("c3", "add", r#"{"a":1,"b":1}"#)]), + assistant_text_response("完成"), + ]; + let provider = Box::new(MockProvider::new(responses)); + let mut cycle = LlmCycle::new(provider, config); + + let mut registry = ToolRegistry::new(); + registry.register(std::sync::Arc::new(AddTool)).unwrap(); + + let result = cycle + .submit_with_tools("test".to_string(), ®istry) + .await; + assert!(matches!(result, Err(LlmError::Other(msg)) if msg.contains("达到最大工具循环轮次"))); + } + + #[tokio::test] + async fn test_submit_with_tools_no_tool_call_response() { + // LLM 直接给出最终响应(不调用工具) + let responses = vec![assistant_text_response("直接回答")]; + let provider = 
Box::new(MockProvider::new(responses)); + let mut cycle = LlmCycle::new(provider, CycleConfig::default()); + + let mut registry = ToolRegistry::new(); + registry.register(std::sync::Arc::new(AddTool)).unwrap(); + + let response = cycle + .submit_with_tools("直接回答".to_string(), ®istry) + .await + .unwrap(); + assert!(matches!( + response.message, + OpenaiChatMessage::Assistant { .. } + )); + } + + #[test] + fn test_truncate_tool_result_short() { + let s = "short text"; + assert_eq!(truncate_tool_result(s, 100), "short text"); + } + + #[test] + fn test_truncate_tool_result_long() { + let s = "a".repeat(1000); + let truncated = truncate_tool_result(&s, 50); + assert!(truncated.len() < s.len()); + assert!(truncated.contains("[... truncated,")); + } + + #[test] + fn test_truncate_tool_result_chinese_chars() { + let s = "中".repeat(100); + let truncated = truncate_tool_result(&s, 50); + // 不会在字符中间截断 + assert!(truncated.starts_with("中")); + } } diff --git a/src/tools.rs b/src/tools.rs new file mode 100644 index 0000000..a28a8db --- /dev/null +++ b/src/tools.rs @@ -0,0 +1,13 @@ +//! 工具系统 —— 工具抽象、注册、调用、权限控制与 MCP 集成。 + +pub mod base; +pub mod error; +pub mod mcp; +pub mod permission; +pub mod registry; + +pub use base::{BaseTool, ToolContext, ToolRef}; +pub use error::ToolError; +pub use mcp::{McpClient, McpTransport}; +pub use permission::{Permission, PermissionChecker, PermissionConfig}; +pub use registry::{ToolInvocation, ToolRegistry}; diff --git a/src/tools/base.rs b/src/tools/base.rs new file mode 100644 index 0000000..3179fe3 --- /dev/null +++ b/src/tools/base.rs @@ -0,0 +1,139 @@ +//! 
工具抽象接口与执行上下文。 + +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::Value; +use tokio_util::sync::CancellationToken; + +use crate::tools::error::ToolError; +use crate::tools::permission::Permission; + +/// 工具执行上下文 —— 携带每次执行的运行时信息。 +/// +/// 字段在 Phase 2 即注入 `execute()` 签名中,防止后续扩展时出现 +/// breaking change。后续阶段可扩展字段(如 `progress`、`shared_state`), +/// 但已有工具实现无需修改。 +#[derive(Debug)] +pub struct ToolContext<'a> { + /// 当前对话/会话 ID,用于关联性追踪。 + pub session_id: &'a str, + /// 链路追踪 ID,用于跨工具调用的耗时分布。 + pub trace_id: &'a str, + /// 取消令牌,用于优雅取消正在执行的工具。 + pub cancellation_token: CancellationToken, +} + +impl<'a> ToolContext<'a> { + /// 创建一个新的工具执行上下文。 + pub fn new(session_id: &'a str, trace_id: &'a str) -> Self { + Self { + session_id, + trace_id, + cancellation_token: CancellationToken::new(), + } + } + + /// 创建一个使用给定取消令牌的上下文。 + pub fn with_cancellation_token( + session_id: &'a str, + trace_id: &'a str, + token: CancellationToken, + ) -> Self { + Self { + session_id, + trace_id, + cancellation_token: token, + } + } +} + +/// 工具抽象接口 —— 所有工具(自定义或 MCP)最终都实现此 trait。 +#[async_trait] +pub trait BaseTool: Send + Sync { + /// 工具名称(唯一标识,用于 LLM 的 tool_calls.name 匹配)。 + fn name(&self) -> &str; + + /// 工具描述(LLM 据此决定是否调用此工具)。 + fn description(&self) -> &str; + + /// 工具参数定义(JSON Schema 格式,传递给 LLM 的 tool.parameters)。 + fn parameters(&self) -> Value; + + /// 声明工具所需的权限列表。 + fn required_permissions(&self) -> Vec { + Vec::new() + } + + /// 执行工具调用。 + /// + /// `ctx` 携带执行上下文(session_id、trace_id、cancellation_token), + /// 工具实现可在执行期间检查 `ctx.cancellation_token` 来支持优雅取消。 + async fn execute(&self, args: Value, ctx: &ToolContext<'_>) -> Result; +} + +/// 为 `Arc` 提供 `Send + Sync` 包装,便于在 `Vec>` 中使用。 +pub type ToolRef = Arc; + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + struct EchoTool; + + #[async_trait] + impl BaseTool for EchoTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "回显输入" + } + + fn parameters(&self) -> Value { + 
json!({ + "type": "object", + "properties": { + "text": { "type": "string" } + }, + "required": ["text"] + }) + } + + async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result { + Ok(args) + } + } + + #[tokio::test] + async fn test_mock_tool_execute() { + let tool = EchoTool; + let ctx = ToolContext::new("session-1", "trace-1"); + let result = tool.execute(json!({"text": "hello"}), &ctx).await.unwrap(); + assert_eq!(result, json!({"text": "hello"})); + } + + #[tokio::test] + async fn test_default_permissions_empty() { + let tool = EchoTool; + assert!(tool.required_permissions().is_empty()); + } + + #[test] + fn test_tool_context_creation() { + let ctx = ToolContext::new("s1", "t1"); + assert_eq!(ctx.session_id, "s1"); + assert_eq!(ctx.trace_id, "t1"); + assert!(!ctx.cancellation_token.is_cancelled()); + } + + #[test] + fn test_tool_context_cancellation() { + let token = CancellationToken::new(); + token.cancel(); + let ctx = ToolContext::with_cancellation_token("s1", "t1", token); + assert!(ctx.cancellation_token.is_cancelled()); + } +} diff --git a/src/tools/error.rs b/src/tools/error.rs new file mode 100644 index 0000000..0c546c7 --- /dev/null +++ b/src/tools/error.rs @@ -0,0 +1,118 @@ +//! 
工具系统错误类型。 + +use std::sync::Arc; + +/// 工具调用过程中可能发生的所有错误。 +#[derive(thiserror::Error, Debug, Clone)] +pub enum ToolError { + /// 工具未注册。 + #[error("工具 '{0}' 未注册")] + NotFound(String), + + /// 工具执行失败(可恢复——文本回传 LLM)。 + #[error("工具 '{0}' 执行失败: {1}")] + ExecutionFailed(String, String), + + /// 工具参数无效(可恢复——文本回传 LLM)。 + #[error("工具 '{0}' 参数无效: {1}")] + InvalidArguments(String, String), + + /// 权限被拒绝(不可恢复——终止循环)。 + #[error("权限被拒绝: 工具 '{0}' 需要 {1} 权限")] + PermissionDenied(String, String), + + /// MCP 协议错误(不可恢复)。 + #[error("MCP 协议错误: {0}")] + McpError(String), + + /// MCP 未初始化(不可恢复)。 + #[error("MCP 未初始化: {0}")] + McpNotInitialized(String), + + /// MCP 超时(不可恢复)。 + #[error("MCP 超时: {0}")] + McpTimeout(String), + + /// IO 错误(不可恢复)。 + #[error("IO 错误: {0}")] + Io(Arc), + + /// 取消。 + #[error("工具执行已取消: {0}")] + Cancelled(String), + + /// 其他未分类错误。 + #[error("其他错误: {0}")] + Other(String), +} + +impl From for ToolError { + fn from(e: std::io::Error) -> Self { + ToolError::Io(Arc::new(e)) + } +} + +impl ToolError { + /// 判断错误是否可恢复——可恢复的错误回传 LLM 由其自行重试, + /// 不可恢复的错误终止自动 tool 循环并返回给调用方。 + pub fn is_recoverable(&self) -> bool { + matches!( + self, + Self::ExecutionFailed(..) | Self::InvalidArguments(..) 
| Self::Other(_) + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_execution_failed_is_recoverable() { + let err = ToolError::ExecutionFailed("foo".into(), "boom".into()); + assert!(err.is_recoverable()); + } + + #[test] + fn test_invalid_arguments_is_recoverable() { + let err = ToolError::InvalidArguments("foo".into(), "missing x".into()); + assert!(err.is_recoverable()); + } + + #[test] + fn test_not_found_is_not_recoverable() { + let err = ToolError::NotFound("foo".into()); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_permission_denied_is_not_recoverable() { + let err = ToolError::PermissionDenied("foo".into(), "Shell".into()); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_mcp_error_is_not_recoverable() { + let err = ToolError::McpError("protocol".into()); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_mcp_timeout_is_not_recoverable() { + let err = ToolError::McpTimeout("foo".into()); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_io_is_not_recoverable() { + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk"); + let err = ToolError::from(io_err); + assert!(!err.is_recoverable()); + } + + #[test] + fn test_other_is_recoverable() { + let err = ToolError::Other("something".into()); + assert!(err.is_recoverable()); + } +} diff --git a/src/tools/mcp.rs b/src/tools/mcp.rs new file mode 100644 index 0000000..cc68151 --- /dev/null +++ b/src/tools/mcp.rs @@ -0,0 +1,640 @@ +//! MCP 协议客户端 —— 与 MCP Server 通过 JSON-RPC over stdio 通信。 +//! +//! 当前 Phase 2 实现 stdio transport。`StreamableHttp` 枚举变体已预留, +//! 但实际实现推迟到后续版本。 +//! +//! ## 协议版本 +//! +//! 
实现遵循 MCP 协议版本 2025-03-26。 + +use std::collections::HashMap; +use std::process::Stdio; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::{oneshot, Mutex}; + +use crate::llm::types::ToolDefinition; +use crate::tools::base::{BaseTool, ToolContext, ToolRef}; +use crate::tools::error::ToolError; + +/// MCP 协议版本。 +const MCP_VERSION: &str = "2025-03-26"; + +/// MCP 传输方式。 +#[derive(Debug, Clone)] +pub enum McpTransport { + /// 通过子进程 stdin/stdout 通信。 + Stdio { + /// 启动命令(如 `"npx"`)。 + command: String, + /// 命令参数(如 `["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]`)。 + args: Vec, + }, + /// Streamable HTTP 传输(MCP 2025-03-26 引入,替代已废弃的 HTTP+SSE)。 + /// + /// 当前 Phase 2 预留枚举变体,调用方法会返回 `ToolError::McpError`。 + StreamableHttp { + /// MCP 端点 URL。 + url: String, + /// 可选的 HTTP 头(如 Authorization)。 + headers: Option>, + }, +} + +/// JSON-RPC 请求。 +#[derive(Debug, Serialize, Deserialize)] +struct JsonRpcRequest { + jsonrpc: &'static str, + id: u64, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option, +} + +impl JsonRpcRequest { + fn new(id: u64, method: impl Into, params: Option) -> Self { + Self { + jsonrpc: "2.0", + id, + method: method.into(), + params, + } + } +} + +/// JSON-RPC 响应。 +#[derive(Debug, Serialize, Deserialize)] +struct JsonRpcResponse { + jsonrpc: String, + id: u64, + #[serde(default)] + result: Option, + #[serde(default)] + error: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct JsonRpcError { + code: i32, + message: String, + #[serde(default)] + data: Option, +} + +/// MCP 子进程运行时状态。 +struct ChildProcessState { + child: Child, + stdin: ChildStdin, + pending: HashMap>>, + next_id: u64, +} + +impl ChildProcessState { + fn 
next_id(&mut self) -> u64 { + self.next_id += 1; + self.next_id + } +} + +/// MCP Server 暴露的工具(缓存结构)。 +#[derive(Debug, Clone)] +struct McpTool { + name: String, + description: Option, + input_schema: Value, +} + +/// MCP 客户端 —— 与 MCP 服务器通信。 +pub struct McpClient { + transport: McpTransport, + server_name: String, + /// 已初始化的工具列表(缓存)。 + tools: Vec, + /// 是否已初始化。 + initialized: AtomicBool, + /// 超时时间(秒)。 + timeout_secs: u64, + /// 子进程运行时状态(`connect()` 后创建,`close()` 后取回)。 + process: Option>>, +} + +impl std::fmt::Debug for McpClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("McpClient") + .field("server_name", &self.server_name) + .field("initialized", &self.initialized.load(Ordering::SeqCst)) + .field("tool_count", &self.tools.len()) + .finish() + } +} + +impl McpClient { + /// 创建一个 MCP 客户端。 + pub fn new(server_name: impl Into, transport: McpTransport) -> Self { + Self { + transport, + server_name: server_name.into(), + tools: Vec::new(), + initialized: AtomicBool::new(false), + timeout_secs: 30, + process: None, + } + } + + /// 设置超时时间(秒)。 + pub fn with_timeout(mut self, secs: u64) -> Self { + self.timeout_secs = secs; + self + } + + /// 检查是否已连接。 + pub fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::SeqCst) + } + + /// 连接并初始化(发送 initialize 请求)。 + pub async fn connect(&mut self) -> Result<(), ToolError> { + if self.is_initialized() { + return Ok(()); + } + + match &self.transport { + McpTransport::Stdio { command, args } => { + let mut cmd = Command::new(command); + cmd.args(args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + #[cfg(unix)] + cmd.kill_on_drop(true); + #[cfg(windows)] + cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW + + let mut child = cmd + .spawn() + .map_err(|e| ToolError::McpError(format!("启动 MCP 子进程失败: {e}")))?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| ToolError::McpError("无法获取子进程 stdin".into()))?; + let stdout = child + 
.stdout + .take() + .ok_or_else(|| ToolError::McpError("无法获取子进程 stdout".into()))?; + + // 启动 reader task 持续读取 stdout + let pending: HashMap>> = + HashMap::new(); + let state = Arc::new(Mutex::new(ChildProcessState { + child, + stdin, + pending, + next_id: 0, + })); + + // 启动后台 reader + let pending_arc = Arc::clone(&state); + tokio::spawn(async move { + Self::read_loop(BufReader::new(stdout), pending_arc).await; + }); + + self.process = Some(state); + } + McpTransport::StreamableHttp { .. } => { + return Err(ToolError::McpError( + "StreamableHttp transport 尚未实现".into(), + )); + } + } + + // 发送 initialize 请求 + let init_params = json!({ + "protocolVersion": MCP_VERSION, + "capabilities": {}, + "clientInfo": { + "name": "agcore", + "version": env!("CARGO_PKG_VERSION") + } + }); + let _response = self + .send_request("initialize", Some(init_params)) + .await?; + + // 发送 initialized 通知(无 id) + self.send_notification("notifications/initialized", Some(json!({}))) + .await?; + + self.initialized.store(true, Ordering::SeqCst); + Ok(()) + } + + /// 列出服务器支持的工具(调用 `tools/list`)。 + pub async fn list_tools(&mut self) -> Result, ToolError> { + if !self.is_initialized() { + return Err(ToolError::McpNotInitialized(self.server_name.clone())); + } + + let response = self.send_request("tools/list", None).await?; + let tools_value = response + .get("tools") + .ok_or_else(|| ToolError::McpError("tools/list 响应缺少 tools 字段".into()))?; + let tools_arr = tools_value + .as_array() + .ok_or_else(|| ToolError::McpError("tools/list 响应 tools 字段不是数组".into()))?; + + self.tools.clear(); + let mut defs = Vec::with_capacity(tools_arr.len()); + for tool in tools_arr { + let name = tool + .get("name") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::McpError("工具缺少 name 字段".into()))? 
+ .to_string(); + let description = tool + .get("description") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let input_schema = tool + .get("inputSchema") + .cloned() + .unwrap_or_else(|| json!({"type": "object", "properties": {}})); + + self.tools.push(McpTool { + name: name.clone(), + description: description.clone(), + input_schema: input_schema.clone(), + }); + defs.push(ToolDefinition { + name, + description, + parameters: input_schema, + strict: None, + }); + } + Ok(defs) + } + + /// 调用一个工具(调用 `tools/call`)。 + pub async fn call_tool(&self, name: &str, args: Value) -> Result { + if !self.is_initialized() { + return Err(ToolError::McpNotInitialized(self.server_name.clone())); + } + + let params = json!({ + "name": name, + "arguments": args, + }); + let response = self.send_request("tools/call", Some(params)).await?; + + // 解析 content 字段 + if let Some(content) = response.get("content").and_then(|c| c.as_array()) { + // 收集所有 text 内容 + let mut combined = String::new(); + for item in content { + if let Some(text) = item.get("text").and_then(|t| t.as_str()) { + if !combined.is_empty() { + combined.push('\n'); + } + combined.push_str(text); + } + } + if !combined.is_empty() { + return Ok(Value::String(combined)); + } + } + + // 如果没有 content 字段,尝试直接返回 is_error 标记 + if let Some(true) = response.get("isError").and_then(|v| v.as_bool()) { + return Err(ToolError::ExecutionFailed( + name.to_string(), + "MCP 工具返回 isError=true".into(), + )); + } + + // 回退:返回完整响应 + Ok(response) + } + + /// 关闭连接(终止子进程)。 + pub async fn close(&mut self) -> Result<(), ToolError> { + if !self.is_initialized() { + return Ok(()); + } + + // 尝试发送 shutdown(不强制要求响应) + let _ = self.send_notification("shutdown", None).await; + + if let Some(state) = self.process.take() { + let mut state = state.lock().await; + // 优雅等待 5 秒 + let graceful = tokio::time::timeout( + Duration::from_secs(5), + state.child.wait(), + ) + .await; + if graceful.is_err() { + // 超时则强杀 + let _ = state.child.kill().await; + 
} + } + + self.initialized.store(false, Ordering::SeqCst); + self.tools.clear(); + Ok(()) + } + + /// 将 MCP 客户端转换为 `BaseTool` 适配器列表(用于注册到 `ToolRegistry`)。 + /// + /// **注意**:返回的适配器持有 `Arc`,但 `McpClient` 的可变性 + /// (如 `list_tools` 刷新缓存)会通过 `Mutex` 处理。当前适配器仅缓存 + /// 转换时的工具列表,不感知后续刷新。 + pub fn into_tools(self) -> Vec { + let mut tools = Vec::with_capacity(self.tools.len()); + for mcp_tool in self.tools { + let tool = McpToolAdapter { + client: McpClientHandle::Empty, + name: mcp_tool.name, + description: mcp_tool.description.unwrap_or_default(), + parameters: mcp_tool.input_schema, + }; + tools.push(Arc::new(tool) as ToolRef); + } + tools + } + + async fn send_request( + &self, + method: &str, + params: Option, + ) -> Result { + let state_arc = self + .process + .as_ref() + .ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))? + .clone(); + + let (id, request_json) = { + let mut state = state_arc.lock().await; + let id = state.next_id(); + let req = JsonRpcRequest::new(id, method, params); + let json = serde_json::to_string(&req) + .map_err(|e| ToolError::McpError(format!("序列化请求失败: {e}")))?; + (id, json) + }; + + // 注册 oneshot 等待响应 + let (tx, rx) = oneshot::channel(); + { + let mut state = state_arc.lock().await; + state.pending.insert(id, tx); + } + + // 写入请求 + { + let mut state = state_arc.lock().await; + state + .stdin + .write_all(request_json.as_bytes()) + .await + .map_err(|e| ToolError::McpError(format!("写入请求失败: {e}")))?; + state + .stdin + .write_all(b"\n") + .await + .map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?; + state.stdin.flush().await.map_err(|e| { + ToolError::McpError(format!("flush stdin 失败: {e}")) + })?; + } + + // 等待响应(带超时) + tokio::time::timeout(Duration::from_secs(self.timeout_secs), rx) + .await + .map_err(|_| { + // 超时:清理 pending + let state_arc = state_arc.clone(); + tokio::spawn(async move { + let mut state = state_arc.lock().await; + state.pending.remove(&id); + }); + ToolError::McpTimeout(method.to_string()) + 
})? + .map_err(|_| ToolError::McpError("response channel 关闭".into()))? + } + + async fn send_notification( + &self, + method: &str, + params: Option, + ) -> Result<(), ToolError> { + let state_arc = self + .process + .as_ref() + .ok_or_else(|| ToolError::McpNotInitialized(self.server_name.clone()))? + .clone(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + let json = serde_json::to_string(¬ification) + .map_err(|e| ToolError::McpError(format!("序列化通知失败: {e}")))?; + + let mut state = state_arc.lock().await; + state + .stdin + .write_all(json.as_bytes()) + .await + .map_err(|e| ToolError::McpError(format!("写入通知失败: {e}")))?; + state + .stdin + .write_all(b"\n") + .await + .map_err(|e| ToolError::McpError(format!("写入换行失败: {e}")))?; + state + .stdin + .flush() + .await + .map_err(|e| ToolError::McpError(format!("flush stdin 失败: {e}")))?; + Ok(()) + } + + /// 持续读取 stdout,将响应分发到对应的 oneshot sender。 + async fn read_loop( + mut reader: BufReader, + state: Arc>, + ) { + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => { + // EOF:通知所有 pending 失败 + let mut state = state.lock().await; + for (_, tx) in state.pending.drain() { + let _ = tx.send(Err(ToolError::McpError("子进程退出".into()))); + } + break; + } + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + // 尝试解析为 JSON-RPC 响应 + let parsed: Result = serde_json::from_str(trimmed); + if let Ok(response) = parsed { + let value = if let Some(err) = response.error { + Err(ToolError::McpError(format!( + "[{}] {}", + err.code, err.message + ))) + } else { + Ok(response.result.unwrap_or(Value::Null)) + }; + let mut state = state.lock().await; + if let Some(tx) = state.pending.remove(&response.id) { + let _ = tx.send(value); + } + } + // 非响应消息(通知、request from server)忽略 + } + Err(e) => { + tracing::warn!("MCP read_loop error: {e}"); + let mut state = state.lock().await; + for (_, tx) in 
state.pending.drain() { + let _ = tx.send(Err(ToolError::McpError(format!("读取失败: {e}")))); + } + break; + } + } + } + } +} + +/// MCP 工具适配器 —— 将 MCP 工具包装为 `BaseTool`。 +struct McpToolAdapter { + /// 持有 client 的弱引用。实际生产中应使用 `Arc`, + /// 但当前 Phase 2 实现不直接持有可变的 `McpClient`。 + /// 标记为 unused 但保留字段以展示扩展路径。 + #[allow(dead_code)] + client: McpClientHandle, + name: String, + description: String, + parameters: Value, +} + +#[allow(dead_code)] +enum McpClientHandle { + Empty, + // Future: Shared(Arc), +} + +#[async_trait] +impl BaseTool for McpToolAdapter { + fn name(&self) -> &str { + &self.name + } + + fn description(&self) -> &str { + &self.description + } + + fn parameters(&self) -> Value { + self.parameters.clone() + } + + async fn execute( + &self, + _args: Value, + _ctx: &ToolContext<'_>, + ) -> Result { + // 当前 Phase 2 实现的简化:McpToolAdapter 不持有活跃 MCP 连接。 + // 实际生产中应持有 Arc 并通过 mcp.call_tool() 执行。 + // 这里返回错误,提示需要通过其他方式调用 MCP 工具。 + Err(ToolError::McpError(format!( + "MCP 工具 '{}' 需要活跃的 McpClient 引用(当前 Phase 2 简化实现)", + self.name + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transport_debug() { + let transport = McpTransport::Stdio { + command: "echo".to_string(), + args: vec!["hello".to_string()], + }; + let formatted = format!("{transport:?}"); + assert!(formatted.contains("echo")); + } + + #[test] + fn test_client_creation() { + let transport = McpTransport::Stdio { + command: "test".to_string(), + args: vec![], + }; + let client = McpClient::new("test-server", transport).with_timeout(60); + assert_eq!(client.server_name, "test-server"); + assert_eq!(client.timeout_secs, 60); + assert!(!client.is_initialized()); + } + + #[test] + fn test_jsonrpc_request_serialize() { + let req = JsonRpcRequest::new(42, "test", Some(json!({"a": 1}))); + let s = serde_json::to_string(&req).unwrap(); + assert!(s.contains("\"jsonrpc\":\"2.0\"")); + assert!(s.contains("\"id\":42")); + assert!(s.contains("\"method\":\"test\"")); + } + + #[test] + fn 
test_jsonrpc_response_parse_ok() { + let s = r#"{"jsonrpc":"2.0","id":1,"result":{"foo":"bar"}}"#; + let resp: JsonRpcResponse = serde_json::from_str(s).unwrap(); + assert_eq!(resp.id, 1); + assert!(resp.result.is_some()); + assert!(resp.error.is_none()); + } + + #[test] + fn test_jsonrpc_response_parse_error() { + let s = + r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#; + let resp: JsonRpcResponse = serde_json::from_str(s).unwrap(); + assert_eq!(resp.id, 1); + assert!(resp.result.is_none()); + let err = resp.error.unwrap(); + assert_eq!(err.code, -32601); + } + + #[tokio::test] + async fn test_streamable_http_not_implemented() { + let mut client = McpClient::new( + "http-server", + McpTransport::StreamableHttp { + url: "https://example.com/mcp".to_string(), + headers: None, + }, + ); + let result = client.connect().await; + // 当前 Phase 2 返回未实现错误 + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::McpError(_)))); + } +} diff --git a/src/tools/permission.rs b/src/tools/permission.rs new file mode 100644 index 0000000..b02df36 --- /dev/null +++ b/src/tools/permission.rs @@ -0,0 +1,286 @@ +//! 
工具权限管理。 + +use crate::tools::error::ToolError; + +/// 权限级别枚举。 +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Permission { + /// 只读(读取文件、查询数据库等)。 + Read, + /// 写入(创建/修改文件、插入数据等)。 + Write, + /// 删除(删除文件、记录等)。 + Delete, + /// 网络访问(HTTP 请求等)。 + Network, + /// Shell 命令执行。 + Shell, + /// 文件系统操作(除读/写/删之外的 FS 操作)。 + FileSystem, + /// 自定义权限(可通过 namespaced 字符串扩展)。 + Custom(String), +} + +impl std::fmt::Display for Permission { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Read => write!(f, "Read"), + Self::Write => write!(f, "Write"), + Self::Delete => write!(f, "Delete"), + Self::Network => write!(f, "Network"), + Self::Shell => write!(f, "Shell"), + Self::FileSystem => write!(f, "FileSystem"), + Self::Custom(s) => write!(f, "Custom({})", s), + } + } +} + +/// 权限配置。 +#[derive(Debug, Clone)] +pub struct PermissionConfig { + /// 允许的权限列表(空 = 全部允许,配合 `allow_unspecified` 决定)。 + pub allowed: Vec, + /// 拒绝的权限列表(优先级高于 `allowed`)。 + pub denied: Vec, + /// 当工具未声明权限时是否允许执行。 + pub allow_unspecified: bool, +} + +impl Default for PermissionConfig { + fn default() -> Self { + Self { + allowed: vec![Permission::Read, Permission::Network], + denied: vec![Permission::Delete, Permission::Shell], + allow_unspecified: true, + } + } +} + +/// 权限检查器。 +#[derive(Debug, Clone)] +pub struct PermissionChecker { + config: PermissionConfig, +} + +impl PermissionChecker { + /// 创建一个新的权限检查器。 + pub fn new(config: PermissionConfig) -> Self { + Self { config } + } + + /// 检查指定工具声明的权限是否允许执行。 + /// + /// 判定规则: + /// 1. 任一权限在 `denied` 中 → 拒绝 + /// 2. 所有权限都在 `allowed` 中 → 允许 + /// 3. `allowed` 非空且存在未声明权限 → 拒绝 + /// 4. `allowed` 为空 → 按 `allow_unspecified` 判定 + /// 5. 
工具未声明任何权限时按 `allow_unspecified` 判定 + pub fn check(&self, tool_name: &str, permissions: &[Permission]) -> Result<(), ToolError> { + // 任一权限在 denied 中 → 拒绝 + for perm in permissions { + if self.config.denied.contains(perm) { + return Err(ToolError::PermissionDenied( + tool_name.to_string(), + perm.to_string(), + )); + } + } + + // 工具未声明任何权限 + if permissions.is_empty() { + return if self.config.allow_unspecified { + Ok(()) + } else { + Err(ToolError::PermissionDenied( + tool_name.to_string(), + "Unspecified".to_string(), + )) + }; + } + + // allowed 为空 → 走 allow_unspecified 兜底 + if self.config.allowed.is_empty() { + return if self.config.allow_unspecified { + Ok(()) + } else { + Err(ToolError::PermissionDenied( + tool_name.to_string(), + "Unspecified".to_string(), + )) + }; + } + + // allowed 非空(白名单模式)—— 所有权限必须在其中 + for perm in permissions { + if !self.config.allowed.contains(perm) { + return Err(ToolError::PermissionDenied( + tool_name.to_string(), + perm.to_string(), + )); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn p(perm: Permission) -> Vec { + vec![perm] + } + + #[test] + fn test_default_config_allows_read() { + let checker = PermissionChecker::new(PermissionConfig::default()); + assert!(checker.check("weather", &p(Permission::Read)).is_ok()); + } + + #[test] + fn test_default_config_allows_network() { + let checker = PermissionChecker::new(PermissionConfig::default()); + assert!(checker.check("http_get", &p(Permission::Network)).is_ok()); + } + + #[test] + fn test_default_config_denies_delete() { + let checker = PermissionChecker::new(PermissionConfig::default()); + assert!(checker + .check("rm_file", &p(Permission::Delete)) + .is_err()); + } + + #[test] + fn test_default_config_denies_shell() { + let checker = PermissionChecker::new(PermissionConfig::default()); + assert!(checker.check("run_shell", &p(Permission::Shell)).is_err()); + } + + #[test] + fn test_white_list_mode_denies_unlisted() { + let cfg = PermissionConfig { + 
allowed: vec![Permission::Read], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &p(Permission::Read)).is_ok()); + assert!(checker.check("t", &p(Permission::Write)).is_err()); + } + + #[test] + fn test_white_list_mode_allows_listed() { + let cfg = PermissionConfig { + allowed: vec![Permission::Read, Permission::Write], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &p(Permission::Write)).is_ok()); + } + + #[test] + fn test_black_list_deny_priority() { + // 即便 allowed 中包含了 denied 权限,仍以 denied 为准 + let cfg = PermissionConfig { + allowed: vec![Permission::Shell, Permission::Read], + denied: vec![Permission::Shell], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &p(Permission::Shell)).is_err()); + } + + #[test] + fn test_empty_allowed_with_allow_unspecified() { + let cfg = PermissionConfig { + allowed: vec![], + denied: vec![], + allow_unspecified: true, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &p(Permission::Write)).is_ok()); + } + + #[test] + fn test_empty_allowed_without_allow_unspecified() { + let cfg = PermissionConfig { + allowed: vec![], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &p(Permission::Write)).is_err()); + } + + #[test] + fn test_unspecified_tool_with_allow() { + let cfg = PermissionConfig { + allowed: vec![], + denied: vec![], + allow_unspecified: true, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &[]).is_ok()); + } + + #[test] + fn test_unspecified_tool_without_allow() { + let cfg = PermissionConfig { + allowed: vec![], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker.check("t", &[]).is_err()); + } + + #[test] + fn 
test_custom_permission_collision() { + let cfg = PermissionConfig { + allowed: vec![Permission::Custom("db:read".into())], + denied: vec![Permission::Custom("db:write".into())], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker + .check("t", &[Permission::Custom("db:read".into())]) + .is_ok()); + assert!(checker + .check("t", &[Permission::Custom("db:write".into())]) + .is_err()); + } + + #[test] + fn test_multi_permission_all_in_allowed() { + let cfg = PermissionConfig { + allowed: vec![Permission::Read, Permission::Write, Permission::Network], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + assert!(checker + .check( + "t", + &[Permission::Read, Permission::Network] + ) + .is_ok()); + } + + #[test] + fn test_multi_permission_one_not_in_allowed() { + let cfg = PermissionConfig { + allowed: vec![Permission::Read, Permission::Network], + denied: vec![], + allow_unspecified: false, + }; + let checker = PermissionChecker::new(cfg); + // 任一权限不在白名单则拒绝 + assert!(checker + .check("t", &[Permission::Read, Permission::Write]) + .is_err()); + } +} diff --git a/src/tools/registry.rs b/src/tools/registry.rs new file mode 100644 index 0000000..db0c15d --- /dev/null +++ b/src/tools/registry.rs @@ -0,0 +1,371 @@ +//! 
工具注册表 —— 管理工具注册、发现、调用。 + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use futures::future::join_all; +use serde_json::Value; + +use crate::llm::types::ToolDefinition; +use crate::tools::base::{ToolContext, ToolRef}; +use crate::tools::error::ToolError; +use crate::tools::permission::PermissionChecker; + +/// 工具调用记录 —— 用于追踪和调试。 +#[derive(Debug, Clone)] +pub struct ToolInvocation { + /// 被调用的工具名。 + pub tool_name: String, + /// 工具的入参。 + pub input: Value, + /// 工具的输出。 + pub output: Result, +} + +impl ToolInvocation { + /// 创建一个新的工具调用记录。 + pub fn new(tool_name: String, input: Value, output: Result) -> Self { + Self { + tool_name, + input, + output, + } + } +} + +/// 工具注册表 —— 管理工具注册、发现、调用。 +/// +/// 通过 `Arc` 共享,方法签名 `&self`,可安全跨 task 并行调用。 +/// 不支持运行时并发注册(应在 setup 阶段一次性构建后冻结)。 +#[derive(Clone, Default)] +pub struct ToolRegistry { + inner: Arc, +} + +#[derive(Clone, Default)] +struct ToolRegistryInner { + tools: HashMap, + permission_checker: Option>, +} + +impl std::fmt::Debug for ToolRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ToolRegistry") + .field("tool_names", &self.inner.tools.keys().collect::>()) + .field("has_checker", &self.inner.permission_checker.is_some()) + .finish() + } +} + +impl ToolRegistry { + /// 创建一个新的工具注册表。 + pub fn new() -> Self { + Self { + inner: Arc::new(ToolRegistryInner { + tools: HashMap::new(), + permission_checker: None, + }), + } + } + + /// 设置权限检查器(Builder 模式)。 + pub fn with_permission_checker(mut self, checker: PermissionChecker) -> Self { + let inner = Arc::make_mut(&mut self.inner); + inner.permission_checker = Some(Arc::new(checker)); + self + } + + /// 注册一个工具。 + /// + /// 重复注册同名工具返回错误。 + pub fn register(&mut self, tool: ToolRef) -> Result<(), ToolError> { + let name = tool.name().to_string(); + let inner = Arc::make_mut(&mut self.inner); + if inner.tools.contains_key(&name) { + return Err(ToolError::ExecutionFailed(name, "工具已存在".to_string())); 
+ } + inner.tools.insert(name, tool); + Ok(()) + } + + /// 批量注册工具。 + pub fn register_all(&mut self, tools: Vec) -> Result<(), ToolError> { + for tool in tools { + self.register(tool)?; + } + Ok(()) + } + + /// 注销一个工具。 + pub fn unregister(&mut self, name: &str) -> Option { + let inner = Arc::make_mut(&mut self.inner); + inner.tools.remove(name) + } + + /// 按名称查找工具。 + pub fn get(&self, name: &str) -> Option { + self.inner.tools.get(name).cloned() + } + + /// 获取所有已注册工具的名称列表。 + pub fn list_tools(&self) -> Vec { + self.inner.tools.keys().cloned().collect() + } + + /// 获取所有工具的 `ToolDefinition` 列表(用于传递给 LLM)。 + pub fn definitions(&self) -> Vec { + self.inner + .tools + .values() + .map(|tool| ToolDefinition { + name: tool.name().to_string(), + description: Some(tool.description().to_string()), + parameters: tool.parameters(), + strict: None, + }) + .collect() + } + + /// 调用单个工具(含权限检查)。 + pub async fn invoke(&self, name: &str, args: Value) -> Result { + let tool = self + .get(name) + .ok_or_else(|| ToolError::NotFound(name.to_string()))?; + + if let Some(checker) = &self.inner.permission_checker { + checker.check(name, &tool.required_permissions())?; + } + + let ctx = ToolContext::new(name, ""); + let output = tool.execute(args.clone(), &ctx).await; + Ok(ToolInvocation::new(name.to_string(), args, output)) + } + + /// 并行执行多个工具调用(互不依赖的工具)。 + /// + /// 每个工具独立超时(`timeout_per_call_secs`,0 表示不超时)。 + /// 单个工具超时不会影响其他工具的返回。 + pub async fn invoke_all( + &self, + calls: Vec<(String, Value)>, + timeout_per_call_secs: u64, + ) -> Vec { + let this = self.clone(); + let futures = calls.into_iter().map(|(name, args)| { + let this = this.clone(); + async move { + match if timeout_per_call_secs == 0 { + Ok(this.invoke(&name, args.clone()).await) + } else { + tokio::time::timeout( + Duration::from_secs(timeout_per_call_secs), + this.invoke(&name, args.clone()), + ) + .await + } { + Ok(result) => result.unwrap_or_else(|e| { + ToolInvocation::new(name.clone(), args.clone(), Err(e)) + }), + 
Err(_) => ToolInvocation::new( + name, + args, + Err(ToolError::McpTimeout("timeout".into())), + ), + } + } + }); + join_all(futures).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tools::BaseTool; + use async_trait::async_trait; + use serde_json::json; + + struct AddTool { + base: i64, + } + + #[async_trait] + impl BaseTool for AddTool { + fn name(&self) -> &str { + "add" + } + + fn description(&self) -> &str { + "加法" + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { "n": { "type": "integer" } }, + "required": ["n"] + }) + } + + async fn execute(&self, args: Value, _ctx: &ToolContext<'_>) -> Result { + let n = args["n"].as_i64().unwrap_or(0); + Ok(json!({ "result": self.base + n })) + } + } + + struct FailTool; + + #[async_trait] + impl BaseTool for FailTool { + fn name(&self) -> &str { + "fail" + } + fn description(&self) -> &str { + "总会失败" + } + fn parameters(&self) -> Value { + json!({}) + } + async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result { + Err(ToolError::ExecutionFailed("fail".into(), "boom".into())) + } + } + + struct ShellTool; + + #[async_trait] + impl BaseTool for ShellTool { + fn name(&self) -> &str { + "shell" + } + fn description(&self) -> &str { + "shell" + } + fn parameters(&self) -> Value { + json!({}) + } + fn required_permissions(&self) -> Vec { + vec![crate::tools::permission::Permission::Shell] + } + async fn execute(&self, _args: Value, _ctx: &ToolContext<'_>) -> Result { + Ok(json!({})) + } + } + + #[test] + fn test_register_and_get() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 10 })).unwrap(); + assert!(reg.get("add").is_some()); + assert!(reg.get("nonexistent").is_none()); + } + + #[test] + fn test_register_duplicate() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 0 })).unwrap(); + let result = reg.register(Arc::new(AddTool { base: 1 })); + assert!(result.is_err()); + } + + #[test] + fn 
test_register_all() { + let mut reg = ToolRegistry::new(); + let result = reg.register_all(vec![ + Arc::new(AddTool { base: 1 }), + Arc::new(AddTool { base: 2 }), + ]); + assert!(result.is_err()); // 重名 add → 失败 + } + + #[test] + fn test_unregister() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 0 })).unwrap(); + let removed = reg.unregister("add"); + assert!(removed.is_some()); + assert!(reg.get("add").is_none()); + } + + #[test] + fn test_list_tools() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 0 })).unwrap(); + reg.register(Arc::new(FailTool)).unwrap(); + let names = reg.list_tools(); + assert_eq!(names.len(), 2); + assert!(names.contains(&"add".to_string())); + assert!(names.contains(&"fail".to_string())); + } + + #[test] + fn test_definitions() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 0 })).unwrap(); + let defs = reg.definitions(); + assert_eq!(defs.len(), 1); + assert_eq!(defs[0].name, "add"); + assert!(defs[0].description.is_some()); + } + + #[tokio::test] + async fn test_invoke_success() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 100 })).unwrap(); + let result = reg.invoke("add", json!({ "n": 5 })).await.unwrap(); + let value = result.output.unwrap(); + assert_eq!(value["result"], 105); + } + + #[tokio::test] + async fn test_invoke_not_found() { + let reg = ToolRegistry::new(); + let result = reg.invoke("nope", json!({})).await; + assert!(matches!(result, Err(ToolError::NotFound(_)))); + } + + #[tokio::test] + async fn test_invoke_execution_error() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(FailTool)).unwrap(); + let result = reg.invoke("fail", json!({})).await.unwrap(); + assert!(result.output.is_err()); + } + + #[tokio::test] + async fn test_invoke_with_permission_denied() { + let mut reg = ToolRegistry::new() + .with_permission_checker(PermissionChecker::new(Default::default())); + 
reg.register(Arc::new(ShellTool)).unwrap(); + let result = reg.invoke("shell", json!({})).await; + assert!(matches!(result, Err(ToolError::PermissionDenied(_, _)))); + } + + #[tokio::test] + async fn test_invoke_all_parallel() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 1 })).unwrap(); + reg.register(Arc::new(FailTool)).unwrap(); + let calls = vec![ + ("add".into(), json!({ "n": 1 })), + ("add".into(), json!({ "n": 2 })), + ("fail".into(), json!({})), + ]; + let results = reg.invoke_all(calls, 0).await; + assert_eq!(results.len(), 3); + assert!(results[0].output.is_ok()); + assert!(results[1].output.is_ok()); + assert!(results[2].output.is_err()); + } + + #[tokio::test] + async fn test_invoke_all_with_timeout() { + let mut reg = ToolRegistry::new(); + reg.register(Arc::new(AddTool { base: 0 })).unwrap(); + let calls = vec![("add".into(), json!({ "n": 1 }))]; + let results = reg.invoke_all(calls, 5).await; + assert_eq!(results.len(), 1); + assert!(results[0].output.is_ok()); + } +}