Files
agcore/src/prompt/composer.rs
T
徐涛 993ae0eb4b feat(prompt): 添加提示词工程模块并扩展LLM周期接口
新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
2026-06-03 06:18:16 +08:00

407 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
use crate::llm::types::request::OpenaiChatRequest;
use crate::prompt::error::PromptError;
use crate::prompt::template::{PromptTemplate, TemplateContext};
/// Prompt composer: a builder that accumulates a sequence of multi-role chat messages.
#[derive(Default)]
pub struct PromptComposer {
// Messages in insertion order; `build` returns this vector.
messages: Vec<OpenaiChatMessage>,
}
impl PromptComposer {
    /// Creates a composer that holds no messages.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
        }
    }

    /// Creates a composer that starts out with the given message list.
    pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
        Self { messages }
    }

    // ===== Plain-text messages =====

    /// Appends a plain-text system message.
    pub fn system(self, text: impl Into<String>) -> Self {
        self.appended(OpenaiChatMessage::system_text(text.into()))
    }

    /// Appends a plain-text user message.
    pub fn user(self, text: impl Into<String>) -> Self {
        self.appended(OpenaiChatMessage::user_text(text.into()))
    }

    /// Appends a plain-text assistant message.
    pub fn assistant(self, text: impl Into<String>) -> Self {
        self.appended(OpenaiChatMessage::assistant_text(text.into()))
    }

    /// Appends a plain-text developer message (used by o1-series models).
    pub fn developer(self, text: impl Into<String>) -> Self {
        self.appended(OpenaiChatMessage::developer_text(text.into()))
    }

    /// Appends a tool message carrying the result of the tool call `tool_call_id`.
    pub fn tool(self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        let call_id = tool_call_id.into();
        let result_text = content.into();
        self.appended(OpenaiChatMessage::tool_result(call_id, result_text))
    }

    // ===== Template messages =====

    /// Renders `template` against `ctx` and appends the output as a user message.
    ///
    /// # Errors
    ///
    /// Propagates the rendering error; nothing is appended in that case.
    pub fn user_template(
        self,
        template: &PromptTemplate,
        ctx: &TemplateContext,
    ) -> Result<Self, PromptError> {
        let rendered = template.render(ctx)?;
        Ok(self.user(rendered))
    }

    /// Renders `template` against `ctx` and appends the output as a system message.
    ///
    /// # Errors
    ///
    /// Propagates the rendering error; nothing is appended in that case.
    pub fn system_template(
        self,
        template: &PromptTemplate,
        ctx: &TemplateContext,
    ) -> Result<Self, PromptError> {
        let rendered = template.render(ctx)?;
        Ok(self.system(rendered))
    }

    /// Renders `template` against `ctx` and appends the output as an assistant message.
    ///
    /// # Errors
    ///
    /// Propagates the rendering error; nothing is appended in that case.
    pub fn assistant_template(
        self,
        template: &PromptTemplate,
        ctx: &TemplateContext,
    ) -> Result<Self, PromptError> {
        let rendered = template.render(ctx)?;
        Ok(self.assistant(rendered))
    }

    /// Renders `template` against `ctx` and appends the output as a developer message.
    ///
    /// # Errors
    ///
    /// Propagates the rendering error; nothing is appended in that case.
    pub fn developer_template(
        self,
        template: &PromptTemplate,
        ctx: &TemplateContext,
    ) -> Result<Self, PromptError> {
        let rendered = template.render(ctx)?;
        Ok(self.developer(rendered))
    }

    // ===== Multimodal content parts =====

    /// Appends a system message whose content is the single given part.
    pub fn system_content(self, part: OpenaiContentPart) -> Self {
        self.system_contents(vec![part])
    }

    /// Appends a user message whose content is the single given part.
    pub fn user_content(self, part: OpenaiContentPart) -> Self {
        self.user_contents(vec![part])
    }

    /// Appends an assistant message whose content is the single given part.
    pub fn assistant_content(self, part: OpenaiContentPart) -> Self {
        self.assistant_contents(vec![part])
    }

    /// Appends a developer message whose content is the single given part.
    pub fn developer_content(self, part: OpenaiContentPart) -> Self {
        self.developer_contents(vec![part])
    }

    /// Appends a tool message for `tool_call_id` whose content is the single given part.
    pub fn tool_content(
        self,
        tool_call_id: impl Into<String>,
        part: OpenaiContentPart,
    ) -> Self {
        let content = ContentField::Array(vec![part]);
        self.appended(OpenaiChatMessage::Tool {
            content,
            tool_call_id: tool_call_id.into(),
        })
    }

    /// Appends one user message containing all of `parts`.
    pub fn user_contents(self, parts: Vec<OpenaiContentPart>) -> Self {
        let content = ContentField::Array(parts);
        self.appended(OpenaiChatMessage::User {
            content,
            name: None,
        })
    }

    /// Appends one system message containing all of `parts`.
    pub fn system_contents(self, parts: Vec<OpenaiContentPart>) -> Self {
        let content = ContentField::Array(parts);
        self.appended(OpenaiChatMessage::System {
            content,
            name: None,
        })
    }

    /// Appends one assistant message containing all of `parts`.
    ///
    /// `refusal`, `name` and `tool_calls` are left unset.
    pub fn assistant_contents(self, parts: Vec<OpenaiContentPart>) -> Self {
        let content = ContentField::Array(parts);
        self.appended(OpenaiChatMessage::Assistant {
            content,
            refusal: None,
            name: None,
            tool_calls: None,
        })
    }

    /// Appends one developer message containing all of `parts`.
    pub fn developer_contents(self, parts: Vec<OpenaiContentPart>) -> Self {
        let content = ContentField::Array(parts);
        self.appended(OpenaiChatMessage::Developer {
            content,
            name: None,
        })
    }

    // ===== Participant name =====

    /// Sets the `name` field on the most recently added message.
    ///
    /// Does nothing when the composer is empty or when the last message is a
    /// variant that `set_message_name` leaves untouched (Tool and Function).
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        let label = name.into();
        if let Some(latest) = self.messages.last_mut() {
            set_message_name(latest, label);
        }
        self
    }

    // ===== Building =====

    /// Consumes the composer and returns the accumulated messages.
    pub fn build(self) -> Vec<OpenaiChatMessage> {
        let Self { messages } = self;
        messages
    }

    /// Consumes the composer and wraps the messages in a chat request for `model`.
    ///
    /// All remaining fields come from `OpenaiChatRequest::default()`. Optional
    /// settings can be filled in afterwards with struct-update syntax, e.g.
    /// `OpenaiChatRequest { tools: Some(...), ..req }`.
    pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
        let Self { messages } = self;
        OpenaiChatRequest {
            model: model.into(),
            messages,
            ..Default::default()
        }
    }

    // ===== Internal helpers =====

    /// Appends `msg` to the end of the message list.
    fn push_message(&mut self, msg: OpenaiChatMessage) {
        self.messages.push(msg);
    }

    /// Chaining form of `push_message`: appends `msg` and hands the composer back.
    fn appended(mut self, msg: OpenaiChatMessage) -> Self {
        self.push_message(msg);
        self
    }
}
/// Stores `name` in the `name` field of `msg` for the Developer, System, User
/// and Assistant variants.
///
/// Tool and Function messages are left unchanged.
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
    match msg {
        OpenaiChatMessage::Developer { name: slot, .. }
        | OpenaiChatMessage::System { name: slot, .. }
        | OpenaiChatMessage::User { name: slot, .. }
        | OpenaiChatMessage::Assistant { name: slot, .. } => *slot = Some(name),
        OpenaiChatMessage::Tool { .. } | OpenaiChatMessage::Function { .. } => {}
    }
}
/// Checks that a message sequence satisfies the ordering rules the OpenAI API
/// places on tool messages.
///
/// Rules enforced:
/// - the list must not be empty;
/// - every Tool message must come after an Assistant message that carries
///   `tool_calls`, with only other Tool messages in between;
/// - the Tool message's `tool_call_id` must equal one of the ids declared by
///   that Assistant message.
///
/// Not checked: whether every declared tool call receives a response, or
/// whether the same id is answered more than once.
///
/// # Errors
///
/// Returns `PromptError::InvalidSequence` describing the first violation found.
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
    if messages.is_empty() {
        let reason = "消息列表不能为空".to_string();
        return Err(PromptError::InvalidSequence(reason));
    }

    // Ids declared by the most recent Assistant message with `tool_calls`.
    // Emptied by any message that is neither such an Assistant message nor a
    // Tool message, so consecutive Tool replies are all matched against the
    // same set.
    let mut pending_ids: Vec<String> = Vec::new();

    for (i, msg) in messages.iter().enumerate() {
        match msg {
            OpenaiChatMessage::Tool { tool_call_id, .. } => {
                if pending_ids.is_empty() {
                    let reason =
                        format!("消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls");
                    return Err(PromptError::InvalidSequence(reason));
                }
                let unmatched = pending_ids.iter().all(|id| id != tool_call_id);
                if unmatched {
                    let reason = format!(
                        "消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
                        tool_call_id
                    );
                    return Err(PromptError::InvalidSequence(reason));
                }
            }
            OpenaiChatMessage::Assistant {
                tool_calls: Some(calls),
                ..
            } => {
                // Replace the previous id set with the ids of this message.
                pending_ids = calls
                    .iter()
                    .map(|call| {
                        let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
                        id.clone()
                    })
                    .collect();
            }
            // Assistant messages without tool calls and every other role end
            // the window in which Tool replies are accepted.
            _ => pending_ids.clear(),
        }
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::TemplateValue;

    /// Three chained text messages produce three entries.
    #[test]
    fn test_composer_basic() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .assistant("Hi there!")
            .build();
        assert_eq!(msgs.len(), 3);
    }

    /// `tool` appends a Tool message carrying the given id and text content.
    #[test]
    fn test_composer_tool() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("What's the weather?")
            .assistant("Let me check")
            .tool("call_123", "Sunny, 25°C")
            .build();
        assert_eq!(msgs.len(), 4);
        match &msgs[3] {
            OpenaiChatMessage::Tool {
                tool_call_id,
                content,
                ..
            } => {
                assert_eq!(tool_call_id, "call_123");
                match content {
                    ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
                    // Fail loudly: a silent fallthrough here would mean the
                    // content assertion above never runs.
                    _ => panic!("Expected ContentField::String for tool content"),
                }
            }
            _ => panic!("Expected Tool message"),
        }
    }

    /// `with_name` stores the name on the most recently added message.
    #[test]
    fn test_with_name_sets_last_message() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .with_name("alice")
            .build();
        assert_eq!(msgs.len(), 2);
        match &msgs[1] {
            OpenaiChatMessage::User { name, .. } => {
                assert_eq!(name.as_deref(), Some("alice"));
            }
            _ => panic!("Expected User message"),
        }
    }

    /// `with_name` on an empty composer is a no-op and adds no message.
    #[test]
    fn test_with_name_on_empty_composer() {
        let msgs = PromptComposer::new().with_name("alice").build();
        assert!(msgs.is_empty());
    }

    /// A sequence without tool messages validates successfully.
    #[test]
    fn test_validate_messages_ok() {
        let msgs = PromptComposer::new()
            .system("You are helpful")
            .user("Hello")
            .build();
        assert!(validate_messages(&msgs).is_ok());
    }

    /// An empty message list is rejected.
    #[test]
    fn test_validate_messages_empty() {
        let msgs: Vec<OpenaiChatMessage> = vec![];
        assert!(validate_messages(&msgs).is_err());
    }

    /// A Tool message with no preceding assistant `tool_calls` is rejected.
    #[test]
    fn test_validate_messages_orphan_tool() {
        let msgs = PromptComposer::new()
            .user("What's the weather?")
            .tool("call_123", "Sunny, 25°C")
            .build();
        assert!(validate_messages(&msgs).is_err());
    }

    /// Plain `{{var}}` placeholders are substituted from the context.
    #[test]
    fn test_template_render() {
        let tpl =
            PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert("name", "Alice");
        ctx.insert("count", "5");
        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Hello Alice, you have 5 messages");
    }

    /// `{{#if}}…{{else}}…{{/if}}` picks the branch based on whether the key is set.
    #[test]
    fn test_template_if() {
        let tpl =
            PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert("name", "Bob");
        let with_name = tpl.render(&ctx).unwrap();
        assert_eq!(with_name, "Hello Bob");
        let without_name = tpl.render(&TemplateContext::new()).unwrap();
        assert_eq!(without_name, "Hello Guest");
    }

    /// `{{#each}}` repeats its body once per array element.
    #[test]
    fn test_template_each() {
        let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
        let mut ctx = TemplateContext::new();
        ctx.insert(
            "items",
            TemplateValue::Array(vec![
                TemplateValue::String("a".to_string()),
                TemplateValue::String("b".to_string()),
                TemplateValue::String("c".to_string()),
            ]),
        );
        let result = tpl.render(&ctx).unwrap();
        assert_eq!(result, "Items: a, b, c, ");
    }

    /// `Display` on a compiled template yields the original source text.
    #[test]
    fn test_template_display() {
        let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
        assert_eq!(format!("{}", tpl), "Hello {{name}}");
    }

    /// A JSON object converts into a context exposing each top-level key.
    #[test]
    fn test_context_from_json() {
        let json: serde_json::Value = serde_json::json!({
            "name": "Alice",
            "active": true
        });
        let ctx = TemplateContext::from_json(&json).unwrap();
        assert!(ctx.get("name").is_some());
        assert!(ctx.get("active").is_some());
    }
}