Files
agcore/src/agent/session.rs
T
徐涛 2b189880a9 feat(agent): 实现 Agent Runtime 核心胶水层 (Phase 4a)
- 添加 Agent trait、AgentSession、RuntimeBundle、AgentBuilder
- 添加 Plan/Step/StepStatus 任务规划数据结构
- 添加 AgentError 统一错误类型(聚合 LlmError/ToolError/MemoryError)
- 实现 submit_turn 单轮对话流程(含 hook 触发与 cost 累计)
- 扩展 LlmCycle 支持 Arc<dyn LlmProvider>
- 扩展 HookEvent 添加 OnTurnStart/OnTurnEnd
- 更新 roadmap 状态
2026-06-11 21:45:28 +08:00

343 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! AgentSession —— 智能体"会话"实例。
//!
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3):
//!
//! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far`
//! - **最小 reference impl**:`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程
//! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定
//! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message`
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::agent::Agent;
use crate::agent::error::AgentError;
use crate::agent::runtime::RuntimeBundle;
use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle};
use crate::llm::hooks::{HookContext, HookEvent};
use crate::llm::types::ChatResponse;
/// An agent session instance.
///
/// The same `Agent` may be shared by several `AgentSession`s; sessions with different
/// `session_id`s do not interfere with each other. Each call to `submit_turn` runs exactly
/// one LLM turn (including the automatic tool loop).
///
/// **Does not implement `Clone`**: a session owns its accumulated `turn_index` /
/// `cost_so_far` / `session_data`, and sharing that state needs explicit synchronization
/// semantics. If higher layers need concurrent access, they should wrap the session in
/// `Arc<Mutex<_>>` themselves.
pub struct AgentSession {
    /// Session ID (chosen by the caller; used for logging / tracing / memory association).
    pub session_id: String,
    /// The role played in this session (can be hot-swapped for another role that uses the
    /// same bundle).
    pub agent: Arc<dyn Agent>,
    /// Shared runtime resources read by `submit_turn`: LLM provider, tool registry,
    /// hook executor and config.
    bundle: Arc<RuntimeBundle>,
    /// Number of turns completed so far; incremented at the end of a successful
    /// `submit_turn`.
    turn_index: u32,
    /// Token usage accumulated across completed turns.
    cost_so_far: CostTracker,
    /// Session-level key/value data (Phase 4a uses an inline `HashMap`; Phase 4c replaces
    /// it with `SessionMemory`).
    session_data: HashMap<String, String>,
}
// Hand-written `Debug`: `agent` is shown by name only, `bundle` is omitted, and for
// `session_data` only the keys are printed. Keys are sorted because `HashMap` iteration
// order is arbitrary, which would otherwise make the output differ between runs
// (noisy logs, flaky snapshot comparisons).
impl std::fmt::Debug for AgentSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut session_data_keys: Vec<&str> =
            self.session_data.keys().map(String::as_str).collect();
        session_data_keys.sort_unstable();
        f.debug_struct("AgentSession")
            .field("session_id", &self.session_id)
            .field("agent", &self.agent.name())
            .field("turn_index", &self.turn_index)
            .field("cost_so_far", &self.cost_so_far.total())
            .field("session_data_keys", &session_data_keys)
            .finish()
    }
}
impl AgentSession {
    /// Creates a new session.
    ///
    /// `agent` and `bundle` together determine what `submit_turn` does: the system prompt
    /// comes from `agent`; the LLM provider, tool registry, hook executor and compaction
    /// config come from `bundle`. The session starts at turn 0 with no accumulated usage
    /// and no session data.
    pub fn new(
        agent: Arc<dyn Agent>,
        session_id: impl Into<String>,
        bundle: Arc<RuntimeBundle>,
    ) -> Self {
        Self {
            session_id: session_id.into(),
            agent,
            bundle,
            turn_index: 0,
            cost_so_far: CostTracker::default(),
            session_data: HashMap::new(),
        }
    }

    /// Current turn index (0-based: becomes 1 once the first `submit_turn` succeeds).
    ///
    /// A failed `submit_turn` does not advance it.
    pub fn turn_index(&self) -> u32 {
        self.turn_index
    }

    /// Usage accumulated across all successful turns.
    pub fn usage(&self) -> &CostTracker {
        &self.cost_so_far
    }

    /// Read-only view of all session-level key/value data.
    pub fn session_data(&self) -> &HashMap<String, String> {
        &self.session_data
    }

