Files
agcore/src/agent/session_memory.rs
T
徐涛 ce1f1aaca0 feat(agent): 实现 Phase 4c 会话级记忆功能
- 新增 `SessionMemory` 结构体,基于 `MemoryStore` 按 namespace 隔离键值数据
- `AgentBuilder` 增加 `session_memory_backend` 配置入口
- `RuntimeBundle` 透传 `session_memory_backend` 字段
- `AgentSession` 将内联 `HashMap` 替换为完整的 `SessionMemory`,`set_session_data` 和 `get_session_data` 改为异步方法
- 新增 3 个内联测试,全量测试从 113 增至 116,clippy 0 警告
2026-06-11 22:14:15 +08:00

184 lines
5.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! SessionMemory —— 会话级记忆,用于 context 间的信息桥接。
//!
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.8):
//!
//! - **会话级**:单 session 内共享,跨 context 桥接信息(不是持久层,也不是对话历史)
//! - **复用 Phase 3 `MemoryStore`**:不引入新的存储后端机制
//! - **按 `namespace` 隔离**:每个 session 一个独立命名空间,防止跨 session 泄漏
//! - **`snapshot()` 格式化为标记文本**:专为注入 system prompt 设计
//! - **所有方法为 `async`**:因为后端可能是跨进程的(Redis / DB)
use std::sync::Arc;
use time::OffsetDateTime;
use crate::agent::error::AgentError;
use crate::memory::store::MemoryStore;
use crate::memory::types::{MemoryFilter, MemoryItem};
/// 会话级记忆实例。
///
/// 基于 [`MemoryStore`] 后端,按 `namespace` 隔离键值数据。
/// 适用于 session 内各 context 之间的信息桥接(如将关键结论传递给后续 context)。
pub struct SessionMemory {
store: Arc<dyn MemoryStore>,
namespace: String,
}
impl SessionMemory {
/// 创建新的 session 级记忆实例。
///
/// - `store`:后端存储(可跨进程共享的 `MemoryStore` 实现)。
/// - `namespace`:按 session_id 隔离,防止跨 session 泄漏。
/// 内部会自动添加 `"_session_"` 前缀。
pub fn new(store: Arc<dyn MemoryStore>, namespace: &str) -> Self {
Self {
store,
namespace: format!("_session_{namespace}"),
}
}
/// 内部 key 格式:`"{namespace}:{key}"`。
fn internal_key(&self, key: &str) -> String {
format!("{}:{}", self.namespace, key)
}
/// 写入一条 key-value 条目(覆盖同名 key)。
pub async fn set(&self, key: &str, value: &str) -> Result<(), AgentError> {
let item = MemoryItem {
id: self.internal_key(key),
content: value.to_string(),
metadata: serde_json::json!({}),
created_at: OffsetDateTime::now_utc(),
};
self.store.save(item).await.map_err(AgentError::Memory)
}
/// 读取指定 key 的值。
pub async fn get(&self, key: &str) -> Result<Option<String>, AgentError> {
let item = self
.store
.get(&self.internal_key(key))
.await
.map_err(AgentError::Memory)?;
Ok(item.map(|i| i.content))
}
/// 返回所有条目的格式化快照,适合注入 system prompt。
///
/// 格式:
/// ```text
/// <session-context>
/// key1: value1
/// key2: value2
/// </session-context>
/// ```
pub async fn snapshot(&self) -> Result<String, AgentError> {
let filter = MemoryFilter {
prefix: Some(format!("{}:", self.namespace)),
..Default::default()
};
let items = self
.store
.list(&filter)
.await
.map_err(AgentError::Memory)?;
let mut lines = Vec::with_capacity(items.len() + 2);
lines.push("<session-context>".to_string());
for item in items {
// 从 id 中提取原始 key(去掉 namespace 前缀)
let key = item
.id
.strip_prefix(&format!("{}:", self.namespace))
.unwrap_or(&item.id);
lines.push(format!("{}: {}", key, item.content));
}
lines.push("</session-context>".to_string());
Ok(lines.join("\n"))
}
/// 删除指定 key。
pub async fn remove(&self, key: &str) -> Result<(), AgentError> {
self.store
.delete(&self.internal_key(key))
.await
.map_err(AgentError::Memory)
}
/// 清空当前 namespace 下所有条目。
pub async fn clear(&self) -> Result<(), AgentError> {
let filter = MemoryFilter {
prefix: Some(format!("{}:", self.namespace)),
..Default::default()
};
let items = self
.store
.list(&filter)
.await
.map_err(AgentError::Memory)?;
for item in items {
self.store
.delete(&item.id)
.await
.map_err(AgentError::Memory)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::store::InMemoryStore;
fn make_store() -> Arc<dyn MemoryStore> {
Arc::new(InMemoryStore::new())
}
/// 烟雾测试 1set / get / remove 基本读写。
#[tokio::test]
async fn set_get_remove() {
let mem = SessionMemory::new(make_store(), "test-session");
assert!(mem.get("k").await.unwrap().is_none());
mem.set("k", "v").await.unwrap();
assert_eq!(mem.get("k").await.unwrap(), Some("v".into()));
mem.remove("k").await.unwrap();
assert!(mem.get("k").await.unwrap().is_none());
}
/// 烟雾测试 2snapshot 格式化输出。
#[tokio::test]
async fn snapshot_format() {
let mem = SessionMemory::new(make_store(), "s1");
mem.set("design", "PostgreSQL").await.unwrap();
mem.set("lang", "Rust").await.unwrap();
let snap = mem.snapshot().await.unwrap();
assert!(snap.contains("<session-context>"));
assert!(snap.contains("</session-context>"));
assert!(snap.contains("design: PostgreSQL"));
assert!(snap.contains("lang: Rust"));
}
/// 烟雾测试 3clear 清空当前 namespace。
#[tokio::test]
async fn clear_only_affects_own_namespace() {
let store = make_store();
let mem_a = SessionMemory::new(store.clone(), "a");
let mem_b = SessionMemory::new(store.clone(), "b");
mem_a.set("key", "val_a").await.unwrap();
mem_b.set("key", "val_b").await.unwrap();
mem_a.clear().await.unwrap();
assert!(mem_a.get("key").await.unwrap().is_none());
assert_eq!(mem_b.get("key").await.unwrap(), Some("val_b".into()));
}
}