feat(llm): 添加 Provider 工厂方法和枚举类型
- 新增 `ProviderType` 枚举和 `FromStr` 解析,支持通过环境变量选择 Provider - 新增 `ProviderConfig` 结构体和 `create_provider` 工厂方法,统一 Provider 创建 - 更新示例代码使用新的工厂模式,移除直接实例化 OpenaiProvider 的方式 - 移除 Assistant 消息中未使用的 `reasoning_content` 字段
This commit is contained in:
@@ -3,7 +3,7 @@ use std::env;
|
|||||||
use agcore::init_tracing;
|
use agcore::init_tracing;
|
||||||
use agcore::llm::{
|
use agcore::llm::{
|
||||||
cycle::{CycleConfig, LlmCycle},
|
cycle::{CycleConfig, LlmCycle},
|
||||||
provider::openai::OpenaiProvider,
|
provider::{create_provider, ProviderConfig, ProviderType},
|
||||||
types::{ChatResponse, OpenaiContentPart},
|
types::{ChatResponse, OpenaiContentPart},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -31,17 +31,30 @@ async fn main() {
|
|||||||
|
|
||||||
let api_key = env::var("OPENAI_API_KEY").expect("未设置 OPENAI_API_KEY 环境变量");
|
let api_key = env::var("OPENAI_API_KEY").expect("未设置 OPENAI_API_KEY 环境变量");
|
||||||
let base_url = env::var("OPENAI_BASE_URL").expect("未设置 OPENAI_BASE_URL 环境变量");
|
let base_url = env::var("OPENAI_BASE_URL").expect("未设置 OPENAI_BASE_URL 环境变量");
|
||||||
let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量选择所要使用的模型");
|
let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量");
|
||||||
|
|
||||||
let provider = OpenaiProvider::new(base_url, api_key, model.clone());
|
let provider_type = env::var("PROVIDER")
|
||||||
let config = CycleConfig {
|
.unwrap_or_else(|_| "openai".into())
|
||||||
|
.parse::<ProviderType>()
|
||||||
|
.expect("无效的 PROVIDER 值");
|
||||||
|
|
||||||
|
let config = ProviderConfig {
|
||||||
|
base_url,
|
||||||
|
api_key,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = create_provider(provider_type, config)
|
||||||
|
.expect("创建 Provider 失败");
|
||||||
|
|
||||||
|
let cycle_config = CycleConfig {
|
||||||
model,
|
model,
|
||||||
max_tokens: Some(65536),
|
max_tokens: Some(65536),
|
||||||
temperature: Some(1.3),
|
temperature: Some(1.3),
|
||||||
..CycleConfig::default()
|
..CycleConfig::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut cycle = LlmCycle::new(Box::new(provider), config)
|
let mut cycle = LlmCycle::new(provider, cycle_config)
|
||||||
.with_system_prompt("你是一个简洁的助手,对于任何问题都是用一句话回答。".to_string());
|
.with_system_prompt("你是一个简洁的助手,对于任何问题都是用一句话回答。".to_string());
|
||||||
|
|
||||||
println!("发送请求...");
|
println!("发送请求...");
|
||||||
|
|||||||
@@ -4,6 +4,51 @@ use crate::llm::error::LlmError;
|
|||||||
use crate::llm::types::{ChatRequest, ChatResponse};
|
use crate::llm::types::{ChatRequest, ChatResponse};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ProviderType {
|
||||||
|
OpenAI,
|
||||||
|
DeepSeek,
|
||||||
|
Qwen,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::str::FromStr for ProviderType {
|
||||||
|
type Err = String;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"openai" => Ok(ProviderType::OpenAI),
|
||||||
|
"deepseek" => Ok(ProviderType::DeepSeek),
|
||||||
|
"qwen" | "dashscope" | "tongyi" => Ok(ProviderType::Qwen),
|
||||||
|
_ => Err(format!("未知的 Provider 类型: {}", s)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_provider(
|
||||||
|
provider_type: ProviderType,
|
||||||
|
config: ProviderConfig,
|
||||||
|
) -> Result<Box<dyn LlmProvider>, LlmError> {
|
||||||
|
match provider_type {
|
||||||
|
ProviderType::OpenAI => Ok(Box::new(openai::OpenaiProvider::new(
|
||||||
|
config.base_url,
|
||||||
|
config.api_key,
|
||||||
|
config.model,
|
||||||
|
))),
|
||||||
|
ProviderType::DeepSeek => {
|
||||||
|
unimplemented!("DeepSeek Provider 尚未实现")
|
||||||
|
}
|
||||||
|
ProviderType::Qwen => {
|
||||||
|
unimplemented!("Qwen Provider 尚未实现")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// LLM Provider 抽象接口。
|
/// LLM Provider 抽象接口。
|
||||||
///
|
///
|
||||||
/// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等)
|
/// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等)
|
||||||
|
|||||||
@@ -107,8 +107,6 @@ pub enum OpenaiChatMessage {
|
|||||||
Assistant {
|
Assistant {
|
||||||
content: ContentField,
|
content: ContentField,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
reasoning_content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
refusal: Option<String>,
|
refusal: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
@@ -136,7 +134,6 @@ impl OpenaiChatMessage {
|
|||||||
pub fn assistant_text<S: Into<String>>(text: S) -> Self {
|
pub fn assistant_text<S: Into<String>>(text: S) -> Self {
|
||||||
OpenaiChatMessage::Assistant {
|
OpenaiChatMessage::Assistant {
|
||||||
content: ContentField::String(text.into()),
|
content: ContentField::String(text.into()),
|
||||||
reasoning_content: None,
|
|
||||||
refusal: None,
|
refusal: None,
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
|
|||||||
Reference in New Issue
Block a user