use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart}; use crate::llm::types::request::OpenaiChatRequest; use crate::prompt::error::PromptError; use crate::prompt::template::{PromptTemplate, TemplateContext}; /// 提示词组合器——构建多角色消息序列。 #[derive(Default)] pub struct PromptComposer { messages: Vec, } impl PromptComposer { /// 创建一个空的组合器。 pub fn new() -> Self { Self::default() } /// 从已有的消息列表初始化。 pub fn from_messages(messages: Vec) -> Self { Self { messages } } // ===== 纯文本消息 ===== /// 添加一条纯文本 system 消息。 pub fn system(mut self, text: impl Into) -> Self { self.push_message(OpenaiChatMessage::system_text(text.into())); self } /// 添加一条纯文本 user 消息。 pub fn user(mut self, text: impl Into) -> Self { self.push_message(OpenaiChatMessage::user_text(text.into())); self } /// 添加一条纯文本 assistant 消息。 pub fn assistant(mut self, text: impl Into) -> Self { self.push_message(OpenaiChatMessage::assistant_text(text.into())); self } /// 添加一条纯文本 developer 消息(o1 系列模型使用)。 pub fn developer(mut self, text: impl Into) -> Self { self.push_message(OpenaiChatMessage::developer_text(text.into())); self } /// 添加一条 Tool 消息(工具执行结果回传)。 pub fn tool(mut self, tool_call_id: impl Into, content: impl Into) -> Self { self.push_message(OpenaiChatMessage::tool_result( tool_call_id.into(), content.into(), )); self } // ===== 模板消息 ===== /// 使用模板和上下文渲染后添加为 user 消息。 pub fn user_template( mut self, template: &PromptTemplate, ctx: &TemplateContext, ) -> Result { let text = template.render(ctx)?; self.push_message(OpenaiChatMessage::user_text(text)); Ok(self) } /// 使用模板和上下文渲染后添加为 system 消息。 pub fn system_template( mut self, template: &PromptTemplate, ctx: &TemplateContext, ) -> Result { let text = template.render(ctx)?; self.push_message(OpenaiChatMessage::system_text(text)); Ok(self) } /// 使用模板和上下文渲染后添加为 assistant 消息。 pub fn assistant_template( mut self, template: &PromptTemplate, ctx: &TemplateContext, ) -> Result { let text = template.render(ctx)?; self.push_message(OpenaiChatMessage::assistant_text(text)); Ok(self) } /// 使用模板和上下文渲染后添加为 developer 消息。 pub fn developer_template( mut self, template: &PromptTemplate, ctx: &TemplateContext, ) -> Result { let text = template.render(ctx)?; self.push_message(OpenaiChatMessage::developer_text(text)); Ok(self) } // ===== 多模态 ContentPart ===== /// 添加一条含指定 ContentPart 的 system 消息。 pub fn system_content(mut self, part: OpenaiContentPart) -> Self { self.push_message(OpenaiChatMessage::System { content: ContentField::Array(vec![part]), name: None, }); self } /// 添加一条含指定 ContentPart 的 user 消息。 pub fn user_content(mut self, part: OpenaiContentPart) -> Self { self.push_message(OpenaiChatMessage::User { content: ContentField::Array(vec![part]), name: None, }); self } /// 添加一条含指定 ContentPart 的 assistant 消息。 pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self { self.push_message(OpenaiChatMessage::Assistant { content: ContentField::Array(vec![part]), refusal: None, name: None, tool_calls: None, }); self } /// 添加一条含指定 ContentPart 的 developer 消息。 pub fn developer_content(mut self, part: OpenaiContentPart) -> Self { self.push_message(OpenaiChatMessage::Developer { content: ContentField::Array(vec![part]), name: None, }); self } /// 添加一条含指定 ContentPart 的 Tool 消息。 pub fn tool_content( mut self, tool_call_id: impl Into, part: OpenaiContentPart, ) -> Self { self.push_message(OpenaiChatMessage::Tool { content: ContentField::Array(vec![part]), tool_call_id: tool_call_id.into(), }); self } /// 批量添加 ContentPart 作为 user 消息。 pub fn user_contents(mut self, parts: Vec) -> Self { self.push_message(OpenaiChatMessage::User { content: ContentField::Array(parts), name: None, }); self } /// 批量添加 ContentPart 作为 system 消息。 pub fn system_contents(mut self, parts: Vec) -> Self { self.push_message(OpenaiChatMessage::System { content: ContentField::Array(parts), name: None, }); self } /// 批量添加 ContentPart 作为 assistant 消息。 pub fn assistant_contents(mut self, parts: Vec) -> Self { self.push_message(OpenaiChatMessage::Assistant { content: ContentField::Array(parts), refusal: None, name: None, tool_calls: None, }); self } /// 批量添加 ContentPart 作为 developer 消息。 pub fn developer_contents(mut self, parts: Vec) -> Self { self.push_message(OpenaiChatMessage::Developer { content: ContentField::Array(parts), name: None, }); self } // ===== 角色标识 ===== /// 为上一条添加的消息设置 `name` 字段。 pub fn with_name(mut self, name: impl Into) -> Self { let name = name.into(); if let Some(msg) = self.messages.last_mut() { set_message_name(msg, name); } self } // ===== 构建 ===== /// 构建最终的消息列表。 pub fn build(self) -> Vec { self.messages } /// 构建并直接创建 ChatRequest(需搭配 model 参数)。 /// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`, /// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。 pub fn build_request(self, model: impl Into) -> OpenaiChatRequest { OpenaiChatRequest { model: model.into(), messages: self.messages, ..Default::default() } } // ===== 内部方法 ===== fn push_message(&mut self, msg: OpenaiChatMessage) { self.messages.push(msg); } } fn set_message_name(msg: &mut OpenaiChatMessage, name: String) { match msg { OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name), OpenaiChatMessage::System { name: n, .. } => *n = Some(name), OpenaiChatMessage::User { name: n, .. } => *n = Some(name), OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name), OpenaiChatMessage::Tool { .. } => {} OpenaiChatMessage::Function { .. } => {} } } /// 验证消息序列是否符合 OpenAI API 要求。 pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> { if messages.is_empty() { return Err(PromptError::InvalidSequence( "消息列表不能为空".to_string(), )); } let mut last_tool_call_ids: Vec = Vec::new(); for (i, msg) in messages.iter().enumerate() { match msg { OpenaiChatMessage::Tool { tool_call_id, .. } => { if last_tool_call_ids.is_empty() { return Err(PromptError::InvalidSequence(format!( "消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls" ))); } if !last_tool_call_ids.iter().any(|id| id == tool_call_id) { return Err(PromptError::InvalidSequence(format!( "消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls", tool_call_id ))); } } OpenaiChatMessage::Assistant { tool_calls: Some(calls), .. } => { last_tool_call_ids.clear(); for call in calls { let crate::llm::types::OpenaiToolCall::Function { id, .. } = call; last_tool_call_ids.push(id.clone()); } } OpenaiChatMessage::Assistant { tool_calls: None, .. } => { last_tool_call_ids.clear(); } _ => { last_tool_call_ids.clear(); } } } Ok(()) } #[cfg(test)] mod tests { use super::*; use crate::prompt::TemplateValue; #[test] fn test_composer_basic() { let msgs = PromptComposer::new() .system("You are helpful") .user("Hello") .assistant("Hi there!") .build(); assert_eq!(msgs.len(), 3); } #[test] fn test_composer_tool() { let msgs = PromptComposer::new() .system("You are helpful") .user("What's the weather?") .assistant("Let me check") .tool("call_123", "Sunny, 25°C") .build(); assert_eq!(msgs.len(), 4); match &msgs[3] { OpenaiChatMessage::Tool { tool_call_id, content, .. } => { assert_eq!(tool_call_id, "call_123"); match content { ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"), _ => {} } } _ => panic!("Expected Tool message"), } } #[test] fn test_validate_messages_ok() { let msgs = PromptComposer::new() .system("You are helpful") .user("Hello") .build(); assert!(validate_messages(&msgs).is_ok()); } #[test] fn test_validate_messages_empty() { let msgs: Vec = vec![]; assert!(validate_messages(&msgs).is_err()); } #[test] fn test_template_render() { let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap(); let mut ctx = TemplateContext::new(); ctx.insert("name", "Alice"); ctx.insert("count", "5"); let result = tpl.render(&ctx).unwrap(); assert_eq!(result, "Hello Alice, you have 5 messages"); } #[test] fn test_template_if() { let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap(); let mut ctx = TemplateContext::new(); ctx.insert("name", "Bob"); let with_name = tpl.render(&ctx).unwrap(); assert_eq!(with_name, "Hello Bob"); let without_name = tpl.render(&TemplateContext::new()).unwrap(); assert_eq!(without_name, "Hello Guest"); } #[test] fn test_template_each() { let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap(); let mut ctx = TemplateContext::new(); ctx.insert("items", TemplateValue::Array(vec![ TemplateValue::String("a".to_string()), TemplateValue::String("b".to_string()), TemplateValue::String("c".to_string()), ])); let result = tpl.render(&ctx).unwrap(); assert_eq!(result, "Items: a, b, c, "); } #[test] fn test_template_display() { let tpl = PromptTemplate::compile("Hello {{name}}").unwrap(); assert_eq!(format!("{}", tpl), "Hello {{name}}"); } #[test] fn test_context_from_json() { let json: serde_json::Value = serde_json::json!({ "name": "Alice", "active": true }); let ctx = TemplateContext::from_json(&json).unwrap(); assert!(ctx.get("name").is_some()); assert!(ctx.get("active").is_some()); } }