    /// Stores a session-level value, overwriting any existing value for `key`.
    pub fn set_session_data(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.session_data.insert(key.into(), value.into());
    }

    /// Looks up a session-level value; returns `None` if `key` has not been set.
    pub fn get_session_data(&self, key: &str) -> Option<&str> {
        self.session_data.get(key).map(String::as_str)
    }

    /// Submits one conversation turn (including the automatic tool loop) and returns the
    /// LLM response.
    ///
    /// Flow:
    /// 1. Fire the `OnTurnStart` hook.
    /// 2. Build a fresh `LlmCycle` on the bundle's provider, with the agent's system prompt
    ///    and the bundle's compact config. The cycle starts with an **empty** message
    ///    history, so a turn does not see earlier turns.
    /// 3. Run `submit_with_tools` against the bundle's tool registry.
    /// 4. On success, add the response usage to `cost_so_far`.
    /// 5. Fire the `OnTurnEnd` hook — on success *and* on failure, so every `OnTurnStart`
    ///    is paired with an `OnTurnEnd`.
    /// 6. On success, increment `turn_index`.
    ///
    /// **Out of scope**:
    /// - Does not own a `ConversationMemory` (the caller decides when to write history back).
    /// - No plan decomposition (`TaskAgent` arrives in Phase 4b).
    /// - No `session_data` persistence (replaced by `SessionMemory` in Phase 4c).
    ///
    /// # Errors
    ///
    /// Returns the error produced by `submit_with_tools`, converted into [`AgentError`].
    /// In that case `cost_so_far` and `turn_index` are left unchanged (the `OnTurnEnd`
    /// hook still fires).
    pub async fn submit_turn(
        &mut self,
        user_input: impl Into<String>,
    ) -> Result<ChatResponse, AgentError> {
        let turn_index = self.turn_index;
        let hook_executor = Arc::clone(&self.bundle.hook_executor);

        // 1. OnTurnStart.
        let start_ctx = HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index);
        hook_executor
            .execute(HookEvent::OnTurnStart, &start_ctx)
            .await;

