feat(memory): 添加记忆系统模块

This commit is contained in:
徐涛
2026-06-08 08:42:43 +08:00
parent 1fe7f02281
commit 2ecc0b4001
9 changed files with 1244 additions and 0 deletions
+1
View File
@@ -1,6 +1,7 @@
//! agcore —— 智能体(Agent)核心工具箱。
pub mod llm;
pub mod memory;
pub mod prompt;
pub mod tools;
+22
View File
@@ -0,0 +1,22 @@
//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。
pub mod conversation;
pub mod error;
pub mod knowledge;
pub mod retriever;
pub mod store;
pub mod types;
// 高频类型(大多数下游需要)
pub use conversation::{ConversationMemory, ConversationMemoryConfig};
pub use error::MemoryError;
pub use knowledge::KnowledgeStore;
pub use retriever::MemoryRetriever;
pub use store::{InMemoryStore, MemoryStore};
// 低频类型(配置/高级使用)
pub use conversation::MemoryStrategy;
pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX};
pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem};
pub use store::{EvictionConfig, EvictionPolicy};
pub use types::{KnowledgePage, MemoryFilter, MemoryItem};
+260
View File
@@ -0,0 +1,260 @@
//! 对话记忆 —— 多轮对话消息管理,复用 `llm::compact` 的压缩逻辑。
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
use crate::llm::types::OpenaiChatMessage;
use crate::memory::error::MemoryError;
use crate::memory::store::MemoryStore;
use crate::memory::types::MemoryItem;
/// 对话消息管理策略。
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryStrategy {
/// 滑动窗口:达到上限时删除最旧消息。
SlidingWindow,
/// 保留所有消息(仅压缩,不删除)。
Full,
}
/// 对话记忆配置。
#[derive(Debug, Clone)]
pub struct ConversationMemoryConfig {
pub strategy: MemoryStrategy,
pub max_turns: usize,
pub compact_config: Option<CompactConfig>,
}
impl Default for ConversationMemoryConfig {
fn default() -> Self {
Self {
strategy: MemoryStrategy::SlidingWindow,
max_turns: 50,
compact_config: Some(CompactConfig::default()),
}
}
}
/// 对话记忆 —— 按 session 管理多轮对话消息历史。
///
/// 内部维护 `Vec<OpenaiChatMessage>` 热缓存(供 `llm::compact` 直接操作),
/// `MemoryStore` 用作冷持久化层。
pub struct ConversationMemory {
store: Arc<dyn MemoryStore>,
session_id: String,
config: ConversationMemoryConfig,
/// 热缓存:消息列表,供 `llm::compact` 直接操作。
messages: Vec<OpenaiChatMessage>,
/// 与 `messages` 一一对应的存储 ID(保持稳定以便淘汰时精准删除)。
message_ids: Vec<String>,
/// 压缩断路器状态。
compact_state: CompactState,
}
impl ConversationMemory {
/// 创建一个新的 ConversationMemory。
pub fn new(
store: Arc<dyn MemoryStore>,
session_id: impl Into<String>,
config: ConversationMemoryConfig,
) -> Self {
Self {
store,
session_id: session_id.into(),
config,
messages: Vec::new(),
message_ids: Vec::new(),
compact_state: CompactState::new(),
}
}
/// 获取 session id。
pub fn session_id(&self) -> &str {
&self.session_id
}
/// 获取配置。
pub fn config(&self) -> &ConversationMemoryConfig {
&self.config
}
/// 从 MemoryStore 加载历史消息到热缓存。
pub async fn load(&mut self) -> Result<(), MemoryError> {
let filter = crate::memory::types::MemoryFilter {
prefix: Some(self.session_prefix()),
..Default::default()
};
let items = self.store.list(&filter).await?;
let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> = Vec::with_capacity(items.len());
for item in items {
match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
Ok(msg) => pairs.push((item.id, msg, item.created_at)),
Err(e) => {
return Err(MemoryError::Serialization(format!(
"load message {} failed: {e}",
item.id
)));
}
}
}
// 按 created_at 升序排列
pairs.sort_by_key(|p| p.2);
self.message_ids = pairs.iter().map(|p| p.0.clone()).collect();
self.messages = pairs.into_iter().map(|p| p.1).collect();
Ok(())
}
/// 添加一条消息。
///
/// 写入热缓存并通过 `MemoryStore` 持久化。如有需要,触发淘汰和压缩。
pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
let now = OffsetDateTime::now_utc();
let index = self.messages.len();
let id = self.make_message_id(index, &now);
// 写入热缓存
self.messages.push(msg);
self.message_ids.push(id.clone());
// 同步到冷存储
let item = MemoryItem {
id: id.clone(),
content: serde_json::to_string(self.messages.last().unwrap())
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
created_at: now,
};
self.store.save(item).await?;
// 触发淘汰和压缩
self.maybe_evict_and_compact().await;
Ok(())
}
/// 获取完整消息历史。
pub fn get_history(&self) -> &[OpenaiChatMessage] {
&self.messages
}
/// 清空所有消息。
pub async fn clear(&mut self) -> Result<(), MemoryError> {
let to_delete = std::mem::take(&mut self.message_ids);
self.messages.clear();
self.compact_state = CompactState::new();
for id in to_delete {
self.store.delete(&id).await?;
}
Ok(())
}
/// 当前消息数量。
pub fn len(&self) -> usize {
self.messages.len()
}
/// 是否为空。
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
fn session_prefix(&self) -> String {
format!("conv:{self}:", self = self.session_id)
}
fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
}
async fn maybe_evict_and_compact(&mut self) {
// 1. Sliding window 淘汰:删除最旧消息
if self.config.strategy == MemoryStrategy::SlidingWindow {
while self.messages.len() > self.config.max_turns {
if let Some(removed_id) = self.message_ids.first().cloned() {
let _ = self.store.delete(&removed_id).await;
}
self.messages.remove(0);
self.message_ids.remove(0);
}
}
// 2. 压缩(复用 llm::compact
if let Some(ref compact_config) = self.config.compact_config {
if should_compact(&self.messages, compact_config, &self.compact_state) {
let keep_recent = compact_config.keep_recent;
let freed = microcompact(&mut self.messages, keep_recent);
if freed > 0 {
self.compact_state.record_success();
} else {
// 没有 token 被释放(可能没找到可压缩的 tool result
let _ = self.compact_state.record_failure();
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::OpenaiChatMessage;
use crate::memory::InMemoryStore;
use crate::memory::MemoryStore;
fn user_text(s: &str) -> OpenaiChatMessage {
OpenaiChatMessage::user_text(s)
}
#[tokio::test]
async fn add_and_get_history() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let mut conv = ConversationMemory::new(store, "session1", ConversationMemoryConfig::default());
conv.add_message(user_text("hello")).await.unwrap();
conv.add_message(user_text("world")).await.unwrap();
assert_eq!(conv.len(), 2);
assert_eq!(conv.get_history().len(), 2);
}
#[tokio::test]
async fn sliding_window_evicts_oldest() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let config = ConversationMemoryConfig {
strategy: MemoryStrategy::SlidingWindow,
max_turns: 3,
compact_config: None,
};
let mut conv = ConversationMemory::new(store, "s1", config);
for i in 0..5 {
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
}
assert_eq!(conv.len(), 3);
}
#[tokio::test]
async fn full_strategy_no_evict() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let config = ConversationMemoryConfig {
strategy: MemoryStrategy::Full,
max_turns: 3,
compact_config: None,
};
let mut conv = ConversationMemory::new(store, "s1", config);
for i in 0..5 {
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
}
// Full 策略不删除消息
assert_eq!(conv.len(), 5);
}
#[tokio::test]
async fn clear_empties_messages() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let mut conv = ConversationMemory::new(store.clone(), "s1", ConversationMemoryConfig::default());
conv.add_message(user_text("hello")).await.unwrap();
assert!(!conv.is_empty());
conv.clear().await.unwrap();
assert!(conv.is_empty());
}
}
+29
View File
@@ -0,0 +1,29 @@
//! 记忆系统错误类型。
use thiserror::Error;
/// 记忆系统错误枚举。
#[derive(Debug, Error)]
pub enum MemoryError {
#[error("Item not found: {0}")]
NotFound(String),
#[error("Storage error: {0}")]
Storage(String),
#[error("Serialization error: {0}")]
Serialization(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Retrieval error: {0}")]
RetrievalError(String),
}
impl MemoryError {
/// 是否为可恢复错误(调用方可重试或调整参数)。
pub fn is_recoverable(&self) -> bool {
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
}
}
+247
View File
@@ -0,0 +1,247 @@
//! 知识库 —— KnowledgePage 存储与关键词检索。
use std::sync::Arc;
use time::OffsetDateTime;
use crate::memory::error::MemoryError;
use crate::memory::store::MemoryStore;
use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
pub use crate::memory::types::PageIndexEntry;
/// `MemoryItem.id` 中知识页面前缀。
pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
/// 知识库 —— KnowledgePage CRUD + 关键词检索 + 内容索引。
///
/// 内部以 `MemoryStore` 为后端存储 KnowledgePage(序列化为 JSON),
/// 同时维护一个 `Vec<PageIndexEntry>` 索引以加速列表遍历。
pub struct KnowledgeStore {
store: Arc<dyn MemoryStore>,
index: std::sync::Mutex<Vec<PageIndexEntry>>,
}
impl KnowledgeStore {
/// 创建一个新的 KnowledgeStore。
pub fn new(store: Arc<dyn MemoryStore>) -> Self {
Self {
store,
index: std::sync::Mutex::new(Vec::new()),
}
}
/// 从 MemoryStore 重建索引(修复 index 与 store 的不同步问题)。
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
let items = self
.store
.list(&MemoryFilter {
prefix: Some(KNOWLEDGE_PREFIX.to_string()),
..Default::default()
})
.await?;
let mut index = self.index.lock().unwrap();
index.clear();
for item in items {
let page: KnowledgePage = serde_json::from_str(&item.content)
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
index.push(PageIndexEntry::from(&page));
}
Ok(())
}
/// 创建一个新的知识页面。
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
if page.id.is_empty() {
return Err(MemoryError::InvalidInput("page.id is empty".into()));
}
let now = OffsetDateTime::now_utc();
let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
let content = serde_json::to_string(&page)
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
let item = MemoryItem {
id,
content,
metadata: serde_json::json!({}),
created_at: now,
};
self.store.save(item).await?;
let mut index = self.index.lock().unwrap();
// 替换或追加
if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
*existing = PageIndexEntry::from(&page);
} else {
index.push(PageIndexEntry::from(&page));
}
Ok(())
}
/// 根据 page id 获取一个页面。
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
let item = self.store.get(&full_id).await?;
match item {
None => Ok(None),
Some(item) => {
let page: KnowledgePage = serde_json::from_str(&item.content)
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
Ok(Some(page))
}
}
}
/// 更新一个已存在的知识页面。
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
if page.id.is_empty() {
return Err(MemoryError::InvalidInput("page.id is empty".into()));
}
// 通过 get_page 检查存在性
if self.get_page(&page.id).await?.is_none() {
return Err(MemoryError::NotFound(page.id));
}
self.add_page(page).await
}
/// 删除一个知识页面。
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
self.store.delete(&full_id).await?;
let mut index = self.index.lock().unwrap();
index.retain(|e| e.id != id);
Ok(())
}
/// 根据关键词搜索知识页面。
///
/// 匹配规则:在 `title` / `summary` / `tags` 中查找子串(不区分大小写)。
/// 全文 `content` 搜索走 `MemoryStore`。
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
if query.is_empty() {
return Ok(Vec::new());
}
let needle = query.to_lowercase();
let mut results = Vec::new();
let index = self.index.lock().unwrap();
for entry in index.iter() {
if entry.title.to_lowercase().contains(&needle)
|| entry.summary.to_lowercase().contains(&needle)
|| entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
{
if let Some(page) = self.get_page(&entry.id).await? {
results.push(page);
}
}
}
Ok(results)
}
/// 获取内容目录(所有页面的轻量级索引条目)。
pub fn get_index(&self) -> Vec<PageIndexEntry> {
self.index.lock().unwrap().clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::InMemoryStore;
use time::OffsetDateTime;
fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
let now = OffsetDateTime::now_utc();
KnowledgePage {
id: id.to_string(),
title: title.to_string(),
summary: format!("summary of {title}"),
content: format!("full content of {title}"),
tags: tags.iter().map(|s| s.to_string()).collect(),
references: Vec::new(),
created_at: now,
updated_at: now,
}
}
#[tokio::test]
async fn add_get_delete_page() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
.await
.unwrap();
let got = ks.get_page("p1").await.unwrap();
assert!(got.is_some());
assert_eq!(got.unwrap().title, "LangGraph");
ks.delete_page("p1").await.unwrap();
assert!(ks.get_page("p1").await.unwrap().is_none());
}
#[tokio::test]
async fn add_page_rejects_empty_id() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
let result = ks.add_page(make_page("", "NoId", &[])).await;
assert!(matches!(result, Err(MemoryError::InvalidInput(_))));
}
#[tokio::test]
async fn update_page_requires_existing() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
let result = ks.update_page(make_page("nope", "Ghost", &[])).await;
assert!(matches!(result, Err(MemoryError::NotFound(_))));
}
#[tokio::test]
async fn search_finds_by_title_summary_tag() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
.await
.unwrap();
ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
.await
.unwrap();
ks.add_page(make_page("p3", "Third", &["unrelated"]))
.await
.unwrap();
let results = ks.search("stategraph").await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, "p1");
let results = ks.search("knowledge-graph").await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, "p2");
let results = ks.search("nonexistent").await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn get_index_returns_all_pages() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
let index = ks.get_index();
assert_eq!(index.len(), 2);
}
#[tokio::test]
async fn rebuild_index_recovers_from_drift() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
// 添加页面
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
assert_eq!(ks.get_index().len(), 2);
// 模拟 index 漂移:清空后重建
ks.index.lock().unwrap().clear();
assert_eq!(ks.get_index().len(), 0);
ks.rebuild_index().await.unwrap();
assert_eq!(ks.get_index().len(), 2);
}
}
+297
View File
@@ -0,0 +1,297 @@
//! 记忆检索器 —— 基于 TextOverlap (Dice 系数) 的单通道关键词检索。
use std::collections::HashSet;
use crate::memory::error::MemoryError;
use crate::memory::knowledge::KnowledgeStore;
use crate::memory::types::KnowledgePage;
/// 检索器配置。
#[derive(Debug, Clone)]
pub struct RetrieverConfig {
/// 最大返回条数(默认 20)。
pub max_results: usize,
/// 最低分数阈值 [0.0, 1.0](默认 0.1)。
pub min_score: f32,
}
impl Default for RetrieverConfig {
fn default() -> Self {
Self {
max_results: 20,
min_score: 0.1,
}
}
}
/// 单条带评分的检索结果。
#[derive(Debug, Clone)]
pub struct ScoredItem {
pub page: KnowledgePage,
/// TextOverlap 评分 [0.0, 1.0]
pub score: f32,
}
/// 检索结果。
#[derive(Debug, Clone)]
pub struct RetrievalResult {
pub items: Vec<ScoredItem>,
pub query: String,
}
/// 记忆检索器 —— 在 `KnowledgeStore` 中做关键词检索并按 TextOverlap 评分。
pub struct MemoryRetriever {
knowledge_store: KnowledgeStore,
config: RetrieverConfig,
/// 停用词表(用于关键词提取)。
stop_words: HashSet<String>,
}
impl MemoryRetriever {
/// 创建一个新的 MemoryRetriever。
pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
Self {
knowledge_store,
config,
stop_words: default_stop_words(),
}
}
/// 替换停用词表。
pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
self.stop_words = stop_words;
self
}
/// 检索相关知识页面。
pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
if query.is_empty() {
return Ok(RetrievalResult {
items: Vec::new(),
query: query.to_string(),
});
}
// 1. 关键词提取
let keywords = extract_keywords(query, &self.stop_words);
// 2. 用关键词在 KnowledgeStore 中搜索
let mut pages = Vec::new();
for keyword in &keywords {
let found = self.knowledge_store.search(keyword).await?;
for page in found {
if !pages.iter().any(|p: &KnowledgePage| p.id == page.id) {
pages.push(page);
}
}
}
// 3. TextOverlap 评分
let mut items: Vec<ScoredItem> = pages
.into_iter()
.map(|page| {
let score = text_overlap_score(query, &page);
ScoredItem { page, score }
})
.collect();
// 4. 过滤 → 排序 → 截取
items.retain(|i| i.score >= self.config.min_score);
items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
items.truncate(self.config.max_results);
Ok(RetrievalResult {
items,
query: query.to_string(),
})
}
}
/// 从 query 中提取关键词:按非字母数字字符分割 → 转小写 → 过滤单字符和停用词。
fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
query
.split(|c: char| !c.is_alphanumeric())
.filter_map(|s| {
let lower = s.to_lowercase();
if lower.is_empty() || lower.chars().count() < 2 {
None
} else if stop_words.contains(&lower) {
None
} else {
Some(lower)
}
})
.collect()
}
/// TextOverlap 评分(基于字符 bigram 的 Dice 系数 + 多字段加权)。
///
/// 字段权重:title 0.5 + summary 0.3 + content 0.2
/// 中文场景按字符级 bigram 处理,不依赖分词器。
pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
let title = text_overlap_dice(query, &page.title);
let summary = text_overlap_dice(query, &page.summary);
let content = text_overlap_dice(query, &page.content);
title * 0.5 + summary * 0.3 + content * 0.2
}
/// Dice 系数(基于字符 bigram)。
fn text_overlap_dice(query: &str, text: &str) -> f32 {
let q_bigrams = char_bigrams(query);
let t_bigrams = char_bigrams(text);
if q_bigrams.is_empty() || t_bigrams.is_empty() {
return 0.0;
}
let q_set: HashSet<&String> = q_bigrams.iter().collect();
let t_set: HashSet<&String> = t_bigrams.iter().collect();
let intersect = q_set.intersection(&t_set).count();
let denom = q_set.len() + t_set.len();
if denom == 0 {
0.0
} else {
(2.0 * intersect as f32) / denom as f32
}
}
/// 提取字符 bigrams(用于字符级 Dice 系数)。
fn char_bigrams(s: &str) -> Vec<String> {
let chars: Vec<char> = s.chars().collect();
chars.windows(2).map(|w| w.iter().collect()).collect()
}
fn default_stop_words() -> HashSet<String> {
[
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
"had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
"can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
"which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
"so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
"by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
]
.iter()
.map(|s| s.to_string())
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::knowledge::KnowledgeStore;
use crate::memory::{InMemoryStore, MemoryStore};
use std::sync::Arc;
use time::OffsetDateTime;
fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
let now = OffsetDateTime::now_utc();
KnowledgePage {
id: id.to_string(),
title: title.to_string(),
summary: summary.to_string(),
content: content.to_string(),
tags: Vec::new(),
references: Vec::new(),
created_at: now,
updated_at: now,
}
}
#[tokio::test]
async fn retrieve_empty_query() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
let result = retriever.retrieve("").await.unwrap();
assert!(result.items.is_empty());
}
#[tokio::test]
async fn retrieve_finds_relevant_page() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
ks.add_page(make_page(
"p1",
"LangGraph StateGraph",
"state management",
"LangGraph uses StateGraph for state machines",
))
.await
.unwrap();
ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
.await
.unwrap();
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
let result = retriever.retrieve("LangGraph state").await.unwrap();
assert!(!result.items.is_empty());
assert_eq!(result.items[0].page.id, "p1");
assert!(result.items[0].score > 0.0);
}
#[tokio::test]
async fn retrieve_respects_min_score() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();
let config = RetrieverConfig {
max_results: 10,
min_score: 0.99,
};
let retriever = MemoryRetriever::new(ks, config);
let result = retriever.retrieve("totally unrelated content").await.unwrap();
assert!(result.items.is_empty());
}
#[tokio::test]
async fn retrieve_respects_max_results() {
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
let ks = KnowledgeStore::new(store);
for i in 0..5 {
ks.add_page(make_page(
&format!("p{i}"),
"LangGraph",
"framework",
"agent runtime",
))
.await
.unwrap();
}
let config = RetrieverConfig {
max_results: 2,
min_score: 0.0,
};
let retriever = MemoryRetriever::new(ks, config);
let result = retriever.retrieve("LangGraph").await.unwrap();
assert_eq!(result.items.len(), 2);
}
#[test]
fn text_overlap_dice_zero_on_empty() {
assert_eq!(text_overlap_dice("hello", ""), 0.0);
assert_eq!(text_overlap_dice("", "hello"), 0.0);
}
#[test]
fn text_overlap_dice_identical() {
let s = "hello world";
assert!((text_overlap_dice(s, s) - 1.0).abs() < 0.001);
}
#[test]
fn extract_keywords_filters_stop_words() {
let stop = default_stop_words();
let kws = extract_keywords("the quick brown fox is fast", &stop);
assert!(!kws.contains(&"the".to_string()));
assert!(!kws.contains(&"is".to_string()));
assert!(kws.contains(&"quick".to_string()));
assert!(kws.contains(&"brown".to_string()));
}
#[test]
fn extract_keywords_filters_single_chars() {
let stop = default_stop_words();
let kws = extract_keywords("a b c dog", &stop);
assert!(!kws.contains(&"a".to_string()));
assert!(!kws.contains(&"b".to_string()));
}
}
+314
View File
@@ -0,0 +1,314 @@
//! MemoryStore 抽象接口与默认实现。
use std::collections::HashMap;
use std::sync::Mutex;
use async_trait::async_trait;
use time::OffsetDateTime;
use crate::memory::error::MemoryError;
use crate::memory::types::{MemoryFilter, MemoryItem};
/// 底层记忆存储抽象接口。
///
/// 下游可实现此 trait 以对接持久化后端(JSON 文件、SQLite、Redis 等)。
/// 默认实现 [`InMemoryStore`] 基于进程内 HashMap。
#[async_trait]
pub trait MemoryStore: Send + Sync {
/// 保存/覆盖一个 MemoryItemupsert 语义)。
/// - 如果 id 不存在,则插入新条目
/// - 如果 id 已存在,则覆盖旧条目
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
/// 根据 id 获取一个 MemoryItem。
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
/// 根据 id 删除一个 MemoryItem。
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
/// 根据 filter 列出 MemoryItem。
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
}
/// 淘汰策略。
#[derive(Debug, Clone)]
pub enum EvictionPolicy {
/// 不淘汰(默认)。
None,
/// 超过存活时间(秒)淘汰。
Ttl { ttl_secs: u64 },
/// 超过容量上限淘汰最旧(基于 created_at)。
Capacity { max_items: usize },
}
/// 淘汰配置。
#[derive(Debug, Clone)]
pub struct EvictionConfig {
pub policy: EvictionPolicy,
/// 每写入 N 条后检查一次淘汰条件。
pub check_interval: usize,
}
impl Default for EvictionConfig {
fn default() -> Self {
Self {
policy: EvictionPolicy::None,
check_interval: 64,
}
}
}
/// 进程内默认实现 —— 基于 HashMap + Mutex,纯内存。
pub struct InMemoryStore {
items: Mutex<HashMap<String, MemoryItem>>,
eviction: EvictionConfig,
/// 自上次淘汰检查以来的写入次数。
writes_since_check: Mutex<usize>,
}
impl InMemoryStore {
/// 创建一个无淘汰策略的 InMemoryStore。
pub fn new() -> Self {
Self {
items: Mutex::new(HashMap::new()),
eviction: EvictionConfig::default(),
writes_since_check: Mutex::new(0),
}
}
/// 创建一个带淘汰配置的 InMemoryStore。
pub fn with_eviction(eviction: EvictionConfig) -> Self {
Self {
items: Mutex::new(HashMap::new()),
eviction,
writes_since_check: Mutex::new(0),
}
}
fn maybe_evict(&self) {
// 不使用 .lock().await 跨点,先取计数判断是否需要淘汰
let should_check = {
let mut counter = self.writes_since_check.lock().unwrap();
*counter += 1;
if *counter >= self.eviction.check_interval {
*counter = 0;
true
} else {
false
}
};
if !should_check {
return;
}
let policy = self.eviction.policy.clone();
match policy {
EvictionPolicy::None => {}
EvictionPolicy::Ttl { ttl_secs } => {
let cutoff = OffsetDateTime::now_utc() - time::Duration::seconds(ttl_secs as i64);
let mut items = self.items.lock().unwrap();
items.retain(|_, v| v.created_at > cutoff);
}
EvictionPolicy::Capacity { max_items } => {
let mut items = self.items.lock().unwrap();
if items.len() > max_items {
let mut vec: Vec<_> = items.drain().collect();
// O(n) 部分排序:保留 created_at 最大的 max_items 个
vec.select_nth_unstable_by(max_items, |a, b| {
b.1.created_at.cmp(&a.1.created_at)
});
vec.truncate(max_items);
*items = vec.into_iter().collect();
}
}
}
}
}
impl Default for InMemoryStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl MemoryStore for InMemoryStore {
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
{
let mut items = self.items.lock().unwrap();
items.insert(item.id.clone(), item);
}
self.maybe_evict();
Ok(())
}
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
let items = self.items.lock().unwrap();
Ok(items.get(id).cloned())
}
async fn delete(&self, id: &str) -> Result<(), MemoryError> {
let mut items = self.items.lock().unwrap();
items.remove(id);
Ok(())
}
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
let items = self.items.lock().unwrap();
let mut result: Vec<MemoryItem> = items
.values()
.filter(|v| match &filter.prefix {
Some(p) => v.id.starts_with(p),
None => true,
})
.filter(|v| match filter.since {
Some(t) => v.created_at > t,
None => true,
})
.cloned()
.collect();
// 按 created_at 升序排列(最旧在前)
result.sort_by_key(|v| v.created_at);
// 应用 offset
if let Some(offset) = filter.offset {
if offset < result.len() {
result.drain(..offset);
} else {
result.clear();
}
}
// 应用 limit
if let Some(limit) = filter.limit {
result.truncate(limit);
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use time::OffsetDateTime;
fn make_item(id: &str) -> MemoryItem {
MemoryItem {
id: id.to_string(),
content: format!("content-{id}"),
metadata: serde_json::json!({}),
created_at: OffsetDateTime::now_utc(),
}
}
#[tokio::test]
async fn save_get_delete_list() {
let store = InMemoryStore::new();
store.save(make_item("a")).await.unwrap();
store.save(make_item("b")).await.unwrap();
let got = store.get("a").await.unwrap();
assert!(got.is_some());
assert_eq!(got.unwrap().id, "a");
let list = store.list(&MemoryFilter::default()).await.unwrap();
assert_eq!(list.len(), 2);
store.delete("a").await.unwrap();
assert!(store.get("a").await.unwrap().is_none());
}
#[tokio::test]
async fn save_is_upsert() {
let store = InMemoryStore::new();
store.save(make_item("a")).await.unwrap();
let mut item = make_item("a");
item.content = "updated".to_string();
store.save(item).await.unwrap();
let got = store.get("a").await.unwrap().unwrap();
assert_eq!(got.content, "updated");
let list = store.list(&MemoryFilter::default()).await.unwrap();
assert_eq!(list.len(), 1);
}
#[tokio::test]
async fn list_with_prefix_and_limit() {
let store = InMemoryStore::new();
store.save(make_item("foo_a")).await.unwrap();
store.save(make_item("foo_b")).await.unwrap();
store.save(make_item("bar_a")).await.unwrap();
let filter = MemoryFilter {
prefix: Some("foo_".to_string()),
..Default::default()
};
let list = store.list(&filter).await.unwrap();
assert_eq!(list.len(), 2);
let filter = MemoryFilter {
prefix: Some("foo_".to_string()),
limit: Some(1),
..Default::default()
};
let list = store.list(&filter).await.unwrap();
assert_eq!(list.len(), 1);
}
#[tokio::test]
async fn capacity_eviction() {
// 强制每次写入都检查
let eviction = EvictionConfig {
policy: EvictionPolicy::Capacity { max_items: 2 },
check_interval: 1,
};
let store = InMemoryStore::with_eviction(eviction);
// 第一条和第二条共存
store.save(make_item("a")).await.unwrap();
store.save(make_item("b")).await.unwrap();
// 第三条写入触发淘汰:a 或 b 之一被淘汰
store.save(make_item("c")).await.unwrap();
let list = store.list(&MemoryFilter::default()).await.unwrap();
assert_eq!(list.len(), 2);
// 留下的应该是 b 和 c(最新的两个)
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
assert!(ids.contains(&"b"));
assert!(ids.contains(&"c"));
}
#[tokio::test]
async fn ttl_eviction() {
// TTL 设为 0 会立即过期,但我们想保留 "a" 等待 "b" 写入后被淘汰。
// 改用小 TTL + 睡眠:先 save asleepsave b 时 a 已过期被淘汰。
let eviction = EvictionConfig {
policy: EvictionPolicy::Ttl { ttl_secs: 1 },
check_interval: 1,
};
let store = InMemoryStore::with_eviction(eviction);
store.save(make_item("a")).await.unwrap();
// 等待超过 1 秒
std::thread::sleep(std::time::Duration::from_millis(1100));
// 触发淘汰:a 已超过 ttl_secs=1,应被淘汰
store.save(make_item("b")).await.unwrap();
let list = store.list(&MemoryFilter::default()).await.unwrap();
// 由于 ttl_secs=1,且 b 刚写入,可能刚好处于临界值。
// 我们只断言 list 不包含 "a" 即可。
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
assert!(
!ids.contains(&"a"),
"expected 'a' to be evicted, but found in {ids:?}"
);
}
#[tokio::test]
async fn none_policy_no_eviction() {
let eviction = EvictionConfig {
policy: EvictionPolicy::None,
check_interval: 1,
};
let store = InMemoryStore::with_eviction(eviction);
for i in 0..100 {
store.save(make_item(&format!("item_{i}"))).await.unwrap();
}
let list = store.list(&MemoryFilter::default()).await.unwrap();
assert_eq!(list.len(), 100);
}
}
+73
View File
@@ -0,0 +1,73 @@
//! 记忆系统核心数据类型。
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
/// 记忆条目 —— MemoryStore 存储的基本单元。
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryItem {
/// 唯一标识。
pub id: String,
/// 内容(通常为 JSON 序列化的具体记忆数据)。
pub content: String,
/// 任意附加元数据。
pub metadata: serde_json::Value,
/// 创建时间。
pub created_at: OffsetDateTime,
}
/// MemoryStore 列表查询条件。
#[derive(Debug, Clone, Default)]
pub struct MemoryFilter {
/// 按 id 前缀过滤。
pub prefix: Option<String>,
/// 仅返回该时间之后创建的条目。
pub since: Option<OffsetDateTime>,
/// 跳过前 N 条。
pub offset: Option<usize>,
/// 最多返回 N 条。
pub limit: Option<usize>,
}
/// 知识页面 —— 描述一段结构化知识。
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgePage {
/// 唯一标识。
pub id: String,
/// 标题。
pub title: String,
/// 一句话摘要。
pub summary: String,
/// 完整内容。
pub content: String,
/// 检索标签。
pub tags: Vec<String>,
/// 交叉引用的其他页面 ID。
pub references: Vec<String>,
/// 创建时间。
pub created_at: OffsetDateTime,
/// 最后更新时间。
pub updated_at: OffsetDateTime,
}
/// 知识页面索引条目 —— 用于轻量遍历和内容目录。
#[derive(Debug, Clone)]
pub struct PageIndexEntry {
pub id: String,
pub title: String,
pub summary: String,
pub tags: Vec<String>,
pub updated_at: OffsetDateTime,
}
impl From<&KnowledgePage> for PageIndexEntry {
fn from(p: &KnowledgePage) -> Self {
Self {
id: p.id.clone(),
title: p.title.clone(),
summary: p.summary.clone(),
tags: p.tags.clone(),
updated_at: p.updated_at,
}
}
}