feat(agent): 实现 Agent Runtime 核心胶水层 (Phase 4a)

- 添加 Agent trait、AgentSession、RuntimeBundle、AgentBuilder
- 添加 Plan/Step/StepStatus 任务规划数据结构
- 添加 AgentError 统一错误类型(聚合 LlmError/ToolError/MemoryError)
- 实现 submit_turn 单轮对话流程(含 hook 触发与 cost 累计)
- 扩展 LlmCycle 支持 Arc<dyn LlmProvider>
- 扩展 HookEvent 添加 OnTurnStart/OnTurnEnd
- 更新 roadmap 状态
This commit is contained in:
徐涛
2026-06-11 21:45:28 +08:00
parent 59ec0f5597
commit 2b189880a9
11 changed files with 1025 additions and 19 deletions
+342
View File
@@ -0,0 +1,342 @@
//! AgentSession —— 智能体"会话"实例。
//!
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3):
//!
//! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far`
//! - **最小 reference impl**`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程
//! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定
//! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message`
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::agent::Agent;
use crate::agent::error::AgentError;
use crate::agent::runtime::RuntimeBundle;
use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle};
use crate::llm::hooks::{HookContext, HookEvent};
use crate::llm::types::ChatResponse;
/// Agent 会话实例。
///
/// 同一 `Agent` 可被多个 `AgentSession` 复用(不同 session_id 互不干扰)。
/// `submit_turn` 一次只跑一轮 LLM 调用(含自动 tool 循环)。
///
/// **不实现 `Clone`**session 持有累计 `turn_index` / `cost_so_far` / `session_data`
/// 共享这些状态需要显式 sync 语义;如果上层需要并发访问,自己用 `Arc<Mutex<_>>` 包装。
pub struct AgentSession {
/// 会话 ID(由调用方指定,用于日志/追踪/记忆关联)。
pub session_id: String,
/// 角色(可热切换为同 bundle 下的其他角色)。
pub agent: Arc<dyn Agent>,
bundle: Arc<RuntimeBundle>,
turn_index: u32,
cost_so_far: CostTracker,
/// 会话级键值数据(Phase 4a 用内联 HashMapPhase 4c 替换为 `SessionMemory`)。
session_data: HashMap<String, String>,
}
impl std::fmt::Debug for AgentSession {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AgentSession")
.field("session_id", &self.session_id)
.field("agent", &self.agent.name())
.field("turn_index", &self.turn_index)
.field("cost_so_far", &self.cost_so_far.total())
.field("session_data_keys", &self.session_data.keys().collect::<Vec<_>>())
.finish()
}
}
impl AgentSession {
/// 创建一个新的会话实例。
///
/// `agent` 与 `bundle` 共同决定 `submit_turn` 行为:system_prompt / 工具集 / LLM 后端均来自它们。
pub fn new(
agent: Arc<dyn Agent>,
session_id: impl Into<String>,
bundle: Arc<RuntimeBundle>,
) -> Self {
Self {
session_id: session_id.into(),
agent,
bundle,
turn_index: 0,
cost_so_far: CostTracker::default(),
session_data: HashMap::new(),
}
}
/// 当前 turn 序号(0-based:第一次 `submit_turn` 完成后变 1)。
pub fn turn_index(&self) -> u32 {
self.turn_index
}
/// 累计用量(跨所有 turn)。
pub fn usage(&self) -> &CostTracker {
&self.cost_so_far
}
/// 会话级数据快照引用。
pub fn session_data(&self) -> &HashMap<String, String> {
&self.session_data
}
/// 写入一条会话级数据(覆盖同名 key)。
pub fn set_session_data(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.session_data.insert(key.into(), value.into());
}
/// 读取一条会话级数据。
pub fn get_session_data(&self, key: &str) -> Option<&str> {
self.session_data.get(key).map(String::as_str)
}
/// 提交一轮对话(含自动 tool 循环),返回 LLM 响应。
///
/// 流程:
/// 1. 触发 `OnTurnStart` hook
/// 2. 组装 `LlmCycle`(注入 system_prompt / hook_executor / compact_config / 消息历史)
/// 3. `submit_with_tools` 跑单轮对话
/// 4. 累计 `cost_so_far`
/// 5. 触发 `OnTurnEnd` hook
/// 6. `turn_index += 1`
///
/// **不做**
/// - 不持有 `ConversationMemory`(由上层独立 task 决定何时回写)
/// - 不做 Plan 拆解(Phase 4b 才加 `TaskAgent`
/// - 不做 session_data 持久化(Phase 4c 替换为 `SessionMemory`
pub async fn submit_turn(
&mut self,
user_input: impl Into<String>,
) -> Result<ChatResponse, AgentError> {
let turn_index = self.turn_index;
let hook_executor = Arc::clone(&self.bundle.hook_executor);
// 1. 触发 OnTurnStart hook
let start_ctx =
HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index);
hook_executor
.execute(HookEvent::OnTurnStart, &start_ctx)
.await;
// 2. 组装 LlmCycle —— 共享 bundle 中的 provider 句柄
// 工具列表从 agent.tool_definitions(bundle) 派生(默认 = bundle 全量);
// submit_with_tools 内部从 registry 自行取 definitions,此处仅消费以触发
// 子 trait 覆盖(白名单/过滤)的副作用。
let _ = self.agent.tool_definitions(&self.bundle);
let mut cycle = LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default())
.with_messages(Vec::new());
if let Some(prompt) = self.agent.system_prompt() {
cycle = cycle.with_system_prompt(prompt.to_string());
}
if let Some(cfg) = self.bundle.config.compact_config.clone() {
cycle = cycle.with_compact_config(cfg);
}
// 3. 提交(HookExecutor 不在这里传——内部 hook 由 LlmCycle 在 PreRequest/PostRequest 触发)
let response = cycle
.submit_with_tools(user_input.into(), &self.bundle.tool_registry)
.await?;
// 4. 累计 cost
self.cost_so_far.add(&response.usage);
// 5. 触发 OnTurnEnd hook
let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index);
hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await;
// 6. turn_index 递增
self.turn_index += 1;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::builder::AgentBuilder;
use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult};
use crate::llm::provider::LlmProvider;
use crate::llm::types::{
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage,
};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
/// 计数 hook —— 每被调用一次 +1。
struct CountHook(AtomicU32);
#[async_trait]
impl Hook for CountHook {
async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult {
self.0.fetch_add(1, Ordering::SeqCst);
HookResult::allow()
}
}
/// 把 `Arc<CountHook>` 包装为 `Box<dyn Hook>`dyn Hook 不能直接来自 Arc)。
struct CountHookAdapter(Arc<CountHook>);
#[async_trait]
impl Hook for CountHookAdapter {
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult {
self.0.execute(ctx).await
}
}
/// MockProvider:按调用顺序返回预设响应。
struct MockProvider {
responses: std::sync::Mutex<Vec<ChatResponse>>,
}
impl MockProvider {
fn new(responses: Vec<ChatResponse>) -> Self {
Self {
responses: std::sync::Mutex::new(responses),
}
}
}
#[async_trait]
impl LlmProvider for MockProvider {
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, crate::llm::error::LlmError> {
let mut responses = self.responses.lock().unwrap();
if responses.is_empty() {
return Err(crate::llm::error::LlmError::Other(
"no more mock responses".into(),
));
}
Ok(responses.remove(0))
}
}
struct StubAgent {
name: String,
prompt: Option<String>,
}
impl Agent for StubAgent {
fn name(&self) -> &str {
&self.name
}
fn system_prompt(&self) -> Option<&str> {
self.prompt.as_deref()
}
}
fn assistant_text(text: &str) -> ChatResponse {
ChatResponse {
message: OpenaiChatMessage::assistant_text(text),
usage: crate::llm::types::Usage::from_input_output(10, 5),
stop_reason: Some(FinishReason::Stop),
}
}
/// 烟雾测试 1AgentSession::submit_turn 跑通 mock provider。
#[tokio::test]
async fn submit_turn_runs_with_mock_provider() {
let provider = Arc::new(MockProvider::new(vec![assistant_text("hello back")]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: Some("you are a test agent".into()),
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s1", bundle);
assert_eq!(session.turn_index(), 0);
let response = session.submit_turn("hi").await.unwrap();
let text = match &response.message {
OpenaiChatMessage::Assistant { content, .. } => {
if let crate::llm::types::ContentField::String(s) = content {
s.clone()
} else {
String::new()
}
}
_ => String::new(),
};
assert_eq!(text, "hello back");
assert_eq!(session.turn_index(), 1);
assert_eq!(session.usage().total().prompt_tokens, 10);
assert_eq!(session.usage().total().completion_tokens, 5);
}
/// 烟雾测试 2session_data 读写。
#[test]
fn session_data_set_get() {
let provider = Arc::new(MockProvider::new(vec![]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: None,
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s2", bundle);
assert!(session.get_session_data("k").is_none());
session.set_session_data("k", "v");
assert_eq!(session.get_session_data("k"), Some("v"));
// 覆盖写
session.set_session_data("k", "v2");
assert_eq!(session.get_session_data("k"), Some("v2"));
}
/// 烟雾测试 3submit_turn 触发 OnTurnStart / OnTurnEnd hook。
#[tokio::test]
async fn submit_turn_triggers_turn_hooks() {
let mut hook_executor = HookExecutor::new();
let start_count = Arc::new(CountHook(AtomicU32::new(0)));
let end_count = Arc::new(CountHook(AtomicU32::new(0)));
hook_executor.register(
HookEvent::OnTurnStart,
Box::new(CountHookAdapter(start_count.clone())),
);
hook_executor.register(
HookEvent::OnTurnEnd,
Box::new(CountHookAdapter(end_count.clone())),
);
let provider = Arc::new(MockProvider::new(vec![
assistant_text("ok"),
assistant_text("ok 2"),
]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: None,
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(hook_executor))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s3", bundle);
session.submit_turn("hi").await.unwrap();
assert_eq!(start_count.0.load(Ordering::SeqCst), 1);
assert_eq!(end_count.0.load(Ordering::SeqCst), 1);
session.submit_turn("hi again").await.unwrap();
assert_eq!(start_count.0.load(Ordering::SeqCst), 2);
assert_eq!(end_count.0.load(Ordering::SeqCst), 2);
}
}