        // 2. Build the LlmCycle on the bundle's shared provider handle.
        // NOTE(review): the agent's tool definitions are computed and then discarded;
        // `submit_with_tools` below is given the full registry, so an agent overriding
        // `tool_definitions` to whitelist/filter tools appears to have no effect on what
        // the LLM sees. The call is kept only to preserve any side effects of overrides.
        let _ = self.agent.tool_definitions(&self.bundle);
        let mut cycle =
            LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default())
                .with_messages(Vec::new());
        if let Some(prompt) = self.agent.system_prompt() {
            cycle = cycle.with_system_prompt(prompt.to_string());
        }
        if let Some(cfg) = self.bundle.config.compact_config.clone() {
            cycle = cycle.with_compact_config(cfg);
        }

        // 3. Run the turn. The bundle's hook executor is not handed to the cycle here.
        // NOTE(review): the original design expects PreRequest/PostRequest hooks to fire
        // inside LlmCycle — confirm how the cycle reaches the bundle's executor.
        let outcome = cycle
            .submit_with_tools(user_input.into(), &self.bundle.tool_registry)
            .await;

        // 4. Only successful turns contribute to accumulated usage.
        if let Ok(response) = &outcome {
            self.cost_so_far.add(&response.usage);
        }

        // 5. Fire OnTurnEnd unconditionally: an early `?` on the error path would leave
        //    hooks with an OnTurnStart that never gets its matching OnTurnEnd.
        let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index);
        hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await;

        // 6. Propagate the error (turn index unchanged) or advance to the next turn.
        let response = outcome?;
        self.turn_index += 1;
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::builder::AgentBuilder;
    use crate::llm::error::LlmError;
    use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult};
    use crate::llm::provider::LlmProvider;
    use crate::llm::types::{
        ChatRequest, ChatResponse, ContentField, FinishReason, OpenaiChatMessage, Usage,
    };
    use crate::tools::ToolRegistry;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// Hook that counts its own invocations.
    struct CountHook(AtomicU32);

    impl CountHook {
        /// Number of times the hook has run so far.
        fn hits(&self) -> u32 {
            self.0.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Hook for CountHook {
        async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult {
            self.0.fetch_add(1, Ordering::SeqCst);
            HookResult::allow()
        }
    }

    /// Forwards to a shared `CountHook`, so the test keeps a handle on the counter while the
    /// executor owns a `Box<dyn Hook>` (which cannot be built from an `Arc` directly).
    struct CountHookAdapter(Arc<CountHook>);

    #[async_trait]
    impl Hook for CountHookAdapter {
        async fn execute(&self, ctx: &HookContext<'_>) -> HookResult {
            self.0.execute(ctx).await
        }
    }

    /// Provider that replays canned responses in FIFO order and errors once they run out.
    struct MockProvider {
        queue: Mutex<Vec<ChatResponse>>,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatResponse>) -> Self {
            Self {
                queue: Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
            let mut queue = self.queue.lock().unwrap();
            if queue.is_empty() {
                Err(LlmError::Other("no more mock responses".into()))
            } else {
                Ok(queue.remove(0))
            }
        }
    }

    /// Minimal `Agent` with a fixed name and an optional system prompt.
    struct StubAgent {
        name: String,
        prompt: Option<String>,
    }

    impl Agent for StubAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn system_prompt(&self) -> Option<&str> {
            self.prompt.as_deref()
        }
    }

    /// Plain assistant reply that stops normally and reports 10 input / 5 output tokens.
    fn assistant_text(text: &str) -> ChatResponse {
        ChatResponse {
            message: OpenaiChatMessage::assistant_text(text),
            usage: Usage::from_input_output(10, 5),
            stop_reason: Some(FinishReason::Stop),
        }
    }

    /// Wires a `StubAgent` named "stub", the given provider and hooks, and an empty tool
    /// registry into a fresh session.
    fn new_session(
        session_id: &str,
        prompt: Option<&str>,
        provider: Arc<MockProvider>,
        hooks: HookExecutor,
    ) -> AgentSession {
        let agent = Arc::new(StubAgent {
            name: "stub".into(),
            prompt: prompt.map(String::from),
        });
        let bundle = AgentBuilder::new()
            .provider(provider)
            .tool_registry(Arc::new(ToolRegistry::new()))
            .hook_executor(Arc::new(hooks))
            .build()
            .unwrap();
        AgentSession::new(agent, session_id, Arc::new(bundle))
    }

    /// Smoke test 1: `submit_turn` completes a turn against the mock provider.
    #[tokio::test]
    async fn submit_turn_runs_with_mock_provider() {
        let provider = Arc::new(MockProvider::new(vec![assistant_text("hello back")]));
        let mut session = new_session(
            "s1",
            Some("you are a test agent"),
            provider,
            HookExecutor::new(),
        );
        assert_eq!(session.turn_index(), 0);

        let response = session.submit_turn("hi").await.unwrap();
        let text = match &response.message {
            OpenaiChatMessage::Assistant {
                content: ContentField::String(s),
                ..
            } => s.clone(),
            _ => String::new(),
        };
        assert_eq!(text, "hello back");
        assert_eq!(session.turn_index(), 1);

        let totals = session.usage().total();
        assert_eq!(totals.prompt_tokens, 10);
        assert_eq!(totals.completion_tokens, 5);
    }

    /// Smoke test 2: session data round-trips, and a second write overwrites the first.
    #[test]
    fn session_data_set_get() {
        let mut session = new_session(
            "s2",
            None,
            Arc::new(MockProvider::new(Vec::new())),
            HookExecutor::new(),
        );
        assert!(session.get_session_data("k").is_none());

        for value in ["v", "v2"] {
            session.set_session_data("k", value);
            assert_eq!(session.get_session_data("k"), Some(value));
        }
    }

    /// Smoke test 3: every `submit_turn` fires `OnTurnStart` and `OnTurnEnd` exactly once.
    #[tokio::test]
    async fn submit_turn_triggers_turn_hooks() {
        let starts = Arc::new(CountHook(AtomicU32::new(0)));
        let ends = Arc::new(CountHook(AtomicU32::new(0)));

        let mut hooks = HookExecutor::new();
        hooks.register(
            HookEvent::OnTurnStart,
            Box::new(CountHookAdapter(Arc::clone(&starts))),
        );
        hooks.register(
            HookEvent::OnTurnEnd,
            Box::new(CountHookAdapter(Arc::clone(&ends))),
        );

        let provider = Arc::new(MockProvider::new(vec![
            assistant_text("ok"),
            assistant_text("ok 2"),
        ]));
        let mut session = new_session("s3", None, provider, hooks);

        for (expected, input) in (1u32..).zip(["hi", "hi again"]) {
            session.submit_turn(input).await.unwrap();
            assert_eq!(starts.hits(), expected);
            assert_eq!(ends.hits(), expected);
        }
    }
}