2b189880a9
- 添加 Agent trait、AgentSession、RuntimeBundle、AgentBuilder - 添加 Plan/Step/StepStatus 任务规划数据结构 - 添加 AgentError 统一错误类型(聚合 LlmError/ToolError/MemoryError) - 实现 submit_turn 单轮对话流程(含 hook 触发与 cost 累计) - 扩展 LlmCycle 支持 Arc<dyn LlmProvider> - 扩展 HookEvent 添加 OnTurnStart/OnTurnEnd - 更新 roadmap 状态
343 lines
12 KiB
Rust
343 lines
12 KiB
Rust
//! AgentSession —— 智能体"会话"实例。
|
||
//!
|
||
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3):
|
||
//!
|
||
//! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far`
|
||
//! - **最小 reference impl**:`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程
|
||
//! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定
|
||
//! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message`
|
||
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
|
||
use crate::agent::agent::Agent;
|
||
use crate::agent::error::AgentError;
|
||
use crate::agent::runtime::RuntimeBundle;
|
||
use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle};
|
||
use crate::llm::hooks::{HookContext, HookEvent};
|
||
use crate::llm::types::ChatResponse;
|
||
|
||
/// Agent 会话实例。
|
||
///
|
||
/// 同一 `Agent` 可被多个 `AgentSession` 复用(不同 session_id 互不干扰)。
|
||
/// `submit_turn` 一次只跑一轮 LLM 调用(含自动 tool 循环)。
|
||
///
|
||
/// **不实现 `Clone`**:session 持有累计 `turn_index` / `cost_so_far` / `session_data`,
|
||
/// 共享这些状态需要显式 sync 语义;如果上层需要并发访问,自己用 `Arc<Mutex<_>>` 包装。
|
||
pub struct AgentSession {
|
||
/// 会话 ID(由调用方指定,用于日志/追踪/记忆关联)。
|
||
pub session_id: String,
|
||
/// 角色(可热切换为同 bundle 下的其他角色)。
|
||
pub agent: Arc<dyn Agent>,
|
||
bundle: Arc<RuntimeBundle>,
|
||
turn_index: u32,
|
||
cost_so_far: CostTracker,
|
||
/// 会话级键值数据(Phase 4a 用内联 HashMap;Phase 4c 替换为 `SessionMemory`)。
|
||
session_data: HashMap<String, String>,
|
||
}
|
||
|
||
impl std::fmt::Debug for AgentSession {
|
||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
f.debug_struct("AgentSession")
|
||
.field("session_id", &self.session_id)
|
||
.field("agent", &self.agent.name())
|
||
.field("turn_index", &self.turn_index)
|
||
.field("cost_so_far", &self.cost_so_far.total())
|
||
.field("session_data_keys", &self.session_data.keys().collect::<Vec<_>>())
|
||
.finish()
|
||
}
|
||
}
|
||
|
||
impl AgentSession {
|
||
/// 创建一个新的会话实例。
|
||
///
|
||
/// `agent` 与 `bundle` 共同决定 `submit_turn` 行为:system_prompt / 工具集 / LLM 后端均来自它们。
|
||
pub fn new(
|
||
agent: Arc<dyn Agent>,
|
||
session_id: impl Into<String>,
|
||
bundle: Arc<RuntimeBundle>,
|
||
) -> Self {
|
||
Self {
|
||
session_id: session_id.into(),
|
||
agent,
|
||
bundle,
|
||
turn_index: 0,
|
||
cost_so_far: CostTracker::default(),
|
||
session_data: HashMap::new(),
|
||
}
|
||
}
|
||
|
||
/// 当前 turn 序号(0-based:第一次 `submit_turn` 完成后变 1)。
|
||
pub fn turn_index(&self) -> u32 {
|
||
self.turn_index
|
||
}
|
||
|
||
/// 累计用量(跨所有 turn)。
|
||
pub fn usage(&self) -> &CostTracker {
|
||
&self.cost_so_far
|
||
}
|
||
|
||
/// 会话级数据快照引用。
|
||
pub fn session_data(&self) -> &HashMap<String, String> {
|
||
&self.session_data
|
||
}
|
||
|
||
/// 写入一条会话级数据(覆盖同名 key)。
|
||
pub fn set_session_data(&mut self, key: impl Into<String>, value: impl Into<String>) {
|
||
self.session_data.insert(key.into(), value.into());
|
||
}
|
||
|
||
/// 读取一条会话级数据。
|
||
pub fn get_session_data(&self, key: &str) -> Option<&str> {
|
||
self.session_data.get(key).map(String::as_str)
|
||
}
|
||
|
||
/// 提交一轮对话(含自动 tool 循环),返回 LLM 响应。
|
||
///
|
||
/// 流程:
|
||
/// 1. 触发 `OnTurnStart` hook
|
||
/// 2. 组装 `LlmCycle`(注入 system_prompt / hook_executor / compact_config / 消息历史)
|
||
/// 3. `submit_with_tools` 跑单轮对话
|
||
/// 4. 累计 `cost_so_far`
|
||
/// 5. 触发 `OnTurnEnd` hook
|
||
/// 6. `turn_index += 1`
|
||
///
|
||
/// **不做**:
|
||
/// - 不持有 `ConversationMemory`(由上层独立 task 决定何时回写)
|
||
/// - 不做 Plan 拆解(Phase 4b 才加 `TaskAgent`)
|
||
/// - 不做 session_data 持久化(Phase 4c 替换为 `SessionMemory`)
|
||
pub async fn submit_turn(
|
||
&mut self,
|
||
user_input: impl Into<String>,
|
||
) -> Result<ChatResponse, AgentError> {
|
||
let turn_index = self.turn_index;
|
||
let hook_executor = Arc::clone(&self.bundle.hook_executor);
|
||
|
||
// 1. 触发 OnTurnStart hook
|
||
let start_ctx =
|
||
HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index);
|
||
hook_executor
|
||
.execute(HookEvent::OnTurnStart, &start_ctx)
|
||
.await;
|
||
|
||
// 2. 组装 LlmCycle —— 共享 bundle 中的 provider 句柄
|
||
// 工具列表从 agent.tool_definitions(bundle) 派生(默认 = bundle 全量);
|
||
// submit_with_tools 内部从 registry 自行取 definitions,此处仅消费以触发
|
||
// 子 trait 覆盖(白名单/过滤)的副作用。
|
||
let _ = self.agent.tool_definitions(&self.bundle);
|
||
let mut cycle = LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default())
|
||
.with_messages(Vec::new());
|
||
if let Some(prompt) = self.agent.system_prompt() {
|
||
cycle = cycle.with_system_prompt(prompt.to_string());
|
||
}
|
||
if let Some(cfg) = self.bundle.config.compact_config.clone() {
|
||
cycle = cycle.with_compact_config(cfg);
|
||
}
|
||
|
||
// 3. 提交(HookExecutor 不在这里传——内部 hook 由 LlmCycle 在 PreRequest/PostRequest 触发)
|
||
let response = cycle
|
||
.submit_with_tools(user_input.into(), &self.bundle.tool_registry)
|
||
.await?;
|
||
|
||
// 4. 累计 cost
|
||
self.cost_so_far.add(&response.usage);
|
||
|
||
// 5. 触发 OnTurnEnd hook
|
||
let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index);
|
||
hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await;
|
||
|
||
// 6. turn_index 递增
|
||
self.turn_index += 1;
|
||
|
||
Ok(response)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::builder::AgentBuilder;
    use crate::llm::error::LlmError;
    use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult};
    use crate::llm::provider::LlmProvider;
    use crate::llm::types::{
        ChatRequest, ChatResponse, ContentField, FinishReason, OpenaiChatMessage, Usage,
    };
    use crate::tools::ToolRegistry;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// Hook that counts how many times it has been invoked.
    struct CountHook(AtomicU32);

    impl CountHook {
        /// A fresh, shareable counter starting at zero.
        fn new() -> Arc<Self> {
            Arc::new(Self(AtomicU32::new(0)))
        }

        /// Number of invocations observed so far.
        fn hits(&self) -> u32 {
            self.0.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Hook for CountHook {
        async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult {
            self.0.fetch_add(1, Ordering::SeqCst);
            HookResult::allow()
        }
    }

    /// Forwards to a shared `CountHook` so it can be registered as `Box<dyn Hook>`
    /// while the test keeps its own `Arc` handle for reading the counter.
    struct CountHookAdapter(Arc<CountHook>);

    #[async_trait]
    impl Hook for CountHookAdapter {
        async fn execute(&self, ctx: &HookContext<'_>) -> HookResult {
            self.0.execute(ctx).await
        }
    }

    /// Provider that replays canned responses in FIFO order and errors once exhausted.
    struct MockProvider {
        queue: Mutex<VecDeque<ChatResponse>>,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatResponse>) -> Self {
            Self {
                queue: Mutex::new(responses.into()),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
            self.queue
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| LlmError::Other("no more mock responses".into()))
        }
    }

    /// Minimal agent with a fixed name and an optional system prompt.
    struct StubAgent {
        name: String,
        prompt: Option<String>,
    }

    impl Agent for StubAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn system_prompt(&self) -> Option<&str> {
            self.prompt.as_deref()
        }
    }

    /// A `StubAgent` named "stub" with the given system prompt.
    fn stub_agent(prompt: Option<&str>) -> Arc<dyn Agent> {
        Arc::new(StubAgent {
            name: "stub".into(),
            prompt: prompt.map(str::to_owned),
        })
    }

    /// A runtime bundle with an empty tool registry and the given provider and hooks.
    fn runtime_bundle(mock: MockProvider, hooks: HookExecutor) -> Arc<RuntimeBundle> {
        let provider = Arc::new(mock);
        let bundle = AgentBuilder::new()
            .provider(provider)
            .tool_registry(Arc::new(ToolRegistry::new()))
            .hook_executor(Arc::new(hooks))
            .build()
            .unwrap();
        Arc::new(bundle)
    }

    /// A plain assistant reply carrying 10 prompt / 5 completion tokens.
    fn assistant_text(text: &str) -> ChatResponse {
        ChatResponse {
            message: OpenaiChatMessage::assistant_text(text),
            usage: Usage::from_input_output(10, 5),
            stop_reason: Some(FinishReason::Stop),
        }
    }

    /// Smoke test 1: `submit_turn` completes against the mock provider.
    #[tokio::test]
    async fn submit_turn_runs_with_mock_provider() {
        let bundle = runtime_bundle(
            MockProvider::new(vec![assistant_text("hello back")]),
            HookExecutor::new(),
        );
        let mut session =
            AgentSession::new(stub_agent(Some("you are a test agent")), "s1", bundle);
        assert_eq!(session.turn_index(), 0);

        let response = session.submit_turn("hi").await.unwrap();
        let text = match &response.message {
            OpenaiChatMessage::Assistant {
                content: ContentField::String(s),
                ..
            } => s.clone(),
            _ => String::new(),
        };

        assert_eq!(text, "hello back");
        assert_eq!(session.turn_index(), 1);
        let total = session.usage().total();
        assert_eq!(total.prompt_tokens, 10);
        assert_eq!(total.completion_tokens, 5);
    }

    /// Smoke test 2: session_data round-trips and overwrites.
    #[test]
    fn session_data_set_get() {
        let bundle = runtime_bundle(MockProvider::new(Vec::new()), HookExecutor::new());
        let mut session = AgentSession::new(stub_agent(None), "s2", bundle);

        assert!(session.get_session_data("k").is_none());
        session.set_session_data("k", "v");
        assert_eq!(session.get_session_data("k"), Some("v"));

        // Writing the same key again replaces the value.
        session.set_session_data("k", "v2");
        assert_eq!(session.get_session_data("k"), Some("v2"));
    }

    /// Smoke test 3: every `submit_turn` fires OnTurnStart and OnTurnEnd exactly once.
    #[tokio::test]
    async fn submit_turn_triggers_turn_hooks() {
        let starts = CountHook::new();
        let ends = CountHook::new();

        let mut hooks = HookExecutor::new();
        hooks.register(
            HookEvent::OnTurnStart,
            Box::new(CountHookAdapter(Arc::clone(&starts))),
        );
        hooks.register(
            HookEvent::OnTurnEnd,
            Box::new(CountHookAdapter(Arc::clone(&ends))),
        );

        let bundle = runtime_bundle(
            MockProvider::new(vec![assistant_text("ok"), assistant_text("ok 2")]),
            hooks,
        );
        let mut session = AgentSession::new(stub_agent(None), "s3", bundle);

        for (round, input) in [(1, "hi"), (2, "hi again")] {
            session.submit_turn(input).await.unwrap();
            assert_eq!(starts.hits(), round);
            assert_eq!(ends.hits(), round);
        }
    }
}
|