diff --git a/Cargo.toml b/Cargo.toml index 539b965..6040a85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ futures-core = "0.3" bytes = "1" async-stream = "0.3" tokio-util = { version = "0.7", features = ["rt"] } +time = { version = "0.3", features = ["serde"] } [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/lib.rs b/src/lib.rs index c63cb2f..4444d6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ //! agcore —— 智能体(Agent)核心工具箱。 pub mod llm; +pub mod memory; pub mod prompt; pub mod tools; diff --git a/src/memory.rs b/src/memory.rs new file mode 100644 index 0000000..84203e0 --- /dev/null +++ b/src/memory.rs @@ -0,0 +1,22 @@ +//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。 + +pub mod conversation; +pub mod error; +pub mod knowledge; +pub mod retriever; +pub mod store; +pub mod types; + +// 高频类型(大多数下游需要) +pub use conversation::{ConversationMemory, ConversationMemoryConfig}; +pub use error::MemoryError; +pub use knowledge::KnowledgeStore; +pub use retriever::MemoryRetriever; +pub use store::{InMemoryStore, MemoryStore}; + +// 低频类型(配置/高级使用) +pub use conversation::MemoryStrategy; +pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX}; +pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem}; +pub use store::{EvictionConfig, EvictionPolicy}; +pub use types::{KnowledgePage, MemoryFilter, MemoryItem}; diff --git a/src/memory/conversation.rs b/src/memory/conversation.rs new file mode 100644 index 0000000..c60fdc0 --- /dev/null +++ b/src/memory/conversation.rs @@ -0,0 +1,260 @@ +//! 
Conversation memory: multi-turn message management that reuses the compaction logic in `llm::compact`.
+
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+use time::OffsetDateTime;
+
+use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
+use crate::llm::types::OpenaiChatMessage;
+use crate::memory::error::MemoryError;
+use crate::memory::store::MemoryStore;
+use crate::memory::types::MemoryItem;
+
+/// Strategy used to manage the conversation message list.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum MemoryStrategy {
+    /// Sliding window: the oldest messages are removed once the limit is exceeded.
+    SlidingWindow,
+    /// Keep every message (compaction only, no removal).
+    Full,
+}
+
+/// Configuration for [`ConversationMemory`].
+#[derive(Debug, Clone)]
+pub struct ConversationMemoryConfig {
+    /// How the message list is bounded.
+    pub strategy: MemoryStrategy,
+    /// Maximum number of cached messages under [`MemoryStrategy::SlidingWindow`].
+    pub max_turns: usize,
+    /// Compaction settings; `None` disables compaction.
+    pub compact_config: Option<CompactConfig>,
+}
+
+impl Default for ConversationMemoryConfig {
+    fn default() -> Self {
+        Self {
+            strategy: MemoryStrategy::SlidingWindow,
+            max_turns: 50,
+            compact_config: Some(CompactConfig::default()),
+        }
+    }
+}
+
+/// Conversation memory: manages the multi-turn message history of one session.
+///
+/// Keeps a `Vec<OpenaiChatMessage>` hot cache (operated on directly by
+/// `llm::compact`) and uses a `MemoryStore` as the cold persistence layer.
+pub struct ConversationMemory {
+    store: Arc<dyn MemoryStore>,
+    session_id: String,
+    config: ConversationMemoryConfig,
+    /// Hot cache: the message list handed to `llm::compact`.
+    messages: Vec<OpenaiChatMessage>,
+    /// Store ids, index-aligned with `messages`, so eviction can delete precisely.
+    message_ids: Vec<String>,
+    /// Compaction circuit-breaker state.
+    compact_state: CompactState,
+}
+
+impl ConversationMemory {
+    /// Creates an empty `ConversationMemory` for `session_id`, backed by `store`.
+    pub fn new(
+        store: Arc<dyn MemoryStore>,
+        session_id: impl Into<String>,
+        config: ConversationMemoryConfig,
+    ) -> Self {
+        Self {
+            store,
+            session_id: session_id.into(),
+            config,
+            messages: Vec::new(),
+            message_ids: Vec::new(),
+            compact_state: CompactState::new(),
+        }
+    }
+
+    /// Returns the session id.
+    pub fn session_id(&self) -> &str {
+        &self.session_id
+    }
+
+    /// Returns the configuration.
+    pub fn config(&self) -> &ConversationMemoryConfig {
+        &self.config
+    }
+
+    /// Loads this session's persisted messages from the store into the hot cache.
+    ///
+    /// The cache is replaced only after every stored item has been decoded.
+    ///
+    /// # Errors
+    ///
+    /// Propagates store errors from `list`, and returns
+    /// [`MemoryError::Serialization`] if a stored item is not a valid message.
+    pub async fn load(&mut self) -> Result<(), MemoryError> {
+        let filter = crate::memory::types::MemoryFilter {
+            prefix: Some(self.session_prefix()),
+            ..Default::default()
+        };
+        let items = self.store.list(&filter).await?;
+        let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> =
+            Vec::with_capacity(items.len());
+        for item in items {
+            match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
+                Ok(msg) => pairs.push((item.id, msg, item.created_at)),
+                Err(e) => {
+                    return Err(MemoryError::Serialization(format!(
+                        "load message {} failed: {e}",
+                        item.id
+                    )));
+                }
+            }
+        }
+        // Oldest first, ordered by creation time.
+        pairs.sort_by_key(|p| p.2);
+        let (ids, messages): (Vec<String>, Vec<OpenaiChatMessage>) =
+            pairs.into_iter().map(|(id, msg, _)| (id, msg)).unzip();
+        self.message_ids = ids;
+        self.messages = messages;
+        Ok(())
+    }
+
+    /// Appends one message.
+    ///
+    /// The message is persisted through the `MemoryStore` first and added to the
+    /// hot cache only once the write succeeded, so a failed write cannot leave the
+    /// cache and the store out of step. Eviction and compaction run afterwards.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MemoryError::Serialization`] if the message cannot be encoded as
+    /// JSON, and propagates store errors from `save`.
+    pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
+        let now = OffsetDateTime::now_utc();
+        let index = self.messages.len();
+        let id = self.make_message_id(index, &now);
+
+        // Cold store first: nothing in memory changes if this fails.
+        let content = serde_json::to_string(&msg)
+            .map_err(|e| MemoryError::Serialization(e.to_string()))?;
+        let item = MemoryItem {
+            id: id.clone(),
+            content,
+            metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
+            created_at: now,
+        };
+        self.store.save(item).await?;
+
+        // Hot cache second; both vectors stay index-aligned.
+        self.messages.push(msg);
+        self.message_ids.push(id);
+
+        self.maybe_evict_and_compact().await;
+        Ok(())
+    }
+
+    /// Returns the full cached message history.
+    pub fn get_history(&self) -> &[OpenaiChatMessage] {
+        &self.messages
+    }
+
+    /// Removes every message from the cache and deletes them from the store.
+    ///
+    /// The cache is emptied before the store deletions start.
+    ///
+    /// # Errors
+    ///
+    /// Propagates the first store error from `delete`.
+    pub async fn clear(&mut self) -> Result<(), MemoryError> {
+        let to_delete = std::mem::take(&mut self.message_ids);
+        self.messages.clear();
+        self.compact_state = CompactState::new();
+        for id in to_delete {
+            self.store.delete(&id).await?;
+        }
+        Ok(())
+    }
+
+    /// Number of cached messages.
+    pub fn len(&self) -> usize {
+        self.messages.len()
+    }
+
+    /// Whether the cache holds no messages.
+    pub fn is_empty(&self) -> bool {
+        self.messages.is_empty()
+    }
+
+    /// Store-id prefix shared by every message of this session.
+    fn session_prefix(&self) -> String {
+        format!("conv:{}:", self.session_id)
+    }
+
+    /// Builds a store id from the session prefix, a zero-padded index and a
+    /// nanosecond timestamp.
+    fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
+        format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
+    }
+
+    /// Applies sliding-window eviction, then compaction, to the hot cache.
+    async fn maybe_evict_and_compact(&mut self) {
+        // 1. Sliding window: drop the oldest messages beyond `max_turns` in one pass.
+        if self.config.strategy == MemoryStrategy::SlidingWindow {
+            let excess = self.messages.len().saturating_sub(self.config.max_turns);
+            if excess > 0 {
+                self.messages.drain(..excess);
+                let evicted: Vec<String> = self.message_ids.drain(..excess).collect();
+                for removed_id in &evicted {
+                    // Best-effort: a failed store delete is ignored.
+                    let _ = self.store.delete(removed_id).await;
+                }
+            }
+        }
+
+        // 2. Compaction (reuses llm::compact). Only the in-memory list is rewritten
+        //    here; nothing is written back to the store.
+        if let Some(ref compact_config) = self.config.compact_config {
+            if should_compact(&self.messages, compact_config, &self.compact_state) {
+                let keep_recent = compact_config.keep_recent;
+                let freed = microcompact(&mut self.messages, keep_recent);
+                if freed > 0 {
+                    self.compact_state.record_success();
+                } else {
+                    // Nothing was freed (possibly no compactable tool result was found).
+                    let _ = self.compact_state.record_failure();
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::llm::types::OpenaiChatMessage;
+    use crate::memory::InMemoryStore;
+    use crate::memory::MemoryStore;
+
+    fn user_text(s: &str) -> OpenaiChatMessage {
+        OpenaiChatMessage::user_text(s)
+    }
+
+    #[tokio::test]
+    async fn add_and_get_history() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let mut conv = ConversationMemory::new(store, "session1", ConversationMemoryConfig::default());
+        conv.add_message(user_text("hello")).await.unwrap();
+        conv.add_message(user_text("world")).await.unwrap();
+        assert_eq!(conv.len(), 2);
+        assert_eq!(conv.get_history().len(), 2);
+    }
+
+    #[tokio::test]
+    async fn sliding_window_evicts_oldest() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let config = ConversationMemoryConfig {
+            strategy: MemoryStrategy::SlidingWindow,
+            max_turns: 3,
+            compact_config: None,
+        };
+        let mut conv = ConversationMemory::new(store, "s1", config);
+        for i in 0..5 {
+            conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
+        }
+        assert_eq!(conv.len(), 3);
+    }
+
+    #[tokio::test]
+    async fn full_strategy_no_evict() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let config = ConversationMemoryConfig {
+            strategy: MemoryStrategy::Full,
+            max_turns: 3,
+            compact_config: None,
+        };
+        let mut conv = ConversationMemory::new(store, "s1", config);
+        for i in 0..5 {
+            conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
+        }
+        // The Full strategy never removes messages.
+        assert_eq!(conv.len(), 5);
+    }
+
+    #[tokio::test]
+    async fn clear_empties_messages() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let mut conv = ConversationMemory::new(store.clone(), "s1", ConversationMemoryConfig::default());
+        conv.add_message(user_text("hello")).await.unwrap();
+        assert!(!conv.is_empty());
+        conv.clear().await.unwrap();
+        assert!(conv.is_empty());
+    }
+}
diff --git a/src/memory/error.rs b/src/memory/error.rs
new file mode 100644
index 0000000..c7ac6d4
--- /dev/null
+++ b/src/memory/error.rs
@@ -0,0 +1,29 @@
+//! Error types of the memory system.
+
+use thiserror::Error;
+
+/// Errors produced by the memory system.
+#[derive(Debug, Error)]
+pub enum MemoryError {
+    #[error("Item not found: {0}")]
+    NotFound(String),
+
+    #[error("Storage error: {0}")]
+    Storage(String),
+
+    #[error("Serialization error: {0}")]
+    Serialization(String),
+
+    #[error("Invalid input: {0}")]
+    InvalidInput(String),
+
+    #[error("Retrieval error: {0}")]
+    RetrievalError(String),
+}
+
+impl MemoryError {
+    /// Whether the error is recoverable (the caller may retry or adjust its input).
+    pub fn is_recoverable(&self) -> bool {
+        matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
+    }
+}
diff --git a/src/memory/knowledge.rs b/src/memory/knowledge.rs
new file mode 100644
index 0000000..79770ea
--- /dev/null
+++ b/src/memory/knowledge.rs
@@ -0,0 +1,247 @@
+//!
Knowledge base: `KnowledgePage` storage and keyword search.
+
+use std::sync::Arc;
+
+use time::OffsetDateTime;
+
+use crate::memory::error::MemoryError;
+use crate::memory::store::MemoryStore;
+use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
+
+pub use crate::memory::types::PageIndexEntry;
+
+/// Prefix of `MemoryItem.id` for knowledge pages.
+pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
+
+/// Knowledge base: `KnowledgePage` CRUD, keyword search and a content index.
+///
+/// Pages are stored in a `MemoryStore` as JSON; a `Vec<PageIndexEntry>` index is
+/// kept alongside to speed up listing and searching.
+pub struct KnowledgeStore {
+    store: Arc<dyn MemoryStore>,
+    index: std::sync::Mutex<Vec<PageIndexEntry>>,
+}
+
+impl KnowledgeStore {
+    /// Creates a `KnowledgeStore` with an empty index on top of `store`.
+    pub fn new(store: Arc<dyn MemoryStore>) -> Self {
+        Self {
+            store,
+            index: std::sync::Mutex::new(Vec::new()),
+        }
+    }
+
+    /// Rebuilds the index from the store (repairs index/store drift).
+    ///
+    /// The new index is assembled completely before it replaces the current one,
+    /// so a failure leaves the existing index untouched.
+    ///
+    /// # Errors
+    ///
+    /// Propagates store errors from `list`, and returns
+    /// [`MemoryError::Serialization`] if a stored page cannot be decoded.
+    pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
+        let items = self
+            .store
+            .list(&MemoryFilter {
+                prefix: Some(KNOWLEDGE_PREFIX.to_string()),
+                ..Default::default()
+            })
+            .await?;
+        let rebuilt = items
+            .iter()
+            .map(|item| {
+                serde_json::from_str::<KnowledgePage>(&item.content)
+                    .map(|page| PageIndexEntry::from(&page))
+                    .map_err(|e| MemoryError::Serialization(e.to_string()))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        // Swap in one step; the lock is taken only after all awaiting is done.
+        *self.index.lock().unwrap() = rebuilt;
+        Ok(())
+    }
+
+    /// Stores a knowledge page, replacing any page with the same id.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MemoryError::InvalidInput`] if `page.id` is empty,
+    /// [`MemoryError::Serialization`] if the page cannot be encoded, and
+    /// propagates store errors from `save`.
+    pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
+        if page.id.is_empty() {
+            return Err(MemoryError::InvalidInput("page.id is empty".into()));
+        }
+        let now = OffsetDateTime::now_utc();
+        let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
+        let content = serde_json::to_string(&page)
+            .map_err(|e| MemoryError::Serialization(e.to_string()))?;
+        let item = MemoryItem {
+            id,
+            content,
+            metadata: serde_json::json!({}),
+            created_at: now,
+        };
+        self.store.save(item).await?;
+        let mut index = self.index.lock().unwrap();
+        // Replace the existing entry, or append a new one.
+        if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
+            *existing = PageIndexEntry::from(&page);
+        } else {
+            index.push(PageIndexEntry::from(&page));
+        }
+        Ok(())
+    }
+
+    /// Fetches a page by its page id; `Ok(None)` if the store has no such page.
+    ///
+    /// # Errors
+    ///
+    /// Propagates store errors from `get`, and returns
+    /// [`MemoryError::Serialization`] if the stored page cannot be decoded.
+    pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
+        let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
+        let item = self.store.get(&full_id).await?;
+        match item {
+            None => Ok(None),
+            Some(item) => {
+                let page: KnowledgePage = serde_json::from_str(&item.content)
+                    .map_err(|e| MemoryError::Serialization(e.to_string()))?;
+                Ok(Some(page))
+            }
+        }
+    }
+
+    /// Updates a page that already exists.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`MemoryError::InvalidInput`] if `page.id` is empty,
+    /// [`MemoryError::NotFound`] if no page with that id is stored, and otherwise
+    /// the same errors as [`KnowledgeStore::add_page`].
+    pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
+        if page.id.is_empty() {
+            return Err(MemoryError::InvalidInput("page.id is empty".into()));
+        }
+        // Existence is checked through get_page.
+        if self.get_page(&page.id).await?.is_none() {
+            return Err(MemoryError::NotFound(page.id));
+        }
+        self.add_page(page).await
+    }
+
+    /// Deletes a page from the store and from the index.
+    ///
+    /// # Errors
+    ///
+    /// Propagates store errors from `delete`.
+    pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
+        let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
+        self.store.delete(&full_id).await?;
+        let mut index = self.index.lock().unwrap();
+        index.retain(|e| e.id != id);
+        Ok(())
+    }
+
+    /// Searches knowledge pages by keyword.
+    ///
+    /// Matching rule: case-insensitive substring match against the indexed
+    /// `title`, `summary` and `tags`. Page `content` is not part of the index and
+    /// is not searched here. An empty query yields no results.
+    ///
+    /// # Errors
+    ///
+    /// Same as [`KnowledgeStore::get_page`], which loads each matching page.
+    pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
+        if query.is_empty() {
+            return Ok(Vec::new());
+        }
+        let needle = query.to_lowercase();
+        // Snapshot the matching ids and release the lock before awaiting: a
+        // std::sync::MutexGuard must not be held across an `.await`.
+        let matching_ids: Vec<String> = {
+            let index = self.index.lock().unwrap();
+            index
+                .iter()
+                .filter(|entry| {
+                    entry.title.to_lowercase().contains(&needle)
+                        || entry.summary.to_lowercase().contains(&needle)
+                        || entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
+                })
+                .map(|entry| entry.id.clone())
+                .collect()
+        };
+        let mut results = Vec::with_capacity(matching_ids.len());
+        for id in &matching_ids {
+            // Index entries whose page is no longer in the store are skipped.
+            if let Some(page) = self.get_page(id).await? {
+                results.push(page);
+            }
+        }
+        Ok(results)
+    }
+
+    /// Returns the table of contents (a copy of every lightweight index entry).
+    pub fn get_index(&self) -> Vec<PageIndexEntry> {
+        self.index.lock().unwrap().clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::memory::InMemoryStore;
+    use time::OffsetDateTime;
+
+    fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
+        let now = OffsetDateTime::now_utc();
+        KnowledgePage {
+            id: id.to_string(),
+            title: title.to_string(),
+            summary: format!("summary of {title}"),
+            content: format!("full content of {title}"),
+            tags: tags.iter().map(|s| s.to_string()).collect(),
+            references: Vec::new(),
+            created_at: now,
+            updated_at: now,
+        }
+    }
+
+    #[tokio::test]
+    async fn add_get_delete_page() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
+            .await
+            .unwrap();
+        let got = ks.get_page("p1").await.unwrap();
+        assert!(got.is_some());
+        assert_eq!(got.unwrap().title, "LangGraph");
+
+        ks.delete_page("p1").await.unwrap();
+        assert!(ks.get_page("p1").await.unwrap().is_none());
+    }
+
+    #[tokio::test]
+    async fn add_page_rejects_empty_id() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        let result = ks.add_page(make_page("", "NoId", &[])).await;
+        assert!(matches!(result, Err(MemoryError::InvalidInput(_))));
+    }
+
+    #[tokio::test]
+    async fn update_page_requires_existing() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        let result = ks.update_page(make_page("nope", "Ghost", &[])).await;
+        assert!(matches!(result, Err(MemoryError::NotFound(_))));
+    }
+
+    #[tokio::test]
+    async fn search_finds_by_title_summary_tag() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
+            .await
+            .unwrap();
+        ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
+            .await
+            .unwrap();
+        ks.add_page(make_page("p3", "Third", &["unrelated"]))
+            .await
+            .unwrap();
+
+        let results = ks.search("stategraph").await.unwrap();
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].id, "p1");
+
+        let results = ks.search("knowledge-graph").await.unwrap();
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].id, "p2");
+
+        let results = ks.search("nonexistent").await.unwrap();
+        assert!(results.is_empty());
+    }
+
+    #[tokio::test]
+    async fn get_index_returns_all_pages() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        ks.add_page(make_page("p1", "A", &[])).await.unwrap();
+        ks.add_page(make_page("p2", "B", &[])).await.unwrap();
+        let index = ks.get_index();
+        assert_eq!(index.len(), 2);
+    }
+
+    #[tokio::test]
+    async fn rebuild_index_recovers_from_drift() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        // Add pages.
+        ks.add_page(make_page("p1", "A", &[])).await.unwrap();
+        ks.add_page(make_page("p2", "B", &[])).await.unwrap();
+        assert_eq!(ks.get_index().len(), 2);
+
+        // Simulate index drift: clear it, then rebuild.
+        ks.index.lock().unwrap().clear();
+        assert_eq!(ks.get_index().len(), 0);
+
+        ks.rebuild_index().await.unwrap();
+        assert_eq!(ks.get_index().len(), 2);
+    }
+}
diff --git a/src/memory/retriever.rs b/src/memory/retriever.rs
new file mode 100644
index 0000000..09fa7bf
--- /dev/null
+++ b/src/memory/retriever.rs
@@ -0,0 +1,297 @@
+//!
Memory retriever: single-channel keyword retrieval scored by TextOverlap (Dice coefficient).
+
+use std::collections::HashSet;
+
+use crate::memory::error::MemoryError;
+use crate::memory::knowledge::KnowledgeStore;
+use crate::memory::types::KnowledgePage;
+
+/// Retriever configuration.
+#[derive(Debug, Clone)]
+pub struct RetrieverConfig {
+    /// Maximum number of results returned (default 20).
+    pub max_results: usize,
+    /// Minimum score threshold in [0.0, 1.0] (default 0.1).
+    pub min_score: f32,
+}
+
+impl Default for RetrieverConfig {
+    fn default() -> Self {
+        Self {
+            max_results: 20,
+            min_score: 0.1,
+        }
+    }
+}
+
+/// One retrieval hit together with its score.
+#[derive(Debug, Clone)]
+pub struct ScoredItem {
+    pub page: KnowledgePage,
+    /// TextOverlap score in [0.0, 1.0].
+    pub score: f32,
+}
+
+/// Result of one retrieval call.
+#[derive(Debug, Clone)]
+pub struct RetrievalResult {
+    pub items: Vec<ScoredItem>,
+    pub query: String,
+}
+
+/// Memory retriever: keyword search over a `KnowledgeStore`, ranked by TextOverlap.
+pub struct MemoryRetriever {
+    knowledge_store: KnowledgeStore,
+    config: RetrieverConfig,
+    /// Stop-word list used during keyword extraction.
+    stop_words: HashSet<String>,
+}
+
+impl MemoryRetriever {
+    /// Creates a `MemoryRetriever` with the default stop-word list.
+    pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
+        Self {
+            knowledge_store,
+            config,
+            stop_words: default_stop_words(),
+        }
+    }
+
+    /// Replaces the stop-word list.
+    pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
+        self.stop_words = stop_words;
+        self
+    }
+
+    /// Retrieves the knowledge pages relevant to `query`.
+    ///
+    /// Keywords are extracted from the query and each is searched in the
+    /// knowledge store. The distinct pages found are scored against the whole
+    /// query, filtered by `min_score`, sorted by descending score and truncated
+    /// to `max_results`. An empty query yields an empty result.
+    ///
+    /// # Errors
+    ///
+    /// Propagates errors from `KnowledgeStore::search`.
+    pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
+        if query.is_empty() {
+            return Ok(RetrievalResult {
+                items: Vec::new(),
+                query: query.to_string(),
+            });
+        }
+
+        // 1. Keyword extraction.
+        let keywords = extract_keywords(query, &self.stop_words);
+
+        // 2. Search each keyword; keep the first occurrence of every page id.
+        let mut seen_ids: HashSet<String> = HashSet::new();
+        let mut pages: Vec<KnowledgePage> = Vec::new();
+        for keyword in &keywords {
+            for page in self.knowledge_store.search(keyword).await? {
+                if seen_ids.insert(page.id.clone()) {
+                    pages.push(page);
+                }
+            }
+        }
+
+        // 3. TextOverlap scoring.
+        let mut items: Vec<ScoredItem> = pages
+            .into_iter()
+            .map(|page| {
+                let score = text_overlap_score(query, &page);
+                ScoredItem { page, score }
+            })
+            .collect();
+
+        // 4. Filter, sort, truncate.
+        items.retain(|i| i.score >= self.config.min_score);
+        items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
+        items.truncate(self.config.max_results);
+
+        Ok(RetrievalResult {
+            items,
+            query: query.to_string(),
+        })
+    }
+}
+
+/// Extracts keywords from `query`: split on non-alphanumeric characters,
+/// lowercase, then drop words shorter than two characters and stop words.
+fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
+    query
+        .split(|c: char| !c.is_alphanumeric())
+        .map(str::to_lowercase)
+        .filter(|word| word.chars().count() >= 2 && !stop_words.contains(word))
+        .collect()
+}
+
+/// TextOverlap score (character-bigram Dice coefficient with per-field weights).
+///
+/// Field weights: title 0.5 + summary 0.3 + content 0.2.
+/// Text is compared by character bigrams, so Chinese needs no tokenizer.
+pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
+    let title = text_overlap_dice(query, &page.title);
+    let summary = text_overlap_dice(query, &page.summary);
+    let content = text_overlap_dice(query, &page.content);
+    title * 0.5 + summary * 0.3 + content * 0.2
+}
+
+/// Dice coefficient over the sets of character bigrams of `query` and `text`.
+///
+/// Both inputs are lowercased first so scoring is case-insensitive, like keyword
+/// extraction and `KnowledgeStore::search`. Returns 0.0 when either side has no
+/// bigram (fewer than two characters).
+fn text_overlap_dice(query: &str, text: &str) -> f32 {
+    let q_bigrams = char_bigrams(&query.to_lowercase());
+    let t_bigrams = char_bigrams(&text.to_lowercase());
+    if q_bigrams.is_empty() || t_bigrams.is_empty() {
+        return 0.0;
+    }
+    let q_set: HashSet<&String> = q_bigrams.iter().collect();
+    let t_set: HashSet<&String> = t_bigrams.iter().collect();
+    let intersect = q_set.intersection(&t_set).count();
+    // Both sets are non-empty here, so the denominator is at least 2.
+    (2.0 * intersect as f32) / (q_set.len() + t_set.len()) as f32
+}
+
+/// Returns the character bigrams of `s` (empty for fewer than two characters).
+fn char_bigrams(s: &str) -> Vec<String> {
+    let chars: Vec<char> = s.chars().collect();
+    chars.windows(2).map(|w| w.iter().collect()).collect()
+}
+
+/// Built-in English stop-word list.
+fn default_stop_words() -> HashSet<String> {
+    [
+        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
+        "had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
+        "can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
+        "which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
+        "so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
+        "by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
+    ]
+    .iter()
+    .map(|s| s.to_string())
+    .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::memory::knowledge::KnowledgeStore;
+    use crate::memory::{InMemoryStore, MemoryStore};
+    use std::sync::Arc;
+    use time::OffsetDateTime;
+
+    fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
+        let now = OffsetDateTime::now_utc();
+        KnowledgePage {
+            id: id.to_string(),
+            title: title.to_string(),
+            summary: summary.to_string(),
+            content: content.to_string(),
+            tags: Vec::new(),
+            references: Vec::new(),
+            created_at: now,
+            updated_at: now,
+        }
+    }
+
+    #[tokio::test]
+    async fn retrieve_empty_query() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
+        let result = retriever.retrieve("").await.unwrap();
+        assert!(result.items.is_empty());
+    }
+
+    #[tokio::test]
+    async fn retrieve_finds_relevant_page() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        ks.add_page(make_page(
+            "p1",
+            "LangGraph StateGraph",
+            "state management",
+            "LangGraph uses StateGraph for state machines",
+        ))
+        .await
+        .unwrap();
+        ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
+            .await
+            .unwrap();
+
+        let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
+        let result = retriever.retrieve("LangGraph state").await.unwrap();
+        assert!(!result.items.is_empty());
+        assert_eq!(result.items[0].page.id, "p1");
+        assert!(result.items[0].score > 0.0);
+    }
+
+    #[tokio::test]
+    async fn retrieve_respects_min_score() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();
+
+        let config = RetrieverConfig {
+            max_results: 10,
+            min_score: 0.99,
+        };
+        let retriever = MemoryRetriever::new(ks, config);
+        let result = retriever.retrieve("totally unrelated content").await.unwrap();
+        assert!(result.items.is_empty());
+    }
+
+    #[tokio::test]
+    async fn retrieve_respects_max_results() {
+        let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
+        let ks = KnowledgeStore::new(store);
+        for i in 0..5 {
+            ks.add_page(make_page(
+                &format!("p{i}"),
+                "LangGraph",
+                "framework",
+                "agent runtime",
+            ))
+            .await
+            .unwrap();
+        }
+        let config = RetrieverConfig {
+            max_results: 2,
+            min_score: 0.0,
+        };
+        let retriever = MemoryRetriever::new(ks, config);
+        let result = retriever.retrieve("LangGraph").await.unwrap();
+        assert_eq!(result.items.len(), 2);
+    }
+
+    #[test]
+    fn text_overlap_dice_zero_on_empty() {
+        assert_eq!(text_overlap_dice("hello", ""), 0.0);
+        assert_eq!(text_overlap_dice("", "hello"), 0.0);
+    }
+
+    #[test]
+    fn text_overlap_dice_identical() {
+        let s = "hello world";
+        assert!((text_overlap_dice(s, s) - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn text_overlap_dice_ignores_case() {
+        assert!((text_overlap_dice("langgraph", "LangGraph") - 1.0).abs() < 0.001);
+    }
+
+    #[test]
+    fn extract_keywords_filters_stop_words() {
+        let stop = default_stop_words();
+        let kws = extract_keywords("the quick brown fox is fast", &stop);
+        assert!(!kws.contains(&"the".to_string()));
+        assert!(!kws.contains(&"is".to_string()));
+        assert!(kws.contains(&"quick".to_string()));
+        assert!(kws.contains(&"brown".to_string()));
+    }
+
+    #[test]
+    fn extract_keywords_filters_single_chars() {
+        let stop = default_stop_words();
+        let kws = extract_keywords("a b c dog", &stop);
+        assert!(!kws.contains(&"a".to_string()));
+        assert!(!kws.contains(&"b".to_string()));
+    }
+}
diff --git a/src/memory/store.rs b/src/memory/store.rs
new file mode 100644
index 0000000..1bb65bb
--- /dev/null
+++ b/src/memory/store.rs
@@ -0,0 +1,314 @@
+//! The `MemoryStore` abstraction and its default implementation.
+
+use std::collections::HashMap;
+use std::sync::Mutex;
+
+use async_trait::async_trait;
+use time::OffsetDateTime;
+
+use crate::memory::error::MemoryError;
+use crate::memory::types::{MemoryFilter, MemoryItem};
+
+/// Abstract interface of the underlying memory storage.
+///
+/// Downstream code can implement this trait to plug in a persistent backend
+/// (JSON files, SQLite, Redis, ...). The default implementation,
+/// [`InMemoryStore`], is an in-process `HashMap`.
+#[async_trait]
+pub trait MemoryStore: Send + Sync {
+    /// Saves or overwrites a `MemoryItem` (upsert semantics).
+    /// - If the id does not exist, a new entry is inserted.
+    /// - If the id already exists, the old entry is overwritten.
+    async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
+
+    /// Fetches a `MemoryItem` by id.
+    async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
+
+    /// Deletes a `MemoryItem` by id.
+    async fn delete(&self, id: &str) -> Result<(), MemoryError>;
+
+    /// Lists the `MemoryItem`s selected by `filter`.
+    async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
+}
+
+/// Eviction policy.
+#[derive(Debug, Clone)]
+pub enum EvictionPolicy {
+    /// Never evict (default).
+    None,
+    /// Evict items older than the time-to-live, in seconds.
+    Ttl { ttl_secs: u64 },
+    /// Evict the oldest items (by `created_at`) beyond the capacity limit.
+    Capacity { max_items: usize },
+}
+
+/// Eviction configuration.
+#[derive(Debug, Clone)]
+pub struct EvictionConfig {
+    pub policy: EvictionPolicy,
+    /// The eviction condition is checked once every N writes.
+    pub check_interval: usize,
+}
+
+impl Default for EvictionConfig {
+    fn default() -> Self {
+        Self {
+            policy: EvictionPolicy::None,
+            check_interval: 64,
+        }
+    }
+}
+
+/// Default in-process implementation: `HashMap` + `Mutex`, memory only.
+pub struct InMemoryStore {
+    items: Mutex<HashMap<String, MemoryItem>>,
+    eviction: EvictionConfig,
+    /// Number of writes since the last eviction check.
+    writes_since_check: Mutex<usize>,
+}
+
+impl InMemoryStore {
+    /// Creates an `InMemoryStore` without eviction.
+    pub fn new() -> Self {
+        Self {
+            items: Mutex::new(HashMap::new()),
+            eviction: EvictionConfig::default(),
+            writes_since_check: Mutex::new(0),
+        }
+    }
+
+    /// Creates an `InMemoryStore` with the given eviction configuration.
+    pub fn with_eviction(eviction: EvictionConfig) -> Self {
+        Self {
+            items: Mutex::new(HashMap::new()),
+            eviction,
+            writes_since_check: Mutex::new(0),
+        }
+    }
+
+    /// Counts one write and, every `check_interval` writes, applies the policy.
+    fn maybe_evict(&self) {
+        // Synchronous on purpose: no lock is held across an await point. Bump the
+        // counter first to decide whether an eviction pass is due.
+        let should_check = {
+            let mut counter = self.writes_since_check.lock().unwrap();
+            *counter += 1;
+            if *counter >= self.eviction.check_interval {
+                *counter = 0;
+                true
+            } else {
+                false
+            }
+        };
+        if !should_check {
+            return;
+        }
+
+        let policy = self.eviction.policy.clone();
+        match policy {
+            EvictionPolicy::None => {}
+            EvictionPolicy::Ttl { ttl_secs } => {
+                // Clamp before the cast: `ttl_secs as i64` would wrap to a negative
+                // duration for values above i64::MAX and evict everything.
+                let ttl = time::Duration::seconds(ttl_secs.min(i64::MAX as u64) as i64);
+                // A cutoff before the representable range means nothing has expired.
+                if let Some(cutoff) = OffsetDateTime::now_utc().checked_sub(ttl) {
+                    let mut items = self.items.lock().unwrap();
+                    items.retain(|_, v| v.created_at > cutoff);
+                }
+            }
+            EvictionPolicy::Capacity { max_items } => {
+                let mut items = self.items.lock().unwrap();
+                if items.len() > max_items {
+                    let mut vec: Vec<_> = items.drain().collect();
+                    // O(n) partial sort: keep the `max_items` entries with the
+                    // largest `created_at`.
+                    vec.select_nth_unstable_by(max_items, |a, b| {
+                        b.1.created_at.cmp(&a.1.created_at)
+                    });
+                    vec.truncate(max_items);
+                    *items = vec.into_iter().collect();
+                }
+            }
+        }
+    }
+}
+
+impl Default for InMemoryStore {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[async_trait]
+impl MemoryStore for InMemoryStore {
+    async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
+        {
+            let mut items = self.items.lock().unwrap();
+            items.insert(item.id.clone(), item);
+        }
+        self.maybe_evict();
+        Ok(())
+    }
+
+    async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
+        let items = self.items.lock().unwrap();
+        Ok(items.get(id).cloned())
+    }
+
+    async fn delete(&self, id: &str) -> Result<(), MemoryError> {
+        let mut items = self.items.lock().unwrap();
+        items.remove(id);
+        Ok(())
+    }
+
+    async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
+        let items = self.items.lock().unwrap();
+        let mut result: Vec<MemoryItem> = items
+            .values()
+            .filter(|v| match &filter.prefix {
+                Some(p) => v.id.starts_with(p),
+                None => true,
+            })
+            .filter(|v| match filter.since {
+                Some(t) => v.created_at > t,
+                None => true,
+            })
+            .cloned()
+            .collect();
+        // Ascending by created_at (oldest first).
+        result.sort_by_key(|v| v.created_at);
+        // Apply offset.
+        if let Some(offset) = filter.offset {
+            if offset < result.len() {
+                result.drain(..offset);
+            } else {
+                result.clear();
+            }
+        }
+        // Apply limit.
+        if let Some(limit) = filter.limit {
+            result.truncate(limit);
+        }
+        Ok(result)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use time::OffsetDateTime;
+
+    fn make_item(id: &str) -> MemoryItem {
+        MemoryItem {
+            id: id.to_string(),
+            content: format!("content-{id}"),
+            metadata: serde_json::json!({}),
+            created_at: OffsetDateTime::now_utc(),
+        }
+    }
+
+    #[tokio::test]
+    async fn save_get_delete_list() {
+        let store = InMemoryStore::new();
+        store.save(make_item("a")).await.unwrap();
+        store.save(make_item("b")).await.unwrap();
+
+        let got = store.get("a").await.unwrap();
+        assert!(got.is_some());
+        assert_eq!(got.unwrap().id, "a");
+
+        let list = store.list(&MemoryFilter::default()).await.unwrap();
+        assert_eq!(list.len(), 2);
+
+        store.delete("a").await.unwrap();
+        assert!(store.get("a").await.unwrap().is_none());
+    }
+
+    #[tokio::test]
+    async fn save_is_upsert() {
+        let store = InMemoryStore::new();
+        store.save(make_item("a")).await.unwrap();
+        let mut item = make_item("a");
+        item.content = "updated".to_string();
+        store.save(item).await.unwrap();
+        let got = store.get("a").await.unwrap().unwrap();
+        assert_eq!(got.content, "updated");
+        let list = store.list(&MemoryFilter::default()).await.unwrap();
+        assert_eq!(list.len(), 1);
+    }
+
+    #[tokio::test]
+    async fn list_with_prefix_and_limit() {
+        let store = InMemoryStore::new();
+        store.save(make_item("foo_a")).await.unwrap();
+        store.save(make_item("foo_b")).await.unwrap();
+        store.save(make_item("bar_a")).await.unwrap();
+
+        let filter = MemoryFilter {
+            prefix: Some("foo_".to_string()),
+            ..Default::default()
+        };
+        let list = store.list(&filter).await.unwrap();
+        assert_eq!(list.len(), 2);
+
+        let filter = MemoryFilter {
+            prefix: Some("foo_".to_string()),
+            limit: Some(1),
+            ..Default::default()
+        };
+        let list = store.list(&filter).await.unwrap();
+        assert_eq!(list.len(), 1);
+    }
+
+    #[tokio::test]
+    async fn capacity_eviction() {
+        // Force a check on every write.
+        let eviction = EvictionConfig {
+            policy: EvictionPolicy::Capacity { max_items: 2 },
+            check_interval: 1,
+        };
+        let store = InMemoryStore::with_eviction(eviction);
+        // The first and second items coexist.
+        store.save(make_item("a")).await.unwrap();
+        store.save(make_item("b")).await.unwrap();
+        // The third write triggers eviction: one of the older items is dropped.
+        store.save(make_item("c")).await.unwrap();
+
+        let list = store.list(&MemoryFilter::default()).await.unwrap();
+        assert_eq!(list.len(), 2);
+        // b and c (the two newest) should remain.
+        let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
+        assert!(ids.contains(&"b"));
+        assert!(ids.contains(&"c"));
+    }
+
+    #[tokio::test]
+    async fn ttl_eviction() {
+        // A TTL of 0 would expire items immediately, but "a" has to survive until
+        // "b" is written. Use a small TTL plus a sleep instead: save a, sleep,
+        // and by the time b is saved a has expired and is evicted.
+        let eviction = EvictionConfig {
+            policy: EvictionPolicy::Ttl { ttl_secs: 1 },
+            check_interval: 1,
+        };
+        let store = InMemoryStore::with_eviction(eviction);
+        store.save(make_item("a")).await.unwrap();
+        // Wait for more than one second.
+        std::thread::sleep(std::time::Duration::from_millis(1100));
+        // Trigger eviction: a is older than ttl_secs=1 and should be dropped.
+        store.save(make_item("b")).await.unwrap();
+        let list = store.list(&MemoryFilter::default()).await.unwrap();
+        // With ttl_secs=1 and b just written, b may sit right at the boundary,
+        // so only assert that the list does not contain "a".
+        let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
+        assert!(
+            !ids.contains(&"a"),
+            "expected 'a' to be evicted, but found in {ids:?}"
+        );
+    }
+
+    #[tokio::test]
+    async fn none_policy_no_eviction() {
+        let eviction = EvictionConfig {
+            policy: EvictionPolicy::None,
+            check_interval: 1,
+        };
+        let store = InMemoryStore::with_eviction(eviction);
+        for i in 0..100 {
+            store.save(make_item(&format!("item_{i}"))).await.unwrap();
+        }
+        let list = store.list(&MemoryFilter::default()).await.unwrap();
+        assert_eq!(list.len(), 100);
+    }
+}
diff --git a/src/memory/types.rs b/src/memory/types.rs
new file mode 100644
index 0000000..e7cdf80
--- /dev/null
+++ b/src/memory/types.rs
@@ -0,0 +1,73 @@
+//! Core data types of the memory system.
+
+use serde::{Deserialize, Serialize};
+use time::OffsetDateTime;
+
+/// Memory item: the basic unit stored by a `MemoryStore`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MemoryItem {
+    /// Unique identifier.
+    pub id: String,
+    /// Content (usually the JSON serialization of the concrete memory data).
+    pub content: String,
+    /// Arbitrary additional metadata.
+    pub metadata: serde_json::Value,
+    /// Creation time.
+    pub created_at: OffsetDateTime,
+}
+
+/// Query conditions for listing a `MemoryStore`.
+#[derive(Debug, Clone, Default)]
+pub struct MemoryFilter {
+    /// Filter by id prefix.
+    pub prefix: Option<String>,
+    /// Only return items created after this time.
+    pub since: Option<OffsetDateTime>,
+    /// Skip the first N items.
+    pub offset: Option<usize>,
+    /// Return at most N items.
+    pub limit: Option<usize>,
+}
+
+/// Knowledge page: describes one piece of structured knowledge.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KnowledgePage {
+    /// Unique identifier.
+    pub id: String,
+    /// Title.
+    pub title: String,
+    /// One-sentence summary.
+    pub summary: String,
+    /// Full content.
+    pub content: String,
+    /// Retrieval tags.
+    pub tags: Vec<String>,
+    /// Ids of other pages this page cross-references.
+    pub references: Vec<String>,
+    /// Creation time.
+    pub created_at: OffsetDateTime,
+    /// Last update time.
+    pub updated_at: OffsetDateTime,
+}
+
+/// Knowledge-page index entry: used for lightweight traversal and the table of contents.
+#[derive(Debug, Clone)]
+pub struct PageIndexEntry {
+    pub id: String,
+    pub title: String,
+    pub summary: String,
+    pub tags: Vec<String>,
+    pub updated_at: OffsetDateTime,
+}
+
+impl From<&KnowledgePage> for PageIndexEntry {
+    fn from(p: &KnowledgePage) -> Self {
+        Self {
+            id: p.id.clone(),
+            title: p.title.clone(),
+            summary: p.summary.clone(),
+            tags: p.tags.clone(),
+            updated_at: p.updated_at,
+        }
+    }
+}