feat(memory): 添加记忆系统模块
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
//! agcore —— 智能体(Agent)核心工具箱。
|
||||
|
||||
pub mod llm;
|
||||
pub mod memory;
|
||||
pub mod prompt;
|
||||
pub mod tools;
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
//! 记忆系统 —— 对话消息管理、知识页面存储与关键词检索。
|
||||
|
||||
pub mod conversation;
|
||||
pub mod error;
|
||||
pub mod knowledge;
|
||||
pub mod retriever;
|
||||
pub mod store;
|
||||
pub mod types;
|
||||
|
||||
// 高频类型(大多数下游需要)
|
||||
pub use conversation::{ConversationMemory, ConversationMemoryConfig};
|
||||
pub use error::MemoryError;
|
||||
pub use knowledge::KnowledgeStore;
|
||||
pub use retriever::MemoryRetriever;
|
||||
pub use store::{InMemoryStore, MemoryStore};
|
||||
|
||||
// 低频类型(配置/高级使用)
|
||||
pub use conversation::MemoryStrategy;
|
||||
pub use knowledge::{PageIndexEntry, KNOWLEDGE_PREFIX};
|
||||
pub use retriever::{RetrieverConfig, RetrievalResult, ScoredItem};
|
||||
pub use store::{EvictionConfig, EvictionPolicy};
|
||||
pub use types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||
@@ -0,0 +1,260 @@
|
||||
//! 对话记忆 —— 多轮对话消息管理,复用 `llm::compact` 的压缩逻辑。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::llm::compact::{CompactConfig, CompactState, microcompact, should_compact};
|
||||
use crate::llm::types::OpenaiChatMessage;
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::memory::types::MemoryItem;
|
||||
|
||||
/// 对话消息管理策略。
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MemoryStrategy {
|
||||
/// 滑动窗口:达到上限时删除最旧消息。
|
||||
SlidingWindow,
|
||||
/// 保留所有消息(仅压缩,不删除)。
|
||||
Full,
|
||||
}
|
||||
|
||||
/// 对话记忆配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConversationMemoryConfig {
|
||||
pub strategy: MemoryStrategy,
|
||||
pub max_turns: usize,
|
||||
pub compact_config: Option<CompactConfig>,
|
||||
}
|
||||
|
||||
impl Default for ConversationMemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
strategy: MemoryStrategy::SlidingWindow,
|
||||
max_turns: 50,
|
||||
compact_config: Some(CompactConfig::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 对话记忆 —— 按 session 管理多轮对话消息历史。
|
||||
///
|
||||
/// 内部维护 `Vec<OpenaiChatMessage>` 热缓存(供 `llm::compact` 直接操作),
|
||||
/// `MemoryStore` 用作冷持久化层。
|
||||
pub struct ConversationMemory {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
session_id: String,
|
||||
config: ConversationMemoryConfig,
|
||||
/// 热缓存:消息列表,供 `llm::compact` 直接操作。
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
/// 与 `messages` 一一对应的存储 ID(保持稳定以便淘汰时精准删除)。
|
||||
message_ids: Vec<String>,
|
||||
/// 压缩断路器状态。
|
||||
compact_state: CompactState,
|
||||
}
|
||||
|
||||
impl ConversationMemory {
|
||||
/// 创建一个新的 ConversationMemory。
|
||||
pub fn new(
|
||||
store: Arc<dyn MemoryStore>,
|
||||
session_id: impl Into<String>,
|
||||
config: ConversationMemoryConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
session_id: session_id.into(),
|
||||
config,
|
||||
messages: Vec::new(),
|
||||
message_ids: Vec::new(),
|
||||
compact_state: CompactState::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取 session id。
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
/// 获取配置。
|
||||
pub fn config(&self) -> &ConversationMemoryConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// 从 MemoryStore 加载历史消息到热缓存。
|
||||
pub async fn load(&mut self) -> Result<(), MemoryError> {
|
||||
let filter = crate::memory::types::MemoryFilter {
|
||||
prefix: Some(self.session_prefix()),
|
||||
..Default::default()
|
||||
};
|
||||
let items = self.store.list(&filter).await?;
|
||||
let mut pairs: Vec<(String, OpenaiChatMessage, OffsetDateTime)> = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
match serde_json::from_str::<OpenaiChatMessage>(&item.content) {
|
||||
Ok(msg) => pairs.push((item.id, msg, item.created_at)),
|
||||
Err(e) => {
|
||||
return Err(MemoryError::Serialization(format!(
|
||||
"load message {} failed: {e}",
|
||||
item.id
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// 按 created_at 升序排列
|
||||
pairs.sort_by_key(|p| p.2);
|
||||
self.message_ids = pairs.iter().map(|p| p.0.clone()).collect();
|
||||
self.messages = pairs.into_iter().map(|p| p.1).collect();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 添加一条消息。
|
||||
///
|
||||
/// 写入热缓存并通过 `MemoryStore` 持久化。如有需要,触发淘汰和压缩。
|
||||
pub async fn add_message(&mut self, msg: OpenaiChatMessage) -> Result<(), MemoryError> {
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let index = self.messages.len();
|
||||
let id = self.make_message_id(index, &now);
|
||||
|
||||
// 写入热缓存
|
||||
self.messages.push(msg);
|
||||
self.message_ids.push(id.clone());
|
||||
|
||||
// 同步到冷存储
|
||||
let item = MemoryItem {
|
||||
id: id.clone(),
|
||||
content: serde_json::to_string(self.messages.last().unwrap())
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?,
|
||||
metadata: serde_json::json!({ "session_id": &self.session_id, "index": index }),
|
||||
created_at: now,
|
||||
};
|
||||
self.store.save(item).await?;
|
||||
|
||||
// 触发淘汰和压缩
|
||||
self.maybe_evict_and_compact().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取完整消息历史。
|
||||
pub fn get_history(&self) -> &[OpenaiChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// 清空所有消息。
|
||||
pub async fn clear(&mut self) -> Result<(), MemoryError> {
|
||||
let to_delete = std::mem::take(&mut self.message_ids);
|
||||
self.messages.clear();
|
||||
self.compact_state = CompactState::new();
|
||||
for id in to_delete {
|
||||
self.store.delete(&id).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 当前消息数量。
|
||||
pub fn len(&self) -> usize {
|
||||
self.messages.len()
|
||||
}
|
||||
|
||||
/// 是否为空。
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.messages.is_empty()
|
||||
}
|
||||
|
||||
fn session_prefix(&self) -> String {
|
||||
format!("conv:{self}:", self = self.session_id)
|
||||
}
|
||||
|
||||
fn make_message_id(&self, index: usize, now: &OffsetDateTime) -> String {
|
||||
format!("{}{:010}_{}", self.session_prefix(), index, now.unix_timestamp_nanos())
|
||||
}
|
||||
|
||||
async fn maybe_evict_and_compact(&mut self) {
|
||||
// 1. Sliding window 淘汰:删除最旧消息
|
||||
if self.config.strategy == MemoryStrategy::SlidingWindow {
|
||||
while self.messages.len() > self.config.max_turns {
|
||||
if let Some(removed_id) = self.message_ids.first().cloned() {
|
||||
let _ = self.store.delete(&removed_id).await;
|
||||
}
|
||||
self.messages.remove(0);
|
||||
self.message_ids.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 压缩(复用 llm::compact)
|
||||
if let Some(ref compact_config) = self.config.compact_config {
|
||||
if should_compact(&self.messages, compact_config, &self.compact_state) {
|
||||
let keep_recent = compact_config.keep_recent;
|
||||
let freed = microcompact(&mut self.messages, keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
} else {
|
||||
// 没有 token 被释放(可能没找到可压缩的 tool result)
|
||||
let _ = self.compact_state.record_failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::llm::types::OpenaiChatMessage;
|
||||
use crate::memory::InMemoryStore;
|
||||
use crate::memory::MemoryStore;
|
||||
|
||||
fn user_text(s: &str) -> OpenaiChatMessage {
|
||||
OpenaiChatMessage::user_text(s)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_and_get_history() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let mut conv = ConversationMemory::new(store, "session1", ConversationMemoryConfig::default());
|
||||
conv.add_message(user_text("hello")).await.unwrap();
|
||||
conv.add_message(user_text("world")).await.unwrap();
|
||||
assert_eq!(conv.len(), 2);
|
||||
assert_eq!(conv.get_history().len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sliding_window_evicts_oldest() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let config = ConversationMemoryConfig {
|
||||
strategy: MemoryStrategy::SlidingWindow,
|
||||
max_turns: 3,
|
||||
compact_config: None,
|
||||
};
|
||||
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||
for i in 0..5 {
|
||||
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||
}
|
||||
assert_eq!(conv.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn full_strategy_no_evict() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let config = ConversationMemoryConfig {
|
||||
strategy: MemoryStrategy::Full,
|
||||
max_turns: 3,
|
||||
compact_config: None,
|
||||
};
|
||||
let mut conv = ConversationMemory::new(store, "s1", config);
|
||||
for i in 0..5 {
|
||||
conv.add_message(user_text(&format!("msg-{i}"))).await.unwrap();
|
||||
}
|
||||
// Full 策略不删除消息
|
||||
assert_eq!(conv.len(), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn clear_empties_messages() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let mut conv = ConversationMemory::new(store.clone(), "s1", ConversationMemoryConfig::default());
|
||||
conv.add_message(user_text("hello")).await.unwrap();
|
||||
assert!(!conv.is_empty());
|
||||
conv.clear().await.unwrap();
|
||||
assert!(conv.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
//! 记忆系统错误类型。
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// 记忆系统错误枚举。
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MemoryError {
|
||||
#[error("Item not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Storage error: {0}")]
|
||||
Storage(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("Retrieval error: {0}")]
|
||||
RetrievalError(String),
|
||||
}
|
||||
|
||||
impl MemoryError {
|
||||
/// 是否为可恢复错误(调用方可重试或调整参数)。
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
matches!(self, Self::NotFound(_) | Self::RetrievalError(_))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
//! 知识库 —— KnowledgePage 存储与关键词检索。
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::store::MemoryStore;
|
||||
use crate::memory::types::{KnowledgePage, MemoryFilter, MemoryItem};
|
||||
|
||||
pub use crate::memory::types::PageIndexEntry;
|
||||
|
||||
/// `MemoryItem.id` 中知识页面前缀。
|
||||
pub const KNOWLEDGE_PREFIX: &str = "knowledge_";
|
||||
|
||||
/// 知识库 —— KnowledgePage CRUD + 关键词检索 + 内容索引。
|
||||
///
|
||||
/// 内部以 `MemoryStore` 为后端存储 KnowledgePage(序列化为 JSON),
|
||||
/// 同时维护一个 `Vec<PageIndexEntry>` 索引以加速列表遍历。
|
||||
pub struct KnowledgeStore {
|
||||
store: Arc<dyn MemoryStore>,
|
||||
index: std::sync::Mutex<Vec<PageIndexEntry>>,
|
||||
}
|
||||
|
||||
impl KnowledgeStore {
|
||||
/// 创建一个新的 KnowledgeStore。
|
||||
pub fn new(store: Arc<dyn MemoryStore>) -> Self {
|
||||
Self {
|
||||
store,
|
||||
index: std::sync::Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 MemoryStore 重建索引(修复 index 与 store 的不同步问题)。
|
||||
pub async fn rebuild_index(&self) -> Result<(), MemoryError> {
|
||||
let items = self
|
||||
.store
|
||||
.list(&MemoryFilter {
|
||||
prefix: Some(KNOWLEDGE_PREFIX.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
index.clear();
|
||||
for item in items {
|
||||
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
index.push(PageIndexEntry::from(&page));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 创建一个新的知识页面。
|
||||
pub async fn add_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||
if page.id.is_empty() {
|
||||
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||
}
|
||||
let now = OffsetDateTime::now_utc();
|
||||
let id = format!("{KNOWLEDGE_PREFIX}{}", page.id);
|
||||
let content = serde_json::to_string(&page)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
let item = MemoryItem {
|
||||
id,
|
||||
content,
|
||||
metadata: serde_json::json!({}),
|
||||
created_at: now,
|
||||
};
|
||||
self.store.save(item).await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
// 替换或追加
|
||||
if let Some(existing) = index.iter_mut().find(|e| e.id == page.id) {
|
||||
*existing = PageIndexEntry::from(&page);
|
||||
} else {
|
||||
index.push(PageIndexEntry::from(&page));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 根据 page id 获取一个页面。
|
||||
pub async fn get_page(&self, id: &str) -> Result<Option<KnowledgePage>, MemoryError> {
|
||||
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||
let item = self.store.get(&full_id).await?;
|
||||
match item {
|
||||
None => Ok(None),
|
||||
Some(item) => {
|
||||
let page: KnowledgePage = serde_json::from_str(&item.content)
|
||||
.map_err(|e| MemoryError::Serialization(e.to_string()))?;
|
||||
Ok(Some(page))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 更新一个已存在的知识页面。
|
||||
pub async fn update_page(&self, page: KnowledgePage) -> Result<(), MemoryError> {
|
||||
if page.id.is_empty() {
|
||||
return Err(MemoryError::InvalidInput("page.id is empty".into()));
|
||||
}
|
||||
// 通过 get_page 检查存在性
|
||||
if self.get_page(&page.id).await?.is_none() {
|
||||
return Err(MemoryError::NotFound(page.id));
|
||||
}
|
||||
self.add_page(page).await
|
||||
}
|
||||
|
||||
/// 删除一个知识页面。
|
||||
pub async fn delete_page(&self, id: &str) -> Result<(), MemoryError> {
|
||||
let full_id = format!("{KNOWLEDGE_PREFIX}{id}");
|
||||
self.store.delete(&full_id).await?;
|
||||
let mut index = self.index.lock().unwrap();
|
||||
index.retain(|e| e.id != id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 根据关键词搜索知识页面。
|
||||
///
|
||||
/// 匹配规则:在 `title` / `summary` / `tags` 中查找子串(不区分大小写)。
|
||||
/// 全文 `content` 搜索走 `MemoryStore`。
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<KnowledgePage>, MemoryError> {
|
||||
if query.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let needle = query.to_lowercase();
|
||||
let mut results = Vec::new();
|
||||
let index = self.index.lock().unwrap();
|
||||
for entry in index.iter() {
|
||||
if entry.title.to_lowercase().contains(&needle)
|
||||
|| entry.summary.to_lowercase().contains(&needle)
|
||||
|| entry.tags.iter().any(|t| t.to_lowercase().contains(&needle))
|
||||
{
|
||||
if let Some(page) = self.get_page(&entry.id).await? {
|
||||
results.push(page);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// 获取内容目录(所有页面的轻量级索引条目)。
|
||||
pub fn get_index(&self) -> Vec<PageIndexEntry> {
|
||||
self.index.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::InMemoryStore;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
fn make_page(id: &str, title: &str, tags: &[&str]) -> KnowledgePage {
|
||||
let now = OffsetDateTime::now_utc();
|
||||
KnowledgePage {
|
||||
id: id.to_string(),
|
||||
title: title.to_string(),
|
||||
summary: format!("summary of {title}"),
|
||||
content: format!("full content of {title}"),
|
||||
tags: tags.iter().map(|s| s.to_string()).collect(),
|
||||
references: Vec::new(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_get_delete_page() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
ks.add_page(make_page("p1", "LangGraph", &["langgraph", "framework"]))
|
||||
.await
|
||||
.unwrap();
|
||||
let got = ks.get_page("p1").await.unwrap();
|
||||
assert!(got.is_some());
|
||||
assert_eq!(got.unwrap().title, "LangGraph");
|
||||
|
||||
ks.delete_page("p1").await.unwrap();
|
||||
assert!(ks.get_page("p1").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn add_page_rejects_empty_id() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
let result = ks.add_page(make_page("", "NoId", &[])).await;
|
||||
assert!(matches!(result, Err(MemoryError::InvalidInput(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_page_requires_existing() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
let result = ks.update_page(make_page("nope", "Ghost", &[])).await;
|
||||
assert!(matches!(result, Err(MemoryError::NotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_finds_by_title_summary_tag() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
ks.add_page(make_page("p1", "LangGraph StateGraph", &["llm"]))
|
||||
.await
|
||||
.unwrap();
|
||||
ks.add_page(make_page("p2", "Other", &["knowledge-graph"]))
|
||||
.await
|
||||
.unwrap();
|
||||
ks.add_page(make_page("p3", "Third", &["unrelated"]))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = ks.search("stategraph").await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].id, "p1");
|
||||
|
||||
let results = ks.search("knowledge-graph").await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].id, "p2");
|
||||
|
||||
let results = ks.search("nonexistent").await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn get_index_returns_all_pages() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||
let index = ks.get_index();
|
||||
assert_eq!(index.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rebuild_index_recovers_from_drift() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
// 添加页面
|
||||
ks.add_page(make_page("p1", "A", &[])).await.unwrap();
|
||||
ks.add_page(make_page("p2", "B", &[])).await.unwrap();
|
||||
assert_eq!(ks.get_index().len(), 2);
|
||||
|
||||
// 模拟 index 漂移:清空后重建
|
||||
ks.index.lock().unwrap().clear();
|
||||
assert_eq!(ks.get_index().len(), 0);
|
||||
|
||||
ks.rebuild_index().await.unwrap();
|
||||
assert_eq!(ks.get_index().len(), 2);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
//! 记忆检索器 —— 基于 TextOverlap (Dice 系数) 的单通道关键词检索。
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::knowledge::KnowledgeStore;
|
||||
use crate::memory::types::KnowledgePage;
|
||||
|
||||
/// 检索器配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrieverConfig {
|
||||
/// 最大返回条数(默认 20)。
|
||||
pub max_results: usize,
|
||||
/// 最低分数阈值 [0.0, 1.0](默认 0.1)。
|
||||
pub min_score: f32,
|
||||
}
|
||||
|
||||
impl Default for RetrieverConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_results: 20,
|
||||
min_score: 0.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 单条带评分的检索结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScoredItem {
|
||||
pub page: KnowledgePage,
|
||||
/// TextOverlap 评分 [0.0, 1.0]
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
/// 检索结果。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalResult {
|
||||
pub items: Vec<ScoredItem>,
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
/// 记忆检索器 —— 在 `KnowledgeStore` 中做关键词检索并按 TextOverlap 评分。
|
||||
pub struct MemoryRetriever {
|
||||
knowledge_store: KnowledgeStore,
|
||||
config: RetrieverConfig,
|
||||
/// 停用词表(用于关键词提取)。
|
||||
stop_words: HashSet<String>,
|
||||
}
|
||||
|
||||
impl MemoryRetriever {
|
||||
/// 创建一个新的 MemoryRetriever。
|
||||
pub fn new(knowledge_store: KnowledgeStore, config: RetrieverConfig) -> Self {
|
||||
Self {
|
||||
knowledge_store,
|
||||
config,
|
||||
stop_words: default_stop_words(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 替换停用词表。
|
||||
pub fn with_stop_words(mut self, stop_words: HashSet<String>) -> Self {
|
||||
self.stop_words = stop_words;
|
||||
self
|
||||
}
|
||||
|
||||
/// 检索相关知识页面。
|
||||
pub async fn retrieve(&self, query: &str) -> Result<RetrievalResult, MemoryError> {
|
||||
if query.is_empty() {
|
||||
return Ok(RetrievalResult {
|
||||
items: Vec::new(),
|
||||
query: query.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// 1. 关键词提取
|
||||
let keywords = extract_keywords(query, &self.stop_words);
|
||||
|
||||
// 2. 用关键词在 KnowledgeStore 中搜索
|
||||
let mut pages = Vec::new();
|
||||
for keyword in &keywords {
|
||||
let found = self.knowledge_store.search(keyword).await?;
|
||||
for page in found {
|
||||
if !pages.iter().any(|p: &KnowledgePage| p.id == page.id) {
|
||||
pages.push(page);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. TextOverlap 评分
|
||||
let mut items: Vec<ScoredItem> = pages
|
||||
.into_iter()
|
||||
.map(|page| {
|
||||
let score = text_overlap_score(query, &page);
|
||||
ScoredItem { page, score }
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 4. 过滤 → 排序 → 截取
|
||||
items.retain(|i| i.score >= self.config.min_score);
|
||||
items.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
|
||||
items.truncate(self.config.max_results);
|
||||
|
||||
Ok(RetrievalResult {
|
||||
items,
|
||||
query: query.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 query 中提取关键词:按非字母数字字符分割 → 转小写 → 过滤单字符和停用词。
|
||||
fn extract_keywords(query: &str, stop_words: &HashSet<String>) -> Vec<String> {
|
||||
query
|
||||
.split(|c: char| !c.is_alphanumeric())
|
||||
.filter_map(|s| {
|
||||
let lower = s.to_lowercase();
|
||||
if lower.is_empty() || lower.chars().count() < 2 {
|
||||
None
|
||||
} else if stop_words.contains(&lower) {
|
||||
None
|
||||
} else {
|
||||
Some(lower)
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// TextOverlap 评分(基于字符 bigram 的 Dice 系数 + 多字段加权)。
|
||||
///
|
||||
/// 字段权重:title 0.5 + summary 0.3 + content 0.2
|
||||
/// 中文场景按字符级 bigram 处理,不依赖分词器。
|
||||
pub fn text_overlap_score(query: &str, page: &KnowledgePage) -> f32 {
|
||||
let title = text_overlap_dice(query, &page.title);
|
||||
let summary = text_overlap_dice(query, &page.summary);
|
||||
let content = text_overlap_dice(query, &page.content);
|
||||
title * 0.5 + summary * 0.3 + content * 0.2
|
||||
}
|
||||
|
||||
/// Dice 系数(基于字符 bigram)。
|
||||
fn text_overlap_dice(query: &str, text: &str) -> f32 {
|
||||
let q_bigrams = char_bigrams(query);
|
||||
let t_bigrams = char_bigrams(text);
|
||||
if q_bigrams.is_empty() || t_bigrams.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let q_set: HashSet<&String> = q_bigrams.iter().collect();
|
||||
let t_set: HashSet<&String> = t_bigrams.iter().collect();
|
||||
let intersect = q_set.intersection(&t_set).count();
|
||||
let denom = q_set.len() + t_set.len();
|
||||
if denom == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(2.0 * intersect as f32) / denom as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// 提取字符 bigrams(用于字符级 Dice 系数)。
|
||||
fn char_bigrams(s: &str) -> Vec<String> {
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
chars.windows(2).map(|w| w.iter().collect()).collect()
|
||||
}
|
||||
|
||||
fn default_stop_words() -> HashSet<String> {
|
||||
[
|
||||
"the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
|
||||
"had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "shall",
|
||||
"can", "this", "that", "these", "those", "it", "its", "they", "them", "their", "what",
|
||||
"which", "who", "whom", "how", "when", "where", "and", "or", "but", "not", "no", "nor",
|
||||
"so", "if", "then", "else", "with", "without", "for", "to", "from", "in", "on", "at",
|
||||
"by", "of", "as", "into", "through", "during", "before", "after", "above", "below",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::knowledge::KnowledgeStore;
|
||||
use crate::memory::{InMemoryStore, MemoryStore};
|
||||
use std::sync::Arc;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
fn make_page(id: &str, title: &str, summary: &str, content: &str) -> KnowledgePage {
|
||||
let now = OffsetDateTime::now_utc();
|
||||
KnowledgePage {
|
||||
id: id.to_string(),
|
||||
title: title.to_string(),
|
||||
summary: summary.to_string(),
|
||||
content: content.to_string(),
|
||||
tags: Vec::new(),
|
||||
references: Vec::new(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retrieve_empty_query() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||
let result = retriever.retrieve("").await.unwrap();
|
||||
assert!(result.items.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retrieve_finds_relevant_page() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
ks.add_page(make_page(
|
||||
"p1",
|
||||
"LangGraph StateGraph",
|
||||
"state management",
|
||||
"LangGraph uses StateGraph for state machines",
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
ks.add_page(make_page("p2", "Other", "unrelated", "nothing matching"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let retriever = MemoryRetriever::new(ks, RetrieverConfig::default());
|
||||
let result = retriever.retrieve("LangGraph state").await.unwrap();
|
||||
assert!(!result.items.is_empty());
|
||||
assert_eq!(result.items[0].page.id, "p1");
|
||||
assert!(result.items[0].score > 0.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retrieve_respects_min_score() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
ks.add_page(make_page("p1", "X", "Y", "Z")).await.unwrap();
|
||||
|
||||
let config = RetrieverConfig {
|
||||
max_results: 10,
|
||||
min_score: 0.99,
|
||||
};
|
||||
let retriever = MemoryRetriever::new(ks, config);
|
||||
let result = retriever.retrieve("totally unrelated content").await.unwrap();
|
||||
assert!(result.items.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retrieve_respects_max_results() {
|
||||
let store = Arc::new(InMemoryStore::new()) as Arc<dyn MemoryStore>;
|
||||
let ks = KnowledgeStore::new(store);
|
||||
for i in 0..5 {
|
||||
ks.add_page(make_page(
|
||||
&format!("p{i}"),
|
||||
"LangGraph",
|
||||
"framework",
|
||||
"agent runtime",
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
let config = RetrieverConfig {
|
||||
max_results: 2,
|
||||
min_score: 0.0,
|
||||
};
|
||||
let retriever = MemoryRetriever::new(ks, config);
|
||||
let result = retriever.retrieve("LangGraph").await.unwrap();
|
||||
assert_eq!(result.items.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn text_overlap_dice_zero_on_empty() {
|
||||
assert_eq!(text_overlap_dice("hello", ""), 0.0);
|
||||
assert_eq!(text_overlap_dice("", "hello"), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn text_overlap_dice_identical() {
|
||||
let s = "hello world";
|
||||
assert!((text_overlap_dice(s, s) - 1.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_keywords_filters_stop_words() {
|
||||
let stop = default_stop_words();
|
||||
let kws = extract_keywords("the quick brown fox is fast", &stop);
|
||||
assert!(!kws.contains(&"the".to_string()));
|
||||
assert!(!kws.contains(&"is".to_string()));
|
||||
assert!(kws.contains(&"quick".to_string()));
|
||||
assert!(kws.contains(&"brown".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_keywords_filters_single_chars() {
|
||||
let stop = default_stop_words();
|
||||
let kws = extract_keywords("a b c dog", &stop);
|
||||
assert!(!kws.contains(&"a".to_string()));
|
||||
assert!(!kws.contains(&"b".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
//! MemoryStore 抽象接口与默认实现。
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
use crate::memory::error::MemoryError;
|
||||
use crate::memory::types::{MemoryFilter, MemoryItem};
|
||||
|
||||
/// 底层记忆存储抽象接口。
|
||||
///
|
||||
/// 下游可实现此 trait 以对接持久化后端(JSON 文件、SQLite、Redis 等)。
|
||||
/// 默认实现 [`InMemoryStore`] 基于进程内 HashMap。
|
||||
#[async_trait]
|
||||
pub trait MemoryStore: Send + Sync {
|
||||
/// 保存/覆盖一个 MemoryItem(upsert 语义)。
|
||||
/// - 如果 id 不存在,则插入新条目
|
||||
/// - 如果 id 已存在,则覆盖旧条目
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError>;
|
||||
|
||||
/// 根据 id 获取一个 MemoryItem。
|
||||
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError>;
|
||||
|
||||
/// 根据 id 删除一个 MemoryItem。
|
||||
async fn delete(&self, id: &str) -> Result<(), MemoryError>;
|
||||
|
||||
/// 根据 filter 列出 MemoryItem。
|
||||
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError>;
|
||||
}
|
||||
|
||||
/// 淘汰策略。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EvictionPolicy {
|
||||
/// 不淘汰(默认)。
|
||||
None,
|
||||
/// 超过存活时间(秒)淘汰。
|
||||
Ttl { ttl_secs: u64 },
|
||||
/// 超过容量上限淘汰最旧(基于 created_at)。
|
||||
Capacity { max_items: usize },
|
||||
}
|
||||
|
||||
/// 淘汰配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EvictionConfig {
|
||||
pub policy: EvictionPolicy,
|
||||
/// 每写入 N 条后检查一次淘汰条件。
|
||||
pub check_interval: usize,
|
||||
}
|
||||
|
||||
impl Default for EvictionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
policy: EvictionPolicy::None,
|
||||
check_interval: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 进程内默认实现 —— 基于 HashMap + Mutex,纯内存。
|
||||
pub struct InMemoryStore {
|
||||
items: Mutex<HashMap<String, MemoryItem>>,
|
||||
eviction: EvictionConfig,
|
||||
/// 自上次淘汰检查以来的写入次数。
|
||||
writes_since_check: Mutex<usize>,
|
||||
}
|
||||
|
||||
impl InMemoryStore {
|
||||
/// 创建一个无淘汰策略的 InMemoryStore。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items: Mutex::new(HashMap::new()),
|
||||
eviction: EvictionConfig::default(),
|
||||
writes_since_check: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建一个带淘汰配置的 InMemoryStore。
|
||||
pub fn with_eviction(eviction: EvictionConfig) -> Self {
|
||||
Self {
|
||||
items: Mutex::new(HashMap::new()),
|
||||
eviction,
|
||||
writes_since_check: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_evict(&self) {
|
||||
// 不使用 .lock().await 跨点,先取计数判断是否需要淘汰
|
||||
let should_check = {
|
||||
let mut counter = self.writes_since_check.lock().unwrap();
|
||||
*counter += 1;
|
||||
if *counter >= self.eviction.check_interval {
|
||||
*counter = 0;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
if !should_check {
|
||||
return;
|
||||
}
|
||||
|
||||
let policy = self.eviction.policy.clone();
|
||||
match policy {
|
||||
EvictionPolicy::None => {}
|
||||
EvictionPolicy::Ttl { ttl_secs } => {
|
||||
let cutoff = OffsetDateTime::now_utc() - time::Duration::seconds(ttl_secs as i64);
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.retain(|_, v| v.created_at > cutoff);
|
||||
}
|
||||
EvictionPolicy::Capacity { max_items } => {
|
||||
let mut items = self.items.lock().unwrap();
|
||||
if items.len() > max_items {
|
||||
let mut vec: Vec<_> = items.drain().collect();
|
||||
// O(n) 部分排序:保留 created_at 最大的 max_items 个
|
||||
vec.select_nth_unstable_by(max_items, |a, b| {
|
||||
b.1.created_at.cmp(&a.1.created_at)
|
||||
});
|
||||
vec.truncate(max_items);
|
||||
*items = vec.into_iter().collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MemoryStore for InMemoryStore {
|
||||
async fn save(&self, item: MemoryItem) -> Result<(), MemoryError> {
|
||||
{
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.insert(item.id.clone(), item);
|
||||
}
|
||||
self.maybe_evict();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, id: &str) -> Result<Option<MemoryItem>, MemoryError> {
|
||||
let items = self.items.lock().unwrap();
|
||||
Ok(items.get(id).cloned())
|
||||
}
|
||||
|
||||
async fn delete(&self, id: &str) -> Result<(), MemoryError> {
|
||||
let mut items = self.items.lock().unwrap();
|
||||
items.remove(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list(&self, filter: &MemoryFilter) -> Result<Vec<MemoryItem>, MemoryError> {
|
||||
let items = self.items.lock().unwrap();
|
||||
let mut result: Vec<MemoryItem> = items
|
||||
.values()
|
||||
.filter(|v| match &filter.prefix {
|
||||
Some(p) => v.id.starts_with(p),
|
||||
None => true,
|
||||
})
|
||||
.filter(|v| match filter.since {
|
||||
Some(t) => v.created_at > t,
|
||||
None => true,
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
// 按 created_at 升序排列(最旧在前)
|
||||
result.sort_by_key(|v| v.created_at);
|
||||
// 应用 offset
|
||||
if let Some(offset) = filter.offset {
|
||||
if offset < result.len() {
|
||||
result.drain(..offset);
|
||||
} else {
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
// 应用 limit
|
||||
if let Some(limit) = filter.limit {
|
||||
result.truncate(limit);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use time::OffsetDateTime;
|
||||
|
||||
fn make_item(id: &str) -> MemoryItem {
|
||||
MemoryItem {
|
||||
id: id.to_string(),
|
||||
content: format!("content-{id}"),
|
||||
metadata: serde_json::json!({}),
|
||||
created_at: OffsetDateTime::now_utc(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_get_delete_list() {
|
||||
let store = InMemoryStore::new();
|
||||
store.save(make_item("a")).await.unwrap();
|
||||
store.save(make_item("b")).await.unwrap();
|
||||
|
||||
let got = store.get("a").await.unwrap();
|
||||
assert!(got.is_some());
|
||||
assert_eq!(got.unwrap().id, "a");
|
||||
|
||||
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||
assert_eq!(list.len(), 2);
|
||||
|
||||
store.delete("a").await.unwrap();
|
||||
assert!(store.get("a").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_is_upsert() {
|
||||
let store = InMemoryStore::new();
|
||||
store.save(make_item("a")).await.unwrap();
|
||||
let mut item = make_item("a");
|
||||
item.content = "updated".to_string();
|
||||
store.save(item).await.unwrap();
|
||||
let got = store.get("a").await.unwrap().unwrap();
|
||||
assert_eq!(got.content, "updated");
|
||||
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||
assert_eq!(list.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_with_prefix_and_limit() {
|
||||
let store = InMemoryStore::new();
|
||||
store.save(make_item("foo_a")).await.unwrap();
|
||||
store.save(make_item("foo_b")).await.unwrap();
|
||||
store.save(make_item("bar_a")).await.unwrap();
|
||||
|
||||
let filter = MemoryFilter {
|
||||
prefix: Some("foo_".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let list = store.list(&filter).await.unwrap();
|
||||
assert_eq!(list.len(), 2);
|
||||
|
||||
let filter = MemoryFilter {
|
||||
prefix: Some("foo_".to_string()),
|
||||
limit: Some(1),
|
||||
..Default::default()
|
||||
};
|
||||
let list = store.list(&filter).await.unwrap();
|
||||
assert_eq!(list.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn capacity_eviction() {
|
||||
// 强制每次写入都检查
|
||||
let eviction = EvictionConfig {
|
||||
policy: EvictionPolicy::Capacity { max_items: 2 },
|
||||
check_interval: 1,
|
||||
};
|
||||
let store = InMemoryStore::with_eviction(eviction);
|
||||
// 第一条和第二条共存
|
||||
store.save(make_item("a")).await.unwrap();
|
||||
store.save(make_item("b")).await.unwrap();
|
||||
// 第三条写入触发淘汰:a 或 b 之一被淘汰
|
||||
store.save(make_item("c")).await.unwrap();
|
||||
|
||||
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||
assert_eq!(list.len(), 2);
|
||||
// 留下的应该是 b 和 c(最新的两个)
|
||||
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||
assert!(ids.contains(&"b"));
|
||||
assert!(ids.contains(&"c"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ttl_eviction() {
|
||||
// TTL 设为 0 会立即过期,但我们想保留 "a" 等待 "b" 写入后被淘汰。
|
||||
// 改用小 TTL + 睡眠:先 save a,sleep,save b 时 a 已过期被淘汰。
|
||||
let eviction = EvictionConfig {
|
||||
policy: EvictionPolicy::Ttl { ttl_secs: 1 },
|
||||
check_interval: 1,
|
||||
};
|
||||
let store = InMemoryStore::with_eviction(eviction);
|
||||
store.save(make_item("a")).await.unwrap();
|
||||
// 等待超过 1 秒
|
||||
std::thread::sleep(std::time::Duration::from_millis(1100));
|
||||
// 触发淘汰:a 已超过 ttl_secs=1,应被淘汰
|
||||
store.save(make_item("b")).await.unwrap();
|
||||
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||
// 由于 ttl_secs=1,且 b 刚写入,可能刚好处于临界值。
|
||||
// 我们只断言 list 不包含 "a" 即可。
|
||||
let ids: Vec<&str> = list.iter().map(|v| v.id.as_str()).collect();
|
||||
assert!(
|
||||
!ids.contains(&"a"),
|
||||
"expected 'a' to be evicted, but found in {ids:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn none_policy_no_eviction() {
|
||||
let eviction = EvictionConfig {
|
||||
policy: EvictionPolicy::None,
|
||||
check_interval: 1,
|
||||
};
|
||||
let store = InMemoryStore::with_eviction(eviction);
|
||||
for i in 0..100 {
|
||||
store.save(make_item(&format!("item_{i}"))).await.unwrap();
|
||||
}
|
||||
let list = store.list(&MemoryFilter::default()).await.unwrap();
|
||||
assert_eq!(list.len(), 100);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
//! 记忆系统核心数据类型。
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
/// 记忆条目 —— MemoryStore 存储的基本单元。
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryItem {
|
||||
/// 唯一标识。
|
||||
pub id: String,
|
||||
/// 内容(通常为 JSON 序列化的具体记忆数据)。
|
||||
pub content: String,
|
||||
/// 任意附加元数据。
|
||||
pub metadata: serde_json::Value,
|
||||
/// 创建时间。
|
||||
pub created_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
/// MemoryStore 列表查询条件。
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct MemoryFilter {
|
||||
/// 按 id 前缀过滤。
|
||||
pub prefix: Option<String>,
|
||||
/// 仅返回该时间之后创建的条目。
|
||||
pub since: Option<OffsetDateTime>,
|
||||
/// 跳过前 N 条。
|
||||
pub offset: Option<usize>,
|
||||
/// 最多返回 N 条。
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
/// 知识页面 —— 描述一段结构化知识。
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KnowledgePage {
|
||||
/// 唯一标识。
|
||||
pub id: String,
|
||||
/// 标题。
|
||||
pub title: String,
|
||||
/// 一句话摘要。
|
||||
pub summary: String,
|
||||
/// 完整内容。
|
||||
pub content: String,
|
||||
/// 检索标签。
|
||||
pub tags: Vec<String>,
|
||||
/// 交叉引用的其他页面 ID。
|
||||
pub references: Vec<String>,
|
||||
/// 创建时间。
|
||||
pub created_at: OffsetDateTime,
|
||||
/// 最后更新时间。
|
||||
pub updated_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
/// 知识页面索引条目 —— 用于轻量遍历和内容目录。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PageIndexEntry {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub summary: String,
|
||||
pub tags: Vec<String>,
|
||||
pub updated_at: OffsetDateTime,
|
||||
}
|
||||
|
||||
impl From<&KnowledgePage> for PageIndexEntry {
|
||||
fn from(p: &KnowledgePage) -> Self {
|
||||
Self {
|
||||
id: p.id.clone(),
|
||||
title: p.title.clone(),
|
||||
summary: p.summary.clone(),
|
||||
tags: p.tags.clone(),
|
||||
updated_at: p.updated_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user