feat(prompt): 添加提示词工程模块并扩展LLM周期接口
新增 `prompt` 模块,包含模板引擎、组合器和错误类型,同时在`LlmCycle`中增加直接操作消息历史的方法和`submit_messages`接口
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
//! agcore —— 智能体(Agent)核心工具箱。
|
//! agcore —— 智能体(Agent)核心工具箱。
|
||||||
|
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
|
pub mod prompt;
|
||||||
|
|
||||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||||
|
|
||||||
|
|||||||
@@ -113,6 +113,94 @@ impl LlmCycle {
|
|||||||
self.usage.reset();
|
self.usage.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。
|
||||||
|
pub fn with_messages(mut self, messages: Vec<OpenaiChatMessage>) -> Self {
|
||||||
|
self.messages = messages;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 追加消息到历史尾部。
|
||||||
|
pub fn extend_messages(&mut self, messages: Vec<OpenaiChatMessage>) {
|
||||||
|
self.messages.extend(messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用预构建消息提交(跳过自动 push user prompt)。
|
||||||
|
///
|
||||||
|
/// 与 `submit()` 不同,不自动添加 `user_text(prompt)`,也不自动插入 system prompt。
|
||||||
|
/// 调用方完全控制消息序列内容。
|
||||||
|
pub async fn submit_messages(
|
||||||
|
&mut self,
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
) -> Result<ChatResponse, LlmError> {
|
||||||
|
let openai_tools: Option<Vec<OpenaiTool>> = if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|t| OpenaiTool::Function {
|
||||||
|
function: t.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages,
|
||||||
|
max_tokens: self.config.max_tokens,
|
||||||
|
temperature: self.config.temperature,
|
||||||
|
tools: openai_tools,
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest)
|
||||||
|
.with_request(&request);
|
||||||
|
let results = executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PreRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
if results.iter().any(|r| r.should_block) {
|
||||||
|
let reason = results
|
||||||
|
.iter()
|
||||||
|
.find(|r| r.should_block)
|
||||||
|
.and_then(|r| r.reason.clone())
|
||||||
|
.unwrap_or_else(|| "Blocked by pre-request hook".to_string());
|
||||||
|
return Err(LlmError::Other(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.provider.chat(request).await {
|
||||||
|
Ok(response) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let post_request = ChatRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
messages: vec![],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest)
|
||||||
|
.with_request(&post_request);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::PostRequest, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
self.usage.add(&response.usage);
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if let Some(ref executor) = self.hook_executor {
|
||||||
|
let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e);
|
||||||
|
executor
|
||||||
|
.execute(crate::llm::hooks::HookEvent::OnError, &ctx)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// 提交用户消息并获取 LLM 响应。
|
/// 提交用户消息并获取 LLM 响应。
|
||||||
pub async fn submit(
|
pub async fn submit(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
pub mod error;
|
||||||
|
pub mod template;
|
||||||
|
pub mod composer;
|
||||||
|
|
||||||
|
pub use error::PromptError;
|
||||||
|
pub use template::{PromptTemplate, PromptTemplateRegistry, TemplateContext, TemplateValue};
|
||||||
|
pub use composer::PromptComposer;
|
||||||
@@ -0,0 +1,406 @@
|
|||||||
|
use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart};
|
||||||
|
use crate::llm::types::request::OpenaiChatRequest;
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
use crate::prompt::template::{PromptTemplate, TemplateContext};
|
||||||
|
|
||||||
|
/// 提示词组合器——构建多角色消息序列。
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct PromptComposer {
|
||||||
|
messages: Vec<OpenaiChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptComposer {
|
||||||
|
/// 创建一个空的组合器。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从已有的消息列表初始化。
|
||||||
|
pub fn from_messages(messages: Vec<OpenaiChatMessage>) -> Self {
|
||||||
|
Self { messages }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 纯文本消息 =====
|
||||||
|
|
||||||
|
/// 添加一条纯文本 system 消息。
|
||||||
|
pub fn system(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::system_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 user 消息。
|
||||||
|
pub fn user(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::user_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 assistant 消息。
|
||||||
|
pub fn assistant(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::assistant_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条纯文本 developer 消息(o1 系列模型使用)。
|
||||||
|
pub fn developer(mut self, text: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::developer_text(text.into()));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条 Tool 消息(工具执行结果回传)。
|
||||||
|
pub fn tool(mut self, tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::tool_result(
|
||||||
|
tool_call_id.into(),
|
||||||
|
content.into(),
|
||||||
|
));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 模板消息 =====
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 user 消息。
|
||||||
|
pub fn user_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::user_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 system 消息。
|
||||||
|
pub fn system_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::system_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 assistant 消息。
|
||||||
|
pub fn assistant_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::assistant_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用模板和上下文渲染后添加为 developer 消息。
|
||||||
|
pub fn developer_template(
|
||||||
|
mut self,
|
||||||
|
template: &PromptTemplate,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
) -> Result<Self, PromptError> {
|
||||||
|
let text = template.render(ctx)?;
|
||||||
|
self.push_message(OpenaiChatMessage::developer_text(text));
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 多模态 ContentPart =====
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 system 消息。
|
||||||
|
pub fn system_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::System {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 user 消息。
|
||||||
|
pub fn user_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::User {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 assistant 消息。
|
||||||
|
pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 developer 消息。
|
||||||
|
pub fn developer_content(mut self, part: OpenaiContentPart) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Developer {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加一条含指定 ContentPart 的 Tool 消息。
|
||||||
|
pub fn tool_content(
|
||||||
|
mut self,
|
||||||
|
tool_call_id: impl Into<String>,
|
||||||
|
part: OpenaiContentPart,
|
||||||
|
) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Tool {
|
||||||
|
content: ContentField::Array(vec![part]),
|
||||||
|
tool_call_id: tool_call_id.into(),
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 user 消息。
|
||||||
|
pub fn user_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::User {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 system 消息。
|
||||||
|
pub fn system_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::System {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 assistant 消息。
|
||||||
|
pub fn assistant_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Assistant {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量添加 ContentPart 作为 developer 消息。
|
||||||
|
pub fn developer_contents(mut self, parts: Vec<OpenaiContentPart>) -> Self {
|
||||||
|
self.push_message(OpenaiChatMessage::Developer {
|
||||||
|
content: ContentField::Array(parts),
|
||||||
|
name: None,
|
||||||
|
});
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 角色标识 =====
|
||||||
|
|
||||||
|
/// 为上一条添加的消息设置 `name` 字段。
|
||||||
|
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||||
|
let name = name.into();
|
||||||
|
if let Some(msg) = self.messages.last_mut() {
|
||||||
|
set_message_name(msg, name);
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 构建 =====
|
||||||
|
|
||||||
|
/// 构建最终的消息列表。
|
||||||
|
pub fn build(self) -> Vec<OpenaiChatMessage> {
|
||||||
|
self.messages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 构建并直接创建 ChatRequest(需搭配 model 参数)。
|
||||||
|
/// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`,
|
||||||
|
/// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。
|
||||||
|
pub fn build_request(self, model: impl Into<String>) -> OpenaiChatRequest {
|
||||||
|
OpenaiChatRequest {
|
||||||
|
model: model.into(),
|
||||||
|
messages: self.messages,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== 内部方法 =====
|
||||||
|
|
||||||
|
fn push_message(&mut self, msg: OpenaiChatMessage) {
|
||||||
|
self.messages.push(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_message_name(msg: &mut OpenaiChatMessage, name: String) {
|
||||||
|
match msg {
|
||||||
|
OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::System { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::User { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name),
|
||||||
|
OpenaiChatMessage::Tool { .. } => {}
|
||||||
|
OpenaiChatMessage::Function { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 验证消息序列是否符合 OpenAI API 要求。
|
||||||
|
pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> {
|
||||||
|
if messages.is_empty() {
|
||||||
|
return Err(PromptError::InvalidSequence(
|
||||||
|
"消息列表不能为空".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_tool_call_ids: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
for (i, msg) in messages.iter().enumerate() {
|
||||||
|
match msg {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
tool_call_id,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
if last_tool_call_ids.is_empty() {
|
||||||
|
return Err(PromptError::InvalidSequence(format!(
|
||||||
|
"消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
if !last_tool_call_ids.iter().any(|id| id == tool_call_id) {
|
||||||
|
return Err(PromptError::InvalidSequence(format!(
|
||||||
|
"消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls",
|
||||||
|
tool_call_id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
for call in calls {
|
||||||
|
let crate::llm::types::OpenaiToolCall::Function { id, .. } = call;
|
||||||
|
last_tool_call_ids.push(id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OpenaiChatMessage::Assistant {
|
||||||
|
tool_calls: None, ..
|
||||||
|
} => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
last_tool_call_ids.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::prompt::TemplateValue;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_composer_basic() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("Hello")
|
||||||
|
.assistant("Hi there!")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert_eq!(msgs.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_composer_tool() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("What's the weather?")
|
||||||
|
.assistant("Let me check")
|
||||||
|
.tool("call_123", "Sunny, 25°C")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert_eq!(msgs.len(), 4);
|
||||||
|
match &msgs[3] {
|
||||||
|
OpenaiChatMessage::Tool {
|
||||||
|
tool_call_id,
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
assert_eq!(tool_call_id, "call_123");
|
||||||
|
match content {
|
||||||
|
ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected Tool message"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_messages_ok() {
|
||||||
|
let msgs = PromptComposer::new()
|
||||||
|
.system("You are helpful")
|
||||||
|
.user("Hello")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assert!(validate_messages(&msgs).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_messages_empty() {
|
||||||
|
let msgs: Vec<OpenaiChatMessage> = vec![];
|
||||||
|
assert!(validate_messages(&msgs).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_render() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("name", "Alice");
|
||||||
|
ctx.insert("count", "5");
|
||||||
|
|
||||||
|
let result = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(result, "Hello Alice, you have 5 messages");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_if() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("name", "Bob");
|
||||||
|
|
||||||
|
let with_name = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(with_name, "Hello Bob");
|
||||||
|
|
||||||
|
let without_name = tpl.render(&TemplateContext::new()).unwrap();
|
||||||
|
assert_eq!(without_name, "Hello Guest");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_each() {
|
||||||
|
let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap();
|
||||||
|
let mut ctx = TemplateContext::new();
|
||||||
|
ctx.insert("items", TemplateValue::Array(vec![
|
||||||
|
TemplateValue::String("a".to_string()),
|
||||||
|
TemplateValue::String("b".to_string()),
|
||||||
|
TemplateValue::String("c".to_string()),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let result = tpl.render(&ctx).unwrap();
|
||||||
|
assert_eq!(result, "Items: a, b, c, ");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_template_display() {
|
||||||
|
let tpl = PromptTemplate::compile("Hello {{name}}").unwrap();
|
||||||
|
assert_eq!(format!("{}", tpl), "Hello {{name}}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_context_from_json() {
|
||||||
|
let json: serde_json::Value = serde_json::json!({
|
||||||
|
"name": "Alice",
|
||||||
|
"active": true
|
||||||
|
});
|
||||||
|
let ctx = TemplateContext::from_json(&json).unwrap();
|
||||||
|
assert!(ctx.get("name").is_some());
|
||||||
|
assert!(ctx.get("active").is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum PromptError {
|
||||||
|
#[error("模板解析错误: {0}")]
|
||||||
|
Parse(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 变量 '{0}' 未找到")]
|
||||||
|
VariableNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: 引用的子模板 '{0}' 未注册")]
|
||||||
|
PartialNotFound(String),
|
||||||
|
|
||||||
|
#[error("渲染错误: '{0}' 不是数组,无法遍历")]
|
||||||
|
NotAnArray(String),
|
||||||
|
|
||||||
|
#[error("渲染递归超过最大深度限制 ({0})")]
|
||||||
|
MaxDepthReached(u8),
|
||||||
|
|
||||||
|
#[error("渲染错误: {0}")]
|
||||||
|
Render(String),
|
||||||
|
|
||||||
|
#[error("消息序列校验失败: {0}")]
|
||||||
|
InvalidSequence(String),
|
||||||
|
|
||||||
|
#[error("文件读取错误: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
@@ -0,0 +1,543 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fmt;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::prompt::error::PromptError;
|
||||||
|
|
||||||
|
const MAX_RENDER_DEPTH: u8 = 16;
|
||||||
|
|
||||||
|
// ===== TemplateValue =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum TemplateValue {
|
||||||
|
String(String),
|
||||||
|
Bool(bool),
|
||||||
|
Array(Vec<TemplateValue>),
|
||||||
|
Object(HashMap<String, TemplateValue>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for TemplateValue {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
TemplateValue::String(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for TemplateValue {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
TemplateValue::String(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<bool> for TemplateValue {
|
||||||
|
fn from(b: bool) -> Self {
|
||||||
|
TemplateValue::Bool(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for TemplateValue {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
TemplateValue::String(s) => write!(f, "{}", s),
|
||||||
|
TemplateValue::Bool(b) => write!(f, "{}", b),
|
||||||
|
TemplateValue::Array(arr) => {
|
||||||
|
let strs: Vec<String> = arr.iter().map(|v| format!("{}", v)).collect();
|
||||||
|
write!(f, "[{}]", strs.join(", "))
|
||||||
|
}
|
||||||
|
TemplateValue::Object(map) => {
|
||||||
|
let strs: Vec<String> = map
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("\"{}\": {}", k, v))
|
||||||
|
.collect();
|
||||||
|
write!(f, "{{{}}}", strs.join(", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemplateValue {
|
||||||
|
fn is_truthy(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
TemplateValue::String(s) => !s.is_empty(),
|
||||||
|
TemplateValue::Bool(b) => *b,
|
||||||
|
TemplateValue::Array(arr) => !arr.is_empty(),
|
||||||
|
TemplateValue::Object(map) => !map.is_empty(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_array(&self) -> Option<&Vec<TemplateValue>> {
|
||||||
|
match self {
|
||||||
|
TemplateValue::Array(arr) => Some(arr),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== TemplateContext =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct TemplateContext {
|
||||||
|
vars: HashMap<String, TemplateValue>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TemplateContext {
|
||||||
|
/// 创建一个空的模板上下文。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。
|
||||||
|
pub fn insert(&mut self, key: impl Into<String>, value: impl Into<TemplateValue>) {
|
||||||
|
self.vars.insert(key.into(), value.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称获取变量值。
|
||||||
|
pub fn get(&self, key: &str) -> Option<&TemplateValue> {
|
||||||
|
self.vars.get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。
|
||||||
|
pub fn from_json(value: &Value) -> Result<Self, PromptError> {
|
||||||
|
let map = value
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| PromptError::Render("JSON 根值必须是对象".to_string()))?;
|
||||||
|
|
||||||
|
let mut ctx = Self::new();
|
||||||
|
for (k, v) in map {
|
||||||
|
ctx.vars.insert(k.clone(), json_to_template_value(v)?);
|
||||||
|
}
|
||||||
|
Ok(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 `HashMap` 构造(适用于配置加载场景)。
|
||||||
|
pub fn from_map(map: HashMap<String, TemplateValue>) -> Self {
|
||||||
|
Self { vars: map }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn json_to_template_value(v: &Value) -> Result<TemplateValue, PromptError> {
|
||||||
|
match v {
|
||||||
|
Value::Null => Ok(TemplateValue::String(String::new())),
|
||||||
|
Value::Bool(b) => Ok(TemplateValue::Bool(*b)),
|
||||||
|
Value::Number(n) => Ok(TemplateValue::String(n.to_string())),
|
||||||
|
Value::String(s) => Ok(TemplateValue::String(s.clone())),
|
||||||
|
Value::Array(arr) => {
|
||||||
|
let items: Result<Vec<TemplateValue>, _> =
|
||||||
|
arr.iter().map(json_to_template_value).collect();
|
||||||
|
Ok(TemplateValue::Array(items?))
|
||||||
|
}
|
||||||
|
Value::Object(obj) => {
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
for (k, v) in obj {
|
||||||
|
map.insert(k.clone(), json_to_template_value(v)?);
|
||||||
|
}
|
||||||
|
Ok(TemplateValue::Object(map))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Fragment (AST) =====
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Fragment {
|
||||||
|
Literal(String),
|
||||||
|
Variable { name: String },
|
||||||
|
If {
|
||||||
|
condition: String,
|
||||||
|
body: Vec<Fragment>,
|
||||||
|
else_body: Vec<Fragment>,
|
||||||
|
},
|
||||||
|
Each {
|
||||||
|
variable: String,
|
||||||
|
body: Vec<Fragment>,
|
||||||
|
},
|
||||||
|
Raw(String),
|
||||||
|
Include(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== PromptTemplate =====
|
||||||
|
|
||||||
|
pub struct PromptTemplate {
|
||||||
|
raw: String,
|
||||||
|
fragments: Vec<Fragment>,
|
||||||
|
partials: HashMap<String, PromptTemplate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplate {
|
||||||
|
/// 从模板字符串编译。
|
||||||
|
pub fn compile(template: &str) -> Result<Self, PromptError> {
|
||||||
|
let fragments = compile_fragments(template)?;
|
||||||
|
Ok(Self {
|
||||||
|
raw: template.to_string(),
|
||||||
|
fragments,
|
||||||
|
partials: HashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用上下文渲染。
|
||||||
|
pub fn render(&self, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||||
|
let mut output = String::new();
|
||||||
|
render_fragments(&self.fragments, ctx, &self.partials, &mut output, 0)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 使用上下文和外部 partials 渲染。
|
||||||
|
pub fn render_with_partials(
|
||||||
|
&self,
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
partials: &HashMap<String, PromptTemplate>,
|
||||||
|
) -> Result<String, PromptError> {
|
||||||
|
let mut output = String::new();
|
||||||
|
render_fragments(&self.fragments, ctx, partials, &mut output, 0)?;
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 注册可引用的子模板。
|
||||||
|
pub fn register_partial(&mut self, name: &str, template: PromptTemplate) {
|
||||||
|
self.partials.insert(name.to_string(), template);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for PromptTemplate {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "{}", self.raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Compiler =====
|
||||||
|
|
||||||
|
fn compile_fragments(template: &str) -> Result<Vec<Fragment>, PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut fragments = Vec::new();
|
||||||
|
let mut i = 0;
|
||||||
|
let mut literal = String::new();
|
||||||
|
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
if !literal.is_empty() {
|
||||||
|
fragments.push(Fragment::Literal(literal.clone()));
|
||||||
|
literal.clear();
|
||||||
|
}
|
||||||
|
let (tag_content, end) = parse_tag(bytes, i)?;
|
||||||
|
i = end;
|
||||||
|
|
||||||
|
let tag = tag_content.trim();
|
||||||
|
if let Some(rest) = tag.strip_prefix("#if ") {
|
||||||
|
let (body, else_body, new_i) =
|
||||||
|
parse_block(template, i, "if")?;
|
||||||
|
let condition = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::If {
|
||||||
|
condition,
|
||||||
|
body,
|
||||||
|
else_body,
|
||||||
|
});
|
||||||
|
i = new_i;
|
||||||
|
} else if let Some(rest) = tag.strip_prefix("#each ") {
|
||||||
|
let (body, new_i) = parse_each_block(template, i)?;
|
||||||
|
let variable = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::Each { variable, body });
|
||||||
|
i = new_i;
|
||||||
|
} else if tag == "#raw" {
|
||||||
|
let (raw, new_i) = parse_raw_block(template, i)?;
|
||||||
|
fragments.push(Fragment::Raw(raw));
|
||||||
|
i = new_i;
|
||||||
|
} else if let Some(rest) = tag.strip_prefix("> ") {
|
||||||
|
let name = rest.trim().to_string();
|
||||||
|
fragments.push(Fragment::Include(name));
|
||||||
|
} else if tag.starts_with("/") {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
let name = tag.to_string();
|
||||||
|
fragments.push(Fragment::Variable { name });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
literal.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !literal.is_empty() {
|
||||||
|
fragments.push(Fragment::Literal(literal));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(fragments)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tag(bytes: &[u8], start: usize) -> Result<(String, usize), PromptError> {
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut i = start + 2;
|
||||||
|
let mut content = String::new();
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'}' && i + 1 < len && bytes[i + 1] == b'}' {
|
||||||
|
return Ok((content, i + 2));
|
||||||
|
}
|
||||||
|
content.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
Err(PromptError::Parse("未闭合的 {{ 标签".to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_block(
|
||||||
|
template: &str,
|
||||||
|
start: usize,
|
||||||
|
kind: &str,
|
||||||
|
) -> Result<(Vec<Fragment>, Vec<Fragment>, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut depth = 1u32;
|
||||||
|
let mut i = start;
|
||||||
|
let mut body = String::new();
|
||||||
|
let mut else_body = String::new();
|
||||||
|
let mut is_else = false;
|
||||||
|
|
||||||
|
while i < len && depth > 0 {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == format!("/{kind}") {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
let (if_fragments, else_fragments) = if is_else {
|
||||||
|
(compile_fragments(&body)?, compile_fragments(&else_body)?)
|
||||||
|
} else {
|
||||||
|
(compile_fragments(&body)?, compile_fragments("")?)
|
||||||
|
};
|
||||||
|
return Ok((if_fragments, else_fragments, end));
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
}
|
||||||
|
i = end;
|
||||||
|
} else if tag == "#if " || tag.starts_with("#if ") {
|
||||||
|
depth += 1;
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
} else if tag == "else" && depth == 1 {
|
||||||
|
is_else = true;
|
||||||
|
i = end;
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if is_else {
|
||||||
|
else_body.push(bytes[i] as char);
|
||||||
|
} else {
|
||||||
|
body.push(bytes[i] as char);
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(format!("未闭合的 {{#{}}} 块", kind)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_each_block(
|
||||||
|
template: &str,
|
||||||
|
start: usize,
|
||||||
|
) -> Result<(Vec<Fragment>, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut depth = 1u32;
|
||||||
|
let mut i = start;
|
||||||
|
let mut body = String::new();
|
||||||
|
|
||||||
|
while i < len && depth > 0 {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == "/each" {
|
||||||
|
depth -= 1;
|
||||||
|
if depth == 0 {
|
||||||
|
let fragments = compile_fragments(&body)?;
|
||||||
|
return Ok((fragments, end));
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
}
|
||||||
|
i = end;
|
||||||
|
} else if tag == "#each " || tag.starts_with("#each ") {
|
||||||
|
depth += 1;
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
} else {
|
||||||
|
body.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(
|
||||||
|
"未闭合的 {{#each}} 块".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_raw_block(template: &str, start: usize) -> Result<(String, usize), PromptError> {
|
||||||
|
let bytes = template.as_bytes();
|
||||||
|
let len = bytes.len();
|
||||||
|
let mut i = start;
|
||||||
|
let mut content = String::new();
|
||||||
|
|
||||||
|
while i < len {
|
||||||
|
if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' {
|
||||||
|
let (tag, end) = parse_tag(bytes, i)?;
|
||||||
|
let tag = tag.trim().to_string();
|
||||||
|
if tag == "/raw" {
|
||||||
|
return Ok((content, end));
|
||||||
|
} else {
|
||||||
|
content.push_str(&template[i..end]);
|
||||||
|
i = end;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content.push(bytes[i] as char);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(PromptError::Parse(
|
||||||
|
"未闭合的 {{#raw}} 块".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Renderer =====
|
||||||
|
|
||||||
|
fn render_fragments(
|
||||||
|
fragments: &[Fragment],
|
||||||
|
ctx: &TemplateContext,
|
||||||
|
partials: &HashMap<String, PromptTemplate>,
|
||||||
|
output: &mut String,
|
||||||
|
depth: u8,
|
||||||
|
) -> Result<(), PromptError> {
|
||||||
|
if depth > MAX_RENDER_DEPTH {
|
||||||
|
return Err(PromptError::MaxDepthReached(MAX_RENDER_DEPTH));
|
||||||
|
}
|
||||||
|
|
||||||
|
for frag in fragments {
|
||||||
|
match frag {
|
||||||
|
Fragment::Literal(text) => {
|
||||||
|
output.push_str(text);
|
||||||
|
}
|
||||||
|
Fragment::Variable { name } => {
|
||||||
|
match ctx.get(name) {
|
||||||
|
Some(val) => {
|
||||||
|
output.push_str(&format!("{}", val));
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err(PromptError::VariableNotFound(name.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Fragment::If {
|
||||||
|
condition,
|
||||||
|
body,
|
||||||
|
else_body,
|
||||||
|
} => {
|
||||||
|
let truthy = ctx
|
||||||
|
.get(condition)
|
||||||
|
.map(|v| v.is_truthy())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let target = if truthy { body } else { else_body };
|
||||||
|
render_fragments(target, ctx, partials, output, depth + 1)?;
|
||||||
|
}
|
||||||
|
Fragment::Each { variable, body } => {
|
||||||
|
let arr = match ctx.get(variable) {
|
||||||
|
Some(val) => val.as_array().ok_or_else(|| {
|
||||||
|
PromptError::NotAnArray(variable.clone())
|
||||||
|
})?,
|
||||||
|
None => {
|
||||||
|
return Err(PromptError::VariableNotFound(variable.clone()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for item in arr {
|
||||||
|
let mut child_ctx = ctx.clone();
|
||||||
|
child_ctx.vars.insert("item".to_string(), item.clone());
|
||||||
|
render_fragments(body, &child_ctx, partials, output, depth + 1)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Fragment::Raw(text) => {
|
||||||
|
output.push_str(text);
|
||||||
|
}
|
||||||
|
Fragment::Include(name) => {
|
||||||
|
if let Some(partial) = partials.get(name) {
|
||||||
|
partial.render_with_partials(ctx, partials)?;
|
||||||
|
} else {
|
||||||
|
return Err(PromptError::PartialNotFound(name.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== PromptTemplateRegistry =====
|
||||||
|
|
||||||
|
/// 内部存储的模板(支持延迟编译)。
|
||||||
|
enum StoredTemplate {
|
||||||
|
Compiled(PromptTemplate),
|
||||||
|
Raw(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 模板注册表——管理多模板实例。
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct PromptTemplateRegistry {
|
||||||
|
templates: HashMap<String, StoredTemplate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptTemplateRegistry {
|
||||||
|
/// 创建一个空的模板注册表。
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
templates: HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从模板字符串编译并注册(立即编译)。
|
||||||
|
pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError> {
|
||||||
|
let compiled = PromptTemplate::compile(template)?;
|
||||||
|
self.templates
|
||||||
|
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 延迟编译注册:只存储原始字符串,首次渲染时编译。
|
||||||
|
pub fn register_lazy(&mut self, name: &str, template: &str) {
|
||||||
|
self.templates.insert(
|
||||||
|
name.to_string(),
|
||||||
|
StoredTemplate::Raw(template.to_string()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从文件读取并编译注册。
|
||||||
|
pub fn register_file(&mut self, name: &str, path: &std::path::Path) -> Result<(), PromptError> {
|
||||||
|
let content = std::fs::read_to_string(path)?;
|
||||||
|
let compiled = PromptTemplate::compile(&content)?;
|
||||||
|
self.templates
|
||||||
|
.insert(name.to_string(), StoredTemplate::Compiled(compiled));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取已注册的模板(延迟编译的模板在此首次编译)。
|
||||||
|
pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError> {
|
||||||
|
if let Some(stored) = self.templates.get_mut(name) {
|
||||||
|
if let StoredTemplate::Raw(raw) = stored {
|
||||||
|
let compiled = PromptTemplate::compile(raw)?;
|
||||||
|
*stored = StoredTemplate::Compiled(compiled);
|
||||||
|
}
|
||||||
|
match stored {
|
||||||
|
StoredTemplate::Compiled(tpl) => Ok(tpl),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(PromptError::PartialNotFound(name.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称渲染。
|
||||||
|
pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result<String, PromptError> {
|
||||||
|
let tpl = self.get(name)?;
|
||||||
|
tpl.render(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user