feat(llm): 重构 types.rs 为完整的 OpenAI 兼容类型系统
将 `types.rs` 拆分为模块化目录,所有类型派生 `Serialize/Deserialize`, 并新增 `OpenaiChatChunk`、`Role` 扩展等 30+ 缺失类型 消除对 `cycle/usage.rs` 的反向依赖,`Usage`/`CostTracker` 移至 `types/usage.rs`
This commit is contained in:
@@ -0,0 +1,243 @@
|
|||||||
|
# 方案:重构 `types.rs` 为完整的 OpenAI 兼容 API 类型系统
|
||||||
|
|
||||||
|
## 1. 现状分析
|
||||||
|
|
||||||
|
### 当前问题
|
||||||
|
|
||||||
|
| 问题 | 详细 |
|
||||||
|
|------|------|
|
||||||
|
| **无 serde** | 所有类型只有 `Debug + Clone`,无 `Serialize/Deserialize`,迫使 `OpenaiProvider` 手动构建 JSON(354 行中约 200 行是序列化代码) |
|
||||||
|
| **请求参数不全** | `ChatRequest` 只支持 `model, messages, system_prompt, tools, max_tokens, temperature, extra_body`,缺失 streaming、response_format、tool_choice、stop、reasoning_effort 等 30+ 参数 |
|
||||||
|
| **响应类型太薄** | `ChatResponse` 只返回 `message + usage + stop_reason`,缺失 `id, created, model, choices` 数组、`logprobs`、`system_fingerprint` 等 |
|
||||||
|
| **无流式支持** | 无 `ChatCompletionChunk` 类型,无法处理 SSE 流式响应 |
|
||||||
|
| **反向依赖** | `types.rs` 引用 `cycle::usage::Usage`,造成模块间反向依赖 |
|
||||||
|
| **手动解析易出错** | `parse_response()` 从 `Value` 中逐字段解析,逻辑脆弱,不支持复杂嵌套类型 |
|
||||||
|
|
||||||
|
### OpenAI API 参考文档覆盖范围
|
||||||
|
|
||||||
|
已完整阅读文档(2177 行),涵盖了完整的请求参数(35+ 个顶层参数)和响应结构。
|
||||||
|
|
||||||
|
## 2. 新类型系统设计
|
||||||
|
|
||||||
|
### 架构
|
||||||
|
|
||||||
|
将 `types.rs` 重构为 Rust 新风格模块目录(符合项目已有惯例),按功能领域拆分:
|
||||||
|
|
||||||
|
```
|
||||||
|
src/llm/
|
||||||
|
├── types/
|
||||||
|
│ ├── mod.rs # 模块根:re-exports + 基础枚举/共用类型
|
||||||
|
│ ├── request.rs # 请求参数(ChatCompletionRequest 等)
|
||||||
|
│ ├── response.rs # 响应类型(ChatCompletionResponse + ChatCompletionChunk)
|
||||||
|
│ ├── message.rs # 消息类型(6 种角色消息 + content parts)
|
||||||
|
│ ├── tool.rs # 工具定义 + 工具调用
|
||||||
|
│ ├── usage.rs # Token 用量(从 cycle/usage.rs 移入,消除反向依赖)
|
||||||
|
│ └── shared.rs # 共用枚举(ReasoningEffort, ServiceTier, ResponseFormat 等)
|
||||||
|
```
|
||||||
|
|
||||||
|
同时,将 `cycle/usage.rs` 中的 `Usage` 和 `CostTracker` **移到** `types/usage.rs`,`cycle/usage.rs` 保留 `pub use` 兼容 re-export。
|
||||||
|
|
||||||
|
### 核心决策
|
||||||
|
|
||||||
|
| 决策 | 选择 | 理由 |
|
||||||
|
|------|------|------|
|
||||||
|
| **序列化方式** | 全部类型 derive `Serialize, Deserialize` | 消除手动 JSON 构建,让 provider 直接 `.json(&req)` / `.json::<Res>()` |
|
||||||
|
| **类型风格** | 直接映射 OpenAI API JSON 形状 | 一目了然,与 API 文档 1:1 对应,调试方便 |
|
||||||
|
| **命名策略** | 添加 `OpenAI` 前缀(如 `OpenaiChatRequest`) | 明确标注为 OpenAI 兼容类型 |
|
||||||
|
| **字段命名** | `#[serde(rename_all = "snake_case")]` | OpenAI API 使用 snake_case |
|
||||||
|
| **可选字段** | `#[serde(skip_serializing_if = "Option::is_none")]` | 不序列化 None 字段,保持请求体干净 |
|
||||||
|
| **默认值** | `#[serde(default)]` | 反序列化时缺失字段用默认值 |
|
||||||
|
| **后向兼容** | 通过类型别名保持 `ChatRequest`/`ChatResponse` 等名称可用 | LlmProvider/LlmCycle 接口不变 |
|
||||||
|
| **泛化策略** | Anthropic 是独立体系,暂不纳入当前设计 | 保持当前类型系统专注 OpenAI,Provider 层做转换 |
|
||||||
|
|
||||||
|
### 关键类型设计原则
|
||||||
|
|
||||||
|
- **`OpenaiChatRequest`**:统一结构体(不拆分 NonStreaming/Streaming),包含 `stream: Option<bool>` 字段,所有字段均为 `Option`,build 时 `skip_serializing_if`
|
||||||
|
- **`OpenaiChatResponse`**:直接对应 `ChatCompletion`(完整响应),保留完整 choices 数组等所有字段
|
||||||
|
- **`OpenaiChatChunk`**:对应流式 chunk,`object = "chat.completion.chunk"`
|
||||||
|
- **消息系统**:用单个 `OpenaiChatMessage` enum 覆盖 6 种角色消息类型(Developer/System/User/Assistant/Tool/Function),每种内部使用对应 struct
|
||||||
|
- **Content parts**:`OpenaiContentPart` enum 覆盖 text/image_url/input_audio/file/refusal
|
||||||
|
|
||||||
|
## 3. 完整类型清单
|
||||||
|
|
||||||
|
### `types/mod.rs` — 共用类型
|
||||||
|
```
|
||||||
|
Role → enum { Developer, System, User, Assistant, Tool, Function }
|
||||||
|
FinishReason → enum { Stop, Length, ToolCalls, ContentFilter, FunctionCall }
|
||||||
|
ServiceTier → enum { Auto, Default, Flex, Scale, Priority }
|
||||||
|
Modality → enum { Text, Audio }
|
||||||
|
ImageDetail → enum { Auto, Low, High }
|
||||||
|
AudioFormat → enum { Wav, Mp3, Aac, Flac, Opus, Pcm16 }
|
||||||
|
Voice → struct { id: String } 或预定义枚举
|
||||||
|
SearchContextSize → enum { Low, Medium, High }
|
||||||
|
StopSequence → enum { Single(String), Multiple(Vec<String>) }
|
||||||
|
Verbosity → enum { Low, Medium, High }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/request.rs` — 请求参数
|
||||||
|
```
|
||||||
|
OpenaiChatRequest → struct (35+ 字段,所有 OpenAI 参数)
|
||||||
|
ResponseFormat → enum { Text, JsonObject { .. }, JsonSchema { .. } }
|
||||||
|
ToolChoice → enum { None, Auto, Required, Named { .. }, AllowedTools { .. } }
|
||||||
|
StreamOptions → struct { include_usage, include_obfuscation }
|
||||||
|
AudioParam → struct { format, voice }
|
||||||
|
PredictionContent → struct { type, content }
|
||||||
|
WebSearchOptions → struct { search_context_size, user_location }
|
||||||
|
UserLocation → struct { type, approximate: Approximate }
|
||||||
|
Approximate → struct { city, country, region, timezone }
|
||||||
|
FunctionCallOption → struct { name } // deprecated
|
||||||
|
FunctionDefinition → struct { name, description, parameters, strict }
|
||||||
|
OpenaiTool → enum { Function { .. }, Custom { .. } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/response.rs` — 响应类型
|
||||||
|
```
|
||||||
|
OpenaiChatResponse → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||||
|
Choice → struct { index, message, finish_reason, logprobs }
|
||||||
|
OpenaiChatMessage → struct { content, refusal, role, tool_calls, function_call, audio, annotations }
|
||||||
|
OpenaiChatChunk → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||||
|
ChunkChoice → struct { index, delta, logprobs, finish_reason }
|
||||||
|
Delta → struct { role, content, tool_calls, function_call }
|
||||||
|
Logprobs → struct { content, refusal }
|
||||||
|
TokenLogprob → struct { token, bytes, logprob, top_logprobs }
|
||||||
|
TopLogprob → struct { token, bytes, logprob }
|
||||||
|
Annotation → struct { type, url_citation }
|
||||||
|
URLCitation → struct { end_index, start_index, title, url }
|
||||||
|
OpenaiAudio → struct { id, data, expires_at, transcript }
|
||||||
|
FunctionCall → struct { name, arguments }
|
||||||
|
OpenaiToolCall → enum { Function { id, function, type }, Custom { id, custom, type } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/message.rs` — 消息类型
|
||||||
|
```
|
||||||
|
OpenaiChatMessage → enum (覆盖 6 种角色消息)
|
||||||
|
DeveloperMessage → struct { content, role, name }
|
||||||
|
SystemMessage → struct { content, role, name }
|
||||||
|
UserMessage → struct { content, role, name }
|
||||||
|
AssistantMessage → struct { content, refusal, role, name, tool_calls, function_call, audio }
|
||||||
|
ToolMessage → struct { content, role, tool_call_id }
|
||||||
|
FunctionMessage → struct { content, role, name }
|
||||||
|
OpenaiContentPart → enum
|
||||||
|
OpenaiContentPartText → struct { type, text }
|
||||||
|
OpenaiContentPartImage → struct { type, image_url: ImageURL }
|
||||||
|
OpenaiContentPartInputAudio → struct { type, input_audio: InputAudio }
|
||||||
|
OpenaiContentPartFile → struct { type, file: FileData }
|
||||||
|
OpenaiContentPartRefusal → struct { type, refusal }
|
||||||
|
ImageURL → struct { url, detail }
|
||||||
|
InputAudio → struct { data, format }
|
||||||
|
FileData → struct { file_data, file_id, filename }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/tool.rs` — 工具类型
|
||||||
|
```
|
||||||
|
OpenaiToolDefinition → struct { name, description, parameters, strict }
|
||||||
|
(保留 ToolDefinition 别名保持后向兼容,重定义为包含所有字段)
|
||||||
|
OpenaiToolCall (在请求中使用) → 见 response.rs 中的定义
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/usage.rs` — Token 用量
|
||||||
|
```
|
||||||
|
Usage → struct { prompt_tokens, completion_tokens, total_tokens,
|
||||||
|
completion_tokens_details, prompt_tokens_details }
|
||||||
|
CompletionTokensDetails → struct { reasoning_tokens, audio_tokens,
|
||||||
|
accepted_prediction_tokens, rejected_prediction_tokens }
|
||||||
|
PromptTokensDetails → struct { audio_tokens, cached_tokens }
|
||||||
|
CostTracker → 从 cycle/usage.rs 移入(累计追踪器)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 删除的旧类型
|
||||||
|
- `ContentBlock` → 被 `OpenaiContentPart` 替代(更准确的 OpenAI API 命名)
|
||||||
|
- `StopReason` → 被 `FinishReason` 替代(与 API 命名一致)
|
||||||
|
- `Message` → 被 `OpenaiChatMessage` 替代
|
||||||
|
|
||||||
|
### 类型别名(后向兼容)
|
||||||
|
```
|
||||||
|
ChatRequest = OpenaiChatRequest
|
||||||
|
ChatResponse = OpenaiChatResponse
|
||||||
|
Message = OpenaiChatMessage
|
||||||
|
ContentBlock = OpenaiContentPart
|
||||||
|
ToolDefinition = OpenaiToolDefinition
|
||||||
|
Role = Role(保持不变,但扩展变体)
|
||||||
|
StopReason = FinishReason
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. 对其他模块的影响
|
||||||
|
|
||||||
|
### `provider/openai.rs`
|
||||||
|
- **大幅简化**:`build_request_body()` → 直接 `serde_json::to_value(&request)`
|
||||||
|
- `parse_response()` 中 100+ 行手动解析 → 直接 `serde_json::from_value::<OpenaiChatResponse>()`
|
||||||
|
- `serialize_messages()`, `serialize_message()`, `serialize_content_block()`, `serialize_tool()` → **全部删除**
|
||||||
|
- 新增 `chat_stream()` 方法返回 `OpenaiChatChunk` 流
|
||||||
|
- 需要适配新类型的字段名变更(如 `Usage` 中 `input_tokens` → `prompt_tokens`)
|
||||||
|
|
||||||
|
### `provider.rs` (trait)
|
||||||
|
- 接口保持不变,继续使用 `ChatRequest`/`ChatResponse` 类型别名
|
||||||
|
- 调整 `Usage` 类型引用路径
|
||||||
|
|
||||||
|
### `cycle.rs`
|
||||||
|
- `CycleConfig` 扩展支持更多请求参数(至少增加 `tools, tool_choice, response_format, stop, reasoning_effort, seed` 等)
|
||||||
|
- `LlmCycle::submit()` 构建 `ChatRequest` 时使用新类型
|
||||||
|
- `response.usage` 字段类型变更(新 `Usage` 含更多字段)
|
||||||
|
- 此时不添加流式支持
|
||||||
|
|
||||||
|
### `cycle/usage.rs`
|
||||||
|
- `Usage` 结构体**被移走**到 `types/usage.rs`
|
||||||
|
- `cycle/usage.rs` 保留 `pub use crate::llm::types::usage::{Usage, CostTracker};` 作为兼容性 re-export
|
||||||
|
- `CostTracker` 逻辑不变
|
||||||
|
|
||||||
|
### `error.rs`
|
||||||
|
- 无明显变更,错误类型和映射逻辑不变
|
||||||
|
|
||||||
|
## 5. 实施步骤
|
||||||
|
|
||||||
|
### Phase 1: 基础设施
|
||||||
|
```
|
||||||
|
1. [准备] 在 Cargo.toml 中确认 serde 依赖(已有 serde = "1",features = ["derive"])
|
||||||
|
2. [创建] 新建 src/llm/types/ 目录
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 2: 类型定义(按依赖顺序)
|
||||||
|
```
|
||||||
|
3. [usage.rs] 从 cycle/usage.rs 迁移 Usage + CostTracker
|
||||||
|
4. [shared.rs] 定义 Role, FinishReason, ServiceTier, Modality, ImageDetail, StopSequence, ResponseFormat
|
||||||
|
5. [message.rs] 定义 OpenaiChatMessage(6种角色)+ OpenaiContentPart + ImageURL + InputAudio
|
||||||
|
6. [tool.rs] 定义 OpenaiToolDefinition + OpenaiToolCall + FunctionCall
|
||||||
|
7. [request.rs] 定义 OpenaiChatRequest(35+ 字段)+ ToolChoice + StreamOptions
|
||||||
|
8. [response.rs] 定义 OpenaiChatResponse + OpenaiChatChunk + Choice + Delta + Logprobs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 3: 模块组装
|
||||||
|
```
|
||||||
|
9. [mod.rs] 创建模块根,re-export 所有类型 + 别名(ChatRequest = OpenaiChatRequest 等)
|
||||||
|
10. [usage.rs] 更新 cycle/usage.rs 为 pub use re-export
|
||||||
|
11. [删除] 删除旧 src/llm/types.rs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 4: Provider 适配
|
||||||
|
```
|
||||||
|
12. [provider/openai.rs] 重写为 serde 序列化(删除 ~200 行手动代码)
|
||||||
|
13. [cycle.rs] 适配新类型字段(prompt_tokens vs input_tokens)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 5: 验证
|
||||||
|
```
|
||||||
|
14. [编译] cargo check 确保编译通过
|
||||||
|
15. [检查] cargo clippy 确保无警告
|
||||||
|
16. [测试] cargo test 确保测试通过
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. 验证方式
|
||||||
|
|
||||||
|
- `cargo check` — 编译通过
|
||||||
|
- `cargo clippy` — 无警告
|
||||||
|
- `cargo test` — 所有测试通过(如果有集成测试,可能需要调整)
|
||||||
|
- 检查 `OpenaiProvider` 代码量减少(预期从 354 行降至 ~150 行)
|
||||||
|
- 手动验证序列化输出是否符合 OpenAI API 格式
|
||||||
|
|
||||||
|
## 7. 注意事项
|
||||||
|
|
||||||
|
1. **Break change**: 某些类型名称变化(如 `StopReason` → `FinishReason`),项目处于早期阶段,可接受
|
||||||
|
2. **后向兼容**: 通过类型别名保持旧名称可用,接口层无需修改
|
||||||
|
3. **Anthropic 处理**: Anthropic 是独立体系,不在当前设计中泛化,单独实现 Provider
|
||||||
|
4. **异步流**: `chat_stream()` 的签名需要仔细设计(`Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>>>>` 或自定义类型)
|
||||||
|
5. **CostTracker 不变**: 虽然 Usage 变复杂了,但 CostTracker 只累计 input/output token 数,逻辑不变
|
||||||
+32
-44
@@ -7,19 +7,15 @@ pub use usage::{CostTracker, Usage};
|
|||||||
use crate::llm::cycle::retry::should_retry;
|
use crate::llm::cycle::retry::should_retry;
|
||||||
use crate::llm::error::LlmError;
|
use crate::llm::error::LlmError;
|
||||||
use crate::llm::provider::LlmProvider;
|
use crate::llm::provider::LlmProvider;
|
||||||
use crate::llm::types::{ChatRequest, ChatResponse, ContentBlock, Message, Role, ToolDefinition};
|
use crate::llm::types::{
|
||||||
|
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
||||||
|
};
|
||||||
|
|
||||||
/// LLM 生命周期引擎的配置。
|
|
||||||
pub struct CycleConfig {
|
pub struct CycleConfig {
|
||||||
/// 使用的模型名称。
|
|
||||||
pub model: String,
|
pub model: String,
|
||||||
/// 最大输出 token 数。
|
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
/// 采样温度。
|
|
||||||
pub temperature: Option<f32>,
|
pub temperature: Option<f32>,
|
||||||
/// 最大对话轮数(预留,暂未使用)。
|
|
||||||
pub max_turns: Option<u32>,
|
pub max_turns: Option<u32>,
|
||||||
/// 重试策略配置。
|
|
||||||
pub retry: RetryConfig,
|
pub retry: RetryConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,22 +31,15 @@ impl Default for CycleConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// LLM 调用生命周期引擎。
|
|
||||||
///
|
|
||||||
/// 管理一次多轮交互的完整生命周期,包括:
|
|
||||||
/// - 消息历史维护
|
|
||||||
/// - Token 用量追踪
|
|
||||||
/// - 自动重试
|
|
||||||
pub struct LlmCycle {
|
pub struct LlmCycle {
|
||||||
provider: Box<dyn LlmProvider>,
|
provider: Box<dyn LlmProvider>,
|
||||||
config: CycleConfig,
|
config: CycleConfig,
|
||||||
usage: CostTracker,
|
usage: CostTracker,
|
||||||
messages: Vec<Message>,
|
messages: Vec<OpenaiChatMessage>,
|
||||||
system_prompt: Option<String>,
|
system_prompt: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmCycle {
|
impl LlmCycle {
|
||||||
/// 创建新的生命周期引擎。
|
|
||||||
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||||
Self {
|
Self {
|
||||||
provider,
|
provider,
|
||||||
@@ -61,50 +50,33 @@ impl LlmCycle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 设置系统提示词(Builder 模式)。
|
|
||||||
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||||
self.system_prompt = Some(prompt);
|
self.system_prompt = Some(prompt);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取 Token 用量追踪器引用。
|
|
||||||
pub fn usage(&self) -> &CostTracker {
|
pub fn usage(&self) -> &CostTracker {
|
||||||
&self.usage
|
&self.usage
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取当前消息历史。
|
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||||
pub fn messages(&self) -> &[Message] {
|
|
||||||
&self.messages
|
&self.messages
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 清空消息历史。
|
|
||||||
pub fn clear_messages(&mut self) {
|
pub fn clear_messages(&mut self) {
|
||||||
self.messages.clear();
|
self.messages.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 重置 Token 用量统计。
|
|
||||||
pub fn reset_usage(&mut self) {
|
pub fn reset_usage(&mut self) {
|
||||||
self.usage.reset();
|
self.usage.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 提交一条用户消息并获取模型响应。
|
|
||||||
///
|
|
||||||
/// 流程:
|
|
||||||
/// 1. 将用户消息追加到消息历史
|
|
||||||
/// 2. 构建 ChatRequest
|
|
||||||
/// 3. 使用重试循环调用 provider.chat()
|
|
||||||
/// 4. 将助手回复追加到消息历史
|
|
||||||
/// 5. 累计 token 用量
|
|
||||||
/// 6. 返回 ChatResponse
|
|
||||||
pub async fn submit(
|
pub async fn submit(
|
||||||
&mut self,
|
&mut self,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
tools: Vec<ToolDefinition>,
|
tools: Vec<ToolDefinition>,
|
||||||
) -> Result<ChatResponse, LlmError> {
|
) -> Result<ChatResponse, LlmError> {
|
||||||
self.messages.push(Message {
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
role: Role::User,
|
|
||||||
content: vec![ContentBlock::Text { text: prompt }],
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
|
|
||||||
@@ -113,10 +85,7 @@ impl LlmCycle {
|
|||||||
|
|
||||||
match self.provider.chat(request).await {
|
match self.provider.chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
self.messages.push(Message {
|
self.messages.push(response.message.clone());
|
||||||
role: Role::Assistant,
|
|
||||||
content: response.message.content.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
self.usage.add(&response.usage);
|
self.usage.add(&response.usage);
|
||||||
|
|
||||||
@@ -134,16 +103,35 @@ impl LlmCycle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 根据当前状态构建 ChatRequest。
|
|
||||||
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||||
|
let mut messages = self.messages.clone();
|
||||||
|
|
||||||
|
if let Some(sys_prompt) = &self.system_prompt
|
||||||
|
&& !messages.iter().any(|m| matches!(m, OpenaiChatMessage::System { .. }))
|
||||||
|
{
|
||||||
|
messages.insert(0, OpenaiChatMessage::system_text(sys_prompt));
|
||||||
|
}
|
||||||
|
|
||||||
|
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools.iter()
|
||||||
|
.map(|t| OpenaiTool::Function {
|
||||||
|
function: t.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
ChatRequest {
|
ChatRequest {
|
||||||
model: self.config.model.clone(),
|
model: self.config.model.clone(),
|
||||||
messages: self.messages.clone(),
|
messages,
|
||||||
system_prompt: self.system_prompt.clone(),
|
|
||||||
tools: tools.to_vec(),
|
|
||||||
max_tokens: self.config.max_tokens,
|
max_tokens: self.config.max_tokens,
|
||||||
temperature: self.config.temperature,
|
temperature: self.config.temperature,
|
||||||
extra_body: None,
|
tools: openai_tools,
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
+1
-42
@@ -1,42 +1 @@
|
|||||||
/// 单次请求的 Token 用量。
|
pub use crate::llm::types::usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage, Usage as LlmUsage};
|
||||||
#[derive(Debug, Clone, Default)]
|
|
||||||
pub struct Usage {
|
|
||||||
/// 输入(提示词)消耗的 token 数。
|
|
||||||
pub input_tokens: u32,
|
|
||||||
/// 输出(生成内容)消耗的 token 数。
|
|
||||||
pub output_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Token 用量累计追踪器。
|
|
||||||
///
|
|
||||||
/// 在多轮对话中累计所有请求的 token 消耗。
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
pub struct CostTracker {
|
|
||||||
accumulated: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CostTracker {
|
|
||||||
/// 累加一次请求的用量。
|
|
||||||
///
|
|
||||||
/// 使用 saturating_add 防止溢出。
|
|
||||||
pub fn add(&mut self, usage: &Usage) {
|
|
||||||
self.accumulated.input_tokens = self
|
|
||||||
.accumulated
|
|
||||||
.input_tokens
|
|
||||||
.saturating_add(usage.input_tokens);
|
|
||||||
self.accumulated.output_tokens = self
|
|
||||||
.accumulated
|
|
||||||
.output_tokens
|
|
||||||
.saturating_add(usage.output_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 获取累计用量。
|
|
||||||
pub fn total(&self) -> &Usage {
|
|
||||||
&self.accumulated
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 重置累计用量。
|
|
||||||
pub fn reset(&mut self) {
|
|
||||||
self.accumulated = Usage::default();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+14
-267
@@ -2,33 +2,19 @@ use std::time::Duration;
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::{json, Value};
|
|
||||||
|
|
||||||
use crate::llm::cycle::usage::Usage;
|
|
||||||
use crate::llm::error::LlmError;
|
use crate::llm::error::LlmError;
|
||||||
use crate::llm::types::{
|
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
|
||||||
ChatRequest, ChatResponse, ContentBlock, Message, Role, StopReason, ToolDefinition,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::LlmProvider;
|
use super::LlmProvider;
|
||||||
|
|
||||||
/// OpenAI 兼容 API 的 Provider 实现。
|
|
||||||
///
|
|
||||||
/// 支持任意实现了 `POST /v1/chat/completions` 标准的 API
|
|
||||||
/// (包括 OpenAI、Azure OpenAI、DashScope、vLLM 等)。
|
|
||||||
pub struct OpenaiProvider {
|
pub struct OpenaiProvider {
|
||||||
http_client: Client,
|
http_client: Client,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
#[allow(dead_code)]
|
|
||||||
model: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenaiProvider {
|
impl OpenaiProvider {
|
||||||
/// 创建新的 OpenAI Provider。
|
pub fn new(base_url: String, api_key: String, _model: String) -> Self {
|
||||||
///
|
|
||||||
/// 默认使用 120 秒超时的 HTTP 客户端。
|
|
||||||
pub fn new(base_url: String, api_key: String, model: String) -> Self {
|
|
||||||
let http_client = Client::builder()
|
let http_client = Client::builder()
|
||||||
.timeout(Duration::from_secs(120))
|
.timeout(Duration::from_secs(120))
|
||||||
.build()
|
.build()
|
||||||
@@ -38,284 +24,45 @@ impl OpenaiProvider {
|
|||||||
http_client,
|
http_client,
|
||||||
base_url,
|
base_url,
|
||||||
api_key,
|
api_key,
|
||||||
model,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 替换为自定义的 HTTP 客户端(用于测试或自定义配置)。
|
|
||||||
pub fn with_client(mut self, client: Client) -> Self {
|
pub fn with_client(mut self, client: Client) -> Self {
|
||||||
self.http_client = client;
|
self.http_client = client;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将 ChatRequest 构建为 OpenAI API 请求体 JSON。
|
fn map_reqwest_error(e: reqwest::Error) -> LlmError {
|
||||||
fn build_request_body(&self, request: &ChatRequest) -> Value {
|
if e.is_timeout() {
|
||||||
let mut body = json!({
|
LlmError::Timeout {
|
||||||
"model": request.model,
|
duration: Duration::from_secs(120),
|
||||||
"messages": Self::serialize_messages(request),
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = json!(max_tokens);
|
|
||||||
}
|
|
||||||
if let Some(temperature) = request.temperature {
|
|
||||||
body["temperature"] = json!(temperature);
|
|
||||||
}
|
|
||||||
if !request.tools.is_empty() {
|
|
||||||
body["tools"] = json!(
|
|
||||||
request
|
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.map(Self::serialize_tool)
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 合并 extra_body 中的扩展参数到请求体顶层
|
|
||||||
if let Some(ref extra) = request.extra_body
|
|
||||||
&& let Some(obj) = extra.as_object()
|
|
||||||
{
|
|
||||||
for (k, v) in obj {
|
|
||||||
body[k] = v.clone();
|
|
||||||
}
|
}
|
||||||
}
|
} else if e.is_connect() {
|
||||||
|
LlmError::Other(format!("连接失败: {}", e))
|
||||||
body
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 将请求中的消息列表序列化为 API 消息数组。
|
|
||||||
fn serialize_messages(request: &ChatRequest) -> Vec<Value> {
|
|
||||||
let mut messages: Vec<Value> = Vec::new();
|
|
||||||
|
|
||||||
// system_prompt 作为独立的 system 角色消息放在最前面
|
|
||||||
if let Some(ref system_prompt) = request.system_prompt {
|
|
||||||
messages.push(json!({
|
|
||||||
"role": "system",
|
|
||||||
"content": system_prompt
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
for msg in &request.messages {
|
|
||||||
messages.push(Self::serialize_message(msg));
|
|
||||||
}
|
|
||||||
|
|
||||||
messages
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 将单条消息序列化为 API 格式。
|
|
||||||
///
|
|
||||||
/// 处理逻辑:
|
|
||||||
/// - 多个 content block 或包含图片 → 使用数组格式
|
|
||||||
/// - ToolResult → 使用 tool 角色格式
|
|
||||||
/// - 其他 → 使用纯文本格式
|
|
||||||
fn serialize_message(msg: &Message) -> Value {
|
|
||||||
let role_str = match msg.role {
|
|
||||||
Role::User => "user",
|
|
||||||
Role::Assistant => "assistant",
|
|
||||||
Role::System => "system",
|
|
||||||
Role::Tool => "tool",
|
|
||||||
};
|
|
||||||
|
|
||||||
let has_mixed = msg.content.len() > 1
|
|
||||||
|| msg
|
|
||||||
.content
|
|
||||||
.iter()
|
|
||||||
.any(|b| matches!(b, ContentBlock::ImageUrl { .. }));
|
|
||||||
|
|
||||||
if has_mixed {
|
|
||||||
let content: Vec<Value> = msg
|
|
||||||
.content
|
|
||||||
.iter()
|
|
||||||
.map(Self::serialize_content_block)
|
|
||||||
.collect();
|
|
||||||
json!({ "role": role_str, "content": content })
|
|
||||||
} else if let Some(ContentBlock::ToolResult {
|
|
||||||
tool_use_id,
|
|
||||||
content,
|
|
||||||
}) = msg.content.first()
|
|
||||||
{
|
|
||||||
json!({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_use_id,
|
|
||||||
"content": content
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
let text = msg
|
LlmError::Other(format!("请求失败: {}", e))
|
||||||
.content
|
|
||||||
.first()
|
|
||||||
.map(|b| match b {
|
|
||||||
ContentBlock::Text { text } => text.clone(),
|
|
||||||
_ => String::new(),
|
|
||||||
})
|
|
||||||
.unwrap_or_default();
|
|
||||||
json!({ "role": role_str, "content": text })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将 ContentBlock 序列化为 API content parts 数组元素。
|
|
||||||
fn serialize_content_block(block: &ContentBlock) -> Value {
|
|
||||||
match block {
|
|
||||||
ContentBlock::Text { text } => {
|
|
||||||
json!({ "type": "text", "text": text })
|
|
||||||
}
|
|
||||||
ContentBlock::ImageUrl { url } => {
|
|
||||||
json!({ "type": "image_url", "image_url": { "url": url } })
|
|
||||||
}
|
|
||||||
ContentBlock::ToolUse { id, name, input } => {
|
|
||||||
json!({ "type": "tool_use", "id": id, "name": name, "input": input })
|
|
||||||
}
|
|
||||||
ContentBlock::ToolResult { .. } => {
|
|
||||||
json!({ "type": "tool_result", "content": "" })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 将 ToolDefinition 序列化为 OpenAI tools 数组元素。
|
|
||||||
fn serialize_tool(tool: &ToolDefinition) -> Value {
|
|
||||||
json!({
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description,
|
|
||||||
"parameters": tool.input_schema
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 将 OpenAI API 响应 JSON 解析为 ChatResponse。
|
|
||||||
fn parse_response(response: Value) -> Result<ChatResponse, LlmError> {
|
|
||||||
let choice = response["choices"][0]
|
|
||||||
.as_object()
|
|
||||||
.ok_or_else(|| LlmError::Other("响应中缺少 choices[0]".into()))?;
|
|
||||||
|
|
||||||
let msg = choice["message"]
|
|
||||||
.as_object()
|
|
||||||
.ok_or_else(|| LlmError::Other("响应中缺少 message".into()))?;
|
|
||||||
|
|
||||||
let role = match msg["role"].as_str() {
|
|
||||||
Some("assistant") => Role::Assistant,
|
|
||||||
Some(_) => Role::Assistant,
|
|
||||||
None => Role::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut content_blocks: Vec<ContentBlock> = Vec::new();
|
|
||||||
|
|
||||||
// 从 content 字段提取文本和 tool_use
|
|
||||||
if let Some(content_val) = msg.get("content") {
|
|
||||||
match content_val {
|
|
||||||
Value::String(s) if !s.is_empty() => {
|
|
||||||
content_blocks.push(ContentBlock::Text { text: s.clone() });
|
|
||||||
}
|
|
||||||
Value::Array(arr) => {
|
|
||||||
for item in arr {
|
|
||||||
if let Some(item_type) = item["type"].as_str() {
|
|
||||||
match item_type {
|
|
||||||
"text" => {
|
|
||||||
if let Some(text) = item["text"].as_str() {
|
|
||||||
content_blocks
|
|
||||||
.push(ContentBlock::Text { text: text.into() });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"tool_use" | "function" => {
|
|
||||||
let id = item["id"].as_str().unwrap_or("").to_string();
|
|
||||||
let name = item["name"].as_str().unwrap_or("").to_string();
|
|
||||||
let input = item.get("input").cloned().unwrap_or(Value::Null);
|
|
||||||
content_blocks
|
|
||||||
.push(ContentBlock::ToolUse { id, name, input });
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从 tool_calls 字段提取工具调用(OpenAI 特有格式)
|
|
||||||
if let Some(tool_calls) = msg.get("tool_calls").and_then(|v| v.as_array()) {
|
|
||||||
for tc in tool_calls {
|
|
||||||
let id = tc["id"].as_str().unwrap_or("").to_string();
|
|
||||||
let name = tc["function"]["name"].as_str().unwrap_or("").to_string();
|
|
||||||
let input = tc["function"]["arguments"]
|
|
||||||
.as_str()
|
|
||||||
.and_then(|s| serde_json::from_str(s).ok())
|
|
||||||
.unwrap_or(Value::Null);
|
|
||||||
content_blocks.push(ContentBlock::ToolUse { id, name, input });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if content_blocks.is_empty() {
|
|
||||||
content_blocks.push(ContentBlock::Text {
|
|
||||||
text: String::new(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析停止原因
|
|
||||||
let stop_reason = choice["finish_reason"].as_str().map(|s| match s {
|
|
||||||
"stop" => StopReason::Stop,
|
|
||||||
"tool_calls" => StopReason::ToolUse,
|
|
||||||
"max_tokens" => StopReason::MaxTokens,
|
|
||||||
"length" => StopReason::Length,
|
|
||||||
"content_filter" => StopReason::ContentFilter,
|
|
||||||
other => StopReason::Other(other.into()),
|
|
||||||
});
|
|
||||||
|
|
||||||
// 解析 token 用量
|
|
||||||
let usage = response["usage"]
|
|
||||||
.as_object()
|
|
||||||
.map(|u| Usage {
|
|
||||||
input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
|
|
||||||
output_tokens: u
|
|
||||||
.get("completion_tokens")
|
|
||||||
.and_then(|v| v.as_u64())
|
|
||||||
.unwrap_or(0) as u32,
|
|
||||||
})
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(ChatResponse {
|
|
||||||
message: Message {
|
|
||||||
role,
|
|
||||||
content: content_blocks,
|
|
||||||
},
|
|
||||||
usage,
|
|
||||||
stop_reason,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LlmProvider for OpenaiProvider {
|
impl LlmProvider for OpenaiProvider {
|
||||||
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||||
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||||
let body = self.build_request_body(&request);
|
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.http_client
|
.http_client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.header("Content-Type", "application/json")
|
.json(&request)
|
||||||
.json(&body)
|
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(Self::map_reqwest_error)?;
|
||||||
if e.is_timeout() {
|
|
||||||
LlmError::Timeout {
|
|
||||||
duration: Duration::from_secs(120),
|
|
||||||
}
|
|
||||||
} else if e.is_connect() {
|
|
||||||
LlmError::Other(format!("连接失败: {}", e))
|
|
||||||
} else {
|
|
||||||
LlmError::Other(format!("请求失败: {}", e))
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let status_code: u16 = status.as_u16();
|
let status_code: u16 = status.as_u16();
|
||||||
|
|
||||||
// 处理非 2xx 响应,将 HTTP 状态码映射为对应的 LlmError 变体
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
// 在消费 response body 之前先读取 retry-after 头部
|
|
||||||
let retry_after = response
|
let retry_after = response
|
||||||
.headers()
|
.headers()
|
||||||
.get("retry-after")
|
.get("retry-after")
|
||||||
@@ -344,11 +91,11 @@ impl LlmProvider for OpenaiProvider {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let json_body: Value = response
|
let chat_response: OpenaiChatResponse = response
|
||||||
.json()
|
.json()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?;
|
.map_err(|e| LlmError::Other(format!("响应解析失败: {}", e)))?;
|
||||||
|
|
||||||
Self::parse_response(json_body)
|
Ok(ChatResponse::from(chat_response))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
use crate::llm::cycle::usage::Usage;
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
/// 对话消息的角色。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum Role {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
System,
|
|
||||||
Tool,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 消息内容块,支持多模态及工具调用。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum ContentBlock {
|
|
||||||
/// 纯文本内容。
|
|
||||||
Text {
|
|
||||||
text: String,
|
|
||||||
},
|
|
||||||
/// 图片 URL(多模态输入预留)。
|
|
||||||
ImageUrl {
|
|
||||||
url: String,
|
|
||||||
},
|
|
||||||
/// 模型发起的工具调用(预留,暂不实现自动执行)。
|
|
||||||
ToolUse {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
input: Value,
|
|
||||||
},
|
|
||||||
/// 工具执行结果的回传(预留,暂不实现自动执行)。
|
|
||||||
ToolResult {
|
|
||||||
tool_use_id: String,
|
|
||||||
content: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 一条对话消息,由角色和内容块列表组成。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct Message {
|
|
||||||
pub role: Role,
|
|
||||||
pub content: Vec<ContentBlock>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 可供模型调用的工具定义。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ToolDefinition {
|
|
||||||
/// 工具名称。
|
|
||||||
pub name: String,
|
|
||||||
/// 工具描述,用于模型理解何时调用。
|
|
||||||
pub description: String,
|
|
||||||
/// JSON Schema 格式的输入参数定义。
|
|
||||||
pub input_schema: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 对 /v1/chat/completions 的完整请求参数。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ChatRequest {
|
|
||||||
/// 模型标识(如 "gpt-4o")。
|
|
||||||
pub model: String,
|
|
||||||
/// 对话历史 + 新消息。
|
|
||||||
pub messages: Vec<Message>,
|
|
||||||
/// 独立的系统提示词,将在序列化时转为 system 角色消息。
|
|
||||||
pub system_prompt: Option<String>,
|
|
||||||
/// 可用的工具定义列表。
|
|
||||||
pub tools: Vec<ToolDefinition>,
|
|
||||||
/// 最大输出 token 数。
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
/// 采样温度。
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
/// 扩展参数(如 enable_thinking),会合并到请求体顶层。
|
|
||||||
pub extra_body: Option<Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 模型返回的完整响应。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ChatResponse {
|
|
||||||
/// 助手的回复消息。
|
|
||||||
pub message: Message,
|
|
||||||
/// 本次请求的 token 用量。
|
|
||||||
pub usage: Usage,
|
|
||||||
/// 停止原因。
|
|
||||||
pub stop_reason: Option<StopReason>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 模型停止生成的原因。
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum StopReason {
|
|
||||||
/// 正常结束。
|
|
||||||
Stop,
|
|
||||||
/// 模型请求调用工具(预留)。
|
|
||||||
ToolUse,
|
|
||||||
/// 达到 max_tokens 上限。
|
|
||||||
MaxTokens,
|
|
||||||
/// 内容被安全过滤。
|
|
||||||
ContentFilter,
|
|
||||||
/// 长度限制(兼容某些 API 的 finish_reason)。
|
|
||||||
Length,
|
|
||||||
/// 其他未分类的原因。
|
|
||||||
Other(String),
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
use crate::llm::types::shared::{AudioFormat, ImageDetail};
|
||||||
|
use crate::llm::types::tool::OpenaiToolCall;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ImageURL {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<ImageDetail>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct InputAudio {
|
||||||
|
pub data: String,
|
||||||
|
pub format: AudioFormat,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct FileData {
|
||||||
|
pub file_data: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub filename: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiContentPart {
|
||||||
|
Text {
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
Image {
|
||||||
|
image_url: ImageURL,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
detail: Option<ImageDetail>,
|
||||||
|
},
|
||||||
|
InputAudio {
|
||||||
|
input_audio: InputAudio,
|
||||||
|
},
|
||||||
|
File {
|
||||||
|
file: FileData,
|
||||||
|
},
|
||||||
|
Refusal {
|
||||||
|
refusal: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "role")]
|
||||||
|
pub enum OpenaiChatMessage {
|
||||||
|
Developer {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
System {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
Assistant {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
refusal: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
Function {
|
||||||
|
content: Vec<OpenaiContentPart>,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenaiChatMessage {
|
||||||
|
pub fn user_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::User {
|
||||||
|
content: vec![OpenaiContentPart::Text { text: text.into() }],
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
content: vec![OpenaiContentPart::Text { text: text.into() }],
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn system_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::System {
|
||||||
|
content: vec![OpenaiContentPart::Text { text: text.into() }],
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn developer_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::Developer {
|
||||||
|
content: vec![OpenaiContentPart::Text { text: text.into() }],
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_result<S: Into<String>>(tool_call_id: String, content: S) -> Self {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
content: vec![OpenaiContentPart::Text {
|
||||||
|
text: content.into(),
|
||||||
|
}],
|
||||||
|
tool_call_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
pub mod message;
|
||||||
|
pub mod request;
|
||||||
|
pub mod response;
|
||||||
|
pub mod shared;
|
||||||
|
pub mod tool;
|
||||||
|
pub mod usage;
|
||||||
|
|
||||||
|
pub use message::{
|
||||||
|
FileData, ImageURL, InputAudio, OpenaiChatMessage, OpenaiContentPart,
|
||||||
|
};
|
||||||
|
pub use request::{OpenaiChatRequest, OpenaiTool, StreamOptions, ToolChoice};
|
||||||
|
pub use response::{
|
||||||
|
Annotation, Choice, ChunkChoice, Delta, Logprobs, OpenaiAudio,
|
||||||
|
OpenaiChatChunk, OpenaiChatResponse, TokenLogprob, TopLogprob, URLCitation,
|
||||||
|
};
|
||||||
|
pub use shared::{
|
||||||
|
AudioFormat, FinishReason, ImageDetail, Modality, ResponseFormat, Role,
|
||||||
|
ServiceTier, StopSequence,
|
||||||
|
};
|
||||||
|
pub use tool::{FunctionCall, OpenaiToolCall, OpenaiToolDefinition};
|
||||||
|
pub use usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ChatResponse {
|
||||||
|
pub message: OpenaiChatMessage,
|
||||||
|
pub usage: Usage,
|
||||||
|
pub stop_reason: Option<FinishReason>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenaiChatResponse> for ChatResponse {
|
||||||
|
fn from(response: OpenaiChatResponse) -> Self {
|
||||||
|
let message = response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.map(|c| c.message.clone())
|
||||||
|
.unwrap_or_else(|| OpenaiChatMessage::assistant_text(""));
|
||||||
|
let stop_reason = response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|c| c.finish_reason);
|
||||||
|
ChatResponse {
|
||||||
|
message,
|
||||||
|
usage: response.usage,
|
||||||
|
stop_reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ChatRequest = OpenaiChatRequest;
|
||||||
|
pub type Message = OpenaiChatMessage;
|
||||||
|
pub type ContentBlock = OpenaiContentPart;
|
||||||
|
pub type ToolDefinition = OpenaiToolDefinition;
|
||||||
|
pub type StopReason = FinishReason;
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use crate::llm::types::shared::{ResponseFormat, ServiceTier, StopSequence};
|
||||||
|
use crate::llm::types::tool::OpenaiToolDefinition;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StreamOptions {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_usage: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_obfuscation: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Required,
|
||||||
|
Named {
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
AllowedTools {
|
||||||
|
#[serde(rename = "tools")]
|
||||||
|
tool_names: Vec<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiTool {
|
||||||
|
Function {
|
||||||
|
function: OpenaiToolDefinition,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AudioParam {
|
||||||
|
pub format: String,
|
||||||
|
pub voice: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PredictionContent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub pred_type: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct UserLocation {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub loc_type: String,
|
||||||
|
pub approximate: Approximate,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Approximate {
|
||||||
|
pub city: String,
|
||||||
|
pub country: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub region: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub timezone: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct WebSearchOptions {
|
||||||
|
pub search_context_size: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user_location: Option<UserLocation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct OpenaiChatRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<crate::llm::types::message::OpenaiChatMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub service_tier: Option<ServiceTier>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StopSequence>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<OpenaiTool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub extra_headers: Option<Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub extra_body: Option<Value>,
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::llm::types::shared::{FinishReason, ServiceTier};
|
||||||
|
use crate::llm::types::message::OpenaiChatMessage;
|
||||||
|
use crate::llm::types::tool::OpenaiToolCall;
|
||||||
|
use crate::llm::types::usage::Usage;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TokenLogprob {
|
||||||
|
pub token: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub bytes: Option<Vec<u32>>,
|
||||||
|
pub logprob: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_logprobs: Option<Vec<TopLogprob>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TopLogprob {
|
||||||
|
pub token: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub bytes: Option<Vec<u32>>,
|
||||||
|
pub logprob: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Logprobs {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<Vec<TokenLogprob>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refusal: Option<Vec<TokenLogprob>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct URLCitation {
|
||||||
|
pub end_index: u32,
|
||||||
|
pub start_index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub title: Option<String>,
|
||||||
|
pub url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Annotation {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub ann_type: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub url_citation: Option<URLCitation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiAudio {
|
||||||
|
pub id: String,
|
||||||
|
pub data: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub expires_at: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub transcript: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: OpenaiChatMessage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub finish_reason: Option<FinishReason>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<Logprobs>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiChatResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<Choice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub service_tier: Option<ServiceTier>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Delta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub role: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refusal: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChunkChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: Delta,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<Logprobs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub finish_reason: Option<FinishReason>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiChatChunk {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Role {
|
||||||
|
Developer,
|
||||||
|
System,
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
Tool,
|
||||||
|
Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum FinishReason {
|
||||||
|
Stop,
|
||||||
|
Length,
|
||||||
|
ToolCalls,
|
||||||
|
ContentFilter,
|
||||||
|
FunctionCall,
|
||||||
|
#[serde(other)]
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ServiceTier {
|
||||||
|
Auto,
|
||||||
|
Default,
|
||||||
|
#[serde(other)]
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Modality {
|
||||||
|
Text,
|
||||||
|
Audio,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ImageDetail {
|
||||||
|
Auto,
|
||||||
|
Low,
|
||||||
|
High,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum AudioFormat {
|
||||||
|
Wav,
|
||||||
|
Mp3,
|
||||||
|
Aac,
|
||||||
|
Flac,
|
||||||
|
Opus,
|
||||||
|
Pcm16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum StopSequence {
|
||||||
|
Single(String),
|
||||||
|
Multiple(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum ResponseFormat {
|
||||||
|
Text,
|
||||||
|
JsonObject,
|
||||||
|
JsonSchema {
|
||||||
|
schema: serde_json::Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
strict: Option<bool>,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiToolDefinition {
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub strict: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionCall {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiToolCall {
|
||||||
|
Function {
|
||||||
|
id: String,
|
||||||
|
function: FunctionCall,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_tokens_details: Option<PromptTokensDetails>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CompletionTokensDetails {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub audio_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub accepted_prediction_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rejected_prediction_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct PromptTokensDetails {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub audio_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub cached_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct CostTracker {
|
||||||
|
accumulated: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CostTracker {
|
||||||
|
pub fn add(&mut self, usage: &Usage) {
|
||||||
|
self.accumulated.prompt_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.prompt_tokens
|
||||||
|
.saturating_add(usage.prompt_tokens);
|
||||||
|
self.accumulated.completion_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.completion_tokens
|
||||||
|
.saturating_add(usage.completion_tokens);
|
||||||
|
self.accumulated.total_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.total_tokens
|
||||||
|
.saturating_add(usage.total_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn total(&self) -> &Usage {
|
||||||
|
&self.accumulated
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.accumulated = Usage::default();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Usage {
|
||||||
|
pub fn from_input_output(input: u32, output: u32) -> Self {
|
||||||
|
let total = input.saturating_add(output);
|
||||||
|
Usage {
|
||||||
|
prompt_tokens: input,
|
||||||
|
completion_tokens: output,
|
||||||
|
total_tokens: total,
|
||||||
|
completion_tokens_details: None,
|
||||||
|
prompt_tokens_details: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user