993ae0eb4b
新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
407 lines
13 KiB
Rust
407 lines
13 KiB
Rust
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||
use crate::llm::types::request::OpenaiChatRequest;
|
||
use crate::prompt::error::PromptError;
|
||
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||
|
||
/// 提示词组合器——构建多角色消息序列。
|
||
#[derive(Default)]
|
||
pub struct PromptComposer {
|
||
messages: Vec<OpenaiChatMessage>,
|
||
}
|
||
|
||
impl PromptComposer {
|
||
/// 创建一个空的组合器。
|
||
pub fn new() -> Self {
|
||
Self::default()
|
||
}
|
||
|
||
/// 从已有的消息列表初始化。
|
||
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
|
||
Self { messages }
|
||
}
|
||
|
||
// ===== 纯文本消息 =====
|
||
|
||
/// 添加一条纯文本 system 消息。
|
||
pub fn system(mut self, text: impl Into<String>) -> Self {
|
||
self.push_message(OpenaiChatMessage::system_text(text.into()));
|
||
self
|
||
}
|
||
|
||
/// 添加一条纯文本 user 消息。
|
||
pub fn user(mut self, text: impl Into<String>) -> Self {
|
||
self.push_message(OpenaiChatMessage::user_text(text.into()));
|
||
self
|
||
}
|
||
|
||
/// 添加一条纯文本 assistant 消息。
|
||
pub fn assistant(mut self, text: impl Into<String>) -> Self {
|
||
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
|
||
self
|
||
}
|
||
|
||
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||
pub fn developer(mut self, text: impl Into<String>) -> Self {
|
||
self.push_message(OpenaiChatMessage::developer_text(text.into()));
|
||
self
|
||
}
|
||
|
||
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||
self.push_message(OpenaiChatMessage::tool_result(
|
||
tool_call_id.into(),
|
||
content.into(),
|
||
));
|
||
self
|
||
}
|
||
|
||
// ===== 模板消息 =====
|
||
|
||
/// 使用模板和上下文渲染后添加为 user 消息。
|
||
pub fn user_template(
|
||
mut self,
|
||
template: &PromptTemplate,
|
||
ctx: &TemplateContext,
|
||
) -> Result<Self, PromptError> {
|
||
let text = template.render(ctx)?;
|
||
self.push_message(OpenaiChatMessage::user_text(text));
|
||
Ok(self)
|
||
}
|
||
|
||
/// 使用模板和上下文渲染后添加为 system 消息。
|
||
pub fn system_template(
|
||
mut self,
|
||
template: &PromptTemplate,
|
||
ctx: &TemplateContext,
|
||
) -> Result<Self, PromptError> {
|
||
let text = template.render(ctx)?;
|
||
self.push_message(OpenaiChatMessage::system_text(text));
|
||
Ok(self)
|
||
}
|
||
|
||
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||
pub fn assistant_template(
|
||
mut self,
|
||
template: &PromptTemplate,
|
||
ctx: &TemplateContext,
|
||
) -> Result<Self, PromptError> {
|
||
let text = template.render(ctx)?;
|
||
self.push_message(OpenaiChatMessage::assistant_text(text));
|
||
Ok(self)
|
||
}
|
||
|
||
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||
pub fn developer_template(
|
||
mut self,
|
||
template: &PromptTemplate,
|
||
ctx: &TemplateContext,
|
||
) -> Result<Self, PromptError> {
|
||
let text = template.render(ctx)?;
|
||
self.push_message(OpenaiChatMessage::developer_text(text));
|
||
Ok(self)
|
||
}
|
||
|
||
// ===== 多模态 ContentPart =====
|
||
|
||
/// 添加一条含指定 ContentPart 的 system 消息。
|
||
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
|
||
self.push_message(OpenaiChatMessage::System {
|
||
content: ContentField::Array(vec![part]),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 添加一条含指定 ContentPart 的 user 消息。
|
||
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
|
||
self.push_message(OpenaiChatMessage::User {
|
||
content: ContentField::Array(vec![part]),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
|
||
self.push_message(OpenaiChatMessage::Assistant {
|
||
content: ContentField::Array(vec![part]),
|
||
refusal: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
|
||
self.push_message(OpenaiChatMessage::Developer {
|
||
content: ContentField::Array(vec![part]),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||
pub fn tool_content(
|
||
mut self,
|
||
tool_call_id: impl Into<String>,
|
||
part: OpenaiContentPart,
|
||
) -> Self {
|
||
self.push_message(OpenaiChatMessage::Tool {
|
||
content: ContentField::Array(vec![part]),
|
||
tool_call_id: tool_call_id.into(),
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 批量添加 ContentPart 作为 user 消息。
|
||
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||
self.push_message(OpenaiChatMessage::User {
|
||
content: ContentField::Array(parts),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 批量添加 ContentPart 作为 system 消息。
|
||
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||
self.push_message(OpenaiChatMessage::System {
|
||
content: ContentField::Array(parts),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 批量添加 ContentPart 作为 assistant 消息。
|
||
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||
self.push_message(OpenaiChatMessage::Assistant {
|
||
content: ContentField::Array(parts),
|
||
refusal: None,
|
||
name: None,
|
||
tool_calls: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
/// 批量添加 ContentPart 作为 developer 消息。
|
||
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||
self.push_message(OpenaiChatMessage::Developer {
|
||
content: ContentField::Array(parts),
|
||
name: None,
|
||
});
|
||
self
|
||
}
|
||
|
||
// ===== 角色标识 =====
|
||
|
||
/// 为上一条添加的消息设置 `name` 字段。
|
||
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||
let name = name.into();
|
||
if let Some(msg) = self.messages.last_mut() {
|
||
set_message_name(msg, name);
|
||
}
|
||
self
|
||
}
|
||
|
||
// ===== 构建 =====
|
||
|
||
/// 构建最终的消息列表。
|
||
pub fn build(self) -> Vec<OpenaiChatMessage> {
|
||
self.messages
|
||
}
|
||
|
||
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
|
||
OpenaiChatRequest {
|
||
model: model.into(),
|
||
messages: self.messages,
|
||
..Default::default()
|
||
}
|
||
}
|
||
|
||
// ===== 内部方法 =====
|
||
|
||
fn push_message(&mut self, msg: OpenaiChatMessage) {
|
||
self.messages.push(msg);
|
||
}
|
||
}
|
||
|
||
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
|
||
match msg {
|
||
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
|
||
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
|
||
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
|
||
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
|
||
OpenaiChatMessage::Tool { .. } => {}
|
||
OpenaiChatMessage::Function { .. } => {}
|
||
}
|
||
}
|
||
|
||
/// 验证消息序列是否符合 OpenAI API 要求。
|
||
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
|
||
if messages.is_empty() {
|
||
return Err(PromptError::InvalidSequence(
|
||
"消息列表不能为空".to_string(),
|
||
));
|
||
}
|
||
|
||
let mut last_tool_call_ids: Vec<String> = Vec::new();
|
||
|
||
for (i, msg) in messages.iter().enumerate() {
|
||
match msg {
|
||
OpenaiChatMessage::Tool {
|
||
tool_call_id,
|
||
..
|
||
} => {
|
||
if last_tool_call_ids.is_empty() {
|
||
return Err(PromptError::InvalidSequence(format!(
|
||
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
|
||
)));
|
||
}
|
||
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
|
||
return Err(PromptError::InvalidSequence(format!(
|
||
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
|
||
tool_call_id
|
||
)));
|
||
}
|
||
}
|
||
OpenaiChatMessage::Assistant {
|
||
tool_calls: Some(calls),
|
||
..
|
||
} => {
|
||
last_tool_call_ids.clear();
|
||
for call in calls {
|
||
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
|
||
last_tool_call_ids.push(id.clone());
|
||
}
|
||
}
|
||
OpenaiChatMessage::Assistant {
|
||
tool_calls: None, ..
|
||
} => {
|
||
last_tool_call_ids.clear();
|
||
}
|
||
_ => {
|
||
last_tool_call_ids.clear();
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::TemplateValue;

    /// Chained text builders append one message each.
    #[test]
    fn test_composer_basic() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .assistant("Hi there!")
            .build();

        assert_eq!(msgs.len(), 3);
    }

    /// `tool` appends a tool message carrying the id and the text result.
    #[test]
    fn test_composer_tool() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("What's the weather?")
            .assistant("Let me check")
            .tool("call_123", "Sunny, 25°C")
            .build();

        assert_eq!(msgs.len(), 4);
        match &msgs[3] {
            OpenaiChatMessage::Tool {
                tool_call_id,
                content: ContentField::String(text),
                ..
            } => {
                assert_eq!(tool_call_id, "call_123");
                assert_eq!(text, "Sunny, 25°C");
            }
            // Fail loudly instead of skipping the content assertion when the
            // tool result is not stored as plain text.
            OpenaiChatMessage::Tool { .. } => panic!("Expected string content"),
            _ => panic!("Expected Tool message"),
        }
    }

    /// A sequence without tool messages passes validation.
    #[test]
    fn test_validate_messages_ok() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .build();

        assert!(validate_messages(&msgs).is_ok());
    }

    /// An empty list is rejected.
    #[test]
    fn test_validate_messages_empty() {
        let msgs: Vec<OpenaiChatMessage> = vec![];
        assert!(validate_messages(&msgs).is_err());
    }

    /// A tool message is rejected when no assistant tool calls are pending.
    #[test]
    fn test_validate_messages_tool_without_tool_calls() {
        let msgs = PromptComposer::new()
            .user("What's the weather?")
            .assistant("Let me check")
            .tool("call_123", "Sunny, 25°C")
            .build();

        assert!(validate_messages(&msgs).is_err());
    }

    /// Plain variables are substituted into the template text.
    #[test]
    fn test_template_render() {
        let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert("name", "Alice");
        ctx.insert("count", "5");

        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Hello Alice, you have 5 messages");
    }

    /// `#if` renders the main branch when the variable is set, else the fallback.
    #[test]
    fn test_template_if() {
        let tpl =
            PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert("name", "Bob");

        let with_name = tpl.render(&ctx).unwrap();
        assert_eq!(with_name, "Hello Bob");

        let without_name = tpl.render(&TemplateContext::new()).unwrap();
        assert_eq!(without_name, "Hello Guest");
    }

    /// `#each` repeats its body once per array element.
    #[test]
    fn test_template_each() {
        let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert(
            "items",
            TemplateValue::Array(vec![
                TemplateValue::String("a".to_string()),
                TemplateValue::String("b".to_string()),
                TemplateValue::String("c".to_string()),
            ]),
        );

        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Items: a, b, c, ");
    }

    /// `Display` prints the original template source.
    #[test]
    fn test_template_display() {
        let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
        assert_eq!(format!("{}", tpl), "Hello {{name}}");
    }

    /// A JSON object converts into a context exposing its top-level keys.
    #[test]
    fn test_context_from_json() {
        let json: serde_json::Value = serde_json::json!({
            "name": "Alice",
            "active": true
        });
        let ctx = TemplateContext::from_json(&json).unwrap();
        assert!(ctx.get("name").is_some());
        assert!(ctx.get("active").is_some());
    }
}
|