feat(prompt): 添加提示词工程模块并扩展LLM周期接口
新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
This commit is contained in:
@@ -0,0 +1,406 @@
|
||||
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||
use crate::llm::types::request::OpenaiChatRequest;
|
||||
use crate::prompt::error::PromptError;
|
||||
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||
|
||||
/// 提示词组合器——构建多角色消息序列。
|
||||
#[derive(Default)]
|
||||
pub struct PromptComposer {
|
||||
messages: Vec<OpenaiChatMessage>,
|
||||
}
|
||||
|
||||
impl PromptComposer {
|
||||
/// 创建一个空的组合器。
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// 从已有的消息列表初始化。
|
||||
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
|
||||
Self { messages }
|
||||
}
|
||||
|
||||
// ===== 纯文本消息 =====
|
||||
|
||||
/// 添加一条纯文本 system 消息。
|
||||
pub fn system(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::system_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 user 消息。
|
||||
pub fn user(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::user_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 assistant 消息。
|
||||
pub fn assistant(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||
pub fn developer(mut self, text: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::developer_text(text.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::tool_result(
|
||||
tool_call_id.into(),
|
||||
content.into(),
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 模板消息 =====
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||
pub fn user_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::user_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||
pub fn system_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::system_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||
pub fn assistant_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::assistant_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||
pub fn developer_template(
|
||||
mut self,
|
||||
template: &PromptTemplate,
|
||||
ctx: &TemplateContext,
|
||||
) -> Result<Self, PromptError> {
|
||||
let text = template.render(ctx)?;
|
||||
self.push_message(OpenaiChatMessage::developer_text(text));
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
// ===== 多模态 ContentPart =====
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::System {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::User {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(vec![part]),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Developer {
|
||||
content: ContentField::Array(vec![part]),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||
pub fn tool_content(
|
||||
mut self,
|
||||
tool_call_id: impl Into<String>,
|
||||
part: OpenaiContentPart,
|
||||
) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Tool {
|
||||
content: ContentField::Array(vec![part]),
|
||||
tool_call_id: tool_call_id.into(),
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 user 消息。
|
||||
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::User {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 system 消息。
|
||||
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::System {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Assistant {
|
||||
content: ContentField::Array(parts),
|
||||
refusal: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// 批量添加 ContentPart 作为 developer 消息。
|
||||
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||
self.push_message(OpenaiChatMessage::Developer {
|
||||
content: ContentField::Array(parts),
|
||||
name: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 角色标识 =====
|
||||
|
||||
/// 为上一条添加的消息设置 `name` 字段。
|
||||
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||
let name = name.into();
|
||||
if let Some(msg) = self.messages.last_mut() {
|
||||
set_message_name(msg, name);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
// ===== 构建 =====
|
||||
|
||||
/// 构建最终的消息列表。
|
||||
pub fn build(self) -> Vec<OpenaiChatMessage> {
|
||||
self.messages
|
||||
}
|
||||
|
||||
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
|
||||
OpenaiChatRequest {
|
||||
model: model.into(),
|
||||
messages: self.messages,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// ===== 内部方法 =====
|
||||
|
||||
fn push_message(&mut self, msg: OpenaiChatMessage) {
|
||||
self.messages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
|
||||
match msg {
|
||||
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
|
||||
OpenaiChatMessage::Tool { .. } => {}
|
||||
OpenaiChatMessage::Function { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
|
||||
if messages.is_empty() {
|
||||
return Err(PromptError::InvalidSequence(
|
||||
"消息列表不能为空".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut last_tool_call_ids: Vec<String> = Vec::new();
|
||||
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
match msg {
|
||||
OpenaiChatMessage::Tool {
|
||||
tool_call_id,
|
||||
..
|
||||
} => {
|
||||
if last_tool_call_ids.is_empty() {
|
||||
return Err(PromptError::InvalidSequence(format!(
|
||||
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
|
||||
)));
|
||||
}
|
||||
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
|
||||
return Err(PromptError::InvalidSequence(format!(
|
||||
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
|
||||
tool_call_id
|
||||
)));
|
||||
}
|
||||
}
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
} => {
|
||||
last_tool_call_ids.clear();
|
||||
for call in calls {
|
||||
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
|
||||
last_tool_call_ids.push(id.clone());
|
||||
}
|
||||
}
|
||||
OpenaiChatMessage::Assistant {
|
||||
tool_calls: None, ..
|
||||
} => {
|
||||
last_tool_call_ids.clear();
|
||||
}
|
||||
_ => {
|
||||
last_tool_call_ids.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::prompt::TemplateValue;
|
||||
|
||||
#[test]
|
||||
fn test_composer_basic() {
|
||||
let msgs = PromptComposer::new()
|
||||
.system("You are helpful")
|
||||
.user("Hello")
|
||||
.assistant("Hi there!")
|
||||
.build();
|
||||
|
||||
assert_eq!(msgs.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_composer_tool() {
|
||||
let msgs = PromptComposer::new()
|
||||
.system("You are helpful")
|
||||
.user("What's the weather?")
|
||||
.assistant("Let me check")
|
||||
.tool("call_123", "Sunny, 25°C")
|
||||
.build();
|
||||
|
||||
assert_eq!(msgs.len(), 4);
|
||||
match &msgs[3] {
|
||||
OpenaiChatMessage::Tool {
|
||||
tool_call_id,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_call_id, "call_123");
|
||||
match content {
|
||||
ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Tool message"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_messages_ok() {
|
||||
let msgs = PromptComposer::new()
|
||||
.system("You are helpful")
|
||||
.user("Hello")
|
||||
.build();
|
||||
|
||||
assert!(validate_messages(&msgs).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_messages_empty() {
|
||||
let msgs: Vec<OpenaiChatMessage> = vec![];
|
||||
assert!(validate_messages(&msgs).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_render() {
|
||||
let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
|
||||
let mut ctx = TemplateContext::new();
|
||||
ctx.insert("name", "Alice");
|
||||
ctx.insert("count", "5");
|
||||
|
||||
let result = tpl.render(&ctx).unwrap();
|
||||
assert_eq!(result, "Hello Alice, you have 5 messages");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_if() {
|
||||
let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
|
||||
let mut ctx = TemplateContext::new();
|
||||
ctx.insert("name", "Bob");
|
||||
|
||||
let with_name = tpl.render(&ctx).unwrap();
|
||||
assert_eq!(with_name, "Hello Bob");
|
||||
|
||||
let without_name = tpl.render(&TemplateContext::new()).unwrap();
|
||||
assert_eq!(without_name, "Hello Guest");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_each() {
|
||||
let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
|
||||
let mut ctx = TemplateContext::new();
|
||||
ctx.insert("items", TemplateValue::Array(vec![
|
||||
TemplateValue::String("a".to_string()),
|
||||
TemplateValue::String("b".to_string()),
|
||||
TemplateValue::String("c".to_string()),
|
||||
]));
|
||||
|
||||
let result = tpl.render(&ctx).unwrap();
|
||||
assert_eq!(result, "Items: a, b, c, ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_display() {
|
||||
let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
|
||||
assert_eq!(format!("{}", tpl), "Hello {{name}}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_from_json() {
|
||||
let json: serde_json::Value = serde_json::json!({
|
||||
"name": "Alice",
|
||||
"active": true
|
||||
});
|
||||
let ctx = TemplateContext::from_json(&json).unwrap();
|
||||
assert!(ctx.get("name").is_some());
|
||||
assert!(ctx.get("active").is_some());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user