feat(llm): 实现 Phase 0 剩余四个模块
实现 ProviderRegistry、HookExecutor、StreamEvents 和 Auto-compaction 模块,并集成到 LlmCycle 中
This commit is contained in:
@@ -1,7 +1,4 @@
|
||||
//! agcore —— 智能体(Agent)核心工具箱。
|
||||
//!
|
||||
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
|
||||
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
|
||||
|
||||
pub mod llm;
|
||||
|
||||
|
||||
+3
-2
@@ -1,8 +1,9 @@
|
||||
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||
//!
|
||||
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
|
||||
|
||||
pub mod compact;
|
||||
pub mod cycle;
|
||||
pub mod error;
|
||||
pub mod hooks;
|
||||
pub mod provider;
|
||||
pub mod stream;
|
||||
pub mod types;
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
//! 上下文自动压缩 —— 当对话历史过长时自动压缩。
|
||||
|
||||
use crate::llm::types::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
|
||||
/// Safety margin, in tokens, subtracted from the usable window when computing
/// the auto-compaction threshold.
const AUTOCOMPACT_BUFFER_TOKENS: u32 = 13_000;

/// Default number of tokens reserved for the model's output.
const RESERVED_OUTPUT_TOKENS: u32 = 20_000;

/// Consecutive-failure count at which compaction is no longer attempted.
const MAX_CONSECUTIVE_FAILURES: u32 = 3;

