diff --git a/examples/simple_visit.rs b/examples/simple_visit.rs index 578bbf2..4c5db16 100644 --- a/examples/simple_visit.rs +++ b/examples/simple_visit.rs @@ -3,7 +3,7 @@ use std::env; use agcore::init_tracing; use agcore::llm::{ cycle::{CycleConfig, LlmCycle}, - provider::openai::OpenaiProvider, + provider::{create_provider, ProviderConfig, ProviderType}, types::{ChatResponse, OpenaiContentPart}, }; @@ -31,17 +31,30 @@ async fn main() { let api_key = env::var("OPENAI_API_KEY").expect("未设置 OPENAI_API_KEY 环境变量"); let base_url = env::var("OPENAI_BASE_URL").expect("未设置 OPENAI_BASE_URL 环境变量"); - let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量选择所要使用的模型"); + let model = env::var("OPENAI_MODEL").expect("未设置 OPENAI_MODEL 环境变量"); - let provider = OpenaiProvider::new(base_url, api_key, model.clone()); - let config = CycleConfig { + let provider_type = env::var("PROVIDER") + .unwrap_or_else(|_| "openai".into()) + .parse::() + .expect("无效的 PROVIDER 值"); + + let config = ProviderConfig { + base_url, + api_key, + model: model.clone(), + }; + + let provider = create_provider(provider_type, config) + .expect("创建 Provider 失败"); + + let cycle_config = CycleConfig { model, max_tokens: Some(65536), temperature: Some(1.3), ..CycleConfig::default() }; - let mut cycle = LlmCycle::new(Box::new(provider), config) + let mut cycle = LlmCycle::new(provider, cycle_config) .with_system_prompt("你是一个简洁的助手,对于任何问题都是用一句话回答。".to_string()); println!("发送请求..."); diff --git a/src/llm/provider.rs b/src/llm/provider.rs index d7d0917..28445c8 100644 --- a/src/llm/provider.rs +++ b/src/llm/provider.rs @@ -4,6 +4,51 @@ use crate::llm::error::LlmError; use crate::llm::types::{ChatRequest, ChatResponse}; use async_trait::async_trait; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderType { + OpenAI, + DeepSeek, + Qwen, +} + +impl std::str::FromStr for ProviderType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "openai" => Ok(ProviderType::OpenAI), + "deepseek" => Ok(ProviderType::DeepSeek), + "qwen" | "dashscope" | "tongyi" => Ok(ProviderType::Qwen), + _ => Err(format!("未知的 Provider 类型: {}", s)), + } + } +} + +pub struct ProviderConfig { + pub base_url: String, + pub api_key: String, + pub model: String, +} + +pub fn create_provider( + provider_type: ProviderType, + config: ProviderConfig, +) -> Result, LlmError> { + match provider_type { + ProviderType::OpenAI => Ok(Box::new(openai::OpenaiProvider::new( + config.base_url, + config.api_key, + config.model, + ))), + ProviderType::DeepSeek => { + unimplemented!("DeepSeek Provider 尚未实现") + } + ProviderType::Qwen => { + unimplemented!("Qwen Provider 尚未实现") + } + } +} + /// LLM Provider 抽象接口。 /// /// 所有具体的 LLM 后端实现(OpenAI、Anthropic、Azure 等) diff --git a/src/llm/types/message.rs b/src/llm/types/message.rs index 172773f..25265f6 100644 --- a/src/llm/types/message.rs +++ b/src/llm/types/message.rs @@ -107,8 +107,6 @@ pub enum OpenaiChatMessage { Assistant { content: ContentField, #[serde(skip_serializing_if = "Option::is_none")] - reasoning_content: Option, - #[serde(skip_serializing_if = "Option::is_none")] refusal: Option, #[serde(skip_serializing_if = "Option::is_none")] name: Option, @@ -136,7 +134,6 @@ impl OpenaiChatMessage { pub fn assistant_text>(text: S) -> Self { OpenaiChatMessage::Assistant { content: ContentField::String(text.into()), - reasoning_content: None, refusal: None, name: None, tool_calls: None,