Compare commits
10 Commits
16c3e14cb6
...
af5a580b5e
| Author | SHA1 | Date | |
|---|---|---|---|
| af5a580b5e | |||
| f7e73dd561 | |||
| 28635e28d5 | |||
| e22c176643 | |||
| a4b7b3b9f9 | |||
| 54387a5f64 | |||
| 91d32a6a82 | |||
| b21e163be0 | |||
| 2f58c803bf | |||
| 93de86e02d |
@@ -0,0 +1,74 @@
|
|||||||
|
# http://editorconfig.org
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
max_line_length = 100
|
||||||
|
|
||||||
|
[*.java]
|
||||||
|
indent_size = 4
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.{yml, yaml}]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.json]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.sql]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.go]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 4
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.py]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 4
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.sh]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.rs]
|
||||||
|
indent_size = 4
|
||||||
|
indent_style = space
|
||||||
|
end_of_line = lf
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
+879
-77
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
|||||||
|
**/*.md
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
+12
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"bracketSpacing": true,
|
||||||
|
"bracketSameLine": true,
|
||||||
|
"tabWidth": 2,
|
||||||
|
"useTabs": false,
|
||||||
|
"semi": true,
|
||||||
|
"singleQuote": true,
|
||||||
|
"jsxSingleQuote": false,
|
||||||
|
"trailingComma": "all",
|
||||||
|
"arrowParens": "always",
|
||||||
|
"printWidth": 100
|
||||||
|
}
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
# AGENTS.md
|
# AGENTS.md
|
||||||
|
|
||||||
Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目特定指令合并使用。
|
---
|
||||||
|
|
||||||
**权衡:** 这些指南倾向于谨慎而非速度。对于简单任务,可自行判断使用。
|
## 交互要求
|
||||||
|
|
||||||
|
- 思考过程全程必须使用中文,包括需求分析、逻辑拆解、方案选择等所有推理环节
|
||||||
|
- 最终输出内容必须全部使用中文,除代码语法本身和英文关键词以外
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -77,6 +80,7 @@ Karpathy 启发式行为指南,适用于 OpenCode。可根据需要与项目
|
|||||||
|
|
||||||
**Git Commit 规范**
|
**Git Commit 规范**
|
||||||
- 使用 Conventional Commits 格式:`<type>(<scope>): <description>`
|
- 使用 Conventional Commits 格式:`<type>(<scope>): <description>`
|
||||||
|
- **描述使用中文**
|
||||||
- 类型:
|
- 类型:
|
||||||
- `feat` - 新功能
|
- `feat` - 新功能
|
||||||
- `fix` - Bug 修复
|
- `fix` - Bug 修复
|
||||||
@@ -143,4 +147,4 @@ agcore 是一个智能体(Agent)核心工具箱,提供:
|
|||||||
- diff 中不必要的改动更少
|
- diff 中不必要的改动更少
|
||||||
- 因过度复杂而导致的重写更少
|
- 因过度复杂而导致的重写更少
|
||||||
- 澄清问题在实现之前提出
|
- 澄清问题在实现之前提出
|
||||||
- 干净、精简的 PR
|
- 干净、精简的 PR
|
||||||
|
|||||||
+11
@@ -4,3 +4,14 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
reqwest = { version = "0.12", features = ["json"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
thiserror = "2"
|
||||||
|
async-trait = "0.1"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
dotenvy = "0.15.7"
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
use std::env;
|
||||||
|
|
||||||
|
use agcore::init_tracing;
|
||||||
|
use agcore::llm::{
|
||||||
|
cycle::{CycleConfig, LlmCycle},
|
||||||
|
provider::{create_provider, ProviderConfig, ProviderType},
|
||||||
|
types::{ChatResponse, OpenaiContentPart},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn extract_response_text(response: &ChatResponse) -> &str {
|
||||||
|
match &response.message {
|
||||||
|
agcore::llm::types::OpenaiChatMessage::Assistant { content, .. } => match content {
|
||||||
|
agcore::llm::types::ContentField::String(s) => s,
|
||||||
|
agcore::llm::types::ContentField::Array(parts) => {
|
||||||
|
for part in parts {
|
||||||
|
if let OpenaiContentPart::Text { text } = part {
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"[无文本内容]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => "[非 assistant 消息]",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
|
init_tracing();
|
||||||
|
|
||||||
|
let api_key = env::var("OPENAI_API_KEY").expect("未设置 OPENAI_API_KEY 环境变量");
|
||||||
|
let base_url = env::var("OPENAI_BASE_URL").expect("未设置 OPENAI_BASE_URL 环境变量");
|
||||||
|
let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量");
|
||||||
|
|
||||||
|
let provider_type = env::var("PROVIDER")
|
||||||
|
.unwrap_or_else(|_| "openai".into())
|
||||||
|
.parse::<ProviderType>()
|
||||||
|
.expect("无效的 PROVIDER 值");
|
||||||
|
|
||||||
|
let config = ProviderConfig {
|
||||||
|
base_url,
|
||||||
|
api_key,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = create_provider(provider_type, config)
|
||||||
|
.expect("创建 Provider 失败");
|
||||||
|
|
||||||
|
let cycle_config = CycleConfig {
|
||||||
|
model,
|
||||||
|
max_tokens: Some(65536),
|
||||||
|
temperature: Some(1.3),
|
||||||
|
..CycleConfig::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cycle = LlmCycle::new(provider, cycle_config)
|
||||||
|
.with_system_prompt("你是一个简洁的助手,对于任何问题都是用一句话回答。".to_string());
|
||||||
|
|
||||||
|
println!("发送请求...");
|
||||||
|
|
||||||
|
match cycle.submit("介绍一下你自己吧。".to_string(), vec![]).await {
|
||||||
|
Ok(response) => {
|
||||||
|
println!("LLM 回复:{}", extract_response_text(&response));
|
||||||
|
println!(
|
||||||
|
"Token 用量:{} 输入, {} 输出",
|
||||||
|
response.usage.prompt_tokens, response.usage.completion_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("请求失败:{e}");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,243 @@
|
|||||||
|
# 方案:重构 `types.rs` 为完整的 OpenAI 兼容 API 类型系统
|
||||||
|
|
||||||
|
## 1. 现状分析
|
||||||
|
|
||||||
|
### 当前问题
|
||||||
|
|
||||||
|
| 问题 | 详细 |
|
||||||
|
|------|------|
|
||||||
|
| **无 serde** | 所有类型只有 `Debug + Clone`,无 `Serialize/Deserialize`,迫使 `OpenaiProvider` 手动构建 JSON(354 行中约 200 行是序列化代码) |
|
||||||
|
| **请求参数不全** | `ChatRequest` 只支持 `model, messages, system_prompt, tools, max_tokens, temperature, extra_body`,缺失 streaming、response_format、tool_choice、stop、reasoning_effort 等 30+ 参数 |
|
||||||
|
| **响应类型太薄** | `ChatResponse` 只返回 `message + usage + stop_reason`,缺失 `id, created, model, choices` 数组、`logprobs`、`system_fingerprint` 等 |
|
||||||
|
| **无流式支持** | 无 `ChatCompletionChunk` 类型,无法处理 SSE 流式响应 |
|
||||||
|
| **反向依赖** | `types.rs` 引用 `cycle::usage::Usage`,造成模块间反向依赖 |
|
||||||
|
| **手动解析易出错** | `parse_response()` 从 `Value` 中逐字段解析,逻辑脆弱,不支持复杂嵌套类型 |
|
||||||
|
|
||||||
|
### OpenAI API 参考文档覆盖范围
|
||||||
|
|
||||||
|
已完整阅读文档(2177 行),涵盖了完整的请求参数(35+ 个顶层参数)和响应结构。
|
||||||
|
|
||||||
|
## 2. 新类型系统设计
|
||||||
|
|
||||||
|
### 架构
|
||||||
|
|
||||||
|
将 `types.rs` 重构为 Rust 新风格模块目录(符合项目已有惯例),按功能领域拆分:
|
||||||
|
|
||||||
|
```
|
||||||
|
src/llm/
|
||||||
|
├── types/
|
||||||
|
│ ├── mod.rs # 模块根:re-exports + 基础枚举/共用类型
|
||||||
|
│ ├── request.rs # 请求参数(ChatCompletionRequest 等)
|
||||||
|
│ ├── response.rs # 响应类型(ChatCompletionResponse + ChatCompletionChunk)
|
||||||
|
│ ├── message.rs # 消息类型(6 种角色消息 + content parts)
|
||||||
|
│ ├── tool.rs # 工具定义 + 工具调用
|
||||||
|
│ ├── usage.rs # Token 用量(从 cycle/usage.rs 移入,消除反向依赖)
|
||||||
|
│ └── shared.rs # 共用枚举(ReasoningEffort, ServiceTier, ResponseFormat 等)
|
||||||
|
```
|
||||||
|
|
||||||
|
同时,将 `cycle/usage.rs` 中的 `Usage` 和 `CostTracker` **移到** `types/usage.rs`,`cycle/usage.rs` 保留 `pub use` 兼容 re-export。
|
||||||
|
|
||||||
|
### 核心决策
|
||||||
|
|
||||||
|
| 决策 | 选择 | 理由 |
|
||||||
|
|------|------|------|
|
||||||
|
| **序列化方式** | 全部类型 derive `Serialize, Deserialize` | 消除手动 JSON 构建,让 provider 直接 `.json(&req)` / `.json::<Res>()` |
|
||||||
|
| **类型风格** | 直接映射 OpenAI API JSON 形状 | 一目了然,与 API 文档 1:1 对应,调试方便 |
|
||||||
|
| **命名策略** | 添加 `OpenAI` 前缀(如 `OpenaiChatRequest`) | 明确标注为 OpenAI 兼容类型 |
|
||||||
|
| **字段命名** | `#[serde(rename_all = "snake_case")]` | OpenAI API 使用 snake_case |
|
||||||
|
| **可选字段** | `#[serde(skip_serializing_if = "Option::is_none")]` | 不序列化 None 字段,保持请求体干净 |
|
||||||
|
| **默认值** | `#[serde(default)]` | 反序列化时缺失字段用默认值 |
|
||||||
|
| **后向兼容** | 通过类型别名保持 `ChatRequest`/`ChatResponse` 等名称可用 | LlmProvider/LlmCycle 接口不变 |
|
||||||
|
| **泛化策略** | Anthropic 是独立体系,暂不纳入当前设计 | 保持当前类型系统专注 OpenAI,Provider 层做转换 |
|
||||||
|
|
||||||
|
### 关键类型设计原则
|
||||||
|
|
||||||
|
- **`OpenaiChatRequest`**:统一结构体(不拆分 NonStreaming/Streaming),包含 `stream: Option<bool>` 字段,所有字段均为 `Option`,build 时 `skip_serializing_if`
|
||||||
|
- **`OpenaiChatResponse`**:直接对应 `ChatCompletion`(完整响应),保留完整 choices 数组等所有字段
|
||||||
|
- **`OpenaiChatChunk`**:对应流式 chunk,`object = "chat.completion.chunk"`
|
||||||
|
- **消息系统**:用单个 `OpenaiChatMessage` enum 覆盖 6 种角色消息类型(Developer/System/User/Assistant/Tool/Function),每种内部使用对应 struct
|
||||||
|
- **Content parts**:`OpenaiContentPart` enum 覆盖 text/image_url/input_audio/file/refusal
|
||||||
|
|
||||||
|
## 3. 完整类型清单
|
||||||
|
|
||||||
|
### `types/mod.rs` — 共用类型
|
||||||
|
```
|
||||||
|
Role → enum { Developer, System, User, Assistant, Tool, Function }
|
||||||
|
FinishReason → enum { Stop, Length, ToolCalls, ContentFilter, FunctionCall }
|
||||||
|
ServiceTier → enum { Auto, Default, Flex, Scale, Priority }
|
||||||
|
Modality → enum { Text, Audio }
|
||||||
|
ImageDetail → enum { Auto, Low, High }
|
||||||
|
AudioFormat → enum { Wav, Mp3, Aac, Flac, Opus, Pcm16 }
|
||||||
|
Voice → struct { id: String } 或预定义枚举
|
||||||
|
SearchContextSize → enum { Low, Medium, High }
|
||||||
|
StopSequence → enum { Single(String), Multiple(Vec<String>) }
|
||||||
|
Verbosity → enum { Low, Medium, High }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/request.rs` — 请求参数
|
||||||
|
```
|
||||||
|
OpenaiChatRequest → struct (35+ 字段,所有 OpenAI 参数)
|
||||||
|
ResponseFormat → enum { Text, JsonObject { .. }, JsonSchema { .. } }
|
||||||
|
ToolChoice → enum { None, Auto, Required, Named { .. }, AllowedTools { .. } }
|
||||||
|
StreamOptions → struct { include_usage, include_obfuscation }
|
||||||
|
AudioParam → struct { format, voice }
|
||||||
|
PredictionContent → struct { type, content }
|
||||||
|
WebSearchOptions → struct { search_context_size, user_location }
|
||||||
|
UserLocation → struct { type, approximate: Approximate }
|
||||||
|
Approximate → struct { city, country, region, timezone }
|
||||||
|
FunctionCallOption → struct { name } // deprecated
|
||||||
|
FunctionDefinition → struct { name, description, parameters, strict }
|
||||||
|
OpenaiTool → enum { Function { .. }, Custom { .. } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/response.rs` — 响应类型
|
||||||
|
```
|
||||||
|
OpenaiChatResponse → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||||
|
Choice → struct { index, message, finish_reason, logprobs }
|
||||||
|
OpenaiChatMessage → struct { content, refusal, role, tool_calls, function_call, audio, annotations }
|
||||||
|
OpenaiChatChunk → struct { id, object, created, model, choices, usage, system_fingerprint, service_tier }
|
||||||
|
ChunkChoice → struct { index, delta, logprobs, finish_reason }
|
||||||
|
Delta → struct { role, content, tool_calls, function_call }
|
||||||
|
Logprobs → struct { content, refusal }
|
||||||
|
TokenLogprob → struct { token, bytes, logprob, top_logprobs }
|
||||||
|
TopLogprob → struct { token, bytes, logprob }
|
||||||
|
Annotation → struct { type, url_citation }
|
||||||
|
URLCitation → struct { end_index, start_index, title, url }
|
||||||
|
OpenaiAudio → struct { id, data, expires_at, transcript }
|
||||||
|
FunctionCall → struct { name, arguments }
|
||||||
|
OpenaiToolCall → enum { Function { id, function, type }, Custom { id, custom, type } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/message.rs` — 消息类型
|
||||||
|
```
|
||||||
|
OpenaiChatMessage → enum (覆盖 6 种角色消息)
|
||||||
|
DeveloperMessage → struct { content, role, name }
|
||||||
|
SystemMessage → struct { content, role, name }
|
||||||
|
UserMessage → struct { content, role, name }
|
||||||
|
AssistantMessage → struct { content, refusal, role, name, tool_calls, function_call, audio }
|
||||||
|
ToolMessage → struct { content, role, tool_call_id }
|
||||||
|
FunctionMessage → struct { content, role, name }
|
||||||
|
OpenaiContentPart → enum
|
||||||
|
OpenaiContentPartText → struct { type, text }
|
||||||
|
OpenaiContentPartImage → struct { type, image_url: ImageURL }
|
||||||
|
OpenaiContentPartInputAudio → struct { type, input_audio: InputAudio }
|
||||||
|
OpenaiContentPartFile → struct { type, file: FileData }
|
||||||
|
OpenaiContentPartRefusal → struct { type, refusal }
|
||||||
|
ImageURL → struct { url, detail }
|
||||||
|
InputAudio → struct { data, format }
|
||||||
|
FileData → struct { file_data, file_id, filename }
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/tool.rs` — 工具类型
|
||||||
|
```
|
||||||
|
OpenaiToolDefinition → struct { name, description, parameters, strict }
|
||||||
|
(保留 ToolDefinition 别名保持后向兼容,重定义为包含所有字段)
|
||||||
|
OpenaiToolCall (在请求中使用) → 见 response.rs 中的定义
|
||||||
|
```
|
||||||
|
|
||||||
|
### `types/usage.rs` — Token 用量
|
||||||
|
```
|
||||||
|
Usage → struct { prompt_tokens, completion_tokens, total_tokens,
|
||||||
|
completion_tokens_details, prompt_tokens_details }
|
||||||
|
CompletionTokensDetails → struct { reasoning_tokens, audio_tokens,
|
||||||
|
accepted_prediction_tokens, rejected_prediction_tokens }
|
||||||
|
PromptTokensDetails → struct { audio_tokens, cached_tokens }
|
||||||
|
CostTracker → 从 cycle/usage.rs 移入(累计追踪器)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 删除的旧类型
|
||||||
|
- `ContentBlock` → 被 `OpenaiContentPart` 替代(更准确的 OpenAI API 命名)
|
||||||
|
- `StopReason` → 被 `FinishReason` 替代(与 API 命名一致)
|
||||||
|
- `Message` → 被 `OpenaiChatMessage` 替代
|
||||||
|
|
||||||
|
### 类型别名(后向兼容)
|
||||||
|
```
|
||||||
|
ChatRequest = OpenaiChatRequest
|
||||||
|
ChatResponse = OpenaiChatResponse
|
||||||
|
Message = OpenaiChatMessage
|
||||||
|
ContentBlock = OpenaiContentPart
|
||||||
|
ToolDefinition = OpenaiToolDefinition
|
||||||
|
Role = Role(保持不变,但扩展变体)
|
||||||
|
StopReason = FinishReason
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. 对其他模块的影响
|
||||||
|
|
||||||
|
### `provider/openai.rs`
|
||||||
|
- **大幅简化**:`build_request_body()` → 直接 `serde_json::to_value(&request)`
|
||||||
|
- `parse_response()` 中 100+ 行手动解析 → 直接 `serde_json::from_value::<OpenaiChatResponse>()`
|
||||||
|
- `serialize_messages()`, `serialize_message()`, `serialize_content_block()`, `serialize_tool()` → **全部删除**
|
||||||
|
- 新增 `chat_stream()` 方法返回 `OpenaiChatChunk` 流
|
||||||
|
- 需要适配新类型的字段名变更(如 `Usage` 中 `input_tokens` → `prompt_tokens`)
|
||||||
|
|
||||||
|
### `provider.rs` (trait)
|
||||||
|
- 接口保持不变,继续使用 `ChatRequest`/`ChatResponse` 类型别名
|
||||||
|
- 调整 `Usage` 类型引用路径
|
||||||
|
|
||||||
|
### `cycle.rs`
|
||||||
|
- `CycleConfig` 扩展支持更多请求参数(至少增加 `tools, tool_choice, response_format, stop, reasoning_effort, seed` 等)
|
||||||
|
- `LlmCycle::submit()` 构建 `ChatRequest` 时使用新类型
|
||||||
|
- `response.usage` 字段类型变更(新 `Usage` 含更多字段)
|
||||||
|
- 此时不添加流式支持
|
||||||
|
|
||||||
|
### `cycle/usage.rs`
|
||||||
|
- `Usage` 结构体**被移走**到 `types/usage.rs`
|
||||||
|
- `cycle/usage.rs` 保留 `pub use crate::llm::types::usage::{Usage, CostTracker};` 作为兼容性 re-export
|
||||||
|
- `CostTracker` 逻辑不变
|
||||||
|
|
||||||
|
### `error.rs`
|
||||||
|
- 无明显变更,错误类型和映射逻辑不变
|
||||||
|
|
||||||
|
## 5. 实施步骤
|
||||||
|
|
||||||
|
### Phase 1: 基础设施
|
||||||
|
```
|
||||||
|
1. [准备] 在 Cargo.toml 中确认 serde 依赖(已有 serde = "1",features = ["derive"])
|
||||||
|
2. [创建] 新建 src/llm/types/ 目录
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 2: 类型定义(按依赖顺序)
|
||||||
|
```
|
||||||
|
3. [usage.rs] 从 cycle/usage.rs 迁移 Usage + CostTracker
|
||||||
|
4. [shared.rs] 定义 Role, FinishReason, ServiceTier, Modality, ImageDetail, StopSequence, ResponseFormat
|
||||||
|
5. [message.rs] 定义 OpenaiChatMessage(6种角色)+ OpenaiContentPart + ImageURL + InputAudio
|
||||||
|
6. [tool.rs] 定义 OpenaiToolDefinition + OpenaiToolCall + FunctionCall
|
||||||
|
7. [request.rs] 定义 OpenaiChatRequest(35+ 字段)+ ToolChoice + StreamOptions
|
||||||
|
8. [response.rs] 定义 OpenaiChatResponse + OpenaiChatChunk + Choice + Delta + Logprobs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 3: 模块组装
|
||||||
|
```
|
||||||
|
9. [mod.rs] 创建模块根,re-export 所有类型 + 别名(ChatRequest = OpenaiChatRequest 等)
|
||||||
|
10. [usage.rs] 更新 cycle/usage.rs 为 pub use re-export
|
||||||
|
11. [删除] 删除旧 src/llm/types.rs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 4: Provider 适配
|
||||||
|
```
|
||||||
|
12. [provider/openai.rs] 重写为 serde 序列化(删除 ~200 行手动代码)
|
||||||
|
13. [cycle.rs] 适配新类型字段(prompt_tokens vs input_tokens)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 5: 验证
|
||||||
|
```
|
||||||
|
14. [编译] cargo check 确保编译通过
|
||||||
|
15. [检查] cargo clippy 确保无警告
|
||||||
|
16. [测试] cargo test 确保测试通过
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. 验证方式
|
||||||
|
|
||||||
|
- `cargo check` — 编译通过
|
||||||
|
- `cargo clippy` — 无警告
|
||||||
|
- `cargo test` — 所有测试通过(如果有集成测试,可能需要调整)
|
||||||
|
- 检查 `OpenaiProvider` 代码量减少(预期从 354 行降至 ~150 行)
|
||||||
|
- 手动验证序列化输出是否符合 OpenAI API 格式
|
||||||
|
|
||||||
|
## 7. 注意事项
|
||||||
|
|
||||||
|
1. **Break change**: 某些类型名称变化(如 `StopReason` → `FinishReason`),项目处于早期阶段,可接受
|
||||||
|
2. **后向兼容**: 通过类型别名保持旧名称可用,接口层无需修改
|
||||||
|
3. **Anthropic 处理**: Anthropic 是独立体系,不在当前设计中泛化,单独实现 Provider
|
||||||
|
4. **异步流**: `chat_stream()` 的签名需要仔细设计(`Pin<Box<dyn Stream<Item = Result<OpenaiChatChunk, LlmError>>>>` 或自定义类型)
|
||||||
|
5. **CostTracker 不变**: 虽然 Usage 变复杂了,但 CostTracker 只累计 input/output token 数,逻辑不变
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
# AG Core Roadmap
|
||||||
|
|
||||||
|
> 定稿日期:2026-05-11
|
||||||
|
|
||||||
|
## 愿景
|
||||||
|
|
||||||
|
AG Core 定位为构建 AI 智能体的底层工具箱,通过模块化、可插拔的架构,提供大模型调用、提示词工程、工具系统、记忆检索四大核心能力,支持快速组合出符合业务需求的智能体应用。
|
||||||
|
|
||||||
|
**当前状态**:代码为空壳,specs 目录有 1 份方案(LLM 调用周期)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 模块完整性评估
|
||||||
|
|
||||||
|
| 功能领域 | 方案状态 | 文档位置 | 实现优先级 |
|
||||||
|
|---------|---------|---------|-----------|
|
||||||
|
| LLM 调用周期 | ✅ 完整 | `specs/llm-call-lifecycle.md` | P0 |
|
||||||
|
| 提示词工程 | ❌ 缺失 | — | P1 |
|
||||||
|
| 工具系统 + 权限 | ❌ 缺失 | — | P1 |
|
||||||
|
| 记忆检索 | ❌ 缺失 | — | P2 |
|
||||||
|
| Agent 运行时 | ❌ 缺失 | — | P2 |
|
||||||
|
| 生命周期钩子 | ❌ 缺失 | — | P0(LLM Cycle 扩展) |
|
||||||
|
| Provider 注册发现 | ❌ 缺失 | — | P0(Provider 接口扩展) |
|
||||||
|
| 流式事件系统 | ❌ 缺失 | — | P0(流式接口前置) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 分阶段 Roadmap
|
||||||
|
|
||||||
|
### Phase 0 — Foundation(基础设施)
|
||||||
|
|
||||||
|
**目标**:实现 LLM 调用周期的核心功能,作为所有上层模块的基础。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `llm/types.rs` — 核心数据类型(Message, ContentBlock, ChatRequest/Response, ToolDefinition, StopReason)
|
||||||
|
2. `llm/error.rs` — 错误体系(LlmError 枚举,可重试/不可重试判断)
|
||||||
|
3. `llm/provider.rs` + `llm/provider/openai.rs` — Provider 接口 + OpenAI 兼容实现
|
||||||
|
4. `llm/provider/registry.rs` — ProviderRegistry(多 Provider 注册发现)
|
||||||
|
5. `llm/cycle.rs` + `llm/cycle/{retry,usage}.rs` — 生命周期引擎(重试策略 + 用量追踪)
|
||||||
|
6. `llm/hooks.rs` — HookExecutor 接口(生命周期钩子)
|
||||||
|
7. `llm/stream.rs` — StreamEvents 流式事件系统(AssistantTextDelta, ToolExecutionStarted 等)
|
||||||
|
8. `llm/compact.rs` — Auto-compaction(上下文自动压缩)
|
||||||
|
9. `Cargo.toml` — 添加依赖(tokio, reqwest, serde, thiserror, async-trait, tracing)
|
||||||
|
|
||||||
|
**依赖**:无
|
||||||
|
|
||||||
|
**优先级**:Must Have
|
||||||
|
|
||||||
|
**预估规模**:约 1000 行核心代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 1 — Prompt Engineering(提示词工程)
|
||||||
|
|
||||||
|
**目标**:提供提示词的组合、模板化与优化能力。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `prompt.rs` + `prompt/` 模块
|
||||||
|
2. `PromptTemplate` — 模板引擎(支持变量插值、条件渲染)
|
||||||
|
3. `PromptComposer` — 提示词组合器(拼接 system/user/assistant 消息)
|
||||||
|
4. `specs/prompt-design.md` — 方案文档
|
||||||
|
|
||||||
|
**依赖**:无(可与 Phase 0 并行)
|
||||||
|
|
||||||
|
**优先级**:Should Have
|
||||||
|
|
||||||
|
**预估规模**:约 400 行代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 2 — Tool System(工具系统)
|
||||||
|
|
||||||
|
**目标**:实现 MCP 协议集成与自定义工具注册、调用、权限控制。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `tools.rs` + `tools/` 模块
|
||||||
|
2. `ToolRegistry` — 工具注册表(注册、发现、调用)
|
||||||
|
3. `BaseTool` trait — 工具抽象接口
|
||||||
|
4. `McpClient` — MCP 协议客户端
|
||||||
|
5. `PermissionChecker` — 工具执行权限检查(读/写/删除/网络等)
|
||||||
|
6. `specs/tool-call-loop.md` — Tool 自动执行循环设计
|
||||||
|
7. 扩展 `llm/cycle.rs` 支持自动 tool 循环(参考 OpenHarness `run_query()`)
|
||||||
|
|
||||||
|
**依赖**:Phase 0(LlmProvider 接口传递 tool definitions)、Phase 1(提示词可能需要注入工具描述)
|
||||||
|
|
||||||
|
**优先级**:Should Have
|
||||||
|
|
||||||
|
**预估规模**:约 900 行代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 3 — Memory System(记忆系统)
|
||||||
|
|
||||||
|
**目标**:提供对话记忆的存储、检索与管理能力。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `memory.rs` + `memory/` 模块
|
||||||
|
2. `MemoryStore` trait — 记忆存储抽象(可插拔后端)
|
||||||
|
3. `VectorStore` — 向量存储实现(支持 embedding 检索)
|
||||||
|
4. `ConversationMemory` — 对话记忆管理(sliding window / 全量)
|
||||||
|
5. `MemoryRetriever` — 记忆检索器(similarity search)
|
||||||
|
6. `specs/memory-system.md` — 方案文档
|
||||||
|
|
||||||
|
**依赖**:Phase 0(LLM 调用可能用于 embedding 生成)
|
||||||
|
|
||||||
|
**优先级**:Could Have
|
||||||
|
|
||||||
|
**预估规模**:约 700 行代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 4 — Agent Runtime(智能体运行时)
|
||||||
|
|
||||||
|
**目标**:实现多轮对话编排与任务规划。
|
||||||
|
|
||||||
|
**交付物**:
|
||||||
|
1. `agent.rs` + `agent/` 模块
|
||||||
|
2. `Agent` trait — 智能体接口定义
|
||||||
|
3. `ConversationAgent` — 对话型智能体实现
|
||||||
|
4. `TaskAgent` — 任务型智能体(规划 → 执行 → 反馈)
|
||||||
|
5. `specs/agent-runtime.md` — 方案文档
|
||||||
|
|
||||||
|
**依赖**:Phase 0, 1, 2, 3(整合所有模块)
|
||||||
|
|
||||||
|
**优先级**:Could Have
|
||||||
|
|
||||||
|
**预估规模**:约 600 行代码
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 依赖关系图
|
||||||
|
|
||||||
|
```
|
||||||
|
Phase 4: Agent Runtime
|
||||||
|
│
|
||||||
|
┌─────────────────┼─────────────────┐
|
||||||
|
▼ ▼ ▼
|
||||||
|
Phase 1 Phase 2 Phase 3
|
||||||
|
Prompt Tool System Memory
|
||||||
|
Engineering + Permission System
|
||||||
|
+ HookExecutor
|
||||||
|
│ │ │
|
||||||
|
└────────┬────────┴────────┬────────┘
|
||||||
|
▼ ▼
|
||||||
|
Phase 0 ─────────────────┘
|
||||||
|
LLM Cycle
|
||||||
|
+ ProviderRegistry
|
||||||
|
+ HookExecutor
|
||||||
|
+ StreamEvents
|
||||||
|
+ Auto-compaction
|
||||||
|
(Foundation)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 扩展计划(v0.2+)
|
||||||
|
|
||||||
|
> 以下功能已在 Phase 0 中实现,流式接口为后续增量优化。
|
||||||
|
|
||||||
|
| 扩展项 | 所在模块 | 说明 | 优先级 |
|
||||||
|
|-------|---------|------|--------|
|
||||||
|
| Prompt Optimizer | `prompt` | 提示词自动优化 | P3 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险与建议
|
||||||
|
|
||||||
|
1. **Phase 0 尚未实现**:项目代码是空壳,建议优先完成 LLM 调用周期,避免后续模块依赖不存在的底层
|
||||||
|
2. **并行可能性**:Phase 0 和 Phase 1 可并行开展(无相互依赖),可加速早期交付
|
||||||
|
3. **MCP 协议复杂性**:MCP 涉及协议握手、session 管理、长期连接,建议预留充足时间调研协议细节
|
||||||
|
4. **Scope 蔓延风险**:当前 specs 只有 1 份文档,建议每个模块上线前都产出对应 spec,避免边实现边设计
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 下一步行动
|
||||||
|
|
||||||
|
1. **Phase 0 方案评审**:对齐 LLM 模块设计(`specs/llm-call-lifecycle.md` 已在 2026-05-11 更新)
|
||||||
|
2. **Phase 1 方案启动**:启动 `specs/prompt-design.md` 设计
|
||||||
|
3. **Phase 2 方案启动**:启动 `specs/tool-call-loop.md` 设计(含 PermissionChecker)
|
||||||
+18
-11
@@ -1,14 +1,21 @@
|
|||||||
pub fn add(left: u64, right: u64) -> u64 {
|
//! agcore —— 智能体(Agent)核心工具箱。
|
||||||
left + right
|
//!
|
||||||
}
|
//! 当前提供 LLM 调用周期控制作为核心底层能力,后续将扩展至
|
||||||
|
//! 提示词工程、记忆系统、工具调用、Agent 运行时等领域。
|
||||||
|
|
||||||
#[cfg(test)]
|
pub mod llm;
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||||
fn it_works() {
|
|
||||||
let result = add(2, 2);
|
static INIT: std::sync::Once = std::sync::Once::new();
|
||||||
assert_eq!(result, 4);
|
|
||||||
}
|
pub fn init_tracing() {
|
||||||
|
INIT.call_once(|| {
|
||||||
|
let filter =
|
||||||
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("agcore=info"));
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(fmt::layer())
|
||||||
|
.with(filter)
|
||||||
|
.init();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
//! LLM 调用周期 —— 大模型基础调用周期控制。
|
||||||
|
//!
|
||||||
|
//! 包含核心数据类型、Provider 抽象、OpenAI 兼容实现以及生命周期引擎。
|
||||||
|
|
||||||
|
pub mod cycle;
|
||||||
|
pub mod error;
|
||||||
|
pub mod provider;
|
||||||
|
pub mod types;
|
||||||
@@ -0,0 +1,140 @@
|
|||||||
|
mod retry;
|
||||||
|
pub mod usage;
|
||||||
|
|
||||||
|
pub use retry::RetryConfig;
|
||||||
|
pub use usage::{CostTracker, Usage};
|
||||||
|
|
||||||
|
use crate::llm::cycle::retry::should_retry;
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::provider::LlmProvider;
|
||||||
|
use crate::llm::types::{
|
||||||
|
ChatRequest, ChatResponse, OpenaiChatMessage, OpenaiTool, ToolChoice, ToolDefinition,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct CycleConfig {
|
||||||
|
pub model: String,
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub max_turns: Option<u32>,
|
||||||
|
pub retry: RetryConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CycleConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
model: String::from("gpt-4o"),
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
max_turns: None,
|
||||||
|
retry: RetryConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LlmCycle {
|
||||||
|
provider: Box<dyn LlmProvider>,
|
||||||
|
config: CycleConfig,
|
||||||
|
usage: CostTracker,
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
system_prompt: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmCycle {
|
||||||
|
pub fn new(provider: Box<dyn LlmProvider>, config: CycleConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
provider,
|
||||||
|
config,
|
||||||
|
usage: CostTracker::default(),
|
||||||
|
messages: Vec::new(),
|
||||||
|
system_prompt: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_system_prompt(mut self, prompt: String) -> Self {
|
||||||
|
self.system_prompt = Some(prompt);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn usage(&self) -> &CostTracker {
|
||||||
|
&self.usage
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn messages(&self) -> &[OpenaiChatMessage] {
|
||||||
|
&self.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_messages(&mut self) {
|
||||||
|
self.messages.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset_usage(&mut self) {
|
||||||
|
self.usage.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn submit(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
self.messages.push(OpenaiChatMessage::user_text(prompt));
|
||||||
|
|
||||||
|
let mut attempts = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let request = self.build_request(&tools);
|
||||||
|
|
||||||
|
match self.provider.chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
self.messages.push(response.message.clone());
|
||||||
|
|
||||||
|
self.usage.add(&response.usage);
|
||||||
|
|
||||||
|
return Ok(response);
|
||||||
|
}
|
||||||
|
Err(e) if should_retry(&e) && attempts < self.config.retry.max_retries => {
|
||||||
|
attempts += 1;
|
||||||
|
let delay = self.config.retry.compute_delay(attempts);
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_request(&self, tools: &[ToolDefinition]) -> ChatRequest {
|
||||||
|
let mut messages = self.messages.clone();
|
||||||
|
|
||||||
|
if let Some(sys_prompt) = &self.system_prompt
|
||||||
|
&& !messages
|
||||||
|
.iter()
|
||||||
|
.any(|m| matches!(m, OpenaiChatMessage::System { .. }))
|
||||||
|
{
|
||||||
|
messages.insert(0, OpenaiChatMessage::system_text(sys_prompt));
|
||||||
|
}
|
||||||
|
|
||||||
|
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|t| OpenaiTool::Function {
|
||||||
|
function: t.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages,
|
||||||
|
max_tokens: self.config.max_tokens,
|
||||||
|
temperature: self.config.temperature,
|
||||||
|
tools: openai_tools,
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
|
||||||
|
/// 重试策略配置。
|
||||||
|
///
|
||||||
|
/// 使用指数退避 + jitter 算法计算每次重试的等待时间。
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RetryConfig {
|
||||||
|
/// 最大重试次数(默认 3)。
|
||||||
|
pub max_retries: u32,
|
||||||
|
/// 初始延迟(默认 1 秒)。
|
||||||
|
pub base_delay: Duration,
|
||||||
|
/// 最大延迟上限(默认 30 秒)。
|
||||||
|
pub max_delay: Duration,
|
||||||
|
/// Jitter 比例因子(默认 0.25)。
|
||||||
|
pub jitter_factor: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RetryConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_retries: 3,
|
||||||
|
base_delay: Duration::from_secs(1),
|
||||||
|
max_delay: Duration::from_secs(30),
|
||||||
|
jitter_factor: 0.25,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RetryConfig {
|
||||||
|
/// 根据当前重试次数计算等待时间。
|
||||||
|
///
|
||||||
|
/// 算法: `delay = min(base * 2^(attempt-1), max_delay) + random(0, delay * jitter_factor)`
|
||||||
|
pub fn compute_delay(&self, attempt: u32) -> Duration {
|
||||||
|
let base = self.base_delay.as_secs_f64();
|
||||||
|
let exponential = base * (2u64.pow(attempt.saturating_sub(1))) as f64;
|
||||||
|
let capped = exponential.min(self.max_delay.as_secs_f64());
|
||||||
|
let jitter = rand_jitter(capped * self.jitter_factor);
|
||||||
|
|
||||||
|
Duration::from_secs_f64(capped + jitter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 判断错误是否可重试。
|
||||||
|
///
|
||||||
|
/// 可重试条件:
|
||||||
|
/// - RateLimit(429)
|
||||||
|
/// - Timeout
|
||||||
|
/// - Request 且状态码 >= 500 或 == 429
|
||||||
|
pub fn should_retry(err: &LlmError) -> bool {
|
||||||
|
match err {
|
||||||
|
LlmError::RateLimit { .. } => true,
|
||||||
|
LlmError::Timeout { .. } => true,
|
||||||
|
LlmError::Request { status, .. } => *status >= 500 || *status == 429,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 基于纳秒时间戳的简单伪随机数,范围 [0, max)。
|
||||||
|
fn rand_jitter(max: f64) -> f64 {
|
||||||
|
if max <= 0.0 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
let t = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_nanos();
|
||||||
|
let r = (t % 1000) as f64 / 1000.0;
|
||||||
|
r * max
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub use crate::llm::types::usage::{
|
||||||
|
CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage, Usage as LlmUsage,
|
||||||
|
};
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// LLM 调用过程中可能发生的所有错误。
|
||||||
|
///
|
||||||
|
/// 错误按可重试性分为两类:
|
||||||
|
/// - **可重试**:`RateLimit`、`Timeout`、状态码 >= 500
|
||||||
|
/// - **不可重试**:`Authentication`、`ContextLength`、状态码 4xx(除 429)
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum LlmError {
|
||||||
|
/// API 认证失败(如 API key 无效)。
|
||||||
|
#[error("认证失败: {0}")]
|
||||||
|
Authentication(String),
|
||||||
|
|
||||||
|
/// 请求被限流,可选地附带重试等待时间。
|
||||||
|
#[error("限流(retry_after={retry_after:?})")]
|
||||||
|
RateLimit { retry_after: Option<Duration> },
|
||||||
|
|
||||||
|
/// HTTP 请求失败,包含状态码和响应体。
|
||||||
|
#[error("请求失败(status={status}): {body}")]
|
||||||
|
Request { status: u16, body: String },
|
||||||
|
|
||||||
|
/// 请求超时。
|
||||||
|
#[error("请求超时(duration={duration:?})")]
|
||||||
|
Timeout { duration: Duration },
|
||||||
|
|
||||||
|
/// 流式响应处理错误(预留)。
|
||||||
|
#[error("流式响应错误: {0}")]
|
||||||
|
Stream(String),
|
||||||
|
|
||||||
|
/// 上下文长度超限。
|
||||||
|
#[error("上下文超限(actual={actual}, limit={limit})")]
|
||||||
|
ContextLength { actual: u32, limit: u32 },
|
||||||
|
|
||||||
|
/// 其他未分类的 LLM 调用失败。
|
||||||
|
#[error("LLM 调用失败: {0}")]
|
||||||
|
Other(String),
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
pub mod openai;
|
||||||
|
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::{ChatRequest, ChatResponse};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ProviderType {
|
||||||
|
OpenAI,
|
||||||
|
DeepSeek,
|
||||||
|
Qwen,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for ProviderType {
|
||||||
|
type Err = String;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"openai" => Ok(ProviderType::OpenAI),
|
||||||
|
"deepseek" => Ok(ProviderType::DeepSeek),
|
||||||
|
"qwen" | "dashscope" | "tongyi" => Ok(ProviderType::Qwen),
|
||||||
|
_ => Err(format!("未知的 Provider 类型: {}", s)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_provider(
|
||||||
|
provider_type: ProviderType,
|
||||||
|
config: ProviderConfig,
|
||||||
|
) -> Result<Box<dyn LlmProvider>, LlmError> {
|
||||||
|
match provider_type {
|
||||||
|
ProviderType::OpenAI => Ok(Box::new(openai::OpenaiProvider::new(
|
||||||
|
config.base_url,
|
||||||
|
config.api_key,
|
||||||
|
config.model,
|
||||||
|
))),
|
||||||
|
ProviderType::DeepSeek => {
|
||||||
|
unimplemented!("DeepSeek Provider 尚未实现")
|
||||||
|
}
|
||||||
|
ProviderType::Qwen => {
|
||||||
|
unimplemented!("Qwen Provider 尚未实现")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLM Provider 抽象接口。
|
||||||
|
///
|
||||||
|
/// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等)
|
||||||
|
/// 均需实现此 trait,以实现可插拔替换。
|
||||||
|
#[async_trait]
|
||||||
|
pub trait LlmProvider: Send + Sync {
|
||||||
|
/// 发送聊天请求并返回完整响应。
|
||||||
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError>;
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
use super::LlmProvider;
|
||||||
|
use crate::llm::error::LlmError;
|
||||||
|
use crate::llm::types::{ChatRequest, ChatResponse, OpenaiChatResponse};
|
||||||
|
|
||||||
|
pub struct OpenaiProvider {
|
||||||
|
http_client: Client,
|
||||||
|
base_url: String,
|
||||||
|
api_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenaiProvider {
|
||||||
|
pub fn new(base_url: String, api_key: String, _model: String) -> Self {
|
||||||
|
let http_client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(120))
|
||||||
|
.build()
|
||||||
|
.expect("创建 HTTP 客户端失败");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
http_client,
|
||||||
|
base_url,
|
||||||
|
api_key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_client(mut self, client: Client) -> Self {
|
||||||
|
self.http_client = client;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_reqwest_error(e: reqwest::Error) -> LlmError {
|
||||||
|
if e.is_timeout() {
|
||||||
|
LlmError::Timeout {
|
||||||
|
duration: Duration::from_secs(120),
|
||||||
|
}
|
||||||
|
} else if e.is_connect() {
|
||||||
|
LlmError::Other(format!("连接失败: {}", e))
|
||||||
|
} else {
|
||||||
|
LlmError::Other(format!("请求失败: {}", e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LlmProvider for OpenaiProvider {
|
||||||
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
|
||||||
|
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
|
||||||
|
|
||||||
|
info!(model = %request.model, max_tokens = request.max_tokens, temperature = request.temperature, "发送 LLM 请求");
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.http_client
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!(error = %e, "请求失败");
|
||||||
|
Self::map_reqwest_error(e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let status_code: u16 = status.as_u16();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let retry_after = response
|
||||||
|
.headers()
|
||||||
|
.get("retry-after")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.parse::<u64>().ok())
|
||||||
|
.map(Duration::from_secs);
|
||||||
|
let body_text = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
error!(status = status_code, body = %body_text, "请求失败");
|
||||||
|
|
||||||
|
return match status_code {
|
||||||
|
401 => Err(LlmError::Authentication(body_text)),
|
||||||
|
429 => Err(LlmError::RateLimit { retry_after }),
|
||||||
|
_ if status_code >= 500 => Err(LlmError::Request {
|
||||||
|
status: status_code,
|
||||||
|
body: body_text,
|
||||||
|
}),
|
||||||
|
_ if status_code == 400 && body_text.contains("context_length_exceeded") => {
|
||||||
|
Err(LlmError::ContextLength {
|
||||||
|
actual: 0,
|
||||||
|
limit: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => Err(LlmError::Request {
|
||||||
|
status: status_code,
|
||||||
|
body: body_text,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let body_text = response.text().await.unwrap_or_default();
|
||||||
|
debug!(body = %body_text, "收到响应体");
|
||||||
|
|
||||||
|
let chat_response: OpenaiChatResponse = serde_json::from_str(&body_text).map_err(|e| {
|
||||||
|
error!(error = %e, body = %body_text, "响应解析失败");
|
||||||
|
LlmError::Other(format!("响应解析失败: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
debug!(response = ?chat_response, "收到 LLM 响应");
|
||||||
|
|
||||||
|
Ok(ChatResponse::from(chat_response))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
use crate::llm::types::shared::{AudioFormat, ImageDetail};
|
||||||
|
use crate::llm::types::tool::OpenaiToolCall;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ContentField {
|
||||||
|
String(String),
|
||||||
|
Array(Vec<OpenaiContentPart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for ContentField {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
match value {
|
||||||
|
Value::String(s) => Ok(ContentField::String(s)),
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let parts: Result<Vec<OpenaiContentPart>, _> =
|
||||||
|
serde_json::from_value(Value::Array(arr));
|
||||||
|
match parts {
|
||||||
|
Ok(parts) => Ok(ContentField::Array(parts)),
|
||||||
|
Err(e) => Err(serde::de::Error::custom(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(serde::de::Error::custom("content must be string or array")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for ContentField {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
ContentField::String(s) => s.serialize(serializer),
|
||||||
|
ContentField::Array(arr) => arr.serialize(serializer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ImageURL {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<ImageDetail>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct InputAudio {
|
||||||
|
pub data: String,
|
||||||
|
pub format: AudioFormat,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct FileData {
|
||||||
|
pub file_data: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub filename: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiContentPart {
|
||||||
|
Text {
|
||||||
|
text: String,
|
||||||
|
},
|
||||||
|
Image {
|
||||||
|
image_url: ImageURL,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
detail: Option<ImageDetail>,
|
||||||
|
},
|
||||||
|
InputAudio {
|
||||||
|
input_audio: InputAudio,
|
||||||
|
},
|
||||||
|
File {
|
||||||
|
file: FileData,
|
||||||
|
},
|
||||||
|
Refusal {
|
||||||
|
refusal: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "role")]
|
||||||
|
pub enum OpenaiChatMessage {
|
||||||
|
Developer {
|
||||||
|
content: ContentField,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
System {
|
||||||
|
content: ContentField,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
content: ContentField,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
Assistant {
|
||||||
|
content: ContentField,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
refusal: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
content: ContentField,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
Function {
|
||||||
|
content: ContentField,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenaiChatMessage {
|
||||||
|
pub fn user_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::User {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::String(text.into()),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn system_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::System {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn developer_text<S: Into<String>>(text: S) -> Self {
|
||||||
|
OpenaiChatMessage::Developer {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text { text: text.into() }]),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_result<S: Into<String>>(tool_call_id: String, content: S) -> Self {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
content: ContentField::Array(vec![OpenaiContentPart::Text {
|
||||||
|
text: content.into(),
|
||||||
|
}]),
|
||||||
|
tool_call_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
pub mod message;
|
||||||
|
pub mod request;
|
||||||
|
pub mod response;
|
||||||
|
pub mod shared;
|
||||||
|
pub mod tool;
|
||||||
|
pub mod usage;
|
||||||
|
|
||||||
|
pub use message::{
|
||||||
|
ContentField, FileData, ImageURL, InputAudio, OpenaiChatMessage, OpenaiContentPart,
|
||||||
|
};
|
||||||
|
pub use request::{OpenaiChatRequest, OpenaiTool, StreamOptions, ToolChoice};
|
||||||
|
pub use response::{
|
||||||
|
Annotation, Choice, ChunkChoice, Delta, Logprobs, OpenaiAudio, OpenaiChatChunk,
|
||||||
|
OpenaiChatResponse, TokenLogprob, TopLogprob, URLCitation,
|
||||||
|
};
|
||||||
|
pub use shared::{
|
||||||
|
AudioFormat, FinishReason, ImageDetail, Modality, ResponseFormat, Role, ServiceTier,
|
||||||
|
StopSequence,
|
||||||
|
};
|
||||||
|
pub use tool::{FunctionCall, OpenaiToolCall, OpenaiToolDefinition};
|
||||||
|
pub use usage::{CompletionTokensDetails, CostTracker, PromptTokensDetails, Usage};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ChatResponse {
|
||||||
|
pub message: OpenaiChatMessage,
|
||||||
|
pub usage: Usage,
|
||||||
|
pub stop_reason: Option<FinishReason>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<OpenaiChatResponse> for ChatResponse {
|
||||||
|
fn from(response: OpenaiChatResponse) -> Self {
|
||||||
|
let message = response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.map(|c| c.message.clone())
|
||||||
|
.unwrap_or_else(|| OpenaiChatMessage::assistant_text(""));
|
||||||
|
let stop_reason = response.choices.first().and_then(|c| c.finish_reason);
|
||||||
|
ChatResponse {
|
||||||
|
message,
|
||||||
|
usage: response.usage,
|
||||||
|
stop_reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ChatRequest = OpenaiChatRequest;
|
||||||
|
pub type Message = OpenaiChatMessage;
|
||||||
|
pub type ContentBlock = OpenaiContentPart;
|
||||||
|
pub type ToolDefinition = OpenaiToolDefinition;
|
||||||
|
pub type StopReason = FinishReason;
|
||||||
@@ -0,0 +1,181 @@
|
|||||||
|
use crate::llm::types::shared::{ResponseFormat, ServiceTier, StopSequence};
|
||||||
|
use crate::llm::types::tool::OpenaiToolDefinition;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StreamOptions {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_usage: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_obfuscation: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Required,
|
||||||
|
Named { name: String },
|
||||||
|
AllowedTools { tool_names: Vec<String> },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for ToolChoice {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
ToolChoice::None => serializer.serialize_str("none"),
|
||||||
|
ToolChoice::Auto => serializer.serialize_str("auto"),
|
||||||
|
ToolChoice::Required => serializer.serialize_str("required"),
|
||||||
|
ToolChoice::Named { name } => {
|
||||||
|
let obj = serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": { "name": name }
|
||||||
|
});
|
||||||
|
obj.serialize(serializer)
|
||||||
|
}
|
||||||
|
ToolChoice::AllowedTools { tool_names } => {
|
||||||
|
let obj = serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": { "name": tool_names.first().cloned().unwrap_or_default() }
|
||||||
|
});
|
||||||
|
obj.serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for ToolChoice {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
match value {
|
||||||
|
Value::String(s) => match s.as_str() {
|
||||||
|
"none" => Ok(ToolChoice::None),
|
||||||
|
"auto" => Ok(ToolChoice::Auto),
|
||||||
|
"required" => Ok(ToolChoice::Required),
|
||||||
|
_ => Err(serde::de::Error::custom(format!(
|
||||||
|
"unknown tool choice: {s}"
|
||||||
|
))),
|
||||||
|
},
|
||||||
|
Value::Object(obj) => {
|
||||||
|
let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'type' field in tool_choice")
|
||||||
|
})?;
|
||||||
|
if typ == "function" {
|
||||||
|
let func =
|
||||||
|
obj.get("function")
|
||||||
|
.and_then(|v| v.as_object())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'function' field in tool_choice")
|
||||||
|
})?;
|
||||||
|
let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'function.name' in tool_choice")
|
||||||
|
})?;
|
||||||
|
Ok(ToolChoice::Named {
|
||||||
|
name: name.to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(serde::de::Error::custom(format!(
|
||||||
|
"unknown tool_choice type: {typ}"
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(serde::de::Error::custom(
|
||||||
|
"tool_choice must be a string or object",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiTool {
|
||||||
|
Function { function: OpenaiToolDefinition },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AudioParam {
|
||||||
|
pub format: String,
|
||||||
|
pub voice: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PredictionContent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub pred_type: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct UserLocation {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub loc_type: String,
|
||||||
|
pub approximate: Approximate,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Approximate {
|
||||||
|
pub city: String,
|
||||||
|
pub country: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub region: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub timezone: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct WebSearchOptions {
|
||||||
|
pub search_context_size: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user_location: Option<UserLocation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub struct OpenaiChatRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<crate::llm::types::message::OpenaiChatMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub service_tier: Option<ServiceTier>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StopSequence>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<OpenaiTool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub extra_headers: Option<Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub extra_body: Option<Value>,
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
use crate::llm::types::message::OpenaiChatMessage;
|
||||||
|
use crate::llm::types::shared::{FinishReason, ServiceTier};
|
||||||
|
use crate::llm::types::tool::OpenaiToolCall;
|
||||||
|
use crate::llm::types::usage::Usage;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TokenLogprob {
|
||||||
|
pub token: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub bytes: Option<Vec<u32>>,
|
||||||
|
pub logprob: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_logprobs: Option<Vec<TopLogprob>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TopLogprob {
|
||||||
|
pub token: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub bytes: Option<Vec<u32>>,
|
||||||
|
pub logprob: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Logprobs {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<Vec<TokenLogprob>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refusal: Option<Vec<TokenLogprob>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct URLCitation {
|
||||||
|
pub end_index: u32,
|
||||||
|
pub start_index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub title: Option<String>,
|
||||||
|
pub url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Annotation {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub ann_type: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub url_citation: Option<URLCitation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiAudio {
|
||||||
|
pub id: String,
|
||||||
|
pub data: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub expires_at: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub transcript: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: OpenaiChatMessage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub finish_reason: Option<FinishReason>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<Logprobs>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiChatResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<Choice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub service_tier: Option<ServiceTier>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Delta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub role: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub refusal: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<OpenaiToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChunkChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: Delta,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<Logprobs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub finish_reason: Option<FinishReason>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiChatChunk {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Role {
|
||||||
|
Developer,
|
||||||
|
System,
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
Tool,
|
||||||
|
Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum FinishReason {
|
||||||
|
Stop,
|
||||||
|
Length,
|
||||||
|
ToolCalls,
|
||||||
|
ContentFilter,
|
||||||
|
FunctionCall,
|
||||||
|
#[serde(other)]
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ServiceTier {
|
||||||
|
Auto,
|
||||||
|
Default,
|
||||||
|
#[serde(other)]
|
||||||
|
Other,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Modality {
|
||||||
|
Text,
|
||||||
|
Audio,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ImageDetail {
|
||||||
|
Auto,
|
||||||
|
Low,
|
||||||
|
High,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum AudioFormat {
|
||||||
|
Wav,
|
||||||
|
Mp3,
|
||||||
|
Aac,
|
||||||
|
Flac,
|
||||||
|
Opus,
|
||||||
|
Pcm16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum StopSequence {
|
||||||
|
Single(String),
|
||||||
|
Multiple(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum ResponseFormat {
|
||||||
|
Text,
|
||||||
|
JsonObject,
|
||||||
|
JsonSchema {
|
||||||
|
schema: serde_json::Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
strict: Option<bool>,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenaiToolDefinition {
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub strict: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionCall {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
|
pub enum OpenaiToolCall {
|
||||||
|
Function { id: String, function: FunctionCall },
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_tokens_details: Option<PromptTokensDetails>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CompletionTokensDetails {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub audio_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub accepted_prediction_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rejected_prediction_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
||||||
|
pub struct PromptTokensDetails {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub audio_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub cached_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct CostTracker {
|
||||||
|
accumulated: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CostTracker {
|
||||||
|
pub fn add(&mut self, usage: &Usage) {
|
||||||
|
self.accumulated.prompt_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.prompt_tokens
|
||||||
|
.saturating_add(usage.prompt_tokens);
|
||||||
|
self.accumulated.completion_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.completion_tokens
|
||||||
|
.saturating_add(usage.completion_tokens);
|
||||||
|
self.accumulated.total_tokens = self
|
||||||
|
.accumulated
|
||||||
|
.total_tokens
|
||||||
|
.saturating_add(usage.total_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn total(&self) -> &Usage {
|
||||||
|
&self.accumulated
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.accumulated = Usage::default();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Usage {
|
||||||
|
pub fn from_input_output(input: u32, output: u32) -> Self {
|
||||||
|
let total = input.saturating_add(output);
|
||||||
|
Usage {
|
||||||
|
prompt_tokens: input,
|
||||||
|
completion_tokens: output,
|
||||||
|
total_tokens: total,
|
||||||
|
completion_tokens_details: None,
|
||||||
|
prompt_tokens_details: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user