feat(agent): 实现 Agent Runtime 核心胶水层 (Phase 4a)

- 添加 Agent trait、AgentSession、RuntimeBundle、AgentBuilder
- 添加 Plan/Step/StepStatus 任务规划数据结构
- 添加 AgentError 统一错误类型(聚合 LlmError/ToolError/MemoryError)
- 实现 submit_turn 单轮对话流程(含 hook 触发与 cost 累计)
- 扩展 LlmCycle 支持 Arc<dyn LlmProvider>
- 扩展 HookEvent 添加 OnTurnStart/OnTurnEnd
- 更新 roadmap 状态
This commit is contained in:
徐涛
2026-06-11 21:45:28 +08:00
parent 59ec0f5597
commit 2b189880a9
11 changed files with 1025 additions and 19 deletions
+25
View File
@@ -0,0 +1,25 @@
//! Agent Runtime —— 智能体(Agent)核心胶水层。
//!
//! 把 Phase 0-3 的能力(LlmCycle / ToolRegistry / MemoryStore / HookExecutor"装配"为
//! 上层可用的智能体抽象:`Agent` / `AgentSession` / `RuntimeBundle` / `AgentBuilder` / `Plan`。
//!
//! **不**实现业务循环,**不**假设上层如何使用 memory。
//! 详细设计见 `docs/7-agent-runtime.md`。
// 模块根文件 `agent.rs` 与子模块 `agent/agent.rs` 同名(项目惯例,与 `llm/cycle.rs` 一致)。
#![allow(clippy::module_inception)]
pub mod agent;
pub mod builder;
pub mod error;
pub mod runtime;
pub mod session;
pub mod task;
// 重导出公共 API(按使用频度排序)
pub use agent::Agent;
pub use builder::AgentBuilder;
pub use error::AgentError;
pub use runtime::{AgentConfig, RuntimeBundle};
pub use session::AgentSession;
pub use task::{Plan, Step, StepStatus};
+30
View File
@@ -0,0 +1,30 @@
//! Agent trait —— 智能体的"角色"抽象。
//!
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.1):
//!
//! - **角色与会话分离**`Agent` 定义"做什么、用什么工具"`AgentSession` 维护"当前状态"
//! - **工具白名单扩展点**:默认从 `RuntimeBundle.tool_registry` 取全部,子 trait 可覆盖做白名单/过滤
//! - **不绑定业务循环**`submit_turn` 在 `AgentSession` 上,不在 trait 上
use crate::agent::runtime::RuntimeBundle;
use crate::llm::types::ToolDefinition;
/// Agent 角色抽象。
///
/// 实现此 trait 即可接入 Agent Runtime。典型实现是 struct 持有静态配置(name、system prompt 模板),
/// 也可以是基于配置动态生成的轻量实现。
pub trait Agent: Send + Sync {
/// 角色名(用于日志、调试、UI 展示)。
fn name(&self) -> &str;
/// 系统提示词。无提示词的纯工具型 agent 返回 `None`。
fn system_prompt(&self) -> Option<&str>;
/// 列出该 Agent 想暴露给 LLM 的工具定义。
///
/// **默认实现**:从 `bundle.tool_registry` 取全部工具(最常用模式)。
/// **子 trait / 具体实现可覆盖**:做白名单、过滤、按状态动态调整等。
fn tool_definitions(&self, bundle: &RuntimeBundle) -> Vec<ToolDefinition> {
bundle.tool_registry.definitions()
}
}
+174
View File
@@ -0,0 +1,174 @@
//! AgentBuilder —— `RuntimeBundle` 的链式构造入口。
//!
//! 设计原则:
//!
//! - **唯一构造入口**:上层应用不应直接 `RuntimeBundle::new`;用 `AgentBuilder` 保证必填字段
//! 校验集中、默认值集中管理
//! - **必填字段在 `build()` 时校验**:缺失返回 `AgentError::Config`,不 panic
//! - **选填字段独立 setter**:未调用对应 setter 时使用 `None` 兜底
use std::sync::Arc;
use crate::agent::error::AgentError;
use crate::agent::runtime::{AgentConfig, RuntimeBundle};
use crate::llm::hooks::HookExecutor;
use crate::llm::provider::LlmProvider;
use crate::memory::retriever::MemoryRetriever;
use crate::memory::store::MemoryStore;
use crate::tools::ToolRegistry;
/// `RuntimeBundle` 的链式构造器。
///
/// 使用示例:
/// ```ignore
/// let bundle = AgentBuilder::new()
/// .provider(my_provider)
/// .tool_registry(my_registry)
/// .hook_executor(my_executor)
/// .build()?;
/// ```
#[derive(Default)]
pub struct AgentBuilder {
provider: Option<Arc<dyn LlmProvider>>,
tool_registry: Option<Arc<ToolRegistry>>,
hook_executor: Option<Arc<HookExecutor>>,
memory_store: Option<Arc<dyn MemoryStore>>,
retriever: Option<Arc<MemoryRetriever>>,
config: Option<AgentConfig>,
}
impl AgentBuilder {
/// 创建一个空的 builder,所有必填字段均为 `None`。
pub fn new() -> Self {
Self::default()
}
/// 设置 LLM provider(必填)。
pub fn provider(mut self, p: Arc<dyn LlmProvider>) -> Self {
self.provider = Some(p);
self
}
/// 设置工具注册表(必填)。
pub fn tool_registry(mut self, r: Arc<ToolRegistry>) -> Self {
self.tool_registry = Some(r);
self
}
/// 设置钩子执行器(必填)。
pub fn hook_executor(mut self, h: Arc<HookExecutor>) -> Self {
self.hook_executor = Some(h);
self
}
/// 设置持久化记忆后端(选填,不传也能跑)。
pub fn memory_store(mut self, m: Arc<dyn MemoryStore>) -> Self {
self.memory_store = Some(m);
self
}
/// 设置记忆检索器(选填,不传也能跑)。
pub fn retriever(mut self, r: Arc<MemoryRetriever>) -> Self {
self.retriever = Some(r);
self
}
/// 整体覆盖 `AgentConfig`(选填,不传则用默认值)。
pub fn config(mut self, c: AgentConfig) -> Self {
self.config = Some(c);
self
}
/// 构造 `RuntimeBundle`,校验必填字段。
///
/// **错误**`provider` / `tool_registry` / `hook_executor` 任一缺失则返回
/// `AgentError::Config("missing <field>")`,不 panic。
pub fn build(self) -> Result<RuntimeBundle, AgentError> {
let provider = self
.provider
.ok_or_else(|| AgentError::Config("missing provider".into()))?;
let tool_registry = self
.tool_registry
.ok_or_else(|| AgentError::Config("missing tool_registry".into()))?;
let hook_executor = self
.hook_executor
.ok_or_else(|| AgentError::Config("missing hook_executor".into()))?;
let config = self.config.unwrap_or_default();
Ok(RuntimeBundle::new(
provider,
tool_registry,
hook_executor,
self.memory_store,
self.retriever,
config,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::provider::LlmProvider;
use crate::llm::types::{ChatRequest, ChatResponse};
use crate::llm::error::LlmError;
use async_trait::async_trait;
struct StubProvider;
#[async_trait]
impl LlmProvider for StubProvider {
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, LlmError> {
unimplemented!()
}
}
#[test]
fn build_with_all_required_succeeds() {
let bundle = AgentBuilder::new()
.provider(Arc::new(StubProvider))
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build();
assert!(bundle.is_ok());
}
#[test]
fn build_missing_provider_returns_config_error() {
let result = AgentBuilder::new()
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build();
assert!(matches!(result, Err(AgentError::Config(s)) if s.contains("provider")));
}
#[test]
fn build_missing_tool_registry_returns_config_error() {
let result = AgentBuilder::new()
.provider(Arc::new(StubProvider))
.hook_executor(Arc::new(HookExecutor::new()))
.build();
assert!(matches!(result, Err(AgentError::Config(s)) if s.contains("tool_registry")));
}
#[test]
fn build_missing_hook_executor_returns_config_error() {
let result = AgentBuilder::new()
.provider(Arc::new(StubProvider))
.tool_registry(Arc::new(ToolRegistry::new()))
.build();
assert!(matches!(result, Err(AgentError::Config(s)) if s.contains("hook_executor")));
}
#[test]
fn optional_fields_default_to_none() {
let bundle = AgentBuilder::new()
.provider(Arc::new(StubProvider))
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build()
.unwrap();
assert!(bundle.memory_store.is_none());
assert!(bundle.retriever.is_none());
}
}
+173
View File
@@ -0,0 +1,173 @@
//! Agent Runtime 统一错误类型。
//!
//! `AgentError` 聚合 Phase 0-3 各层错误(LlmError / ToolError / MemoryError),
//! 加上 Agent 层特有的错误变体。设计原则:
//!
//! - 聚合而非包装:保留内层错误的类型信息(避免 `Box<dyn Error>` 丢失上下文)
//! - 显式 `From` 实现:让 `?` 运算符能透明传播下层错误
//! - `is_recoverable()`:根据变体类型判定可恢复性,便于上层决策
use thiserror::Error;
use crate::llm::error::LlmError;
use crate::memory::error::MemoryError;
use crate::tools::error::ToolError;
/// Agent Runtime 统一错误枚举。
///
/// **不实现 `Clone`**:透传内层 `LlmError` / `MemoryError`,两者均未派生 `Clone`(保留
/// 完整错误信息,传递所有权)。如需在多 session 间共享错误状态,用 `Arc<AgentError>` 包装。
#[derive(Debug, Error)]
pub enum AgentError {
/// LLM 调用错误(透传 Phase 0)。
#[error("LLM 错误: {0}")]
Llm(#[from] LlmError),
/// 工具调用错误(透传 Phase 2)。
#[error("工具错误: {0}")]
Tool(#[from] ToolError),
/// 记忆系统错误(透传 Phase 3)。
#[error("记忆错误: {0}")]
Memory(#[from] MemoryError),
/// 钩子阻断操作(Agent 层特有)。
#[error("钩子阻断: {0}")]
HookBlocked(String),
/// 达到限制阈值(最大 turn、token 预算等)。
#[error("超过限制: {0}")]
LimitExceeded(String),
/// 配置错误(构建 RuntimeBundle / AgentSession 时校验失败)。
#[error("配置错误: {0}")]
Config(String),
/// 其他未分类错误(兜底)。
#[error("Agent 错误: {0}")]
Other(String),
}
impl AgentError {
/// 判定错误是否可恢复。
///
/// - `Llm` / `Memory`:由内层 `is_recoverable()` 决定
/// - `Tool`:由内层 `is_recoverable()` 决定
/// - `HookBlocked` / `LimitExceeded`:不可恢复(需人工介入或终止循环)
/// - `Config` / `Other`:不可恢复
pub fn is_recoverable(&self) -> bool {
match self {
Self::Llm(e) => matches!(
e,
LlmError::RateLimit { .. } | LlmError::Timeout { .. } | LlmError::Stream(_)
),
Self::Tool(e) => e.is_recoverable(),
Self::Memory(e) => e.is_recoverable(),
Self::HookBlocked(_) | Self::LimitExceeded(_) | Self::Config(_) | Self::Other(_) => {
false
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn llm_recoverable_propagation() {
let err = AgentError::Llm(LlmError::Timeout {
duration: std::time::Duration::from_secs(30),
});
assert!(err.is_recoverable());
}
#[test]
fn llm_non_recoverable_propagation() {
let err = AgentError::Llm(LlmError::Authentication("bad key".into()));
assert!(!err.is_recoverable());
}
#[test]
fn tool_recoverable_propagation() {
let err = AgentError::Tool(ToolError::ExecutionFailed("foo".into(), "boom".into()));
assert!(err.is_recoverable());
}
#[test]
fn tool_non_recoverable_propagation() {
let err = AgentError::Tool(ToolError::NotFound("foo".into()));
assert!(!err.is_recoverable());
}
#[test]
fn memory_recoverable_propagation() {
let err = AgentError::Memory(MemoryError::NotFound("page".into()));
assert!(err.is_recoverable());
}
#[test]
fn memory_non_recoverable_propagation() {
let err = AgentError::Memory(MemoryError::Storage("disk full".into()));
assert!(!err.is_recoverable());
}
#[test]
fn hook_blocked_not_recoverable() {
assert!(!AgentError::HookBlocked("denied".into()).is_recoverable());
}
#[test]
fn limit_exceeded_not_recoverable() {
assert!(!AgentError::LimitExceeded("max turns".into()).is_recoverable());
}
#[test]
fn config_not_recoverable() {
assert!(!AgentError::Config("missing provider".into()).is_recoverable());
}
#[test]
fn other_not_recoverable() {
assert!(!AgentError::Other("unknown".into()).is_recoverable());
}
#[test]
fn from_llm_via_question_mark() {
fn returns_llm() -> Result<(), LlmError> {
Err(LlmError::Other("test".into()))
}
fn caller() -> Result<(), AgentError> {
returns_llm()?;
Ok(())
}
let err = caller().unwrap_err();
assert!(matches!(err, AgentError::Llm(_)));
}
#[test]
fn from_tool_via_question_mark() {
fn returns_tool() -> Result<(), ToolError> {
Err(ToolError::NotFound("x".into()))
}
fn caller() -> Result<(), AgentError> {
returns_tool()?;
Ok(())
}
let err = caller().unwrap_err();
assert!(matches!(err, AgentError::Tool(_)));
}
#[test]
fn from_memory_via_question_mark() {
fn returns_mem() -> Result<(), MemoryError> {
Err(MemoryError::Storage("x".into()))
}
fn caller() -> Result<(), AgentError> {
returns_mem()?;
Ok(())
}
let err = caller().unwrap_err();
assert!(matches!(err, AgentError::Memory(_)));
}
}
+110
View File
@@ -0,0 +1,110 @@
//! Runtime Bundle —— 显式依赖注入容器(OpenHarness 风格)。
//!
//! 集中持有 Agent 运行所需的全部运行时依赖:`LlmProvider` / `ToolRegistry` / `HookExecutor` /
//! `MemoryStore`(弱引用)/ `MemoryRetriever`(弱引用) / `AgentConfig`。
//!
//! **设计意图**(参见 `docs/7-agent-runtime.md` §3.2.2):
//!
//! - 所有运行时依赖显式打包,便于跨 `AgentSession` 共享、便于测试注入 mock
//! - `memory_store` / `retriever` 为 `Option`:上层应用不传也能跑(无记忆模式)
//! - 构造时若 `retriever` 为 `Some`,自动注册 `"retrieve"` toolv0.1 占位——
//! Phase 4a 不在 `submit_turn` 中真正调用;Phase 4a 任务范围仅"装配可注册",
//! 真正的 `RetrieveTool` 实现留待 v0.2 接入)
//! - 不持有 `Box<dyn LlmProvider>` 而是 `Arc<dyn LlmProvider>`:支持多 session 共享
use std::sync::Arc;
use std::time::Duration;
use crate::llm::compact::CompactConfig;
use crate::llm::provider::LlmProvider;
use crate::llm::hooks::HookExecutor;
use crate::memory::retriever::MemoryRetriever;
use crate::memory::store::MemoryStore;
use crate::tools::ToolRegistry;
/// Agent 运行配置。
#[derive(Debug, Clone)]
pub struct AgentConfig {
/// 单次会话最大 turn 数(含工具循环内部 turn),默认 50。
pub max_turns: u32,
/// 单次会话最大工具循环轮次(与 LlmCycle 的 `max_tool_turns` 对齐),默认 10。
pub max_tool_turns: u32,
/// 会话 TTLNone 表示无过期),默认 None。
pub session_ttl: Option<Duration>,
/// 上下文压缩配置(None 表示不启用自动压缩),默认 None。
pub compact_config: Option<CompactConfig>,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
max_turns: 50,
max_tool_turns: 10,
session_ttl: None,
compact_config: None,
}
}
}
/// Agent Runtime 依赖注入容器。
///
/// 通过 `AgentBuilder::build()` 构造;构造完成后内部为只读视图。
/// `Arc` 共享,多个 `AgentSession` 可共用同一个 bundle。
#[derive(Clone)]
pub struct RuntimeBundle {
/// LLM 后端(强引用,多 session 共享)。
pub provider: Arc<dyn LlmProvider>,
/// 工具注册表(强引用,多 session 共享)。
pub tool_registry: Arc<ToolRegistry>,
/// 钩子执行器(强引用,多 session 共享)。
pub hook_executor: Arc<HookExecutor>,
/// 持久化记忆后端(弱引用 —— 不传也能跑)。
pub memory_store: Option<Arc<dyn MemoryStore>>,
/// 记忆检索器(弱引用 —— 不传也能跑)。
/// 传入时可在 `submit_turn` 内部将检索能力作为工具暴露给 LLM。
pub retriever: Option<Arc<MemoryRetriever>>,
/// 运行时配置。
pub config: AgentConfig,
}
impl std::fmt::Debug for RuntimeBundle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RuntimeBundle")
.field("provider_type", &"<dyn LlmProvider>")
.field("tool_names", &self.tool_registry.list_tools())
.field("has_memory_store", &self.memory_store.is_some())
.field("has_retriever", &self.retriever.is_some())
.field("config", &self.config)
.finish()
}
}
impl RuntimeBundle {
/// 构造一个 `RuntimeBundle`。
///
/// **Phase 4a 行为**`retriever` 存在时仅占位记录,不真正注入工具
/// v0.1 不在 `submit_turn` 中启用检索;Phase 4c 之后再决定是否注册成 tool)。
/// 真正的工具注入留待 v0.2 接入 `RetrieveTool` 实现。
pub fn new(
provider: Arc<dyn LlmProvider>,
tool_registry: Arc<ToolRegistry>,
hook_executor: Arc<HookExecutor>,
memory_store: Option<Arc<dyn MemoryStore>>,
retriever: Option<Arc<MemoryRetriever>>,
config: AgentConfig,
) -> Self {
Self {
provider,
tool_registry,
hook_executor,
memory_store,
retriever,
config,
}
}
}
+342
View File
@@ -0,0 +1,342 @@
//! AgentSession —— 智能体"会话"实例。
//!
//! 设计要点(参见 `docs/7-agent-runtime.md` §3.2.3):
//!
//! - **会话 = 角色 + 状态**:绑定 `session_id` / `agent` / `bundle`,累计 `turn_index` 和 `cost_so_far`
//! - **最小 reference impl**`submit_turn` 演示"组装 LlmCycle → submit_with_tools → 累计 cost"的标准流程
//! - **不做业务循环**:多轮策略、错误重试、记忆回写由上层应用或具体 `TaskAgent` 决定
//! - **不持有 ConversationMemory**:上层可独立 new 一个 `ConversationMemory`,在合适的时机调 `add_message`
use std::collections::HashMap;
use std::sync::Arc;
use crate::agent::agent::Agent;
use crate::agent::error::AgentError;
use crate::agent::runtime::RuntimeBundle;
use crate::llm::cycle::{CostTracker, CycleConfig, LlmCycle};
use crate::llm::hooks::{HookContext, HookEvent};
use crate::llm::types::ChatResponse;
/// Agent 会话实例。
///
/// 同一 `Agent` 可被多个 `AgentSession` 复用(不同 session_id 互不干扰)。
/// `submit_turn` 一次只跑一轮 LLM 调用(含自动 tool 循环)。
///
/// **不实现 `Clone`**session 持有累计 `turn_index` / `cost_so_far` / `session_data`
/// 共享这些状态需要显式 sync 语义;如果上层需要并发访问,自己用 `Arc<Mutex<_>>` 包装。
pub struct AgentSession {
/// 会话 ID(由调用方指定,用于日志/追踪/记忆关联)。
pub session_id: String,
/// 角色(可热切换为同 bundle 下的其他角色)。
pub agent: Arc<dyn Agent>,
bundle: Arc<RuntimeBundle>,
turn_index: u32,
cost_so_far: CostTracker,
/// 会话级键值数据(Phase 4a 用内联 HashMapPhase 4c 替换为 `SessionMemory`)。
session_data: HashMap<String, String>,
}
impl std::fmt::Debug for AgentSession {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AgentSession")
.field("session_id", &self.session_id)
.field("agent", &self.agent.name())
.field("turn_index", &self.turn_index)
.field("cost_so_far", &self.cost_so_far.total())
.field("session_data_keys", &self.session_data.keys().collect::<Vec<_>>())
.finish()
}
}
impl AgentSession {
/// 创建一个新的会话实例。
///
/// `agent` 与 `bundle` 共同决定 `submit_turn` 行为:system_prompt / 工具集 / LLM 后端均来自它们。
pub fn new(
agent: Arc<dyn Agent>,
session_id: impl Into<String>,
bundle: Arc<RuntimeBundle>,
) -> Self {
Self {
session_id: session_id.into(),
agent,
bundle,
turn_index: 0,
cost_so_far: CostTracker::default(),
session_data: HashMap::new(),
}
}
/// 当前 turn 序号(0-based:第一次 `submit_turn` 完成后变 1)。
pub fn turn_index(&self) -> u32 {
self.turn_index
}
/// 累计用量(跨所有 turn)。
pub fn usage(&self) -> &CostTracker {
&self.cost_so_far
}
/// 会话级数据快照引用。
pub fn session_data(&self) -> &HashMap<String, String> {
&self.session_data
}
/// 写入一条会话级数据(覆盖同名 key)。
pub fn set_session_data(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.session_data.insert(key.into(), value.into());
}
/// 读取一条会话级数据。
pub fn get_session_data(&self, key: &str) -> Option<&str> {
self.session_data.get(key).map(String::as_str)
}
/// 提交一轮对话(含自动 tool 循环),返回 LLM 响应。
///
/// 流程:
/// 1. 触发 `OnTurnStart` hook
/// 2. 组装 `LlmCycle`(注入 system_prompt / hook_executor / compact_config / 消息历史)
/// 3. `submit_with_tools` 跑单轮对话
/// 4. 累计 `cost_so_far`
/// 5. 触发 `OnTurnEnd` hook
/// 6. `turn_index += 1`
///
/// **不做**
/// - 不持有 `ConversationMemory`(由上层独立 task 决定何时回写)
/// - 不做 Plan 拆解(Phase 4b 才加 `TaskAgent`
/// - 不做 session_data 持久化(Phase 4c 替换为 `SessionMemory`
pub async fn submit_turn(
&mut self,
user_input: impl Into<String>,
) -> Result<ChatResponse, AgentError> {
let turn_index = self.turn_index;
let hook_executor = Arc::clone(&self.bundle.hook_executor);
// 1. 触发 OnTurnStart hook
let start_ctx =
HookContext::new(HookEvent::OnTurnStart).with_turn_index(turn_index);
hook_executor
.execute(HookEvent::OnTurnStart, &start_ctx)
.await;
// 2. 组装 LlmCycle —— 共享 bundle 中的 provider 句柄
// 工具列表从 agent.tool_definitions(bundle) 派生(默认 = bundle 全量);
// submit_with_tools 内部从 registry 自行取 definitions,此处仅消费以触发
// 子 trait 覆盖(白名单/过滤)的副作用。
let _ = self.agent.tool_definitions(&self.bundle);
let mut cycle = LlmCycle::new_with_arc(Arc::clone(&self.bundle.provider), CycleConfig::default())
.with_messages(Vec::new());
if let Some(prompt) = self.agent.system_prompt() {
cycle = cycle.with_system_prompt(prompt.to_string());
}
if let Some(cfg) = self.bundle.config.compact_config.clone() {
cycle = cycle.with_compact_config(cfg);
}
// 3. 提交(HookExecutor 不在这里传——内部 hook 由 LlmCycle 在 PreRequest/PostRequest 触发)
let response = cycle
.submit_with_tools(user_input.into(), &self.bundle.tool_registry)
.await?;
// 4. 累计 cost
self.cost_so_far.add(&response.usage);
// 5. 触发 OnTurnEnd hook
let end_ctx = HookContext::new(HookEvent::OnTurnEnd).with_turn_index(turn_index);
hook_executor.execute(HookEvent::OnTurnEnd, &end_ctx).await;
// 6. turn_index 递增
self.turn_index += 1;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::builder::AgentBuilder;
use crate::llm::hooks::{Hook, HookContext, HookExecutor, HookResult};
use crate::llm::provider::LlmProvider;
use crate::llm::types::{
ChatRequest, ChatResponse, FinishReason, OpenaiChatMessage,
};
use crate::tools::ToolRegistry;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
/// 计数 hook —— 每被调用一次 +1。
struct CountHook(AtomicU32);
#[async_trait]
impl Hook for CountHook {
async fn execute(&self, _ctx: &HookContext<'_>) -> HookResult {
self.0.fetch_add(1, Ordering::SeqCst);
HookResult::allow()
}
}
/// 把 `Arc<CountHook>` 包装为 `Box<dyn Hook>`dyn Hook 不能直接来自 Arc)。
struct CountHookAdapter(Arc<CountHook>);
#[async_trait]
impl Hook for CountHookAdapter {
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult {
self.0.execute(ctx).await
}
}
/// MockProvider:按调用顺序返回预设响应。
struct MockProvider {
responses: std::sync::Mutex<Vec<ChatResponse>>,
}
impl MockProvider {
fn new(responses: Vec<ChatResponse>) -> Self {
Self {
responses: std::sync::Mutex::new(responses),
}
}
}
#[async_trait]
impl LlmProvider for MockProvider {
async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, crate::llm::error::LlmError> {
let mut responses = self.responses.lock().unwrap();
if responses.is_empty() {
return Err(crate::llm::error::LlmError::Other(
"no more mock responses".into(),
));
}
Ok(responses.remove(0))
}
}
struct StubAgent {
name: String,
prompt: Option<String>,
}
impl Agent for StubAgent {
fn name(&self) -> &str {
&self.name
}
fn system_prompt(&self) -> Option<&str> {
self.prompt.as_deref()
}
}
fn assistant_text(text: &str) -> ChatResponse {
ChatResponse {
message: OpenaiChatMessage::assistant_text(text),
usage: crate::llm::types::Usage::from_input_output(10, 5),
stop_reason: Some(FinishReason::Stop),
}
}
/// 烟雾测试 1AgentSession::submit_turn 跑通 mock provider。
#[tokio::test]
async fn submit_turn_runs_with_mock_provider() {
let provider = Arc::new(MockProvider::new(vec![assistant_text("hello back")]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: Some("you are a test agent".into()),
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s1", bundle);
assert_eq!(session.turn_index(), 0);
let response = session.submit_turn("hi").await.unwrap();
let text = match &response.message {
OpenaiChatMessage::Assistant { content, .. } => {
if let crate::llm::types::ContentField::String(s) = content {
s.clone()
} else {
String::new()
}
}
_ => String::new(),
};
assert_eq!(text, "hello back");
assert_eq!(session.turn_index(), 1);
assert_eq!(session.usage().total().prompt_tokens, 10);
assert_eq!(session.usage().total().completion_tokens, 5);
}
/// 烟雾测试 2session_data 读写。
#[test]
fn session_data_set_get() {
let provider = Arc::new(MockProvider::new(vec![]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: None,
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(HookExecutor::new()))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s2", bundle);
assert!(session.get_session_data("k").is_none());
session.set_session_data("k", "v");
assert_eq!(session.get_session_data("k"), Some("v"));
// 覆盖写
session.set_session_data("k", "v2");
assert_eq!(session.get_session_data("k"), Some("v2"));
}
/// 烟雾测试 3submit_turn 触发 OnTurnStart / OnTurnEnd hook。
#[tokio::test]
async fn submit_turn_triggers_turn_hooks() {
let mut hook_executor = HookExecutor::new();
let start_count = Arc::new(CountHook(AtomicU32::new(0)));
let end_count = Arc::new(CountHook(AtomicU32::new(0)));
hook_executor.register(
HookEvent::OnTurnStart,
Box::new(CountHookAdapter(start_count.clone())),
);
hook_executor.register(
HookEvent::OnTurnEnd,
Box::new(CountHookAdapter(end_count.clone())),
);
let provider = Arc::new(MockProvider::new(vec![
assistant_text("ok"),
assistant_text("ok 2"),
]));
let agent = Arc::new(StubAgent {
name: "stub".into(),
prompt: None,
});
let bundle = Arc::new(
AgentBuilder::new()
.provider(provider)
.tool_registry(Arc::new(ToolRegistry::new()))
.hook_executor(Arc::new(hook_executor))
.build()
.unwrap(),
);
let mut session = AgentSession::new(agent, "s3", bundle);
session.submit_turn("hi").await.unwrap();
assert_eq!(start_count.0.load(Ordering::SeqCst), 1);
assert_eq!(end_count.0.load(Ordering::SeqCst), 1);
session.submit_turn("hi again").await.unwrap();
assert_eq!(start_count.0.load(Ordering::SeqCst), 2);
assert_eq!(end_count.0.load(Ordering::SeqCst), 2);
}
}
+121
View File
@@ -0,0 +1,121 @@
//! 任务规划数据结构 + Phase 4b 任务执行 trait。
//!
//! Phase 4a 范围:仅 `Plan` / `Step` / `StepStatus` 纯数据结构。
//! Phase 4b 在此文件追加 `TaskAgent` trait / `PlanParser` trait / `JsonPlanParser` 参考实现。
//!
//! 设计意图(参见 `docs/7-agent-runtime.md` §3.2.4、§3.3.1):
//!
//! - `StepStatus` 用 enum 而非简单 bool,便于 UI 展示和统计
//! - 状态机单向:`Pending → Running → (Completed | Failed | Skipped)`,不回退
//! - 重试由上层新建 `Plan` 实现,`TaskAgent` 不做自动重试
use crate::agent::error::AgentError;
use crate::llm::types::ChatResponse;
/// 任务规划 —— 一组有序的 Step。
#[derive(Debug)]
pub struct Plan {
/// 规划唯一标识。
pub id: String,
/// 规划目标(人类可读)。
pub goal: String,
/// 步骤列表。
pub steps: Vec<Step>,
}
/// 任务步骤。
#[derive(Debug)]
pub struct Step {
/// 步骤在 Plan 中的位置(0-based)。
pub index: usize,
/// 步骤描述(注入 LLM 作为 user prompt)。
pub description: String,
/// 当前状态。
pub status: StepStatus,
}
impl Step {
/// 创建一个初始为 `Pending` 的步骤。
pub fn new(index: usize, description: impl Into<String>) -> Self {
Self {
index,
description: description.into(),
status: StepStatus::Pending,
}
}
}
/// 步骤状态机。
///
/// 转换路径:`Pending → Running → (Completed | Failed | Skipped)`,单向不回退。
///
/// **不实现 `Clone`**`Failed` 变体携带 `AgentError`,下层 `LlmError` / `MemoryError`
/// 均未派生 `Clone`(保留原始错误信息,传递所有权而非克隆)。如需复制 `Plan`,
/// 只能 clone 处于 `Pending` / `Running` / `Completed` / `Skipped` 状态的步骤。
#[derive(Debug)]
pub enum StepStatus {
/// 初始状态 —— 等待执行。
Pending,
/// 正在执行(`TaskAgent::execute_plan` 进入)。
Running,
/// 已完成(含 LLM 响应)。
Completed(ChatResponse),
/// 失败(含错误)。
Failed(AgentError),
/// 跳过(上层主动跳过)。
Skipped,
}
impl StepStatus {
/// 状态是否处于"未完成"。
pub fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
/// 状态是否处于终态。
pub fn is_terminal(&self) -> bool {
matches!(self, Self::Completed(_) | Self::Failed(_) | Self::Skipped)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn step_initial_state_is_pending() {
let s = Step::new(0, "do something");
assert!(s.status.is_pending());
assert!(!s.status.is_terminal());
assert_eq!(s.index, 0);
assert_eq!(s.description, "do something");
}
#[test]
fn terminal_states_classified() {
let err = AgentError::Other("x".into());
assert!(StepStatus::Failed(err).is_terminal());
assert!(StepStatus::Skipped.is_terminal());
}
#[test]
fn running_is_not_terminal() {
assert!(!StepStatus::Running.is_terminal());
assert!(!StepStatus::Running.is_pending());
}
#[test]
fn plan_holds_steps() {
let plan = Plan {
id: "p1".into(),
goal: "test goal".into(),
steps: vec![
Step::new(0, "first"),
Step::new(1, "second"),
],
};
assert_eq!(plan.steps.len(), 2);
assert_eq!(plan.steps[0].index, 0);
assert_eq!(plan.steps[1].index, 1);
}
}
+1
View File
@@ -1,5 +1,6 @@
//! agcore —— 智能体(Agent)核心工具箱。
pub mod agent;
pub mod llm;
pub mod memory;
pub mod prompt;
+12 -2
View File
@@ -63,7 +63,7 @@ impl Default for CycleConfig {
/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。
pub struct LlmCycle {
provider: Box<dyn LlmProvider>,
provider: Arc<dyn LlmProvider>,
config: CycleConfig,
usage: CostTracker,
messages: Vec<OpenaiChatMessage>,
@@ -74,8 +74,18 @@ pub struct LlmCycle {
}
impl LlmCycle {
/// 创建一个新的 LlmCycle。
/// 创建一个新的 LlmCycle(持有 `Box<dyn LlmProvider>` 的独占所有权)
///
/// 内部将 Box 转为 `Arc<dyn LlmProvider>` 以便 `new_with_arc` 复用句柄。
/// 公共签名保持不变,向后兼容。
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
Self::new_with_arc(Arc::from(provider), config)
}
/// 创建一个新的 LlmCycle,共享传入的 `Arc<dyn LlmProvider>` 句柄。
///
/// **新增**Phase 4a 引入):用于 `AgentSession::submit_turn` 在多 session 间共享 provider。
pub fn new_with_arc(provider: Arc<dyn LlmProvider>, config: CycleConfig) -> Self {
Self {
provider,
config,
+13
View File
@@ -16,6 +16,10 @@ pub enum HookEvent {
OnRetry,
/// 不可恢复错误返回之前。
OnError,
/// Agent 会话开始一轮 turn 之前(Phase 4a 新增)。
OnTurnStart,
/// Agent 会话完成一轮 turn 之后(Phase 4a 新增)。
OnTurnEnd,
}
/// 此次钩子调用的上下文。
@@ -29,6 +33,8 @@ pub struct HookContext<'a> {
pub error: Option<&'a LlmError>,
/// 当前重试次数(从 1 开始,仅 OnRetry 可用)。
pub attempt: u32,
/// 当前 turn 序号(0-based,仅 OnTurnStart / OnTurnEnd 可用,Phase 4a 新增)。
pub turn_index: Option<u32>,
}
impl<'a> HookContext<'a> {
@@ -38,6 +44,7 @@ impl<'a> HookContext<'a> {
request: None,
error: None,
attempt: 0,
turn_index: None,
}
}
@@ -55,6 +62,12 @@ impl<'a> HookContext<'a> {
self.attempt = attempt;
self
}
/// 设置 turn 序号(仅 OnTurnStart / OnTurnEnd 使用)。
pub(crate) fn with_turn_index(mut self, turn_index: u32) -> Self {
self.turn_index = Some(turn_index);
self
}
}
/// 钩子执行结果。