feat(prompt): 添加提示词工程模块并扩展LLM周期接口

新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
This commit is contained in:
徐涛
2026-06-03 06:18:16 +08:00
parent 7f5513adf3
commit 993ae0eb4b
6 changed files with 1073 additions and 0 deletions
+1
View File
@@ -1,6 +1,7 @@
//! agcore —— 智能体(Agent)核心工具箱。
pub mod llm;
pub mod prompt;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
+88
View File
@@ -113,6 +113,94 @@ impl LlmCycle {
self.usage.reset();
}
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
pub fn with_messages(mut self, messages: Vec<OpenaiChatMessage>) -> Self {
self.messages = messages;
self
}
/// 追加消息到历史尾部。
pub fn extend_messages(&mut self, messages: Vec<OpenaiChatMessage>) {
self.messages.extend(messages);
}
/// 使用预构建消息提交(跳过自动 push user prompt)。
///
/// 与 `submit()` 不同,不自动添加 `user_text(prompt)`,也不自动插入 system prompt。
/// 调用方完全控制消息序列内容。
pub async fn submit_messages(
&mut self,
messages: Vec<OpenaiChatMessage>,
tools: Vec<ToolDefinition>,
) -> Result<ChatResponse, LlmError> {
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
None
} else {
Some(
tools
.iter()
.map(|t| OpenaiTool::Function {
function: t.clone(),
})
.collect(),
)
};
let request = ChatRequest {
model: self.config.model.clone(),
messages,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
tools: openai_tools,
tool_choice: Some(ToolChoice::Auto),
..Default::default()
};
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
.with_request(&request);
let results = executor
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
.await;
if results.iter().any(|r| r.should_block) {
let reason = results
.iter()
.find(|r| r.should_block)
.and_then(|r| r.reason.clone())
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
return Err(LlmError::Other(reason));
}
}
match self.provider.chat(request).await {
Ok(response) => {
if let Some(ref executor) = self.hook_executor {
let post_request = ChatRequest {
model: self.config.model.clone(),
messages: vec![],
..Default::default()
};
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
.with_request(&post_request);
executor
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
.await;
}
self.usage.add(&response.usage);
Ok(response)
}
Err(e) => {
if let Some(ref executor) = self.hook_executor {
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
executor
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
.await;
}
Err(e)
}
}
}
/// 提交用户消息并获取 LLM 响应。
pub async fn submit(
&mut self,
+7
View File
@@ -0,0 +1,7 @@
pub mod error;
pub mod template;
pub mod composer;
pub use error::PromptError;
pub use template::{PromptTemplate, PromptTemplateRegistry, TemplateContext, TemplateValue};
pub use composer::PromptComposer;
+406
View File
@@ -0,0 +1,406 @@
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
use crate::llm::types::request::OpenaiChatRequest;
use crate::prompt::error::PromptError;
use crate::prompt::template::{PromptTemplate, TemplateContext};
/// 提示词组合器——构建多角色消息序列。
#[derive(Default)]
pub struct PromptComposer {
messages: Vec<OpenaiChatMessage>,
}
impl PromptComposer {
/// 创建一个空的组合器。
pub fn new() -> Self {
Self::default()
}
/// 从已有的消息列表初始化。
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
Self { messages }
}
// ===== 纯文本消息 =====
/// 添加一条纯文本 system 消息。
pub fn system(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::system_text(text.into()));
self
}
/// 添加一条纯文本 user 消息。
pub fn user(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::user_text(text.into()));
self
}
/// 添加一条纯文本 assistant 消息。
pub fn assistant(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
self
}
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
pub fn developer(mut self, text: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::developer_text(text.into()));
self
}
/// 添加一条 Tool 消息(工具执行结果回传)。
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
self.push_message(OpenaiChatMessage::tool_result(
tool_call_id.into(),
content.into(),
));
self
}
// ===== 模板消息 =====
/// 使用模板和上下文渲染后添加为 user 消息。
pub fn user_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::user_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 system 消息。
pub fn system_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::system_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 assistant 消息。
pub fn assistant_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::assistant_text(text));
Ok(self)
}
/// 使用模板和上下文渲染后添加为 developer 消息。
pub fn developer_template(
mut self,
template: &PromptTemplate,
ctx: &TemplateContext,
) -> Result<Self, PromptError> {
let text = template.render(ctx)?;
self.push_message(OpenaiChatMessage::developer_text(text));
Ok(self)
}
// ===== 多模态 ContentPart =====
/// 添加一条含指定 ContentPart 的 system 消息。
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::System {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 user 消息。
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::User {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 assistant 消息。
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::Assistant {
content: ContentField::Array(vec![part]),
refusal: None,
name: None,
tool_calls: None,
});
self
}
/// 添加一条含指定 ContentPart 的 developer 消息。
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
self.push_message(OpenaiChatMessage::Developer {
content: ContentField::Array(vec![part]),
name: None,
});
self
}
/// 添加一条含指定 ContentPart 的 Tool 消息。
pub fn tool_content(
mut self,
tool_call_id: impl Into<String>,
part: OpenaiContentPart,
) -> Self {
self.push_message(OpenaiChatMessage::Tool {
content: ContentField::Array(vec![part]),
tool_call_id: tool_call_id.into(),
});
self
}
/// 批量添加 ContentPart 作为 user 消息。
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::User {
content: ContentField::Array(parts),
name: None,
});
self
}
/// 批量添加 ContentPart 作为 system 消息。
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::System {
content: ContentField::Array(parts),
name: None,
});
self
}
/// 批量添加 ContentPart 作为 assistant 消息。
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::Assistant {
content: ContentField::Array(parts),
refusal: None,
name: None,
tool_calls: None,
});
self
}
/// 批量添加 ContentPart 作为 developer 消息。
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
self.push_message(OpenaiChatMessage::Developer {
content: ContentField::Array(parts),
name: None,
});
self
}
// ===== 角色标识 =====
/// 为上一条添加的消息设置 `name` 字段。
pub fn with_name(mut self, name: impl Into<String>) -> Self {
let name = name.into();
if let Some(msg) = self.messages.last_mut() {
set_message_name(msg, name);
}
self
}
// ===== 构建 =====
/// 构建最终的消息列表。
pub fn build(self) -> Vec<OpenaiChatMessage> {
self.messages
}
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
OpenaiChatRequest {
model: model.into(),
messages: self.messages,
..Default::default()
}
}
// ===== 内部方法 =====
fn push_message(&mut self, msg: OpenaiChatMessage) {
self.messages.push(msg);
}
}
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
match msg {
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
OpenaiChatMessage::Tool { .. } => {}
OpenaiChatMessage::Function { .. } => {}
}
}
/// 验证消息序列是否符合 OpenAI API 要求。
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
if messages.is_empty() {
return Err(PromptError::InvalidSequence(
"消息列表不能为空".to_string(),
));
}
let mut last_tool_call_ids: Vec<String> = Vec::new();
for (i, msg) in messages.iter().enumerate() {
match msg {
OpenaiChatMessage::Tool {
tool_call_id,
..
} => {
if last_tool_call_ids.is_empty() {
return Err(PromptError::InvalidSequence(format!(
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
)));
}
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
return Err(PromptError::InvalidSequence(format!(
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
tool_call_id
)));
}
}
OpenaiChatMessage::Assistant {
tool_calls: Some(calls),
..
} => {
last_tool_call_ids.clear();
for call in calls {
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
last_tool_call_ids.push(id.clone());
}
}
OpenaiChatMessage::Assistant {
tool_calls: None, ..
} => {
last_tool_call_ids.clear();
}
_ => {
last_tool_call_ids.clear();
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prompt::TemplateValue;
#[test]
fn test_composer_basic() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("Hello")
.assistant("Hi there!")
.build();
assert_eq!(msgs.len(), 3);
}
#[test]
fn test_composer_tool() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("What's the weather?")
.assistant("Let me check")
.tool("call_123", "Sunny, 25°C")
.build();
assert_eq!(msgs.len(), 4);
match &msgs[3] {
OpenaiChatMessage::Tool {
tool_call_id,
content,
..
} => {
assert_eq!(tool_call_id, "call_123");
match content {
ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
_ => {}
}
}
_ => panic!("Expected Tool message"),
}
}
#[test]
fn test_validate_messages_ok() {
let msgs = PromptComposer::new()
.system("You are helpful")
.user("Hello")
.build();
assert!(validate_messages(&msgs).is_ok());
}
#[test]
fn test_validate_messages_empty() {
let msgs: Vec<OpenaiChatMessage> = vec![];
assert!(validate_messages(&msgs).is_err());
}
#[test]
fn test_template_render() {
let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("name", "Alice");
ctx.insert("count", "5");
let result = tpl.render(&ctx).unwrap();
assert_eq!(result, "Hello Alice, you have 5 messages");
}
#[test]
fn test_template_if() {
let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("name", "Bob");
let with_name = tpl.render(&ctx).unwrap();
assert_eq!(with_name, "Hello Bob");
let without_name = tpl.render(&TemplateContext::new()).unwrap();
assert_eq!(without_name, "Hello Guest");
}
#[test]
fn test_template_each() {
let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
let mut ctx = TemplateContext::new();
ctx.insert("items", TemplateValue::Array(vec![
TemplateValue::String("a".to_string()),
TemplateValue::String("b".to_string()),
TemplateValue::String("c".to_string()),
]));
let result = tpl.render(&ctx).unwrap();
assert_eq!(result, "Items: a, b, c, ");
}
#[test]
fn test_template_display() {
let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
assert_eq!(format!("{}", tpl), "Hello {{name}}");
}
#[test]
fn test_context_from_json() {
let json: serde_json::Value = serde_json::json!({
"name": "Alice",
"active": true
});
let ctx = TemplateContext::from_json(&json).unwrap();
assert!(ctx.get("name").is_some());
assert!(ctx.get("active").is_some());
}
}
+28
View File
@@ -0,0 +1,28 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum PromptError {
#[error("模板解析错误: {0}")]
Parse(String),
#[error("渲染错误: 变量 '{0}' 未找到")]
VariableNotFound(String),
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
PartialNotFound(String),
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
NotAnArray(String),
#[error("渲染递归超过最大深度限制 ({0})")]
MaxDepthReached(u8),
#[error("渲染错误: {0}")]
Render(String),
#[error("消息序列校验失败: {0}")]
InvalidSequence(String),
#[error("文件读取错误: {0}")]
Io(#[from] std::io::Error),
}
+543
View File
@@ -0,0 +1,543 @@
use std::collections::HashMap;
use std::fmt;
use serde_json::Value;
use crate::prompt::error::PromptError;
const MAX_RENDER_DEPTH: u8 = 16;
// ===== TemplateValue =====
#[derive(Debug, Clone)]
pub enum TemplateValue {
String(String),
Bool(bool),
Array(Vec<TemplateValue>),
Object(HashMap<String, TemplateValue>),
}
impl From<String> for TemplateValue {
fn from(s: String) -> Self {
TemplateValue::String(s)
}
}
impl From<&str> for TemplateValue {
fn from(s: &str) -> Self {
TemplateValue::String(s.to_string())
}
}
impl From<bool> for TemplateValue {
fn from(b: bool) -> Self {
TemplateValue::Bool(b)
}
}
impl fmt::Display for TemplateValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TemplateValue::String(s) => write!(f, "{}", s),
TemplateValue::Bool(b) => write!(f, "{}", b),
TemplateValue::Array(arr) => {
let strs: Vec<String> = arr.iter().map(|v| format!("{}", v)).collect();
write!(f, "[{}]", strs.join(", "))
}
TemplateValue::Object(map) => {
let strs: Vec<String> = map
.iter()
.map(|(k, v)| format!("\"{}\": {}", k, v))
.collect();
write!(f, "{{{}}}", strs.join(", "))
}
}
}
}
impl TemplateValue {
fn is_truthy(&self) -> bool {
match self {
TemplateValue::String(s) => !s.is_empty(),
TemplateValue::Bool(b) => *b,
TemplateValue::Array(arr) => !arr.is_empty(),
TemplateValue::Object(map) => !map.is_empty(),
}
}
fn as_array(&self) -> Option<&Vec<TemplateValue>> {
match self {
TemplateValue::Array(arr) => Some(arr),
_ => None,
}
}
}
// ===== TemplateContext =====
#[derive(Debug, Clone, Default)]
pub struct TemplateContext {
vars: HashMap<String, TemplateValue>,
}
impl TemplateContext {
/// 创建一个空的模板上下文。
pub fn new() -> Self {
Self::default()
}
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>) {
self.vars.insert(key.into(), value.into());
}
/// 按名称获取变量值。
pub fn get(&self, key: &str) -> Option<&TemplateValue> {
self.vars.get(key)
}
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
pub fn from_json(value: &Value) -> Result<Self, PromptError> {
let map = value
.as_object()
.ok_or_else(|| PromptError::Render("JSON 根值必须是对象".to_string()))?;
let mut ctx = Self::new();
for (k, v) in map {
ctx.vars.insert(k.clone(), json_to_template_value(v)?);
}
Ok(ctx)
}
/// 从 `HashMap` 构造(适用于配置加载场景)。
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self {
Self { vars: map }
}
}
fn json_to_template_value(v: &Value) -> Result<TemplateValue, PromptError> {
match v {
Value::Null => Ok(TemplateValue::String(String::new())),
Value::Bool(b) => Ok(TemplateValue::Bool(*b)),
Value::Number(n) => Ok(TemplateValue::String(n.to_string())),
Value::String(s) => Ok(TemplateValue::String(s.clone())),
Value::Array(arr) => {
let items: Result<Vec<TemplateValue>, _> =
arr.iter().map(json_to_template_value).collect();
Ok(TemplateValue::Array(items?))
}
Value::Object(obj) => {
let mut map = HashMap::new();
for (k, v) in obj {
map.insert(k.clone(), json_to_template_value(v)?);
}
Ok(TemplateValue::Object(map))
}
}
}
// ===== Fragment (AST) =====
#[derive(Debug, Clone)]
enum Fragment {
Literal(String),
Variable { name: String },
If {
condition: String,
body: Vec<Fragment>,
else_body: Vec<Fragment>,
},
Each {
variable: String,
body: Vec<Fragment>,
},
Raw(String),
Include(String),
}
// ===== PromptTemplate =====
pub struct PromptTemplate {
raw: String,
fragments: Vec<Fragment>,
partials: HashMap<String, PromptTemplate>,
}
impl PromptTemplate {
/// 从模板字符串编译。
pub fn compile(template: &str) -> Result<Self, PromptError> {
let fragments = compile_fragments(template)?;
Ok(Self {
raw: template.to_string(),
fragments,
partials: HashMap::new(),
})
}
/// 使用上下文渲染。
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError> {
let mut output = String::new();
render_fragments(&self.fragments, ctx, &self.partials, &mut output, 0)?;
Ok(output)
}
/// 使用上下文和外部 partials 渲染。
pub fn render_with_partials(
&self,
ctx: &TemplateContext,
partials: &HashMap<String, PromptTemplate>,
) -> Result<String, PromptError> {
let mut output = String::new();
render_fragments(&self.fragments, ctx, partials, &mut output, 0)?;
Ok(output)
}
/// 注册可引用的子模板。
pub fn register_partial(&mut self, name: &str, template: PromptTemplate) {
self.partials.insert(name.to_string(), template);
}
}
impl fmt::Display for PromptTemplate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.raw)
}
}
// ===== Compiler =====
fn compile_fragments(template: &str) -> Result<Vec<Fragment>, PromptError> {
let bytes = template.as_bytes();
let len = bytes.len();
let mut fragments = Vec::new();
let mut i = 0;
let mut literal = String::new();
while i < len {
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
if !literal.is_empty() {
fragments.push(Fragment::Literal(literal.clone()));
literal.clear();
}
let (tag_content, end) = parse_tag(bytes, i)?;
i = end;
let tag = tag_content.trim();
if let Some(rest) = tag.strip_prefix("#if ") {
let (body, else_body, new_i) =
parse_block(template, i, "if")?;
let condition = rest.trim().to_string();
fragments.push(Fragment::If {
condition,
body,
else_body,
});
i = new_i;
} else if let Some(rest) = tag.strip_prefix("#each ") {
let (body, new_i) = parse_each_block(template, i)?;
let variable = rest.trim().to_string();
fragments.push(Fragment::Each { variable, body });
i = new_i;
} else if tag == "#raw" {
let (raw, new_i) = parse_raw_block(template, i)?;
fragments.push(Fragment::Raw(raw));
i = new_i;
} else if let Some(rest) = tag.strip_prefix("> ") {
let name = rest.trim().to_string();
fragments.push(Fragment::Include(name));
} else if tag.starts_with("/") {
break;
} else {
let name = tag.to_string();
fragments.push(Fragment::Variable { name });
}
} else {
literal.push(bytes[i] as char);
i += 1;
}
}
if !literal.is_empty() {
fragments.push(Fragment::Literal(literal));
}
Ok(fragments)
}
fn parse_tag(bytes: &[u8], start: usize) -> Result<(String, usize), PromptError> {
let len = bytes.len();
let mut i = start + 2;
let mut content = String::new();
while i < len {
if bytes[i] == b'}' && i + 1 < len && bytes[i + 1] == b'}' {
return Ok((content, i + 2));
}
content.push(bytes[i] as char);
i += 1;
}
Err(PromptError::Parse("未闭合的 {{ 标签".to_string()))
}
fn parse_block(
template: &str,
start: usize,
kind: &str,
) -> Result<(Vec<Fragment>, Vec<Fragment>, usize), PromptError> {
let bytes = template.as_bytes();
let len = bytes.len();
let mut depth = 1u32;
let mut i = start;
let mut body = String::new();
let mut else_body = String::new();
let mut is_else = false;
while i < len && depth > 0 {
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
let (tag, end) = parse_tag(bytes, i)?;
let tag = tag.trim().to_string();
if tag == format!("/{kind}") {
depth -= 1;
if depth == 0 {
let (if_fragments, else_fragments) = if is_else {
(compile_fragments(&body)?, compile_fragments(&else_body)?)
} else {
(compile_fragments(&body)?, compile_fragments("")?)
};
return Ok((if_fragments, else_fragments, end));
} else {
body.push_str(&template[i..end]);
}
i = end;
} else if tag == "#if " || tag.starts_with("#if ") {
depth += 1;
body.push_str(&template[i..end]);
i = end;
} else if tag == "else" && depth == 1 {
is_else = true;
i = end;
} else {
body.push_str(&template[i..end]);
i = end;
}
} else {
if is_else {
else_body.push(bytes[i] as char);
} else {
body.push(bytes[i] as char);
}
i += 1;
}
}
Err(PromptError::Parse(format!("未闭合的 {{#{}}}", kind)))
}
fn parse_each_block(
template: &str,
start: usize,
) -> Result<(Vec<Fragment>, usize), PromptError> {
let bytes = template.as_bytes();
let len = bytes.len();
let mut depth = 1u32;
let mut i = start;
let mut body = String::new();
while i < len && depth > 0 {
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
let (tag, end) = parse_tag(bytes, i)?;
let tag = tag.trim().to_string();
if tag == "/each" {
depth -= 1;
if depth == 0 {
let fragments = compile_fragments(&body)?;
return Ok((fragments, end));
} else {
body.push_str(&template[i..end]);
}
i = end;
} else if tag == "#each " || tag.starts_with("#each ") {
depth += 1;
body.push_str(&template[i..end]);
i = end;
} else {
body.push_str(&template[i..end]);
i = end;
}
} else {
body.push(bytes[i] as char);
i += 1;
}
}
Err(PromptError::Parse(
"未闭合的 {{#each}} 块".to_string(),
))
}
fn parse_raw_block(template: &str, start: usize) -> Result<(String, usize), PromptError> {
let bytes = template.as_bytes();
let len = bytes.len();
let mut i = start;
let mut content = String::new();
while i < len {
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
let (tag, end) = parse_tag(bytes, i)?;
let tag = tag.trim().to_string();
if tag == "/raw" {
return Ok((content, end));
} else {
content.push_str(&template[i..end]);
i = end;
}
} else {
content.push(bytes[i] as char);
i += 1;
}
}
Err(PromptError::Parse(
"未闭合的 {{#raw}} 块".to_string(),
))
}
// ===== Renderer =====
fn render_fragments(
fragments: &[Fragment],
ctx: &TemplateContext,
partials: &HashMap<String, PromptTemplate>,
output: &mut String,
depth: u8,
) -> Result<(), PromptError> {
if depth > MAX_RENDER_DEPTH {
return Err(PromptError::MaxDepthReached(MAX_RENDER_DEPTH));
}
for frag in fragments {
match frag {
Fragment::Literal(text) => {
output.push_str(text);
}
Fragment::Variable { name } => {
match ctx.get(name) {
Some(val) => {
output.push_str(&format!("{}", val));
}
None => {
return Err(PromptError::VariableNotFound(name.clone()));
}
}
}
Fragment::If {
condition,
body,
else_body,
} => {
let truthy = ctx
.get(condition)
.map(|v| v.is_truthy())
.unwrap_or(false);
let target = if truthy { body } else { else_body };
render_fragments(target, ctx, partials, output, depth + 1)?;
}
Fragment::Each { variable, body } => {
let arr = match ctx.get(variable) {
Some(val) => val.as_array().ok_or_else(|| {
PromptError::NotAnArray(variable.clone())
})?,
None => {
return Err(PromptError::VariableNotFound(variable.clone()));
}
};
for item in arr {
let mut child_ctx = ctx.clone();
child_ctx.vars.insert("item".to_string(), item.clone());
render_fragments(body, &child_ctx, partials, output, depth + 1)?;
}
}
Fragment::Raw(text) => {
output.push_str(text);
}
Fragment::Include(name) => {
if let Some(partial) = partials.get(name) {
partial.render_with_partials(ctx, partials)?;
} else {
return Err(PromptError::PartialNotFound(name.clone()));
}
}
}
}
Ok(())
}
// ===== PromptTemplateRegistry =====
/// 内部存储的模板(支持延迟编译)。
enum StoredTemplate {
Compiled(PromptTemplate),
Raw(String),
}
/// 模板注册表——管理多模板实例。
#[derive(Default)]
pub struct PromptTemplateRegistry {
templates: HashMap<String, StoredTemplate>,
}
impl PromptTemplateRegistry {
/// 创建一个空的模板注册表。
pub fn new() -> Self {
Self {
templates: HashMap::new(),
}
}
/// 从模板字符串编译并注册(立即编译)。
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError> {
let compiled = PromptTemplate::compile(template)?;
self.templates
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
Ok(())
}
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
pub fn register_lazy(&mut self, name: &str, template: &str) {
self.templates.insert(
name.to_string(),
StoredTemplate::Raw(template.to_string()),
);
}
/// 从文件读取并编译注册。
pub fn register_file(&mut self, name: &str, path: &std::path::Path) -> Result<(), PromptError> {
let content = std::fs::read_to_string(path)?;
let compiled = PromptTemplate::compile(&content)?;
self.templates
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
Ok(())
}
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError> {
if let Some(stored) = self.templates.get_mut(name) {
if let StoredTemplate::Raw(raw) = stored {
let compiled = PromptTemplate::compile(raw)?;
*stored = StoredTemplate::Compiled(compiled);
}
match stored {
StoredTemplate::Compiled(tpl) => Ok(tpl),
_ => unreachable!(),
}
} else {
Err(PromptError::PartialNotFound(name.to_string()))
}
}
/// 按名称渲染。
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError> {
let tpl = self.get(name)?;
tpl.render(ctx)
}
}