diff --git a/specs/llm-types-design.md b/specs/llm-types-design.md new file mode 100644 index 0000000..2a36ac5 --- /dev/null +++ b/specs/llm-types-design.md @@ -0,0 +1,243 @@ +# 方案:重构 `types.rs` 为完整的 OpenAI 兼容 API 类型系统 + +## 1. 现状分析 + +### 当前问题 + +| 问题 | 详细 | +|------|------| +| **无 serde** | 所有类型只有 `Debug + Clone`,无 `Serialize/Deserialize`,迫使 `OpenaiProvider` 手动构建 JSON(354 行中约 200 行是序列化代码) | +| **请求参数不全** | `ChatRequest` 只支持 `model, messages, system_prompt, tools, max_tokens, temperature, extra_body`,缺失 streaming、response_format、tool_choice、stop、reasoning_effort 等 30+ 参数 | +| **响应类型太薄** | `ChatResponse` 只返回 `message + usage + stop_reason`,缺失 `id, created, model, choices` 数组、`logprobs`、`system_fingerprint` 等 | +| **无流式支持** | 无 `ChatCompletionChunk` 类型,无法处理 SSE 流式响应 | +| **反向依赖** | `types.rs` 引用 `cycle::usage::Usage`,造成模块间反向依赖 | +| **手动解析易出错** | `parse_response()` 从 `Value` 中逐字段解析,逻辑脆弱,不支持复杂嵌套类型 | + +### OpenAI API 参考文档覆盖范围 + +已完整阅读文档(2177 行),涵盖了完整的请求参数(35+ 个顶层参数)和响应结构。 + +## 2. 新类型系统设计 + +### 架构 + +将 `types.rs` 重构为 Rust 新风格模块目录(符合项目已有惯例),按功能领域拆分: + +``` +src/llm/ +├── types/ +│ ├── mod.rs # 模块根:re-exports + 基础枚举/共用类型 +│ ├── request.rs # 请求参数(ChatCompletionRequest 等) +│ ├── response.rs # 响应类型(ChatCompletionResponse + ChatCompletionChunk) +│ ├── message.rs # 消息类型(6 种角色消息 + content parts) +│ ├── tool.rs # 工具定义 + 工具调用 +│ ├── usage.rs # Token 用量(从 cycle/usage.rs 移入,消除反向依赖) +│ └── shared.rs # 共用枚举(ReasoningEffort, ServiceTier, ResponseFormat 等) +``` + +同时,将 `cycle/usage.rs` 中的 `Usage` 和 `CostTracker` **移到** `types/usage.rs`,`cycle/usage.rs` 保留 `pub use` 兼容 re-export。 + +### 核心决策 + +| 决策 | 选择 | 理由 | +|------|------|------| +| **序列化方式** | 全部类型 derive `Serialize, Deserialize` | 消除手动 JSON 构建,让 provider 直接 `.json(&req)` / `.json::()` | +| **类型风格** | 直接映射 OpenAI API JSON 形状 | 一目了然,与 API 文档 1:1 对应,调试方便 | +| **命名策略** | 添加 `OpenAI` 前缀(如 `OpenaiChatRequest`) | 明确标注为 OpenAI 兼容类型 | +| **字段命名** | `#[serde(rename_all = "snake_case")]` | OpenAI API 使用 snake_case | +| **可选字段** | `#[serde(skip_serializing_if = "Option::is_none")]` | 不序列化 None 字段,保持请求体干净 | +| **默认值** | `#[serde(default)]` | 反序列化时缺失字段用默认值 | +| **后向兼容** | 通过类型别名保持 `ChatRequest`/`ChatResponse` 等名称可用 | LlmProvider/LlmCycle 接口不变 | +| **泛化策略** | Anthropic 是独立体系,暂不纳入当前设计 | 保持当前类型系统专注 OpenAI,Provider 层做转换 | + +### 关键类型设计原则 + +- **`OpenaiChatRequest`**:统一结构体(不拆分 NonStreaming/Streaming),包含 `stream: Option` 字段,所有字段均为 `Option`,build 时 `skip_serializing_if` +- **`OpenaiChatResponse`**:直接对应 `ChatCompletion`(完整响应),保留完整 choices 数组等所有字段 +- **`OpenaiChatChunk`**:对应流式 chunk,`object = "chat.completion.chunk"` +- **消息系统**:用单个 `OpenaiChatMessage` enum 覆盖 6 种角色消息类型(Developer/System/User/Assistant/Tool/Function),每种内部使用对应 struct +- **Content parts**:`OpenaiContentPart` enum 覆盖 text/image_url/input_audio/file/refusal + +## 3. 完整类型清单 + +### `types/mod.rs` — 共用类型 +``` +Role → enum { Developer, System, User, Assistant, Tool, Function } +FinishReason → enum { Stop, Length, ToolCalls, ContentFilter, FunctionCall } +ServiceTier → enum { Auto, Default, Flex, Scale, Priority } +Modality → enum { Text, Audio } +ImageDetail → enum { Auto, Low, High } +AudioFormat → enum { Wav, Mp3, Aac, Flac, Opus, Pcm16 } +Voice → struct { id: String } 或预定义枚举 +SearchContextSize → enum { Low, Medium, High } +StopSequence → enum { Single(String), Multiple(Vec) } +Verbosity → enum { Low, Medium, High } +``` + +### `types/request.rs` — 请求参数 +``` +OpenaiChatRequest → struct (35+ 字段,所有 OpenAI 参数) +ResponseFormat → enum { Text, JsonObject { .. }, JsonSchema { .. } } +ToolChoice → enum { None, Auto, Required, Named { .. }, AllowedTools { .. } } +StreamOptions → struct { include_usage, include_obfuscation } +AudioParam → struct { format, voice } +PredictionContent → struct { type, content } +WebSearchOptions → struct { search_context_size, user_location } +UserLocation → struct { type, approximate: Approximate } +Approximate → struct { city, country, region, timezone } +FunctionCallOption → struct { name } // deprecated +FunctionDefinition → struct { name, description, parameters, strict } +OpenaiTool → enum { Function { .. }, Custom { .. } } +``` + +### `types/response.rs` — 响应类型 +``` +OpenaiChatResponse → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier } +Choice → struct { index, message, finish_reason, logprobs } +OpenaiChatMessage → struct { content, refusal, role, tool_calls, function_call, audio, annotations } +OpenaiChatChunk → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier } +ChunkChoice → struct { index, delta, logprobs, finish_reason } +Delta → struct { role, content, tool_calls, function_call } +Logprobs → struct { content, refusal } +TokenLogprob → struct { token, bytes, logprob, top_logprobs } +TopLogprob → struct { token, bytes, logprob } +Annotation → struct { type, url_citation } +URLCitation → struct { end_index, start_index, title, url } +OpenaiAudio → struct { id, data, expires_at, transcript } +FunctionCall → struct { name, arguments } +OpenaiToolCall → enum { Function { id, function, type }, Custom { id, custom, type } } +``` + +### `types/message.rs` — 消息类型 +``` +OpenaiChatMessage → enum (覆盖 6 种角色消息) + DeveloperMessage → struct { content, role, name } + SystemMessage → struct { content, role, name } + UserMessage → struct { content, role, name } + AssistantMessage → struct { content, refusal, role, name, tool_calls, function_call, audio } + ToolMessage → struct { content, role, tool_call_id } + FunctionMessage → struct { content, role, name } +OpenaiContentPart → enum + OpenaiContentPartText → struct { type, text } + OpenaiContentPartImage → struct { type, image_url: ImageURL } + OpenaiContentPartInputAudio → struct { type, input_audio: InputAudio } + OpenaiContentPartFile → struct { type, file: FileData } + OpenaiContentPartRefusal → struct { type, refusal } +ImageURL → struct { url, detail } +InputAudio → struct { data, format } +FileData → struct { file_data, file_id, filename } +``` + +### `types/tool.rs` — 工具类型 +``` +OpenaiToolDefinition → struct { name, description, parameters, strict } + (保留 ToolDefinition 别名保持后向兼容,重定义为包含所有字段) +OpenaiToolCall (在请求中使用) → 见 response.rs 中的定义 +``` + +### `types/usage.rs` — Token 用量 +``` +Usage → struct { prompt_tokens, completion_tokens, total_tokens, + completion_tokens_details, prompt_tokens_details } +CompletionTokensDetails → struct { reasoning_tokens, audio_tokens, + accepted_prediction_tokens, rejected_prediction_tokens } +PromptTokensDetails → struct { audio_tokens, cached_tokens } +CostTracker → 从 cycle/usage.rs 移入(累计追踪器) +``` + +### 删除的旧类型 +- `ContentBlock` → 被 `OpenaiContentPart` 替代(更准确的 OpenAI API 命名) +- `StopReason` → 被 `FinishReason` 替代(与 API 命名一致) +- `Message` → 被 `OpenaiChatMessage` 替代 + +### 类型别名(后向兼容) +``` +ChatRequest = OpenaiChatRequest +ChatResponse = OpenaiChatResponse +Message = OpenaiChatMessage +ContentBlock = OpenaiContentPart +ToolDefinition = OpenaiToolDefinition +Role = Role(保持不变,但扩展变体) +StopReason = FinishReason +``` + +## 4. 对其他模块的影响 + +### `provider/openai.rs` +- **大幅简化**:`build_request_body()` → 直接 `serde_json::to_value(&request)` +- `parse_response()` 中 100+ 行手动解析 → 直接 `serde_json::from_value::()` +- `serialize_messages()`, `serialize_message()`, `serialize_content_block()`, `serialize_tool()` → **全部删除** +- 新增 `chat_stream()` 方法返回 `OpenaiChatChunk` 流 +- 需要适配新类型的字段名变更(如 `Usage` 中 `input_tokens` → `prompt_tokens`) + +### `provider.rs` (trait) +- 接口保持不变,继续使用 `ChatRequest`/`ChatResponse` 类型别名 +- 调整 `Usage` 类型引用路径 + +### `cycle.rs` +- `CycleConfig` 扩展支持更多请求参数(至少增加 `tools, tool_choice, response_format, stop, reasoning_effort, seed` 等) +- `LlmCycle::submit()` 构建 `ChatRequest` 时使用新类型 +- `response.usage` 字段类型变更(新 `Usage` 含更多字段) +- 此时不添加流式支持 + +### `cycle/usage.rs` +- `Usage` 结构体**被移走**到 `types/usage.rs` +- `cycle/usage.rs` 保留 `pub use crate::llm::types::usage::{Usage, CostTracker};` 作为兼容性 re-export +- `CostTracker` 逻辑不变 + +### `error.rs` +- 无明显变更,错误类型和映射逻辑不变 + +## 5. 实施步骤 + +### Phase 1: 基础设施 +``` +1. [准备] 在 Cargo.toml 中确认 serde 依赖(已有 serde = "1",features = ["derive"]) +2. [创建] 新建 src/llm/types/ 目录 +``` + +### Phase 2: 类型定义(按依赖顺序) +``` +3. [usage.rs] 从 cycle/usage.rs 迁移 Usage + CostTracker +4. [shared.rs] 定义 Role, FinishReason, ServiceTier, Modality, ImageDetail, StopSequence, ResponseFormat +5. [message.rs] 定义 OpenaiChatMessage(6种角色)+ OpenaiContentPart + ImageURL + InputAudio +6. [tool.rs] 定义 OpenaiToolDefinition + OpenaiToolCall + FunctionCall +7. [request.rs] 定义 OpenaiChatRequest(35+ 字段)+ ToolChoice + StreamOptions +8. [response.rs] 定义 OpenaiChatResponse + OpenaiChatChunk + Choice + Delta + Logprobs +``` + +### Phase 3: 模块组装 +``` +9. [mod.rs] 创建模块根,re-export 所有类型 + 别名(ChatRequest = OpenaiChatRequest 等) +10. [usage.rs] 更新 cycle/usage.rs 为 pub use re-export +11. [删除] 删除旧 src/llm/types.rs +``` + +### Phase 4: Provider 适配 +``` +12. [provider/openai.rs] 重写为 serde 序列化(删除 ~200 行手动代码) +13. [cycle.rs] 适配新类型字段(prompt_tokens vs input_tokens) +``` + +### Phase 5: 验证 +``` +14. [编译] cargo check 确保编译通过 +15. [检查] cargo clippy 确保无警告 +16. [测试] cargo test 确保测试通过 +``` + +## 6. 验证方式 + +- `cargo check` — 编译通过 +- `cargo clippy` — 无警告 +- `cargo test` — 所有测试通过(如果有集成测试,可能需要调整) +- 检查 `OpenaiProvider` 代码量减少(预期从 354 行降至 ~150 行) +- 手动验证序列化输出是否符合 OpenAI API 格式 + +## 7. 注意事项 + +1. **Break change**: 某些类型名称变化(如 `StopReason` → `FinishReason`),项目处于早期阶段,可接受 +2. **后向兼容**: 通过类型别名保持旧名称可用,接口层无需修改 +3. **Anthropic 处理**: Anthropic 是独立体系,不在当前设计中泛化,单独实现 Provider +4. **异步流**: `chat_stream()` 的签名需要仔细设计(`Pin>>>` 或自定义类型) +5. **CostTracker 不变**: 虽然 Usage 变复杂了,但 CostTracker 只累计 input/output token 数,逻辑不变 diff --git a/src/llm/cycle.rs b/src/llm/cycle.rs index f710818..250fc73 100644 --- a/src/llm/cycle.rs +++ b/src/llm/cycle.rs @@ -7,19 +7,15 @@ pub use usage::{CostTracker, Usage}; use crate::llm::cycle::retry::should_retry; use crate::llm::error::LlmError; use crate::llm::provider::LlmProvider; -use crate::llm::types::{ChatRequest, ChatResponse, ContentBlock, Message, Role, ToolDefinition}; +use crate::llm::types::{ + ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition, +}; -/// LLM 生命周期引擎的配置。 pub struct CycleConfig { - /// 使用的模型名称。 pub model: String, - /// 最大输出 token 数。 pub max_tokens: Option, - /// 采样温度。 pub temperature: Option, - /// 最大对话轮数(预留,暂未使用)。 pub max_turns: Option, - /// 重试策略配置。 pub retry: RetryConfig, } @@ -35,22 +31,15 @@ impl Default for CycleConfig { } } -/// LLM 调用生命周期引擎。 -/// -/// 管理一次多轮交互的完整生命周期,包括: -/// - 消息历史维护 -/// - Token 用量追踪 -/// - 自动重试 pub struct LlmCycle { provider: Box, config: CycleConfig, usage: CostTracker, - messages: Vec, + messages: Vec, system_prompt: Option, } impl LlmCycle { - /// 创建新的生命周期引擎。 pub fn new(provider: Box, config: CycleConfig) -> Self { Self { provider, @@ -61,50 +50,33 @@ impl LlmCycle { } } - /// 设置系统提示词(Builder 模式)。 pub fn with_system_prompt(mut self, prompt: String) -> Self { self.system_prompt = Some(prompt); self } - /// 获取 Token 用量追踪器引用。 pub fn usage(&self) -> &CostTracker { &self.usage } - /// 获取当前消息历史。 - pub fn messages(&self) -> &[Message] { + pub fn messages(&self) -> &[OpenaiChatMessage] { &self.messages } - /// 清空消息历史。 pub fn clear_messages(&mut self) { self.messages.clear(); } - /// 重置 Token 用量统计。 pub fn reset_usage(&mut self) { self.usage.reset(); } - /// 提交一条用户消息并获取模型响应。 - /// - /// 流程: - /// 1. 将用户消息追加到消息历史 - /// 2. 构建 ChatRequest - /// 3. 使用重试循环调用 provider.chat() - /// 4. 将助手回复追加到消息历史 - /// 5. 累计 token 用量 - /// 6. 返回 ChatResponse pub async fn submit( &mut self, prompt: String, tools: Vec, ) -> Result { - self.messages.push(Message { - role: Role::User, - content: vec![ContentBlock::Text { text: prompt }], - }); + self.messages.push(OpenaiChatMessage::user_text(prompt)); let mut attempts = 0; @@ -113,10 +85,7 @@ impl LlmCycle { match self.provider.chat(request).await { Ok(response) => { - self.messages.push(Message { - role: Role::Assistant, - content: response.message.content.clone(), - }); + self.messages.push(response.message.clone()); self.usage.add(&response.usage); @@ -134,16 +103,35 @@ impl LlmCycle { } } - /// 根据当前状态构建 ChatRequest。 fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest { + let mut messages = self.messages.clone(); + + if let Some(sys_prompt) = &self.system_prompt + && !messages.iter().any(|m| matches!(m, OpenaiChatMessage::System { .. })) + { + messages.insert(0, OpenaiChatMessage::system_text(sys_prompt)); + } + + let openai_tools: Option> = if tools.is_empty() { + None + } else { + Some( + tools.iter() + .map(|t| OpenaiTool::Function { + function: t.clone(), + }) + .collect(), + ) + }; + ChatRequest { model: self.config.model.clone(), - messages: self.messages.clone(), - system_prompt: self.system_prompt.clone(), - tools: tools.to_vec(), + messages, max_tokens: self.config.max_tokens, temperature: self.config.temperature, - extra_body: None, + tools: openai_tools, + tool_choice: Some(ToolChoice::Auto), + ..Default::default() } } -} +} \ No newline at end of file diff --git a/src/llm/cycle/usage.rs b/src/llm/cycle/usage.rs index 32da25e..52974c8 100644 --- a/src/llm/cycle/usage.rs +++ b/src/llm/cycle/usage.rs @@ -1,42 +1 @@ -/// 单次请求的 Token 用量。 -#[derive(Debug, Clone, Default)] -pub struct Usage { - /// 输入(提示词)消耗的 token 数。 - pub input_tokens: u32, - /// 输出(生成内容)消耗的 token 数。 - pub output_tokens: u32, -} - -/// Token 用量累计追踪器。 -/// -/// 在多轮对话中累计所有请求的 token 消耗。 -#[derive(Debug, Default)] -pub struct CostTracker { - accumulated: Usage, -} - -impl CostTracker { - /// 累加一次请求的用量。 - /// - /// 使用 saturating_add 防止溢出。 - pub fn add(&mut self, usage: &Usage) { - self.accumulated.input_tokens = self - .accumulated - .input_tokens - .saturating_add(usage.input_tokens); - self.accumulated.output_tokens = self - .accumulated - .output_tokens - .saturating_add(usage.output_tokens); - } - - /// 获取累计用量。 - pub fn total(&self) -> &Usage { - &self.accumulated - } - - /// 重置累计用量。 - pub fn reset(&mut self) { - self.accumulated = Usage::default(); - } -} +pub use crate::llm::types::usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage, Usage as LlmUsage}; diff --git a/src/llm/provider/openai.rs b/src/llm/provider/openai.rs index e0b72b1..fb9a605 100644 --- a/src/llm/provider/openai.rs +++ b/src/llm/provider/openai.rs @@ -2,33 +2,19 @@ use std::time::Duration; use async_trait::async_trait; use reqwest::Client; -use serde_json::{json, Value}; -use crate::llm::cycle::usage::Usage; use crate::llm::error::LlmError; -use crate::llm::types::{ - ChatRequest, ChatResponse, ContentBlock, Message, Role, StopReason, ToolDefinition, -}; - +use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse}; use super::LlmProvider; -/// OpenAI 兼容 API 的 Provider 实现。 -/// -/// 支持任意实现了 `POST /v1/chat/completions` 标准的 API -/// (包括 OpenAI、Azure OpenAI、DashScope、vLLM 等)。 pub struct OpenaiProvider { http_client: Client, base_url: String, api_key: String, - #[allow(dead_code)] - model: String, } impl OpenaiProvider { - /// 创建新的 OpenAI Provider。 - /// - /// 默认使用 120 秒超时的 HTTP 客户端。 - pub fn new(base_url: String, api_key: String, model: String) -> Self { + pub fn new(base_url: String, api_key: String, _model: String) -> Self { let http_client = Client::builder() .timeout(Duration::from_secs(120)) .build() @@ -38,284 +24,45 @@ impl OpenaiProvider { http_client, base_url, api_key, - model, } } - /// 替换为自定义的 HTTP 客户端(用于测试或自定义配置)。 pub fn with_client(mut self, client: Client) -> Self { self.http_client = client; self } - /// 将 ChatRequest 构建为 OpenAI API 请求体 JSON。 - fn build_request_body(&self, request: &ChatRequest) -> Value { - let mut body = json!({ - "model": request.model, - "messages": Self::serialize_messages(request), - }); - - if let Some(max_tokens) = request.max_tokens { - body["max_tokens"] = json!(max_tokens); - } - if let Some(temperature) = request.temperature { - body["temperature"] = json!(temperature); - } - if !request.tools.is_empty() { - body["tools"] = json!( - request - .tools - .iter() - .map(Self::serialize_tool) - .collect::>() - ); - } - - // 合并 extra_body 中的扩展参数到请求体顶层 - if let Some(ref extra) = request.extra_body - && let Some(obj) = extra.as_object() - { - for (k, v) in obj { - body[k] = v.clone(); + fn map_reqwest_error(e: reqwest::Error) -> LlmError { + if e.is_timeout() { + LlmError::Timeout { + duration: Duration::from_secs(120), } - } - - body - } - - /// 将请求中的消息列表序列化为 API 消息数组。 - fn serialize_messages(request: &ChatRequest) -> Vec { - let mut messages: Vec = Vec::new(); - - // system_prompt 作为独立的 system 角色消息放在最前面 - if let Some(ref system_prompt) = request.system_prompt { - messages.push(json!({ - "role": "system", - "content": system_prompt - })); - } - - for msg in &request.messages { - messages.push(Self::serialize_message(msg)); - } - - messages - } - - /// 将单条消息序列化为 API 格式。 - /// - /// 处理逻辑: - /// - 多个 content block 或包含图片 → 使用数组格式 - /// - ToolResult → 使用 tool 角色格式 - /// - 其他 → 使用纯文本格式 - fn serialize_message(msg: &Message) -> Value { - let role_str = match msg.role { - Role::User => "user", - Role::Assistant => "assistant", - Role::System => "system", - Role::Tool => "tool", - }; - - let has_mixed = msg.content.len() > 1 - || msg - .content - .iter() - .any(|b| matches!(b, ContentBlock::ImageUrl { .. })); - - if has_mixed { - let content: Vec = msg - .content - .iter() - .map(Self::serialize_content_block) - .collect(); - json!({ "role": role_str, "content": content }) - } else if let Some(ContentBlock::ToolResult { - tool_use_id, - content, - }) = msg.content.first() - { - json!({ - "role": "tool", - "tool_call_id": tool_use_id, - "content": content - }) + } else if e.is_connect() { + LlmError::Other(format!("连接失败: {}", e)) } else { - let text = msg - .content - .first() - .map(|b| match b { - ContentBlock::Text { text } => text.clone(), - _ => String::new(), - }) - .unwrap_or_default(); - json!({ "role": role_str, "content": text }) + LlmError::Other(format!("请求失败: {}", e)) } } - - /// 将 ContentBlock 序列化为 API content parts 数组元素。 - fn serialize_content_block(block: &ContentBlock) -> Value { - match block { - ContentBlock::Text { text } => { - json!({ "type": "text", "text": text }) - } - ContentBlock::ImageUrl { url } => { - json!({ "type": "image_url", "image_url": { "url": url } }) - } - ContentBlock::ToolUse { id, name, input } => { - json!({ "type": "tool_use", "id": id, "name": name, "input": input }) - } - ContentBlock::ToolResult { .. } => { - json!({ "type": "tool_result", "content": "" }) - } - } - } - - /// 将 ToolDefinition 序列化为 OpenAI tools 数组元素。 - fn serialize_tool(tool: &ToolDefinition) -> Value { - json!({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.input_schema - } - }) - } - - /// 将 OpenAI API 响应 JSON 解析为 ChatResponse。 - fn parse_response(response: Value) -> Result { - let choice = response["choices"][0] - .as_object() - .ok_or_else(|| LlmError::Other("响应中缺少 choices[0]".into()))?; - - let msg = choice["message"] - .as_object() - .ok_or_else(|| LlmError::Other("响应中缺少 message".into()))?; - - let role = match msg["role"].as_str() { - Some("assistant") => Role::Assistant, - Some(_) => Role::Assistant, - None => Role::Assistant, - }; - - let mut content_blocks: Vec = Vec::new(); - - // 从 content 字段提取文本和 tool_use - if let Some(content_val) = msg.get("content") { - match content_val { - Value::String(s) if !s.is_empty() => { - content_blocks.push(ContentBlock::Text { text: s.clone() }); - } - Value::Array(arr) => { - for item in arr { - if let Some(item_type) = item["type"].as_str() { - match item_type { - "text" => { - if let Some(text) = item["text"].as_str() { - content_blocks - .push(ContentBlock::Text { text: text.into() }); - } - } - "tool_use" | "function" => { - let id = item["id"].as_str().unwrap_or("").to_string(); - let name = item["name"].as_str().unwrap_or("").to_string(); - let input = item.get("input").cloned().unwrap_or(Value::Null); - content_blocks - .push(ContentBlock::ToolUse { id, name, input }); - } - _ => {} - } - } - } - } - _ => {} - } - } - - // 从 tool_calls 字段提取工具调用(OpenAI 特有格式) - if let Some(tool_calls) = msg.get("tool_calls").and_then(|v| v.as_array()) { - for tc in tool_calls { - let id = tc["id"].as_str().unwrap_or("").to_string(); - let name = tc["function"]["name"].as_str().unwrap_or("").to_string(); - let input = tc["function"]["arguments"] - .as_str() - .and_then(|s| serde_json::from_str(s).ok()) - .unwrap_or(Value::Null); - content_blocks.push(ContentBlock::ToolUse { id, name, input }); - } - } - - if content_blocks.is_empty() { - content_blocks.push(ContentBlock::Text { - text: String::new(), - }); - } - - // 解析停止原因 - let stop_reason = choice["finish_reason"].as_str().map(|s| match s { - "stop" => StopReason::Stop, - "tool_calls" => StopReason::ToolUse, - "max_tokens" => StopReason::MaxTokens, - "length" => StopReason::Length, - "content_filter" => StopReason::ContentFilter, - other => StopReason::Other(other.into()), - }); - - // 解析 token 用量 - let usage = response["usage"] - .as_object() - .map(|u| Usage { - input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32, - output_tokens: u - .get("completion_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as u32, - }) - .unwrap_or_default(); - - Ok(ChatResponse { - message: Message { - role, - content: content_blocks, - }, - usage, - stop_reason, - }) - } } #[async_trait] impl LlmProvider for OpenaiProvider { async fn chat(&self, request: ChatRequest) -> Result { let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); - let body = self.build_request_body(&request); let response = self .http_client .post(&url) .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&body) + .json(&request) .send() .await - .map_err(|e| { - if e.is_timeout() { - LlmError::Timeout { - duration: Duration::from_secs(120), - } - } else if e.is_connect() { - LlmError::Other(format!("连接失败: {}", e)) - } else { - LlmError::Other(format!("请求失败: {}", e)) - } - })?; + .map_err(Self::map_reqwest_error)?; let status = response.status(); let status_code: u16 = status.as_u16(); - // 处理非 2xx 响应,将 HTTP 状态码映射为对应的 LlmError 变体 if !status.is_success() { - // 在消费 response body 之前先读取 retry-after 头部 let retry_after = response .headers() .get("retry-after") @@ -344,11 +91,11 @@ impl LlmProvider for OpenaiProvider { }; } - let json_body: Value = response + let chat_response: OpenaiChatResponse = response .json() .await .map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?; - Self::parse_response(json_body) + Ok(ChatResponse::from(chat_response)) } -} +} \ No newline at end of file diff --git a/src/llm/types.rs b/src/llm/types.rs deleted file mode 100644 index 950ac44..0000000 --- a/src/llm/types.rs +++ /dev/null @@ -1,100 +0,0 @@ -use crate::llm::cycle::usage::Usage; -use serde_json::Value; - -/// 对话消息的角色。 -#[derive(Debug, Clone)] -pub enum Role { - User, - Assistant, - System, - Tool, -} - -/// 消息内容块,支持多模态及工具调用。 -#[derive(Debug, Clone)] -pub enum ContentBlock { - /// 纯文本内容。 - Text { - text: String, - }, - /// 图片 URL(多模态输入预留)。 - ImageUrl { - url: String, - }, - /// 模型发起的工具调用(预留,暂不实现自动执行)。 - ToolUse { - id: String, - name: String, - input: Value, - }, - /// 工具执行结果的回传(预留,暂不实现自动执行)。 - ToolResult { - tool_use_id: String, - content: String, - }, -} - -/// 一条对话消息,由角色和内容块列表组成。 -#[derive(Debug, Clone)] -pub struct Message { - pub role: Role, - pub content: Vec, -} - -/// 可供模型调用的工具定义。 -#[derive(Debug, Clone)] -pub struct ToolDefinition { - /// 工具名称。 - pub name: String, - /// 工具描述,用于模型理解何时调用。 - pub description: String, - /// JSON Schema 格式的输入参数定义。 - pub input_schema: Value, -} - -/// 对 /v1/chat/completions 的完整请求参数。 -#[derive(Debug, Clone)] -pub struct ChatRequest { - /// 模型标识(如 "gpt-4o")。 - pub model: String, - /// 对话历史 + 新消息。 - pub messages: Vec, - /// 独立的系统提示词,将在序列化时转为 system 角色消息。 - pub system_prompt: Option, - /// 可用的工具定义列表。 - pub tools: Vec, - /// 最大输出 token 数。 - pub max_tokens: Option, - /// 采样温度。 - pub temperature: Option, - /// 扩展参数(如 enable_thinking),会合并到请求体顶层。 - pub extra_body: Option, -} - -/// 模型返回的完整响应。 -#[derive(Debug, Clone)] -pub struct ChatResponse { - /// 助手的回复消息。 - pub message: Message, - /// 本次请求的 token 用量。 - pub usage: Usage, - /// 停止原因。 - pub stop_reason: Option, -} - -/// 模型停止生成的原因。 -#[derive(Debug, Clone)] -pub enum StopReason { - /// 正常结束。 - Stop, - /// 模型请求调用工具(预留)。 - ToolUse, - /// 达到 max_tokens 上限。 - MaxTokens, - /// 内容被安全过滤。 - ContentFilter, - /// 长度限制(兼容某些 API 的 finish_reason)。 - Length, - /// 其他未分类的原因。 - Other(String), -} diff --git a/src/llm/types/message.rs b/src/llm/types/message.rs new file mode 100644 index 0000000..d81a7f3 --- /dev/null +++ b/src/llm/types/message.rs @@ -0,0 +1,125 @@ +use crate::llm::types::shared::{AudioFormat, ImageDetail}; +use crate::llm::types::tool::OpenaiToolCall; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageURL { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InputAudio { + pub data: String, + pub format: AudioFormat, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileData { + pub file_data: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum OpenaiContentPart { + Text { + text: String, + }, + Image { + image_url: ImageURL, + #[serde(skip_serializing_if = "Option::is_none")] + detail: Option, + }, + InputAudio { + input_audio: InputAudio, + }, + File { + file: FileData, + }, + Refusal { + refusal: String, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "role")] +pub enum OpenaiChatMessage { + Developer { + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + System { + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + User { + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + Assistant { + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + refusal: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + }, + Tool { + content: Vec, + tool_call_id: String, + }, + Function { + content: Vec, + name: String, + }, +} + +impl OpenaiChatMessage { + pub fn user_text>(text: S) -> Self { + OpenaiChatMessage::User { + content: vec![OpenaiContentPart::Text { text: text.into() }], + name: None, + } + } + + pub fn assistant_text>(text: S) -> Self { + OpenaiChatMessage::Assistant { + content: vec![OpenaiContentPart::Text { text: text.into() }], + refusal: None, + name: None, + tool_calls: None, + } + } + + pub fn system_text>(text: S) -> Self { + OpenaiChatMessage::System { + content: vec![OpenaiContentPart::Text { text: text.into() }], + name: None, + } + } + + pub fn developer_text>(text: S) -> Self { + OpenaiChatMessage::Developer { + content: vec![OpenaiContentPart::Text { text: text.into() }], + name: None, + } + } + + pub fn tool_result>(tool_call_id: String, content: S) -> Self { + OpenaiChatMessage::Tool { + content: vec![OpenaiContentPart::Text { + text: content.into(), + }], + tool_call_id, + } + } +} diff --git a/src/llm/types/mod.rs b/src/llm/types/mod.rs new file mode 100644 index 0000000..b3747a9 --- /dev/null +++ b/src/llm/types/mod.rs @@ -0,0 +1,53 @@ +pub mod message; +pub mod request; +pub mod response; +pub mod shared; +pub mod tool; +pub mod usage; + +pub use message::{ + FileData, ImageURL, InputAudio, OpenaiChatMessage, OpenaiContentPart, +}; +pub use request::{OpenaiChatRequest, OpenaiTool, StreamOptions, ToolChoice}; +pub use response::{ + Annotation, Choice, ChunkChoice, Delta, Logprobs, OpenaiAudio, + OpenaiChatChunk, OpenaiChatResponse, TokenLogprob, TopLogprob, URLCitation, +}; +pub use shared::{ + AudioFormat, FinishReason, ImageDetail, Modality, ResponseFormat, Role, + ServiceTier, StopSequence, +}; +pub use tool::{FunctionCall, OpenaiToolCall, OpenaiToolDefinition}; +pub use usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage}; + +#[derive(Debug, Clone)] +pub struct ChatResponse { + pub message: OpenaiChatMessage, + pub usage: Usage, + pub stop_reason: Option, +} + +impl From for ChatResponse { + fn from(response: OpenaiChatResponse) -> Self { + let message = response + .choices + .first() + .map(|c| c.message.clone()) + .unwrap_or_else(|| OpenaiChatMessage::assistant_text("")); + let stop_reason = response + .choices + .first() + .and_then(|c| c.finish_reason); + ChatResponse { + message, + usage: response.usage, + stop_reason, + } + } +} + +pub type ChatRequest = OpenaiChatRequest; +pub type Message = OpenaiChatMessage; +pub type ContentBlock = OpenaiContentPart; +pub type ToolDefinition = OpenaiToolDefinition; +pub type StopReason = FinishReason; \ No newline at end of file diff --git a/src/llm/types/request.rs b/src/llm/types/request.rs new file mode 100644 index 0000000..5530b9f --- /dev/null +++ b/src/llm/types/request.rs @@ -0,0 +1,117 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use crate::llm::types::shared::{ResponseFormat, ServiceTier, StopSequence}; +use crate::llm::types::tool::OpenaiToolDefinition; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_obfuscation: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum ToolChoice { + None, + Auto, + Required, + Named { + name: String, + }, + AllowedTools { + #[serde(rename = "tools")] + tool_names: Vec, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum OpenaiTool { + Function { + function: OpenaiToolDefinition, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AudioParam { + pub format: String, + pub voice: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictionContent { + #[serde(rename = "type")] + pub pred_type: String, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserLocation { + #[serde(rename = "type")] + pub loc_type: String, + pub approximate: Approximate, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Approximate { + pub city: String, + pub country: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub region: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub timezone: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WebSearchOptions { + pub search_context_size: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub user_location: Option, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct OpenaiChatRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub extra_headers: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub extra_body: Option, +} \ No newline at end of file diff --git a/src/llm/types/response.rs b/src/llm/types/response.rs new file mode 100644 index 0000000..90011ce --- /dev/null +++ b/src/llm/types/response.rs @@ -0,0 +1,117 @@ +use serde::{Deserialize, Serialize}; +use crate::llm::types::shared::{FinishReason, ServiceTier}; +use crate::llm::types::message::OpenaiChatMessage; +use crate::llm::types::tool::OpenaiToolCall; +use crate::llm::types::usage::Usage; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenLogprob { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub bytes: Option>, + pub logprob: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TopLogprob { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub bytes: Option>, + pub logprob: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Logprobs { + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct URLCitation { + pub end_index: u32, + pub start_index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub url: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Annotation { + #[serde(rename = "type")] + pub ann_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub url_citation: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenaiAudio { + pub id: String, + pub data: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub transcript: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Choice { + pub index: u32, + pub message: OpenaiChatMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenaiChatResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Delta { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChunkChoice { + pub index: u32, + pub delta: Delta, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenaiChatChunk { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} \ No newline at end of file diff --git a/src/llm/types/shared.rs b/src/llm/types/shared.rs new file mode 100644 index 0000000..210e0fc --- /dev/null +++ b/src/llm/types/shared.rs @@ -0,0 +1,78 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Role { + Developer, + System, + User, + Assistant, + Tool, + Function, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ToolCalls, + ContentFilter, + FunctionCall, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + Default, + #[serde(other)] + Other, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Modality { + Text, + Audio, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ImageDetail { + Auto, + Low, + High, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AudioFormat { + Wav, + Mp3, + Aac, + Flac, + Opus, + Pcm16, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StopSequence { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum ResponseFormat { + Text, + JsonObject, + JsonSchema { + schema: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + strict: Option, + }, +} \ No newline at end of file diff --git a/src/llm/types/tool.rs b/src/llm/types/tool.rs new file mode 100644 index 0000000..1506361 --- /dev/null +++ b/src/llm/types/tool.rs @@ -0,0 +1,27 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenaiToolDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum OpenaiToolCall { + Function { + id: String, + function: FunctionCall, + }, +} \ No newline at end of file diff --git a/src/llm/types/usage.rs b/src/llm/types/usage.rs new file mode 100644 index 0000000..e7121f6 --- /dev/null +++ b/src/llm/types/usage.rs @@ -0,0 +1,75 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct CompletionTokensDetails { + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub audio_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub accepted_prediction_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rejected_prediction_tokens: Option, +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct PromptTokensDetails { + #[serde(skip_serializing_if = "Option::is_none")] + pub audio_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cached_tokens: Option, +} + +#[derive(Debug, Default)] +pub struct CostTracker { + accumulated: Usage, +} + +impl CostTracker { + pub fn add(&mut self, usage: &Usage) { + self.accumulated.prompt_tokens = self + .accumulated + .prompt_tokens + .saturating_add(usage.prompt_tokens); + self.accumulated.completion_tokens = self + .accumulated + .completion_tokens + .saturating_add(usage.completion_tokens); + self.accumulated.total_tokens = self + .accumulated + .total_tokens + .saturating_add(usage.total_tokens); + } + + pub fn total(&self) -> &Usage { + &self.accumulated + } + + pub fn reset(&mut self) { + self.accumulated = Usage::default(); + } +} + +impl Usage { + pub fn from_input_output(input: u32, output: u32) -> Self { + let total = input.saturating_add(output); + Usage { + prompt_tokens: input, + completion_tokens: output, + total_tokens: total, + completion_tokens_details: None, + prompt_tokens_details: None, + } + } +} \ No newline at end of file