feat(memory): 添加记忆系统模块
This commit is contained in:
@@ -19,6 +19,7 @@ futures-core = "0.3"
|
|||||||
bytes = "1"
|
bytes = "1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
tokio-util = { version = "0.7", features = ["rt"] }
|
tokio-util = { version = "0.7", features = ["rt"] }
|
||||||
|
time = { version = "0.3", features = ["serde"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
//! agcore —— 智能体(Agent)核心工具箱。
|
//! agcore —— 智能体(Agent)核心工具箱。
|
||||||
|
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
|
pub mod memory;
|
||||||
pub mod prompt;
|
pub mod prompt;
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。
|
||||||
|
|
||||||
|
pub mod conversation;
|
||||||
|
pub mod error;
|
||||||
|
pub mod knowledge;
|
||||||
|
pub mod retriever;
|
||||||
|
pub mod store;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
// 高频类型(大多数下游需要)
|
||||||
|
pub use conversation::{ConversationMemory, ConversationMemoryConfig};
|
||||||
|
pub use error::MemoryError;
|
||||||
|
pub use knowledge::KnowledgeStore;
|
||||||
|
pub use retriever::MemoryRetriever;
|
||||||
|
pub use store::{InMemoryStore, MemoryStore};
|
||||||
|
|
||||||
|
// 低频类型(配置/高级使用)
|
||||||
|
pub use conversation::MemoryStrategy;
|
||||||
|
pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX};
|
||||||
|
pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem};
|
||||||
|
pub use store::{EvictionConfig, EvictionPolicy};
|
||||||
|
pub use types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||||
@@ -0,0 +1,260 @@
|
|||||||
|
//! 对话记忆 —— 多轮对话消息管理,复用 `llm::compact` 的压缩逻辑。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
|
||||||
|
use crate::llm::types::OpenaiChatMessage;
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::store::MemoryStore;
|
||||||
|
use crate::memory::types::MemoryItem;
|
||||||
|
|
||||||
|
/// 对话消息管理策略。
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum MemoryStrategy {
|
||||||
|
/// 滑动窗口:达到上限时删除最旧消息。
|
||||||
|
SlidingWindow,
|
||||||
|
/// 保留所有消息(仅压缩,不删除)。
|
||||||
|
Full,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对话记忆配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ConversationMemoryConfig {
|
||||||
|
pub strategy: MemoryStrategy,
|
||||||
|
pub max_turns: usize,
|
||||||
|
pub compact_config: Option<CompactConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConversationMemoryConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
strategy: MemoryStrategy::SlidingWindow,
|
||||||
|
max_turns: 50,
|
||||||
|
compact_config: Some(CompactConfig::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 对话记忆 —— 按 session 管理多轮对话消息历史。
|
||||||
|
///
|
||||||
|
/// 内部维护 `Vec<OpenaiChatMessage>` 热缓存(供 `llm::compact` 直接操作),
|
||||||
|
/// `MemoryStore` 用作冷持久化层。
|
||||||
|
pub struct ConversationMemory {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
session_id: String,
|
||||||
|
config: ConversationMemoryConfig,
|
||||||
|
/// 热缓存:消息列表,供 `llm::compact` 直接操作。
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
/// 与 `messages` 一一对应的存储 ID(保持稳定以便淘汰时精准删除)。
|
||||||
|
message_ids: Vec<String>,
|
||||||
|
/// 压缩断路器状态。
|
||||||
|
compact_state: CompactState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationMemory {
|
||||||
|
/// 创建一个新的 ConversationMemory。
|
||||||
|
pub fn new(
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
session_id: impl Into<String>,
|
||||||
|
config: ConversationMemoryConfig,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
session_id: session_id.into(),
|
||||||
|
config,
|
||||||
|
messages: Vec::new(),
|
||||||
|
message_ids: Vec::new(),
|
||||||
|
compact_state: CompactState::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取 session id。
|
||||||
|
pub fn session_id(&self) -> &str {
|
||||||
|
&self.session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取配置。
|
||||||
|
pub fn config(&self) -> &ConversationMemoryConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 MemoryStore 加载历史消息到热缓存。
|
||||||
|
pub async fn load(&mut self) -> Result<(), MemoryError> {
|
||||||
|
let filter = crate::memory::types::MemoryFilter {
|
||||||
|
prefix: Some(self.session_prefix()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let items = self.store.list(&filter).await?;
|
||||||
|
let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> = Vec::with_capacity(items.len());
|
||||||
|
for item in items {
|
||||||
|
match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
|
||||||
|
Ok(msg) => pairs.push((item.id, msg, item.created_at)),
|
||||||
|
Err(e) => {
|
||||||
|
return Err(MemoryError::Serialization(format!(
|
||||||
|
"load message {} failed: {e}",
|
||||||
|
item.id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 按 created_at 升序排列
|
||||||
|
pairs.sort_by_key(|p| p.2);
|
||||||
|
self.message_ids = pairs.iter().map(|p| p.0.clone()).collect();
|
||||||
|
self.messages = pairs.into_iter().map(|p| p.1).collect();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条消息。
|
||||||
|
///
|
||||||
|
/// 写入热缓存并通过 `MemoryStore` 持久化。如有需要,触发淘汰和压缩。
|
||||||
|
pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
let index = self.messages.len();
|
||||||
|
let id = self.make_message_id(index, &now);
|
||||||
|
|
||||||
|
// 写入热缓存
|
||||||
|
self.messages.push(msg);
|
||||||
|
self.message_ids.push(id.clone());
|
||||||
|
|
||||||
|
// 同步到冷存储
|
||||||
|
let item = MemoryItem {
|
||||||
|
id: id.clone(),
|
||||||
|
content: serde_json::to_string(self.messages.last().unwrap())
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
|
||||||
|
metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
self.store.save(item).await?;
|
||||||
|
|
||||||
|
// 触发淘汰和压缩
|
||||||
|
self.maybe_evict_and_compact().await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取完整消息历史。
|
||||||
|
pub fn get_history(&self) -> &[OpenaiChatMessage] {
|
||||||
|
&self.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清空所有消息。
|
||||||
|
pub async fn clear(&mut self) -> Result<(), MemoryError> {
|
||||||
|
let to_delete = std::mem::take(&mut self.message_ids);
|
||||||
|
self.messages.clear();
|
||||||
|
self.compact_state = CompactState::new();
|
||||||
|
for id in to_delete {
|
||||||
|
self.store.delete(&id).await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 当前消息数量。
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.messages.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 是否为空。
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.messages.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn session_prefix(&self) -> String {
|
||||||
|
format!("conv:{self}:", self = self.session_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
|
||||||
|
format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn maybe_evict_and_compact(&mut self) {
|
||||||
|
// 1. Sliding window 淘汰:删除最旧消息
|
||||||
|
if self.config.strategy == MemoryStrategy::SlidingWindow {
|
||||||
|
while self.messages.len() > self.config.max_turns {
|
||||||
|
if let Some(removed_id) = self.message_ids.first().cloned() {
|
||||||
|
let _ = self.store.delete(&removed_id).await;
|
||||||
|
}
|
||||||
|
self.messages.remove(0);
|
||||||
|
self.message_ids.remove(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 压缩(复用 llm::compact)
|
||||||
|
if let Some(ref compact_config) = self.config.compact_config {
|
||||||
|
if should_compact(&self.messages, compact_config, &self.compact_state) {
|
||||||
|
let keep_recent = compact_config.keep_recent;
|
||||||
|
let freed = microcompact(&mut self.messages, keep_recent);
|
||||||
|
if freed > 0 {
|
||||||
|
self.compact_state.record_success();
|
||||||
|
} else {
|
||||||
|
// 没有 token 被释放(可能没找到可压缩的 tool result)
|
||||||
|
let _ = self.compact_state.record_failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::llm::types::OpenaiChatMessage;
|
||||||
|
use crate::memory::InMemoryStore;
|
||||||
|
use crate::memory::MemoryStore;
|
||||||
|
|
||||||
|
fn user_text(s: &str) -> OpenaiChatMessage {
|
||||||
|
OpenaiChatMessage::user_text(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_and_get_history() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let mut conv = ConversationMemory::new(store, "session1", ConversationMemoryConfig::default());
|
||||||
|
conv.add_message(user_text("hello")).await.unwrap();
|
||||||
|
conv.add_message(user_text("world")).await.unwrap();
|
||||||
|
assert_eq!(conv.len(), 2);
|
||||||
|
assert_eq!(conv.get_history().len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sliding_window_evicts_oldest() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let config = ConversationMemoryConfig {
|
||||||
|
strategy: MemoryStrategy::SlidingWindow,
|
||||||
|
max_turns: 3,
|
||||||
|
compact_config: None,
|
||||||
|
};
|
||||||
|
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||||
|
for i in 0..5 {
|
||||||
|
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
assert_eq!(conv.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn full_strategy_no_evict() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let config = ConversationMemoryConfig {
|
||||||
|
strategy: MemoryStrategy::Full,
|
||||||
|
max_turns: 3,
|
||||||
|
compact_config: None,
|
||||||
|
};
|
||||||
|
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||||
|
for i in 0..5 {
|
||||||
|
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
// Full 策略不删除消息
|
||||||
|
assert_eq!(conv.len(), 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn clear_empties_messages() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let mut conv = ConversationMemory::new(store.clone(), "s1", ConversationMemoryConfig::default());
|
||||||
|
conv.add_message(user_text("hello")).await.unwrap();
|
||||||
|
assert!(!conv.is_empty());
|
||||||
|
conv.clear().await.unwrap();
|
||||||
|
assert!(conv.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
//! 记忆系统错误类型。
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// 记忆系统错误枚举。
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum MemoryError {
|
||||||
|
#[error("Item not found: {0}")]
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
#[error("Storage error: {0}")]
|
||||||
|
Storage(String),
|
||||||
|
|
||||||
|
#[error("Serialization error: {0}")]
|
||||||
|
Serialization(String),
|
||||||
|
|
||||||
|
#[error("Invalid input: {0}")]
|
||||||
|
InvalidInput(String),
|
||||||
|
|
||||||
|
#[error("Retrieval error: {0}")]
|
||||||
|
RetrievalError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryError {
|
||||||
|
/// 是否为可恢复错误(调用方可重试或调整参数)。
|
||||||
|
pub fn is_recoverable(&self) -> bool {
|
||||||
|
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,247 @@
|
|||||||
|
//! 知识库 —— KnowledgePage 存储与关键词检索。
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::store::MemoryStore;
|
||||||
|
use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||||
|
|
||||||
|
pub use crate::memory::types::PageIndexEntry;
|
||||||
|
|
||||||
|
/// `MemoryItem.id` 中知识页面前缀。
|
||||||
|
pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
|
||||||
|
|
||||||
|
/// 知识库 —— KnowledgePage CRUD + 关键词检索 + 内容索引。
|
||||||
|
///
|
||||||
|
/// 内部以 `MemoryStore` 为后端存储 KnowledgePage(序列化为 JSON),
|
||||||
|
/// 同时维护一个 `Vec<PageIndexEntry>` 索引以加速列表遍历。
|
||||||
|
pub struct KnowledgeStore {
|
||||||
|
store: Arc<dyn MemoryStore>,
|
||||||
|
index: std::sync::Mutex<Vec<PageIndexEntry>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KnowledgeStore {
|
||||||
|
/// 创建一个新的 KnowledgeStore。
|
||||||
|
pub fn new(store: Arc<dyn MemoryStore>) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
index: std::sync::Mutex::new(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 MemoryStore 重建索引(修复 index 与 store 的不同步问题)。
|
||||||
|
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||||
|
let items = self
|
||||||
|
.store
|
||||||
|
.list(&MemoryFilter {
|
||||||
|
prefix: Some(KNOWLEDGE_PREFIX.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
index.clear();
|
||||||
|
for item in items {
|
||||||
|
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
index.push(PageIndexEntry::from(&page));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个新的知识页面。
|
||||||
|
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||||
|
if page.id.is_empty() {
|
||||||
|
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||||
|
}
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
|
||||||
|
let content = serde_json::to_string(&page)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
let item = MemoryItem {
|
||||||
|
id,
|
||||||
|
content,
|
||||||
|
metadata: serde_json::json!({}),
|
||||||
|
created_at: now,
|
||||||
|
};
|
||||||
|
self.store.save(item).await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
// 替换或追加
|
||||||
|
if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
|
||||||
|
*existing = PageIndexEntry::from(&page);
|
||||||
|
} else {
|
||||||
|
index.push(PageIndexEntry::from(&page));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 根据 page id 获取一个页面。
|
||||||
|
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
|
||||||
|
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||||
|
let item = self.store.get(&full_id).await?;
|
||||||
|
match item {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(item) => {
|
||||||
|
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||||
|
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||||
|
Ok(Some(page))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 更新一个已存在的知识页面。
|
||||||
|
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||||
|
if page.id.is_empty() {
|
||||||
|
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||||
|
}
|
||||||
|
// 通过 get_page 检查存在性
|
||||||
|
if self.get_page(&page.id).await?.is_none() {
|
||||||
|
return Err(MemoryError::NotFound(page.id));
|
||||||
|
}
|
||||||
|
self.add_page(page).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 删除一个知识页面。
|
||||||
|
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
|
||||||
|
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||||
|
self.store.delete(&full_id).await?;
|
||||||
|
let mut index = self.index.lock().unwrap();
|
||||||
|
index.retain(|e| e.id != id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 根据关键词搜索知识页面。
|
||||||
|
///
|
||||||
|
/// 匹配规则:在 `title` / `summary` / `tags` 中查找子串(不区分大小写)。
|
||||||
|
/// 全文 `content` 搜索走 `MemoryStore`。
|
||||||
|
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
|
||||||
|
if query.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
let needle = query.to_lowercase();
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let index = self.index.lock().unwrap();
|
||||||
|
for entry in index.iter() {
|
||||||
|
if entry.title.to_lowercase().contains(&needle)
|
||||||
|
|| entry.summary.to_lowercase().contains(&needle)
|
||||||
|
|| entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
|
||||||
|
{
|
||||||
|
if let Some(page) = self.get_page(&entry.id).await? {
|
||||||
|
results.push(page);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取内容目录(所有页面的轻量级索引条目)。
|
||||||
|
pub fn get_index(&self) -> Vec<PageIndexEntry> {
|
||||||
|
self.index.lock().unwrap().clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::InMemoryStore;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
KnowledgePage {
|
||||||
|
id: id.to_string(),
|
||||||
|
title: title.to_string(),
|
||||||
|
summary: format!("summary of {title}"),
|
||||||
|
content: format!("full content of {title}"),
|
||||||
|
tags: tags.iter().map(|s| s.to_string()).collect(),
|
||||||
|
references: Vec::new(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_get_delete_page() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let got = ks.get_page("p1").await.unwrap();
|
||||||
|
assert!(got.is_some());
|
||||||
|
assert_eq!(got.unwrap().title, "LangGraph");
|
||||||
|
|
||||||
|
ks.delete_page("p1").await.unwrap();
|
||||||
|
assert!(ks.get_page("p1").await.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn add_page_rejects_empty_id() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let result = ks.add_page(make_page("", "NoId", &[])).await;
|
||||||
|
assert!(matches!(result, Err(MemoryError::InvalidInput(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn update_page_requires_existing() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let result = ks.update_page(make_page("nope", "Ghost", &[])).await;
|
||||||
|
assert!(matches!(result, Err(MemoryError::NotFound(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn search_finds_by_title_summary_tag() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p3", "Third", &["unrelated"]))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = ks.search("stategraph").await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].id, "p1");
|
||||||
|
|
||||||
|
let results = ks.search("knowledge-graph").await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].id, "p2");
|
||||||
|
|
||||||
|
let results = ks.search("nonexistent").await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_index_returns_all_pages() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||||
|
let index = ks.get_index();
|
||||||
|
assert_eq!(index.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn rebuild_index_recovers_from_drift() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
// 添加页面
|
||||||
|
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||||
|
assert_eq!(ks.get_index().len(), 2);
|
||||||
|
|
||||||
|
// 模拟 index 漂移:清空后重建
|
||||||
|
ks.index.lock().unwrap().clear();
|
||||||
|
assert_eq!(ks.get_index().len(), 0);
|
||||||
|
|
||||||
|
ks.rebuild_index().await.unwrap();
|
||||||
|
assert_eq!(ks.get_index().len(), 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,297 @@
|
|||||||
|
//! 记忆检索器 —— 基于 TextOverlap (Dice 系数) 的单通道关键词检索。
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::knowledge::KnowledgeStore;
|
||||||
|
use crate::memory::types::KnowledgePage;
|
||||||
|
|
||||||
|
/// 检索器配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetrieverConfig {
|
||||||
|
/// 最大返回条数(默认 20)。
|
||||||
|
pub max_results: usize,
|
||||||
|
/// 最低分数阈值 [0.0, 1.0](默认 0.1)。
|
||||||
|
pub min_score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetrieverConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_results: 20,
|
||||||
|
min_score: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 单条带评分的检索结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ScoredItem {
|
||||||
|
pub page: KnowledgePage,
|
||||||
|
/// TextOverlap 评分 [0.0, 1.0]
|
||||||
|
pub score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检索结果。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetrievalResult {
|
||||||
|
pub items: Vec<ScoredItem>,
|
||||||
|
pub query: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记忆检索器 —— 在 `KnowledgeStore` 中做关键词检索并按 TextOverlap 评分。
|
||||||
|
pub struct MemoryRetriever {
|
||||||
|
knowledge_store: KnowledgeStore,
|
||||||
|
config: RetrieverConfig,
|
||||||
|
/// 停用词表(用于关键词提取)。
|
||||||
|
stop_words: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryRetriever {
|
||||||
|
/// 创建一个新的 MemoryRetriever。
|
||||||
|
pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
knowledge_store,
|
||||||
|
config,
|
||||||
|
stop_words: default_stop_words(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 替换停用词表。
|
||||||
|
pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
|
||||||
|
self.stop_words = stop_words;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检索相关知识页面。
|
||||||
|
pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
|
||||||
|
if query.is_empty() {
|
||||||
|
return Ok(RetrievalResult {
|
||||||
|
items: Vec::new(),
|
||||||
|
query: query.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 关键词提取
|
||||||
|
let keywords = extract_keywords(query, &self.stop_words);
|
||||||
|
|
||||||
|
// 2. 用关键词在 KnowledgeStore 中搜索
|
||||||
|
let mut pages = Vec::new();
|
||||||
|
for keyword in &keywords {
|
||||||
|
let found = self.knowledge_store.search(keyword).await?;
|
||||||
|
for page in found {
|
||||||
|
if !pages.iter().any(|p: &KnowledgePage| p.id == page.id) {
|
||||||
|
pages.push(page);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. TextOverlap 评分
|
||||||
|
let mut items: Vec<ScoredItem> = pages
|
||||||
|
.into_iter()
|
||||||
|
.map(|page| {
|
||||||
|
let score = text_overlap_score(query, &page);
|
||||||
|
ScoredItem { page, score }
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// 4. 过滤 → 排序 → 截取
|
||||||
|
items.retain(|i| i.score >= self.config.min_score);
|
||||||
|
items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
items.truncate(self.config.max_results);
|
||||||
|
|
||||||
|
Ok(RetrievalResult {
|
||||||
|
items,
|
||||||
|
query: query.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 query 中提取关键词:按非字母数字字符分割 → 转小写 → 过滤单字符和停用词。
|
||||||
|
fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
|
||||||
|
query
|
||||||
|
.split(|c: char| !c.is_alphanumeric())
|
||||||
|
.filter_map(|s| {
|
||||||
|
let lower = s.to_lowercase();
|
||||||
|
if lower.is_empty() || lower.chars().count() < 2 {
|
||||||
|
None
|
||||||
|
} else if stop_words.contains(&lower) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(lower)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TextOverlap 评分(基于字符 bigram 的 Dice 系数 + 多字段加权)。
|
||||||
|
///
|
||||||
|
/// 字段权重:title 0.5 + summary 0.3 + content 0.2
|
||||||
|
/// 中文场景按字符级 bigram 处理,不依赖分词器。
|
||||||
|
pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
|
||||||
|
let title = text_overlap_dice(query, &page.title);
|
||||||
|
let summary = text_overlap_dice(query, &page.summary);
|
||||||
|
let content = text_overlap_dice(query, &page.content);
|
||||||
|
title * 0.5 + summary * 0.3 + content * 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dice 系数(基于字符 bigram)。
|
||||||
|
fn text_overlap_dice(query: &str, text: &str) -> f32 {
|
||||||
|
let q_bigrams = char_bigrams(query);
|
||||||
|
let t_bigrams = char_bigrams(text);
|
||||||
|
if q_bigrams.is_empty() || t_bigrams.is_empty() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
let q_set: HashSet<&String> = q_bigrams.iter().collect();
|
||||||
|
let t_set: HashSet<&String> = t_bigrams.iter().collect();
|
||||||
|
let intersect = q_set.intersection(&t_set).count();
|
||||||
|
let denom = q_set.len() + t_set.len();
|
||||||
|
if denom == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
(2.0 * intersect as f32) / denom as f32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 提取字符 bigrams(用于字符级 Dice 系数)。
|
||||||
|
fn char_bigrams(s: &str) -> Vec<String> {
|
||||||
|
let chars: Vec<char> = s.chars().collect();
|
||||||
|
chars.windows(2).map(|w| w.iter().collect()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_stop_words() -> HashSet<String> {
|
||||||
|
[
|
||||||
|
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
|
||||||
|
"had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
|
||||||
|
"can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
|
||||||
|
"which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
|
||||||
|
"so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
|
||||||
|
"by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::knowledge::KnowledgeStore;
|
||||||
|
use crate::memory::{InMemoryStore, MemoryStore};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
|
||||||
|
let now = OffsetDateTime::now_utc();
|
||||||
|
KnowledgePage {
|
||||||
|
id: id.to_string(),
|
||||||
|
title: title.to_string(),
|
||||||
|
summary: summary.to_string(),
|
||||||
|
content: content.to_string(),
|
||||||
|
tags: Vec::new(),
|
||||||
|
references: Vec::new(),
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_empty_query() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||||
|
let result = retriever.retrieve("").await.unwrap();
|
||||||
|
assert!(result.items.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_finds_relevant_page() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page(
|
||||||
|
"p1",
|
||||||
|
"LangGraph StateGraph",
|
||||||
|
"state management",
|
||||||
|
"LangGraph uses StateGraph for state machines",
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||||
|
let result = retriever.retrieve("LangGraph state").await.unwrap();
|
||||||
|
assert!(!result.items.is_empty());
|
||||||
|
assert_eq!(result.items[0].page.id, "p1");
|
||||||
|
assert!(result.items[0].score > 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_respects_min_score() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();
|
||||||
|
|
||||||
|
let config = RetrieverConfig {
|
||||||
|
max_results: 10,
|
||||||
|
min_score: 0.99,
|
||||||
|
};
|
||||||
|
let retriever = MemoryRetriever::new(ks, config);
|
||||||
|
let result = retriever.retrieve("totally unrelated content").await.unwrap();
|
||||||
|
assert!(result.items.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retrieve_respects_max_results() {
|
||||||
|
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||||
|
let ks = KnowledgeStore::new(store);
|
||||||
|
for i in 0..5 {
|
||||||
|
ks.add_page(make_page(
|
||||||
|
&format!("p{i}"),
|
||||||
|
"LangGraph",
|
||||||
|
"framework",
|
||||||
|
"agent runtime",
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
let config = RetrieverConfig {
|
||||||
|
max_results: 2,
|
||||||
|
min_score: 0.0,
|
||||||
|
};
|
||||||
|
let retriever = MemoryRetriever::new(ks, config);
|
||||||
|
let result = retriever.retrieve("LangGraph").await.unwrap();
|
||||||
|
assert_eq!(result.items.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn text_overlap_dice_zero_on_empty() {
|
||||||
|
assert_eq!(text_overlap_dice("hello", ""), 0.0);
|
||||||
|
assert_eq!(text_overlap_dice("", "hello"), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn text_overlap_dice_identical() {
|
||||||
|
let s = "hello world";
|
||||||
|
assert!((text_overlap_dice(s, s) - 1.0).abs() < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_keywords_filters_stop_words() {
|
||||||
|
let stop = default_stop_words();
|
||||||
|
let kws = extract_keywords("the quick brown fox is fast", &stop);
|
||||||
|
assert!(!kws.contains(&"the".to_string()));
|
||||||
|
assert!(!kws.contains(&"is".to_string()));
|
||||||
|
assert!(kws.contains(&"quick".to_string()));
|
||||||
|
assert!(kws.contains(&"brown".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_keywords_filters_single_chars() {
|
||||||
|
let stop = default_stop_words();
|
||||||
|
let kws = extract_keywords("a b c dog", &stop);
|
||||||
|
assert!(!kws.contains(&"a".to_string()));
|
||||||
|
assert!(!kws.contains(&"b".to_string()));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,314 @@
|
|||||||
|
//! MemoryStore 抽象接口与默认实现。
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::memory::error::MemoryError;
|
||||||
|
use crate::memory::types::{MemoryFilter, MemoryItem};
|
||||||
|
|
||||||
|
/// 底层记忆存储抽象接口。
|
||||||
|
///
|
||||||
|
/// 下游可实现此 trait 以对接持久化后端(JSON 文件、SQLite、Redis 等)。
|
||||||
|
/// 默认实现 [`InMemoryStore`] 基于进程内 HashMap。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait MemoryStore: Send + Sync {
|
||||||
|
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||||
|
/// - 如果 id 不存在,则插入新条目
|
||||||
|
/// - 如果 id 已存在,则覆盖旧条目
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 id 获取一个 MemoryItem。
|
||||||
|
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 id 删除一个 MemoryItem。
|
||||||
|
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||||
|
|
||||||
|
/// 根据 filter 列出 MemoryItem。
|
||||||
|
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 淘汰策略。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum EvictionPolicy {
|
||||||
|
/// 不淘汰(默认)。
|
||||||
|
None,
|
||||||
|
/// 超过存活时间(秒)淘汰。
|
||||||
|
Ttl { ttl_secs: u64 },
|
||||||
|
/// 超过容量上限淘汰最旧(基于 created_at)。
|
||||||
|
Capacity { max_items: usize },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 淘汰配置。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EvictionConfig {
|
||||||
|
pub policy: EvictionPolicy,
|
||||||
|
/// 每写入 N 条后检查一次淘汰条件。
|
||||||
|
pub check_interval: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EvictionConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
policy: EvictionPolicy::None,
|
||||||
|
check_interval: 64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 进程内默认实现 —— 基于 HashMap + Mutex,纯内存。
|
||||||
|
pub struct InMemoryStore {
|
||||||
|
items: Mutex<HashMap<String, MemoryItem>>,
|
||||||
|
eviction: EvictionConfig,
|
||||||
|
/// 自上次淘汰检查以来的写入次数。
|
||||||
|
writes_since_check: Mutex<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InMemoryStore {
|
||||||
|
/// 创建一个无淘汰策略的 InMemoryStore。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
items: Mutex::new(HashMap::new()),
|
||||||
|
eviction: EvictionConfig::default(),
|
||||||
|
writes_since_check: Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建一个带淘汰配置的 InMemoryStore。
|
||||||
|
pub fn with_eviction(eviction: EvictionConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
items: Mutex::new(HashMap::new()),
|
||||||
|
eviction,
|
||||||
|
writes_since_check: Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn maybe_evict(&self) {
|
||||||
|
// 不使用 .lock().await 跨点,先取计数判断是否需要淘汰
|
||||||
|
let should_check = {
|
||||||
|
let mut counter = self.writes_since_check.lock().unwrap();
|
||||||
|
*counter += 1;
|
||||||
|
if *counter >= self.eviction.check_interval {
|
||||||
|
*counter = 0;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if !should_check {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let policy = self.eviction.policy.clone();
|
||||||
|
match policy {
|
||||||
|
EvictionPolicy::None => {}
|
||||||
|
EvictionPolicy::Ttl { ttl_secs } => {
|
||||||
|
let cutoff = OffsetDateTime::now_utc() - time::Duration::seconds(ttl_secs as i64);
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.retain(|_, v| v.created_at > cutoff);
|
||||||
|
}
|
||||||
|
EvictionPolicy::Capacity { max_items } => {
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
if items.len() > max_items {
|
||||||
|
let mut vec: Vec<_> = items.drain().collect();
|
||||||
|
// O(n) 部分排序:保留 created_at 最大的 max_items 个
|
||||||
|
vec.select_nth_unstable_by(max_items, |a, b| {
|
||||||
|
b.1.created_at.cmp(&a.1.created_at)
|
||||||
|
});
|
||||||
|
vec.truncate(max_items);
|
||||||
|
*items = vec.into_iter().collect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for InMemoryStore {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MemoryStore for InMemoryStore {
|
||||||
|
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||||
|
{
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.insert(item.id.clone(), item);
|
||||||
|
}
|
||||||
|
self.maybe_evict();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
|
||||||
|
let items = self.items.lock().unwrap();
|
||||||
|
Ok(items.get(id).cloned())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete(&self, id: &str) -> Result<(), MemoryError> {
|
||||||
|
let mut items = self.items.lock().unwrap();
|
||||||
|
items.remove(id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
|
||||||
|
let items = self.items.lock().unwrap();
|
||||||
|
let mut result: Vec<MemoryItem> = items
|
||||||
|
.values()
|
||||||
|
.filter(|v| match &filter.prefix {
|
||||||
|
Some(p) => v.id.starts_with(p),
|
||||||
|
None => true,
|
||||||
|
})
|
||||||
|
.filter(|v| match filter.since {
|
||||||
|
Some(t) => v.created_at > t,
|
||||||
|
None => true,
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
// 按 created_at 升序排列(最旧在前)
|
||||||
|
result.sort_by_key(|v| v.created_at);
|
||||||
|
// 应用 offset
|
||||||
|
if let Some(offset) = filter.offset {
|
||||||
|
if offset < result.len() {
|
||||||
|
result.drain(..offset);
|
||||||
|
} else {
|
||||||
|
result.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 应用 limit
|
||||||
|
if let Some(limit) = filter.limit {
|
||||||
|
result.truncate(limit);
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
fn make_item(id: &str) -> MemoryItem {
|
||||||
|
MemoryItem {
|
||||||
|
id: id.to_string(),
|
||||||
|
content: format!("content-{id}"),
|
||||||
|
metadata: serde_json::json!({}),
|
||||||
|
created_at: OffsetDateTime::now_utc(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_get_delete_list() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
|
||||||
|
let got = store.get("a").await.unwrap();
|
||||||
|
assert!(got.is_some());
|
||||||
|
assert_eq!(got.unwrap().id, "a");
|
||||||
|
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
|
||||||
|
store.delete("a").await.unwrap();
|
||||||
|
assert!(store.get("a").await.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn save_is_upsert() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
let mut item = make_item("a");
|
||||||
|
item.content = "updated".to_string();
|
||||||
|
store.save(item).await.unwrap();
|
||||||
|
let got = store.get("a").await.unwrap().unwrap();
|
||||||
|
assert_eq!(got.content, "updated");
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_with_prefix_and_limit() {
|
||||||
|
let store = InMemoryStore::new();
|
||||||
|
store.save(make_item("foo_a")).await.unwrap();
|
||||||
|
store.save(make_item("foo_b")).await.unwrap();
|
||||||
|
store.save(make_item("bar_a")).await.unwrap();
|
||||||
|
|
||||||
|
let filter = MemoryFilter {
|
||||||
|
prefix: Some("foo_".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let list = store.list(&filter).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
|
||||||
|
let filter = MemoryFilter {
|
||||||
|
prefix: Some("foo_".to_string()),
|
||||||
|
limit: Some(1),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let list = store.list(&filter).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn capacity_eviction() {
|
||||||
|
// 强制每次写入都检查
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::Capacity { max_items: 2 },
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
// 第一条和第二条共存
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
// 第三条写入触发淘汰:a 或 b 之一被淘汰
|
||||||
|
store.save(make_item("c")).await.unwrap();
|
||||||
|
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 2);
|
||||||
|
// 留下的应该是 b 和 c(最新的两个)
|
||||||
|
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||||
|
assert!(ids.contains(&"b"));
|
||||||
|
assert!(ids.contains(&"c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ttl_eviction() {
|
||||||
|
// TTL 设为 0 会立即过期,但我们想保留 "a" 等待 "b" 写入后被淘汰。
|
||||||
|
// 改用小 TTL + 睡眠:先 save a,sleep,save b 时 a 已过期被淘汰。
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::Ttl { ttl_secs: 1 },
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
store.save(make_item("a")).await.unwrap();
|
||||||
|
// 等待超过 1 秒
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(1100));
|
||||||
|
// 触发淘汰:a 已超过 ttl_secs=1,应被淘汰
|
||||||
|
store.save(make_item("b")).await.unwrap();
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
// 由于 ttl_secs=1,且 b 刚写入,可能刚好处于临界值。
|
||||||
|
// 我们只断言 list 不包含 "a" 即可。
|
||||||
|
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||||
|
assert!(
|
||||||
|
!ids.contains(&"a"),
|
||||||
|
"expected 'a' to be evicted, but found in {ids:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn none_policy_no_eviction() {
|
||||||
|
let eviction = EvictionConfig {
|
||||||
|
policy: EvictionPolicy::None,
|
||||||
|
check_interval: 1,
|
||||||
|
};
|
||||||
|
let store = InMemoryStore::with_eviction(eviction);
|
||||||
|
for i in 0..100 {
|
||||||
|
store.save(make_item(&format!("item_{i}"))).await.unwrap();
|
||||||
|
}
|
||||||
|
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||||
|
assert_eq!(list.len(), 100);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
//! 记忆系统核心数据类型。
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
/// 记忆条目 —— MemoryStore 存储的基本单元。
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MemoryItem {
|
||||||
|
/// 唯一标识。
|
||||||
|
pub id: String,
|
||||||
|
/// 内容(通常为 JSON 序列化的具体记忆数据)。
|
||||||
|
pub content: String,
|
||||||
|
/// 任意附加元数据。
|
||||||
|
pub metadata: serde_json::Value,
|
||||||
|
/// 创建时间。
|
||||||
|
pub created_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// MemoryStore 列表查询条件。
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct MemoryFilter {
|
||||||
|
/// 按 id 前缀过滤。
|
||||||
|
pub prefix: Option<String>,
|
||||||
|
/// 仅返回该时间之后创建的条目。
|
||||||
|
pub since: Option<OffsetDateTime>,
|
||||||
|
/// 跳过前 N 条。
|
||||||
|
pub offset: Option<usize>,
|
||||||
|
/// 最多返回 N 条。
|
||||||
|
pub limit: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 知识页面 —— 描述一段结构化知识。
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct KnowledgePage {
|
||||||
|
/// 唯一标识。
|
||||||
|
pub id: String,
|
||||||
|
/// 标题。
|
||||||
|
pub title: String,
|
||||||
|
/// 一句话摘要。
|
||||||
|
pub summary: String,
|
||||||
|
/// 完整内容。
|
||||||
|
pub content: String,
|
||||||
|
/// 检索标签。
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
/// 交叉引用的其他页面 ID。
|
||||||
|
pub references: Vec<String>,
|
||||||
|
/// 创建时间。
|
||||||
|
pub created_at: OffsetDateTime,
|
||||||
|
/// 最后更新时间。
|
||||||
|
pub updated_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 知识页面索引条目 —— 用于轻量遍历和内容目录。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PageIndexEntry {
|
||||||
|
pub id: String,
|
||||||
|
pub title: String,
|
||||||
|
pub summary: String,
|
||||||
|
pub tags: Vec<String>,
|
||||||
|
pub updated_at: OffsetDateTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&KnowledgePage> for PageIndexEntry {
|
||||||
|
fn from(p: &KnowledgePage) -> Self {
|
||||||
|
Self {
|
||||||
|
id: p.id.clone(),
|
||||||
|
title: p.title.clone(),
|
||||||
|
summary: p.summary.clone(),
|
||||||
|
tags: p.tags.clone(),
|
||||||
|
updated_at: p.updated_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user