/// Default number of most recent messages that micro-compaction leaves untouched.
const KEEP_RECENT: usize = 6;
|
||||
|
||||
/// 上下文压缩配置。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactConfig {
|
||||
/// 模型上下文窗口大小(token 数)。
|
||||
pub context_window: u32,
|
||||
/// 为输出预留的 token 数。
|
||||
pub reserved_tokens: u32,
|
||||
/// 微压缩保留的最近消息数。
|
||||
pub keep_recent: usize,
|
||||
}
|
||||
|
||||
impl Default for CompactConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
context_window: 128_000,
|
||||
reserved_tokens: RESERVED_OUTPUT_TOKENS,
|
||||
keep_recent: KEEP_RECENT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactConfig {
|
||||
/// 计算自动压缩触发的阈值。
|
||||
pub fn threshold(&self) -> u32 {
|
||||
self.context_window
|
||||
.saturating_sub(self.reserved_tokens)
|
||||
.saturating_sub(AUTOCOMPACT_BUFFER_TOKENS)
|
||||
}
|
||||
}
|
||||
|
||||
/// 压缩状态 —— 跟踪连续失败次数(断路器模式)。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactState {
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
impl Default for CompactState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompactState {
|
||||
/// 创建一个新的压缩状态。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
consecutive_failures: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// 记录一次成功的压缩。
|
||||
pub fn record_success(&mut self) {
|
||||
self.consecutive_failures = 0;
|
||||
}
|
||||
|
||||
/// 记录一次压缩失败。
|
||||
///
|
||||
/// 返回 `true` 表示已达断路器上限,不再尝试。
|
||||
pub fn record_failure(&mut self) -> bool {
|
||||
self.consecutive_failures += 1;
|
||||
self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES
|
||||
}
|
||||
}
|
||||
|
||||
/// 粗略估计消息列表的 token 数(基于字符数,4 字符 ≈ 1 token)。
|
||||
pub fn estimate_message_tokens(messages: &[OpenaiChatMessage]) -> u32 {
|
||||
messages
|
||||
.iter()
|
||||
.map(estimate_single_message_tokens)
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn estimate_single_message_tokens(msg: &OpenaiChatMessage) -> u32 {
|
||||
let role_overhead: u32 = 4;
|
||||
let content_tokens = match msg {
|
||||
OpenaiChatMessage::Developer { content, .. }
|
||||
| OpenaiChatMessage::System { content, .. }
|
||||
| OpenaiChatMessage::User { content, .. }
|
||||
| OpenaiChatMessage::Assistant { content, .. }
|
||||
| OpenaiChatMessage::Function { content, .. } => estimate_content_tokens(content),
|
||||
OpenaiChatMessage::Tool { content, .. } => estimate_content_tokens(content),
|
||||
};
|
||||
role_overhead + content_tokens
|
||||
}
|
||||
|
||||
fn estimate_content_tokens(content: &ContentField) -> u32 {
|
||||
match content {
|
||||
ContentField::String(s) => estimate_text_tokens(s),
|
||||
ContentField::Array(parts) => parts.iter().map(estimate_part_tokens).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_part_tokens(part: &OpenaiContentPart) -> u32 {
|
||||
match part {
|
||||
OpenaiContentPart::Text { text } => estimate_text_tokens(text),
|
||||
_ => 50,
|
||||
}
|
||||
}
|
||||
|
||||
/// Roughly estimates the token count of a text using the documented heuristic
/// of about four characters per token.
///
/// The length is measured in UTF-8 bytes (`str::len`), so multi-byte text is
/// charged more tokens per character than ASCII. The result is rounded up, is
/// `0` for empty text, and saturates at `u32::MAX` instead of wrapping.
///
/// The previous formula, `(len * 4) / 3`, charged 4/3 tokens per byte. That
/// contradicted the "4 characters ≈ 1 token" rule and could overflow `u32`.
fn estimate_text_tokens(text: &str) -> u32 {
    /// Approximate number of bytes that make up one token.
    const BYTES_PER_TOKEN: usize = 4;

    let tokens = text.len().div_ceil(BYTES_PER_TOKEN);
    // The cast cannot truncate: the value is clamped to `u32::MAX` first.
    tokens.min(u32::MAX as usize) as u32
}
|
||||
|
||||
/// 判断是否需要触发自动压缩。
|
||||
pub fn should_compact(
|
||||
messages: &[OpenaiChatMessage],
|
||||
config: &CompactConfig,
|
||||
state: &CompactState,
|
||||
) -> bool {
|
||||
if state.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
|
||||
return false;
|
||||
}
|
||||
let tokens = estimate_message_tokens(messages);
|
||||
tokens >= config.threshold()
|
||||
}
|
||||
|
||||
/// 执行微压缩 —— 用 `[pruned]` 替换旧的 tool result 内容。
|
||||
///
|
||||
/// 这是最便宜的压缩方式,不需要 LLM 调用。
|
||||
/// 保留最近的 `keep_recent` 条消息不变。
|
||||
///
|
||||
/// 返回释放的估算 token 数。
|
||||
pub fn microcompact(messages: &mut [OpenaiChatMessage], keep_recent: usize) -> u32 {
|
||||
if messages.len() <= keep_recent {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let prune_start = messages.len() - keep_recent;
|
||||
let mut freed_tokens: u32 = 0;
|
||||
|
||||
for msg in &messages[..prune_start] {
|
||||
if matches!(msg, OpenaiChatMessage::Tool { .. }) {
|
||||
freed_tokens += estimate_single_message_tokens(msg);
|
||||
}
|
||||
}
|
||||
|
||||
for msg in &mut messages[..prune_start] {
|
||||
if let OpenaiChatMessage::Tool { content, .. } = msg {
|
||||
*content = ContentField::Array(vec![OpenaiContentPart::Text {
|
||||
text: "[pruned]".to_string(),
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
freed_tokens
|
||||
}
|
||||
+216
-1
@@ -1,21 +1,38 @@
|
||||
//! LLM 调用周期控制模块。
|
||||
|
||||
mod retry;
|
||||
pub mod usage;
|
||||
|
||||
pub use retry::RetryConfig;
|
||||
pub use usage::{CostTracker, Usage};
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use async_stream::stream;
|
||||
|
||||
use crate::llm::compact::{should_compact, microcompact, CompactConfig, CompactState};
|
||||
use crate::llm::cycle::retry::should_retry;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::hooks::{HookContext, HookExecutor};
|
||||
use crate::llm::provider::LlmProvider;
|
||||
use crate::llm::stream::StreamEvent;
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
||||
};
|
||||
|
||||
/// LLM 调用周期配置。
|
||||
pub struct CycleConfig {
|
||||
/// 模型名称。
|
||||
pub model: String,
|
||||
/// 最大输出 token 数。
|
||||
pub max_tokens: Option<u32>,
|
||||
/// 采样温度。
|
||||
pub temperature: Option<f32>,
|
||||
/// 最大对话轮次。
|
||||
pub max_turns: Option<u32>,
|
||||
/// 重试策略配置。
|
||||
pub retry: RetryConfig,
|
||||
}
|
||||
|
||||
@@ -31,15 +48,20 @@ impl Default for CycleConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM 调用周期 —— 管理一次或多次 LLM 请求的生命周期。
|
||||
pub struct LlmCycle {
|
||||
provider: Box<dyn LlmProvider>,
|
||||
config: CycleConfig,
|
||||
usage: CostTracker,
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
system_prompt: Option<String>,
|
||||
hook_executor: Option<Arc<HookExecutor>>,
|
||||
compact_config: Option<CompactConfig>,
|
||||
compact_state: CompactState,
|
||||
}
|
||||
|
||||
impl LlmCycle {
|
||||
/// 创建一个新的 LlmCycle。
|
||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
@@ -47,30 +69,51 @@ impl LlmCycle {
|
||||
usage: CostTracker::default(),
|
||||
messages: Vec::new(),
|
||||
system_prompt: None,
|
||||
hook_executor: None,
|
||||
compact_config: None,
|
||||
compact_state: CompactState::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 设置系统提示词。
|
||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||
self.system_prompt = Some(prompt);
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置钩子执行器。
|
||||
pub fn with_hook_executor(mut self, executor: HookExecutor) -> Self {
|
||||
self.hook_executor = Some(Arc::new(executor));
|
||||
self
|
||||
}
|
||||
|
||||
/// 设置上下文压缩配置。
|
||||
pub fn with_compact_config(mut self, config: CompactConfig) -> Self {
|
||||
self.compact_config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
/// 获取用量追踪器引用。
|
||||
pub fn usage(&self) -> &CostTracker {
|
||||
&self.usage
|
||||
}
|
||||
|
||||
/// 获取消息历史引用。
|
||||
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// 清空消息历史。
|
||||
pub fn clear_messages(&mut self) {
|
||||
self.messages.clear();
|
||||
}
|
||||
|
||||
/// 重置用量统计。
|
||||
pub fn reset_usage(&mut self) {
|
||||
self.usage.reset();
|
||||
}
|
||||
|
||||
/// 提交用户消息并获取 LLM 响应。
|
||||
pub async fn submit(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
@@ -78,31 +121,203 @@ impl LlmCycle {
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let mut attempts = 0;
|
||||
|
||||
loop {
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
self.messages.push(response.message.clone());
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let post_request = self.build_request(&tools);
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
self.messages.push(response.message.clone());
|
||||
self.usage.add(&response.usage);
|
||||
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||
attempts += 1;
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnRetry)
|
||||
.with_error(&e)
|
||||
.with_attempt(attempts);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnRetry, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
let delay = self.config.retry.compute_delay(attempts);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx =
|
||||
HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 提交用户消息并返回语义事件流。
|
||||
///
|
||||
/// 与 `submit` 不同,该方法返回流式事件而非完整响应。
|
||||
/// 适用于需要实时处理 LLM 输出的场景。
|
||||
pub async fn submit_stream(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>, LlmError> {
|
||||
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||
|
||||
if let Some(ref config) = self.compact_config
|
||||
&& should_compact(&self.messages, config, &self.compact_state)
|
||||
{
|
||||
let freed = microcompact(&mut self.messages, config.keep_recent);
|
||||
if freed > 0 {
|
||||
self.compact_state.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
let request = self.build_request(&tools);
|
||||
|
||||
if let Some(ref executor) = self.hook_executor {
|
||||
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||
.with_request(&request);
|
||||
let results = executor
|
||||
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||
.await;
|
||||
if results.iter().any(|r| r.should_block) {
|
||||
let reason = results
|
||||
.iter()
|
||||
.find(|r| r.should_block)
|
||||
.and_then(|r| r.reason.clone())
|
||||
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||
return Err(LlmError::Other(reason));
|
||||
}
|
||||
}
|
||||
|
||||
let chunk_stream = self.provider.chat_stream(request).await?;
|
||||
let hook_executor = self.hook_executor.clone();
|
||||
let post_request = self.build_request(&tools);
|
||||
|
||||
Ok(Box::pin(stream! {
|
||||
use futures_util::StreamExt;
|
||||
let mut chunk_stream = chunk_stream;
|
||||
|
||||
while let Some(result) = chunk_stream.next().await {
|
||||
match result {
|
||||
Ok(chunk) => {
|
||||
let mut assistant_text = String::new();
|
||||
let mut tool_started: Option<(String, String, String)> = None;
|
||||
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
assistant_text.push_str(content);
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let crate::llm::types::OpenaiToolCall::Function { id, function } = tc;
|
||||
tool_started = Some((id.clone(), function.name.clone(), function.arguments.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if !assistant_text.is_empty() {
|
||||
yield StreamEvent::AssistantTextDelta { text: assistant_text };
|
||||
}
|
||||
|
||||
if let Some((tool_call_id, tool_name, arguments)) = tool_started {
|
||||
let args: serde_json::Value = serde_json::from_str(&arguments)
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
yield StreamEvent::ToolExecutionStarted {
|
||||
tool_name,
|
||||
input: args,
|
||||
tool_call_id,
|
||||
};
|
||||
}
|
||||
|
||||
for choice in &chunk.choices {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
yield StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage_info) = &chunk.usage {
|
||||
yield StreamEvent::CostUpdate { usage: *usage_info };
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::OnError,
|
||||
)
|
||||
.with_error(&e);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||
.await;
|
||||
}
|
||||
yield StreamEvent::Error { message: e.to_string() };
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref executor) = hook_executor {
|
||||
let ctx = crate::llm::hooks::HookContext::new(
|
||||
crate::llm::hooks::HookEvent::PostRequest,
|
||||
)
|
||||
.with_request(&post_request);
|
||||
executor
|
||||
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||
.await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||
let mut messages = self.messages.clone();
|
||||
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
//! 生命周期钩子 —— 在 LLM 调用周期的关键节点插入自定义逻辑。
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::ChatRequest;
|
||||
|
||||
/// Points in the LLM call lifecycle at which hooks can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    /// Before an LLM request is sent (hooks may block the request).
    PreRequest,
    /// After a successful response.
    PostRequest,
    /// Before a retry (fired only for retryable errors).
    OnRetry,
    /// Before an unrecoverable error is returned.
    OnError,
}
|
||||
|
||||
/// 此次钩子调用的上下文。
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookContext<'a> {
|
||||
/// 当前事件点。
|
||||
pub event: HookEvent,
|
||||
/// 当前的请求(部分事件点可用)。
|
||||
pub request: Option<&'a ChatRequest>,
|
||||
/// 当前错误(仅 OnError 和 OnRetry 可用)。
|
||||
pub error: Option<&'a LlmError>,
|
||||
/// 当前重试次数(从 1 开始,仅 OnRetry 可用)。
|
||||
pub attempt: u32,
|
||||
}
|
||||
|
||||
impl<'a> HookContext<'a> {
|
||||
pub(crate) fn new(event: HookEvent) -> Self {
|
||||
Self {
|
||||
event,
|
||||
request: None,
|
||||
error: None,
|
||||
attempt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_request(mut self, request: &'a ChatRequest) -> Self {
|
||||
self.request = Some(request);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_error(mut self, error: &'a LlmError) -> Self {
|
||||
self.error = Some(error);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_attempt(mut self, attempt: u32) -> Self {
|
||||
self.attempt = attempt;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of running a hook.
#[derive(Debug, Clone)]
pub struct HookResult {
    /// Whether to block the following operation (honoured for `PreRequest` only).
    pub should_block: bool,
    /// Reason for blocking, or a free-form note.
    pub reason: Option<String>,
}

impl HookResult {
    /// Builds a result that lets the operation proceed, with no reason attached.
    #[allow(dead_code)]
    pub(crate) fn allow() -> Self {
        let reason = None;
        Self {
            should_block: false,
            reason,
        }
    }

    /// Builds a result that blocks the operation for the given reason.
    #[allow(dead_code)]
    pub(crate) fn block(reason: impl Into<String>) -> Self {
        let reason = Some(reason.into());
        Self {
            should_block: true,
            reason,
        }
    }
}
|
||||
|
||||
/// 生命周期钩子 trait。
|
||||
#[async_trait]
|
||||
pub trait Hook: Send + Sync {
|
||||
/// 执行钩子逻辑。
|
||||
async fn execute(&self, ctx: &HookContext<'_>) -> HookResult;
|
||||
}
|
||||
|
||||
/// 钩子执行器 —— 管理注册与触发。
|
||||
pub struct HookExecutor {
|
||||
hooks: Vec<(HookEvent, Box<dyn Hook>)>,
|
||||
}
|
||||
|
||||
impl Default for HookExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl HookExecutor {
|
||||
/// 创建一个空的执行器。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hooks: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个钩子到指定事件点。
|
||||
pub fn register(&mut self, event: HookEvent, hook: Box<dyn Hook>) {
|
||||
self.hooks.push((event, hook));
|
||||
}
|
||||
|
||||
/// 执行指定事件点的所有钩子。
|
||||
pub async fn execute(&self, event: HookEvent, ctx: &HookContext<'_>) -> Vec<HookResult> {
|
||||
let mut results = Vec::new();
|
||||
for (e, hook) in &self.hooks {
|
||||
if *e == event {
|
||||
results.push(hook.execute(ctx).await);
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
}
|
||||
+21
-1
@@ -1,7 +1,12 @@
|
||||
pub mod openai;
|
||||
pub mod registry;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse};
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatChunk};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -57,4 +62,19 @@ pub fn create_provider(
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// 发送聊天请求并返回完整响应。
|
||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||
|
||||
/// 流式聊天请求 —— 返回原始 SSE chunk 流。
|
||||
///
|
||||
/// 默认实现回退到非流式调用(包装为单元素流)。
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let response = self.chat(request).await?;
|
||||
let chunk = OpenaiChatChunk::from(response);
|
||||
Ok(Box::pin(tokio_stream::once(Ok(chunk))))
|
||||
}
|
||||
}
|
||||
|
||||
+122
-1
@@ -1,12 +1,18 @@
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::LlmProvider;
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
|
||||
use crate::llm::types::{
|
||||
ChatRequest, ChatResponse, OpenaiChatChunk, OpenaiChatResponse, StreamOptions,
|
||||
};
|
||||
|
||||
pub struct OpenaiProvider {
|
||||
http_client: Client,
|
||||
@@ -111,4 +117,119 @@ impl LlmProvider for OpenaiProvider {
|
||||
|
||||
Ok(ChatResponse::from(chat_response))
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
mut request: ChatRequest,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
LlmError,
|
||||
> {
|
||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||
|
||||
request.stream = Some(true);
|
||||
request.stream_options = Some(StreamOptions {
|
||||
include_usage: Some(true),
|
||||
include_obfuscation: None,
|
||||
});
|
||||
|
||||
info!(model = %request.model, "发送 LLM 流式请求");
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "流式请求失败");
|
||||
Self::map_reqwest_error(e)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let status_code: u16 = status.as_u16();
|
||||
|
||||
if !status.is_success() {
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
error!(status = status_code, body = %body_text, "流式请求失败");
|
||||
return Err(LlmError::Request {
|
||||
status: status_code,
|
||||
body: body_text,
|
||||
});
|
||||
}
|
||||
|
||||
let byte_stream = response.bytes_stream().map(|r| {
|
||||
r.map_err(|e| LlmError::Other(format!("流式读取失败: {}", e)))
|
||||
});
|
||||
|
||||
Ok(Box::pin(SseChunkStream::new(byte_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
struct SseChunkStream<S> {
|
||||
inner: S,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> SseChunkStream<S> {
|
||||
fn new(stream: S) -> Self {
|
||||
Self {
|
||||
inner: stream,
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Stream<Item = Result<Bytes, LlmError>> + Unpin> Stream for SseChunkStream<S> {
|
||||
type Item = Result<OpenaiChatChunk, LlmError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
if let Some(pos) = self.buffer.find("\n") {
|
||||
let line = self.buffer.drain(..pos + 1).collect::<String>();
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty()
|
||||
|| trimmed == "data: [DONE]"
|
||||
|| trimmed == "[DONE]"
|
||||
|| trimmed == "data:"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let data = if let Some(p) = trimmed.strip_prefix("data: ") {
|
||||
p
|
||||
} else {
|
||||
trimmed
|
||||
};
|
||||
match serde_json::from_str::<OpenaiChatChunk>(data) {
|
||||
Ok(chunk) => return std::task::Poll::Ready(Some(Ok(chunk))),
|
||||
Err(e) => {
|
||||
return std::task::Poll::Ready(Some(Err(LlmError::Other(format!(
|
||||
"Chunk 解析失败: {} | raw: {}",
|
||||
e, data
|
||||
)))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Ready(Some(Ok(bytes))) => {
|
||||
if let Ok(s) = std::str::from_utf8(&bytes) {
|
||||
self.buffer.push_str(s);
|
||||
}
|
||||
}
|
||||
std::task::Poll::Ready(Some(Err(e))) => {
|
||||
return std::task::Poll::Ready(Some(Err(e)));
|
||||
}
|
||||
std::task::Poll::Ready(None) => {
|
||||
if self.buffer.is_empty() {
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
self.buffer.clear();
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
//! Provider 注册表 —— 多 Provider 实例的注册与发现。
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::provider::{create_provider, LlmProvider, ProviderConfig, ProviderType};
|
||||
|
||||
/// Provider 注册表 —— 管理多个 LLM Provider 实例。
|
||||
///
|
||||
/// 支持注册命名 Provider、按名称查找、设置默认 Provider。
|
||||
pub struct ProviderRegistry {
|
||||
providers: HashMap<String, Box<dyn LlmProvider>>,
|
||||
default_name: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for ProviderRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRegistry {
|
||||
/// 创建一个空的注册表。
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: HashMap::new(),
|
||||
default_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 注册一个已初始化的 Provider 实例。
|
||||
pub fn register(&mut self, name: impl Into<String>, provider: Box<dyn LlmProvider>) {
|
||||
self.providers.insert(name.into(), provider);
|
||||
}
|
||||
|
||||
/// 通过 ProviderType + ProviderConfig 创建并注册。
|
||||
pub fn register_with_config(
|
||||
&mut self,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
config: ProviderConfig,
|
||||
) -> Result<(), LlmError> {
|
||||
let provider = create_provider(provider_type, config)?;
|
||||
self.register(name, provider);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 设置默认 Provider。
|
||||
pub fn set_default(&mut self, name: &str) -> Result<(), LlmError> {
|
||||
if !self.providers.contains_key(name) {
|
||||
return Err(LlmError::Other(format!("Provider '{}' 不存在", name)));
|
||||
}
|
||||
self.default_name = Some(name.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 按名称查找 Provider。
|
||||
pub fn get(&self, name: &str) -> Option<&dyn LlmProvider> {
|
||||
self.providers.get(name).map(|p| p.as_ref())
|
||||
}
|
||||
|
||||
/// 获取默认 Provider。
|
||||
pub fn get_default(&self) -> Option<&dyn LlmProvider> {
|
||||
self.default_name
|
||||
.as_ref()
|
||||
.and_then(|name| self.get(name))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
//! 流式事件系统 —— 将 LLM 流式响应解析为语义化事件。
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use futures_core::stream::Stream;
|
||||
use futures_util::future::poll_fn;
|
||||
use futures_util::FutureExt;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::llm::error::LlmError;
|
||||
use crate::llm::types::{FinishReason, OpenaiChatChunk, OpenaiToolCall, Usage};
|
||||
|
||||
/// 流式事件 —— LLM 调用全生命周期的语义化事件。
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
/// 助手回复文本增量。
|
||||
AssistantTextDelta { text: String },
|
||||
/// 工具调用开始。
|
||||
ToolExecutionStarted {
|
||||
tool_name: String,
|
||||
input: Value,
|
||||
tool_call_id: String,
|
||||
},
|
||||
/// 工具调用完成。
|
||||
ToolExecutionCompleted {
|
||||
tool_name: String,
|
||||
output: Value,
|
||||
is_error: bool,
|
||||
},
|
||||
/// Token 用量更新。
|
||||
CostUpdate { usage: Usage },
|
||||
/// 一轮会话完成。
|
||||
TurnComplete { reason: FinishReason },
|
||||
/// 错误事件。
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
impl StreamEvent {
|
||||
fn error(message: impl Into<String>) -> Self {
|
||||
Self::Error {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 将原始 OpenaiChatChunk 流解析为 StreamEvent 流。
|
||||
pub fn parse_chunk_stream(
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
) -> Pin<Box<dyn futures_core::Stream<Item = StreamEvent> + Send>> {
|
||||
Box::pin(ChunkToEventStream { chunks })
|
||||
}
|
||||
|
||||
struct ChunkToEventStream {
|
||||
chunks: Pin<Box<dyn futures_core::Stream<Item = Result<OpenaiChatChunk, LlmError>> + Send>>,
|
||||
}
|
||||
|
||||
impl Stream for ChunkToEventStream {
|
||||
type Item = StreamEvent;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = &mut *self;
|
||||
poll_fn(|cx| match Pin::new(&mut this.chunks).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(chunk))) => {
|
||||
for choice in &chunk.choices {
|
||||
let delta = &choice.delta;
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
return Poll::Ready(Some(StreamEvent::AssistantTextDelta {
|
||||
text: content.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = &delta.tool_calls
|
||||
&& let Some(tc) = tool_calls.first()
|
||||
{
|
||||
let OpenaiToolCall::Function { id, function } = tc;
|
||||
let args: Value =
|
||||
serde_json::from_str(&function.arguments).unwrap_or(Value::Null);
|
||||
return Poll::Ready(Some(StreamEvent::ToolExecutionStarted {
|
||||
tool_name: function.name.clone(),
|
||||
input: args,
|
||||
tool_call_id: id.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
return Poll::Ready(Some(StreamEvent::TurnComplete {
|
||||
reason: *finish_reason,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = &chunk.usage {
|
||||
return Poll::Ready(Some(StreamEvent::CostUpdate {
|
||||
usage: *usage,
|
||||
}));
|
||||
}
|
||||
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(StreamEvent::error(e.to_string()))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
})
|
||||
.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,34 @@ impl From<OpenaiChatResponse> for ChatResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: ChatResponse) -> Self {
|
||||
let delta = Delta::from(response.message.clone());
|
||||
let chunk_choice = ChunkChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
logprobs: None,
|
||||
finish_reason: response.stop_reason,
|
||||
};
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: format!("chunk-{}", std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos())
|
||||
.unwrap_or(0)),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0),
|
||||
model: String::new(),
|
||||
choices: vec![chunk_choice],
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ChatRequest = OpenaiChatRequest;
|
||||
pub type Message = OpenaiChatMessage;
|
||||
pub type ContentBlock = OpenaiContentPart;
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::llm::types::shared::{FinishReason, ServiceTier};
|
||||
use crate::llm::types::tool::OpenaiToolCall;
|
||||
use crate::llm::types::usage::Usage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::llm::types::{ContentField, OpenaiContentPart};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenLogprob {
|
||||
@@ -115,3 +116,66 @@ pub struct OpenaiChatChunk {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
impl From<OpenaiChatMessage> for Delta {
|
||||
fn from(msg: OpenaiChatMessage) -> Self {
|
||||
match msg {
|
||||
OpenaiChatMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} => Delta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: match content {
|
||||
ContentField::String(s) => Some(s),
|
||||
ContentField::Array(parts) => {
|
||||
let mut text = String::new();
|
||||
for part in parts {
|
||||
if let OpenaiContentPart::Text { text: t } = part {
|
||||
text.push_str(&t);
|
||||
}
|
||||
}
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text)
|
||||
}
|
||||
}
|
||||
},
|
||||
refusal: None,
|
||||
tool_calls,
|
||||
},
|
||||
_ => Delta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OpenaiChatResponse> for OpenaiChatChunk {
|
||||
fn from(response: OpenaiChatResponse) -> Self {
|
||||
let choices = response
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|c| ChunkChoice {
|
||||
index: c.index,
|
||||
delta: Delta::from(c.message),
|
||||
logprobs: c.logprobs,
|
||||
finish_reason: c.finish_reason,
|
||||
})
|
||||
.collect();
|
||||
|
||||
OpenaiChatChunk {
|
||||
id: response.id,
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: response.created,
|
||||
model: response.model,
|
||||
choices,
|
||||
usage: Some(response.usage),
|
||||
system_fingerprint: response.system_fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user