feat(prompt): 添加提示词工程模块并扩展LLM周期接口

新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
This commit is contained in:
徐涛
2026-06-03 06:18:16 +08:00
parent 7f5513adf3
commit 993ae0eb4b
6 changed files with 1073 additions and 0 deletions
+406
View File
@@ -0,0 +1,406 @@
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
use crate::llm::types::request::OpenaiChatRequest;
use crate::prompt::error::PromptError;
use crate::prompt::template::{PromptTemplate, TemplateContext};
/// 提示词组合器——构建多角色消息序列。
#[derive(Default)]
pub struct PromptComposer {
messages: Vec<OpenaiChatMessage>,
}
impl PromptComposer {
/// 创建一个空的组合器。
pub fn new() -> Self {
Self::default()
}
/// 从已有的消息列表初始化。
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
Self { messages }
}
// ===== 纯文本消息 =====
/// 添加一条纯文本 system 消息。
pub fn system(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::system_text(text.into()));
self
}
/// 添加一条纯文本 user 消息。
pub fn user(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::user_text(text.into()));
self
}
/// 添加一条纯文本 assistant 消息。
pub fn assistant(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
self
}
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
pub fn developer(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::developer_text(text.into()));
self
}
/// 添加一条 Tool 消息(工具执行结果回传)。
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::tool_result(
tool_call_id.into(),
content.into(),
));
self
}
// ===== 模板消息 =====
/// 使用模板和上下文渲染后添加为 user 消息。
pub fn user_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::user_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 system 消息。
pub fn system_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::system_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 assistant 消息。
pub fn assistant_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::assistant_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 developer 消息。
pub fn developer_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::developer_text(text));
Ok(self)
}
// ===== 多模态 ContentPart =====
/// 添加一条含指定 ContentPart 的 system 消息。
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::System {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 user 消息。
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::User {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 assistant 消息。
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::Assistant {
content: ContentField::Array(vec![part]),
refusal: None,
name: None,
tool_calls: None,
});
self
}
/// 添加一条含指定 ContentPart 的 developer 消息。
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::Developer {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 Tool 消息。
pub fn tool_content(
mut self,
tool_call_id: impl Into<String>,
part: OpenaiContentPart,
) -> Self {
self.push_message(OpenaiChatMessage::Tool {
content: ContentField::Array(vec![part]),
tool_call_id: tool_call_id.into(),
});
self
}
/// 批量添加 ContentPart 作为 user 消息。
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::User {
content: ContentField::Array(parts),
name: None,
});
self
}
/// 批量添加 ContentPart 作为 system 消息。
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::System {
content: ContentField::Array(parts),
name: None,
});
self
}
/// 批量添加 ContentPart 作为 assistant 消息。
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::Assistant {
content: ContentField::Array(parts),
refusal: None,
name: None,
tool_calls: None,
});
self
}
/// 批量添加 ContentPart 作为 developer 消息。
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::Developer {
content: ContentField::Array(parts),
name: None,
});
self
}
// ===== 角色标识 =====
/// 为上一条添加的消息设置 `name` 字段。
pub fn with_name(mut self, name: impl Into<String>) -> Self {
let name = name.into();
if let Some(msg) = self.messages.last_mut() {
set_message_name(msg, name);
}
self
}
// ===== 构建 =====
/// 构建最终的消息列表。
pub fn build(self) -> Vec<OpenaiChatMessage> {
self.messages
}
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
OpenaiChatRequest {
model: model.into(),
messages: self.messages,
..Default::default()
}
}
// ===== 内部方法 =====
fn push_message(&mut self, msg: OpenaiChatMessage) {
self.messages.push(msg);
}
}
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
match msg {
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
OpenaiChatMessage::Tool { .. } => {}
OpenaiChatMessage::Function { .. } => {}
}
}
/// 验证消息序列是否符合 OpenAI API 要求。
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
if messages.is_empty() {
return Err(PromptError::InvalidSequence(
"消息列表不能为空".to_string(),
));
}
let mut last_tool_call_ids: Vec<String> = Vec::new();
for (i, msg) in messages.iter().enumerate() {
match msg {
OpenaiChatMessage::Tool {
tool_call_id,
..
} => {
if last_tool_call_ids.is_empty() {
return Err(PromptError::InvalidSequence(format!(
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
)));
}
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
return Err(PromptError::InvalidSequence(format!(
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
tool_call_id
)));
}
}
OpenaiChatMessage::Assistant {
tool_calls: Some(calls),
..
} => {
last_tool_call_ids.clear();
for call in calls {
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
last_tool_call_ids.push(id.clone());
}
}
OpenaiChatMessage::Assistant {
tool_calls: None, ..
} => {
last_tool_call_ids.clear();
}
_ => {
last_tool_call_ids.clear();
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prompt::TemplateValue;
#[test]
fn test_composer_basic() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("Hello")
.assistant("Hi there!")
.build();
assert_eq!(msgs.len(), 3);
}
#[test]
fn test_composer_tool() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("What's the weather?")
.assistant("Let me check")
.tool("call_123", "Sunny, 25°C")
.build();
assert_eq!(msgs.len(), 4);
match &msgs[3] {
OpenaiChatMessage::Tool {
tool_call_id,
content,
..
} => {
assert_eq!(tool_call_id, "call_123");
match content {
ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
_ => {}
}
}
_ => panic!("Expected Tool message"),
}
}
#[test]
fn test_validate_messages_ok() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("Hello")
.build();
assert!(validate_messages(&msgs).is_ok());
}
#[test]
fn test_validate_messages_empty() {
let msgs: Vec<OpenaiChatMessage> = vec![];
assert!(validate_messages(&msgs).is_err());
}
#[test]
fn test_template_render() {
let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("name", "Alice");
ctx.insert("count", "5");
let result = tpl.render(&ctx).unwrap();
assert_eq!(result, "Hello Alice, you have 5 messages");
}
#[test]
fn test_template_if() {
let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("name", "Bob");
let with_name = tpl.render(&ctx).unwrap();
assert_eq!(with_name, "Hello Bob");
let without_name = tpl.render(&TemplateContext::new()).unwrap();
assert_eq!(without_name, "Hello Guest");
}
#[test]
fn test_template_each() {
let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("items", TemplateValue::Array(vec![
TemplateValue::String("a".to_string()),
TemplateValue::String("b".to_string()),
TemplateValue::String("c".to_string()),
]));
let result = tpl.render(&ctx).unwrap();
assert_eq!(result, "Items: a, b, c, ");
}
#[test]
fn test_template_display() {
let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
assert_eq!(format!("{}", tpl), "Hello {{name}}");
}
#[test]
fn test_context_from_json() {
let json: serde_json::Value = serde_json::json!({
"name": "Alice",
"active": true
});
let ctx = TemplateContext::from_json(&json).unwrap();
assert!(ctx.get("name").is_some());
assert!(ctx.get("active").is_some());
}
}