//! AgentSession —— 智能体"会话"实例。 //! //! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3): //! //! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far` //! - **最小 reference impl**:`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程 //! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定 //! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message` use std::collections::HashMap; use std::sync::Arc; use crate::agent::agent::Agent; use crate::agent::error::AgentError; use crate::agent::runtime::RuntimeBundle; use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle}; use crate::llm::hooks::{HookContext, HookEvent}; use crate::llm::types::ChatResponse; /// Agent 会话实例。 /// /// 同一 `Agent` 可被多个 `AgentSession` 复用(不同 session_id 互不干扰)。 /// `submit_turn` 一次只跑一轮 LLM 调用(含自动 tool 循环)。 /// /// **不实现 `Clone`**:session 持有累计 `turn_index` / `cost_so_far` / `session_data`, /// 共享这些状态需要显式 sync 语义;如果上层需要并发访问,自己用 `Arc>` 包装。 pub struct AgentSession { /// 会话 ID(由调用方指定,用于日志/追踪/记忆关联)。 pub session_id: String, /// 角色(可热切换为同 bundle 下的其他角色)。 pub agent: Arc, bundle: Arc, turn_index: u32, cost_so_far: CostTracker, /// 会话级键值数据(Phase 4a 用内联 HashMap;Phase 4c 替换为 `SessionMemory`)。 session_data: HashMap, } impl std::fmt::Debug for AgentSession { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AgentSession") .field("session_id", &self.session_id) .field("agent", &self.agent.name()) .field("turn_index", &self.turn_index) .field("cost_so_far", &self.cost_so_far.total()) .field("session_data_keys", &self.session_data.keys().collect::>()) .finish() } } impl AgentSession { /// 创建一个新的会话实例。 /// /// `agent` 与 `bundle` 共同决定 `submit_turn` 行为:system_prompt / 工具集 / LLM 后端均来自它们。 pub fn new( agent: Arc, session_id: impl Into, bundle: Arc, ) -> Self { Self { session_id: session_id.into(), agent, bundle, turn_index: 0, cost_so_far: CostTracker::default(), session_data: HashMap::new(), } } /// 当前 turn 序号(0-based:第一次 `submit_turn` 完成后变 1)。 pub fn turn_index(&self) -> u32 { self.turn_index } /// 累计用量(跨所有 turn)。 pub fn usage(&self) -> &CostTracker { &self.cost_so_far } /// 会话级数据快照引用。 pub fn session_data(&self) -> &HashMap { &self.session_data } /// 写入一条会话级数据(覆盖同名 key)。 pub fn set_session_data(&mut self, key: impl Into, value: impl Into) { self.session_data.insert(key.into(), value.into()); } /// 读取一条会话级数据。 pub fn get_session_data(&self, key: &str) -> Option<&str> { self.session_data.get(key).map(String::as_str) } /// 提交一轮对话(含自动 tool 循环),返回 LLM 响应。 /// /// 流程: /// 1. 触发 `OnTurnStart` hook /// 2. 组装 `LlmCycle`(注入 system_prompt / hook_executor / compact_config / 消息历史) /// 3. `submit_with_tools` 跑单轮对话 /// 4. 累计 `cost_so_far` /// 5. 触发 `OnTurnEnd` hook /// 6. `turn_index += 1` /// /// **不做**: /// - 不持有 `ConversationMemory`(由上层独立 task 决定何时回写) /// - 不做 Plan 拆解(Phase 4b 才加 `TaskAgent`) /// - 不做 session_data 持久化(Phase 4c 替换为 `SessionMemory`) pub async fn submit_turn( &mut self, user_input: impl Into, ) -> Result { let turn_index = self.turn_index; let hook_executor = Arc::clone(&self.bundle.hook_executor); // 1. 触发 OnTurnStart hook let start_ctx = HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index); hook_executor .execute(HookEvent::OnTurnStart, &start_ctx) .await; // 2. 组装 LlmCycle —— 共享 bundle 中的 provider 句柄 // 工具列表从 agent.tool_definitions(bundle) 派生(默认 = bundle 全量); // submit_with_tools 内部从 registry 自行取 definitions,此处仅消费以触发 // 子 trait 覆盖(白名单/过滤)的副作用。 let _ = self.agent.tool_definitions(&self.bundle); let mut cycle = LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default()) .with_messages(Vec::new()); if let Some(prompt) = self.agent.system_prompt() { cycle = cycle.with_system_prompt(prompt.to_string()); } if let Some(cfg) = self.bundle.config.compact_config.clone() { cycle = cycle.with_compact_config(cfg); } // 3. 提交(HookExecutor 不在这里传——内部 hook 由 LlmCycle 在 PreRequest/PostRequest 触发) let response = cycle .submit_with_tools(user_input.into(), &self.bundle.tool_registry) .await?; // 4. 累计 cost self.cost_so_far.add(&response.usage); // 5. 触发 OnTurnEnd hook let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index); hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await; // 6. turn_index 递增 self.turn_index += 1; Ok(response) } } #[cfg(test)] mod tests { use super::*; use crate::agent::builder::AgentBuilder; use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult}; use crate::llm::provider::LlmProvider; use crate::llm::types::{ ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage, }; use crate::tools::ToolRegistry; use async_trait::async_trait; use std::sync::atomic::{AtomicU32, Ordering}; /// 计数 hook —— 每被调用一次 +1。 struct CountHook(AtomicU32); #[async_trait] impl Hook for CountHook { async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult { self.0.fetch_add(1, Ordering::SeqCst); HookResult::allow() } } /// 把 `Arc` 包装为 `Box`(dyn Hook 不能直接来自 Arc)。 struct CountHookAdapter(Arc); #[async_trait] impl Hook for CountHookAdapter { async fn execute(&self, ctx: &HookContext<'_>) -> HookResult { self.0.execute(ctx).await } } /// MockProvider:按调用顺序返回预设响应。 struct MockProvider { responses: std::sync::Mutex>, } impl MockProvider { fn new(responses: Vec) -> Self { Self { responses: std::sync::Mutex::new(responses), } } } #[async_trait] impl LlmProvider for MockProvider { async fn chat(&self, _request: ChatRequest) -> Result { let mut responses = self.responses.lock().unwrap(); if responses.is_empty() { return Err(crate::llm::error::LlmError::Other( "no more mock responses".into(), )); } Ok(responses.remove(0)) } } struct StubAgent { name: String, prompt: Option, } impl Agent for StubAgent { fn name(&self) -> &str { &self.name } fn system_prompt(&self) -> Option<&str> { self.prompt.as_deref() } } fn assistant_text(text: &str) -> ChatResponse { ChatResponse { message: OpenaiChatMessage::assistant_text(text), usage: crate::llm::types::Usage::from_input_output(10, 5), stop_reason: Some(FinishReason::Stop), } } /// 烟雾测试 1:AgentSession::submit_turn 跑通 mock provider。 #[tokio::test] async fn submit_turn_runs_with_mock_provider() { let provider = Arc::new(MockProvider::new(vec![assistant_text("hello back")])); let agent = Arc::new(StubAgent { name: "stub".into(), prompt: Some("you are a test agent".into()), }); let bundle = Arc::new( AgentBuilder::new() .provider(provider) .tool_registry(Arc::new(ToolRegistry::new())) .hook_executor(Arc::new(HookExecutor::new())) .build() .unwrap(), ); let mut session = AgentSession::new(agent, "s1", bundle); assert_eq!(session.turn_index(), 0); let response = session.submit_turn("hi").await.unwrap(); let text = match &response.message { OpenaiChatMessage::Assistant { content, .. } => { if let crate::llm::types::ContentField::String(s) = content { s.clone() } else { String::new() } } _ => String::new(), }; assert_eq!(text, "hello back"); assert_eq!(session.turn_index(), 1); assert_eq!(session.usage().total().prompt_tokens, 10); assert_eq!(session.usage().total().completion_tokens, 5); } /// 烟雾测试 2:session_data 读写。 #[test] fn session_data_set_get() { let provider = Arc::new(MockProvider::new(vec![])); let agent = Arc::new(StubAgent { name: "stub".into(), prompt: None, }); let bundle = Arc::new( AgentBuilder::new() .provider(provider) .tool_registry(Arc::new(ToolRegistry::new())) .hook_executor(Arc::new(HookExecutor::new())) .build() .unwrap(), ); let mut session = AgentSession::new(agent, "s2", bundle); assert!(session.get_session_data("k").is_none()); session.set_session_data("k", "v"); assert_eq!(session.get_session_data("k"), Some("v")); // 覆盖写 session.set_session_data("k", "v2"); assert_eq!(session.get_session_data("k"), Some("v2")); } /// 烟雾测试 3:submit_turn 触发 OnTurnStart / OnTurnEnd hook。 #[tokio::test] async fn submit_turn_triggers_turn_hooks() { let mut hook_executor = HookExecutor::new(); let start_count = Arc::new(CountHook(AtomicU32::new(0))); let end_count = Arc::new(CountHook(AtomicU32::new(0))); hook_executor.register( HookEvent::OnTurnStart, Box::new(CountHookAdapter(start_count.clone())), ); hook_executor.register( HookEvent::OnTurnEnd, Box::new(CountHookAdapter(end_count.clone())), ); let provider = Arc::new(MockProvider::new(vec![ assistant_text("ok"), assistant_text("ok 2"), ])); let agent = Arc::new(StubAgent { name: "stub".into(), prompt: None, }); let bundle = Arc::new( AgentBuilder::new() .provider(provider) .tool_registry(Arc::new(ToolRegistry::new())) .hook_executor(Arc::new(hook_executor)) .build() .unwrap(), ); let mut session = AgentSession::new(agent, "s3", bundle); session.submit_turn("hi").await.unwrap(); assert_eq!(start_count.0.load(Ordering::SeqCst), 1); assert_eq!(end_count.0.load(Ordering::SeqCst), 1); session.submit_turn("hi again").await.unwrap(); assert_eq!(start_count.0.load(Ordering::SeqCst), 2); assert_eq!(end_count.0.load(Ordering::SeqCst), 2); } }