diff --git a/src/lib.rs b/src/lib.rs index b7061ac..b873931 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ //! agcore —— 智能体(Agent)核心工具箱。 pub mod llm; +pub mod prompt; use tracing_subscriber::{EnvFilter, fmt, prelude::*}; diff --git a/src/llm/cycle.rs b/src/llm/cycle.rs index 529a2fd..06e67cf 100644 --- a/src/llm/cycle.rs +++ b/src/llm/cycle.rs @@ -113,6 +113,94 @@ impl LlmCycle { self.usage.reset(); } + /// 直接设置消息历史(覆盖已有消息),支持 Builder 链式调用。 + pub fn with_messages(mut self, messages: Vec) -> Self { + self.messages = messages; + self + } + + /// 追加消息到历史尾部。 + pub fn extend_messages(&mut self, messages: Vec) { + self.messages.extend(messages); + } + + /// 使用预构建消息提交(跳过自动 push user prompt)。 + /// + /// 与 `submit()` 不同,不自动添加 `user_text(prompt)`,也不自动插入 system prompt。 + /// 调用方完全控制消息序列内容。 + pub async fn submit_messages( + &mut self, + messages: Vec, + tools: Vec, + ) -> Result { + let openai_tools: Option> = if tools.is_empty() { + None + } else { + Some( + tools + .iter() + .map(|t| OpenaiTool::Function { + function: t.clone(), + }) + .collect(), + ) + }; + + let request = ChatRequest { + model: self.config.model.clone(), + messages, + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, + tools: openai_tools, + tool_choice: Some(ToolChoice::Auto), + ..Default::default() + }; + + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PreRequest) + .with_request(&request); + let results = executor + .execute(crate::llm::hooks::HookEvent::PreRequest, &ctx) + .await; + if results.iter().any(|r| r.should_block) { + let reason = results + .iter() + .find(|r| r.should_block) + .and_then(|r| r.reason.clone()) + .unwrap_or_else(|| "Blocked by pre-request hook".to_string()); + return Err(LlmError::Other(reason)); + } + } + + match self.provider.chat(request).await { + Ok(response) => { + if let Some(ref executor) = self.hook_executor { + let post_request = ChatRequest { + model: self.config.model.clone(), + messages: vec![], + ..Default::default() + }; + let ctx = HookContext::new(crate::llm::hooks::HookEvent::PostRequest) + .with_request(&post_request); + executor + .execute(crate::llm::hooks::HookEvent::PostRequest, &ctx) + .await; + } + self.usage.add(&response.usage); + Ok(response) + } + Err(e) => { + if let Some(ref executor) = self.hook_executor { + let ctx = HookContext::new(crate::llm::hooks::HookEvent::OnError).with_error(&e); + executor + .execute(crate::llm::hooks::HookEvent::OnError, &ctx) + .await; + } + Err(e) + } + } + } + /// 提交用户消息并获取 LLM 响应。 pub async fn submit( &mut self, diff --git a/src/prompt.rs b/src/prompt.rs new file mode 100644 index 0000000..b6086c1 --- /dev/null +++ b/src/prompt.rs @@ -0,0 +1,7 @@ +pub mod error; +pub mod template; +pub mod composer; + +pub use error::PromptError; +pub use template::{PromptTemplate, PromptTemplateRegistry, TemplateContext, TemplateValue}; +pub use composer::PromptComposer; diff --git a/src/prompt/composer.rs b/src/prompt/composer.rs new file mode 100644 index 0000000..4999a64 --- /dev/null +++ b/src/prompt/composer.rs @@ -0,0 +1,406 @@ +use crate::llm::types::message::{ContentField, OpenaiChatMessage, OpenaiContentPart}; +use crate::llm::types::request::OpenaiChatRequest; +use crate::prompt::error::PromptError; +use crate::prompt::template::{PromptTemplate, TemplateContext}; + +/// 提示词组合器——构建多角色消息序列。 +#[derive(Default)] +pub struct PromptComposer { + messages: Vec, +} + +impl PromptComposer { + /// 创建一个空的组合器。 + pub fn new() -> Self { + Self::default() + } + + /// 从已有的消息列表初始化。 + pub fn from_messages(messages: Vec) -> Self { + Self { messages } + } + + // ===== 纯文本消息 ===== + + /// 添加一条纯文本 system 消息。 + pub fn system(mut self, text: impl Into) -> Self { + self.push_message(OpenaiChatMessage::system_text(text.into())); + self + } + + /// 添加一条纯文本 user 消息。 + pub fn user(mut self, text: impl Into) -> Self { + self.push_message(OpenaiChatMessage::user_text(text.into())); + self + } + + /// 添加一条纯文本 assistant 消息。 + pub fn assistant(mut self, text: impl Into) -> Self { + self.push_message(OpenaiChatMessage::assistant_text(text.into())); + self + } + + /// 添加一条纯文本 developer 消息(o1 系列模型使用)。 + pub fn developer(mut self, text: impl Into) -> Self { + self.push_message(OpenaiChatMessage::developer_text(text.into())); + self + } + + /// 添加一条 Tool 消息(工具执行结果回传)。 + pub fn tool(mut self, tool_call_id: impl Into, content: impl Into) -> Self { + self.push_message(OpenaiChatMessage::tool_result( + tool_call_id.into(), + content.into(), + )); + self + } + + // ===== 模板消息 ===== + + /// 使用模板和上下文渲染后添加为 user 消息。 + pub fn user_template( + mut self, + template: &PromptTemplate, + ctx: &TemplateContext, + ) -> Result { + let text = template.render(ctx)?; + self.push_message(OpenaiChatMessage::user_text(text)); + Ok(self) + } + + /// 使用模板和上下文渲染后添加为 system 消息。 + pub fn system_template( + mut self, + template: &PromptTemplate, + ctx: &TemplateContext, + ) -> Result { + let text = template.render(ctx)?; + self.push_message(OpenaiChatMessage::system_text(text)); + Ok(self) + } + + /// 使用模板和上下文渲染后添加为 assistant 消息。 + pub fn assistant_template( + mut self, + template: &PromptTemplate, + ctx: &TemplateContext, + ) -> Result { + let text = template.render(ctx)?; + self.push_message(OpenaiChatMessage::assistant_text(text)); + Ok(self) + } + + /// 使用模板和上下文渲染后添加为 developer 消息。 + pub fn developer_template( + mut self, + template: &PromptTemplate, + ctx: &TemplateContext, + ) -> Result { + let text = template.render(ctx)?; + self.push_message(OpenaiChatMessage::developer_text(text)); + Ok(self) + } + + // ===== 多模态 ContentPart ===== + + /// 添加一条含指定 ContentPart 的 system 消息。 + pub fn system_content(mut self, part: OpenaiContentPart) -> Self { + self.push_message(OpenaiChatMessage::System { + content: ContentField::Array(vec![part]), + name: None, + }); + self + } + + /// 添加一条含指定 ContentPart 的 user 消息。 + pub fn user_content(mut self, part: OpenaiContentPart) -> Self { + self.push_message(OpenaiChatMessage::User { + content: ContentField::Array(vec![part]), + name: None, + }); + self + } + + /// 添加一条含指定 ContentPart 的 assistant 消息。 + pub fn assistant_content(mut self, part: OpenaiContentPart) -> Self { + self.push_message(OpenaiChatMessage::Assistant { + content: ContentField::Array(vec![part]), + refusal: None, + name: None, + tool_calls: None, + }); + self + } + + /// 添加一条含指定 ContentPart 的 developer 消息。 + pub fn developer_content(mut self, part: OpenaiContentPart) -> Self { + self.push_message(OpenaiChatMessage::Developer { + content: ContentField::Array(vec![part]), + name: None, + }); + self + } + + /// 添加一条含指定 ContentPart 的 Tool 消息。 + pub fn tool_content( + mut self, + tool_call_id: impl Into, + part: OpenaiContentPart, + ) -> Self { + self.push_message(OpenaiChatMessage::Tool { + content: ContentField::Array(vec![part]), + tool_call_id: tool_call_id.into(), + }); + self + } + + /// 批量添加 ContentPart 作为 user 消息。 + pub fn user_contents(mut self, parts: Vec) -> Self { + self.push_message(OpenaiChatMessage::User { + content: ContentField::Array(parts), + name: None, + }); + self + } + + /// 批量添加 ContentPart 作为 system 消息。 + pub fn system_contents(mut self, parts: Vec) -> Self { + self.push_message(OpenaiChatMessage::System { + content: ContentField::Array(parts), + name: None, + }); + self + } + + /// 批量添加 ContentPart 作为 assistant 消息。 + pub fn assistant_contents(mut self, parts: Vec) -> Self { + self.push_message(OpenaiChatMessage::Assistant { + content: ContentField::Array(parts), + refusal: None, + name: None, + tool_calls: None, + }); + self + } + + /// 批量添加 ContentPart 作为 developer 消息。 + pub fn developer_contents(mut self, parts: Vec) -> Self { + self.push_message(OpenaiChatMessage::Developer { + content: ContentField::Array(parts), + name: None, + }); + self + } + + // ===== 角色标识 ===== + + /// 为上一条添加的消息设置 `name` 字段。 + pub fn with_name(mut self, name: impl Into) -> Self { + let name = name.into(); + if let Some(msg) = self.messages.last_mut() { + set_message_name(msg, name); + } + self + } + + // ===== 构建 ===== + + /// 构建最终的消息列表。 + pub fn build(self) -> Vec { + self.messages + } + + /// 构建并直接创建 ChatRequest(需搭配 model 参数)。 + /// 返回的 `OpenaiChatRequest` 中 `tools`、`temperature`、`max_tokens` 等字段均为 `None`, + /// 可通过结构体更新语法补全:`OpenaiChatRequest { tools: Some(...), ..req }`。 + pub fn build_request(self, model: impl Into) -> OpenaiChatRequest { + OpenaiChatRequest { + model: model.into(), + messages: self.messages, + ..Default::default() + } + } + + // ===== 内部方法 ===== + + fn push_message(&mut self, msg: OpenaiChatMessage) { + self.messages.push(msg); + } +} + +fn set_message_name(msg: &mut OpenaiChatMessage, name: String) { + match msg { + OpenaiChatMessage::Developer { name: n, .. } => *n = Some(name), + OpenaiChatMessage::System { name: n, .. } => *n = Some(name), + OpenaiChatMessage::User { name: n, .. } => *n = Some(name), + OpenaiChatMessage::Assistant { name: n, .. } => *n = Some(name), + OpenaiChatMessage::Tool { .. } => {} + OpenaiChatMessage::Function { .. } => {} + } +} + +/// 验证消息序列是否符合 OpenAI API 要求。 +pub fn validate_messages(messages: &[OpenaiChatMessage]) -> Result<(), PromptError> { + if messages.is_empty() { + return Err(PromptError::InvalidSequence( + "消息列表不能为空".to_string(), + )); + } + + let mut last_tool_call_ids: Vec = Vec::new(); + + for (i, msg) in messages.iter().enumerate() { + match msg { + OpenaiChatMessage::Tool { + tool_call_id, + .. + } => { + if last_tool_call_ids.is_empty() { + return Err(PromptError::InvalidSequence(format!( + "消息[{i}] Tool 消息前必须有 Assistant 消息且含 tool_calls" + ))); + } + if !last_tool_call_ids.iter().any(|id| id == tool_call_id) { + return Err(PromptError::InvalidSequence(format!( + "消息[{i}] Tool 消息的 tool_call_id '{}' 未匹配任何 assistant tool_calls", + tool_call_id + ))); + } + } + OpenaiChatMessage::Assistant { + tool_calls: Some(calls), + .. + } => { + last_tool_call_ids.clear(); + for call in calls { + let crate::llm::types::OpenaiToolCall::Function { id, .. } = call; + last_tool_call_ids.push(id.clone()); + } + } + OpenaiChatMessage::Assistant { + tool_calls: None, .. + } => { + last_tool_call_ids.clear(); + } + _ => { + last_tool_call_ids.clear(); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prompt::TemplateValue; + + #[test] + fn test_composer_basic() { + let msgs = PromptComposer::new() + .system("You are helpful") + .user("Hello") + .assistant("Hi there!") + .build(); + + assert_eq!(msgs.len(), 3); + } + + #[test] + fn test_composer_tool() { + let msgs = PromptComposer::new() + .system("You are helpful") + .user("What's the weather?") + .assistant("Let me check") + .tool("call_123", "Sunny, 25°C") + .build(); + + assert_eq!(msgs.len(), 4); + match &msgs[3] { + OpenaiChatMessage::Tool { + tool_call_id, + content, + .. + } => { + assert_eq!(tool_call_id, "call_123"); + match content { + ContentField::String(s) => assert_eq!(s, "Sunny, 25°C"), + _ => {} + } + } + _ => panic!("Expected Tool message"), + } + } + + #[test] + fn test_validate_messages_ok() { + let msgs = PromptComposer::new() + .system("You are helpful") + .user("Hello") + .build(); + + assert!(validate_messages(&msgs).is_ok()); + } + + #[test] + fn test_validate_messages_empty() { + let msgs: Vec = vec![]; + assert!(validate_messages(&msgs).is_err()); + } + + #[test] + fn test_template_render() { + let tpl = PromptTemplate::compile("Hello {{name}}, you have {{count}} messages").unwrap(); + let mut ctx = TemplateContext::new(); + ctx.insert("name", "Alice"); + ctx.insert("count", "5"); + + let result = tpl.render(&ctx).unwrap(); + assert_eq!(result, "Hello Alice, you have 5 messages"); + } + + #[test] + fn test_template_if() { + let tpl = PromptTemplate::compile("Hello {{#if name}}{{name}}{{else}}Guest{{/if}}").unwrap(); + let mut ctx = TemplateContext::new(); + ctx.insert("name", "Bob"); + + let with_name = tpl.render(&ctx).unwrap(); + assert_eq!(with_name, "Hello Bob"); + + let without_name = tpl.render(&TemplateContext::new()).unwrap(); + assert_eq!(without_name, "Hello Guest"); + } + + #[test] + fn test_template_each() { + let tpl = PromptTemplate::compile("Items: {{#each items}}{{item}}, {{/each}}").unwrap(); + let mut ctx = TemplateContext::new(); + ctx.insert("items", TemplateValue::Array(vec![ + TemplateValue::String("a".to_string()), + TemplateValue::String("b".to_string()), + TemplateValue::String("c".to_string()), + ])); + + let result = tpl.render(&ctx).unwrap(); + assert_eq!(result, "Items: a, b, c, "); + } + + #[test] + fn test_template_display() { + let tpl = PromptTemplate::compile("Hello {{name}}").unwrap(); + assert_eq!(format!("{}", tpl), "Hello {{name}}"); + } + + #[test] + fn test_context_from_json() { + let json: serde_json::Value = serde_json::json!({ + "name": "Alice", + "active": true + }); + let ctx = TemplateContext::from_json(&json).unwrap(); + assert!(ctx.get("name").is_some()); + assert!(ctx.get("active").is_some()); + } +} diff --git a/src/prompt/error.rs b/src/prompt/error.rs new file mode 100644 index 0000000..3afa1cb --- /dev/null +++ b/src/prompt/error.rs @@ -0,0 +1,28 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum PromptError { + #[error("模板解析错误: {0}")] + Parse(String), + + #[error("渲染错误: 变量 '{0}' 未找到")] + VariableNotFound(String), + + #[error("渲染错误: 引用的子模板 '{0}' 未注册")] + PartialNotFound(String), + + #[error("渲染错误: '{0}' 不是数组,无法遍历")] + NotAnArray(String), + + #[error("渲染递归超过最大深度限制 ({0})")] + MaxDepthReached(u8), + + #[error("渲染错误: {0}")] + Render(String), + + #[error("消息序列校验失败: {0}")] + InvalidSequence(String), + + #[error("文件读取错误: {0}")] + Io(#[from] std::io::Error), +} diff --git a/src/prompt/template.rs b/src/prompt/template.rs new file mode 100644 index 0000000..8a5738f --- /dev/null +++ b/src/prompt/template.rs @@ -0,0 +1,543 @@ +use std::collections::HashMap; +use std::fmt; +use serde_json::Value; + +use crate::prompt::error::PromptError; + +const MAX_RENDER_DEPTH: u8 = 16; + +// ===== TemplateValue ===== + +#[derive(Debug, Clone)] +pub enum TemplateValue { + String(String), + Bool(bool), + Array(Vec), + Object(HashMap), +} + +impl From for TemplateValue { + fn from(s: String) -> Self { + TemplateValue::String(s) + } +} + +impl From<&str> for TemplateValue { + fn from(s: &str) -> Self { + TemplateValue::String(s.to_string()) + } +} + +impl From for TemplateValue { + fn from(b: bool) -> Self { + TemplateValue::Bool(b) + } +} + +impl fmt::Display for TemplateValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TemplateValue::String(s) => write!(f, "{}", s), + TemplateValue::Bool(b) => write!(f, "{}", b), + TemplateValue::Array(arr) => { + let strs: Vec = arr.iter().map(|v| format!("{}", v)).collect(); + write!(f, "[{}]", strs.join(", ")) + } + TemplateValue::Object(map) => { + let strs: Vec = map + .iter() + .map(|(k, v)| format!("\"{}\": {}", k, v)) + .collect(); + write!(f, "{{{}}}", strs.join(", ")) + } + } + } +} + +impl TemplateValue { + fn is_truthy(&self) -> bool { + match self { + TemplateValue::String(s) => !s.is_empty(), + TemplateValue::Bool(b) => *b, + TemplateValue::Array(arr) => !arr.is_empty(), + TemplateValue::Object(map) => !map.is_empty(), + } + } + + fn as_array(&self) -> Option<&Vec> { + match self { + TemplateValue::Array(arr) => Some(arr), + _ => None, + } + } +} + +// ===== TemplateContext ===== + +#[derive(Debug, Clone, Default)] +pub struct TemplateContext { + vars: HashMap, +} + +impl TemplateContext { + /// 创建一个空的模板上下文。 + pub fn new() -> Self { + Self::default() + } + + /// 插入变量(支持 `&str` / `String` / `bool` 自动转换)。 + pub fn insert(&mut self, key: impl Into, value: impl Into) { + self.vars.insert(key.into(), value.into()); + } + + /// 按名称获取变量值。 + pub fn get(&self, key: &str) -> Option<&TemplateValue> { + self.vars.get(key) + } + + /// 从 `serde_json::Value` 递归构造(支持嵌套 Object/Array)。 + pub fn from_json(value: &Value) -> Result { + let map = value + .as_object() + .ok_or_else(|| PromptError::Render("JSON 根值必须是对象".to_string()))?; + + let mut ctx = Self::new(); + for (k, v) in map { + ctx.vars.insert(k.clone(), json_to_template_value(v)?); + } + Ok(ctx) + } + + /// 从 `HashMap` 构造(适用于配置加载场景)。 + pub fn from_map(map: HashMap) -> Self { + Self { vars: map } + } +} + +fn json_to_template_value(v: &Value) -> Result { + match v { + Value::Null => Ok(TemplateValue::String(String::new())), + Value::Bool(b) => Ok(TemplateValue::Bool(*b)), + Value::Number(n) => Ok(TemplateValue::String(n.to_string())), + Value::String(s) => Ok(TemplateValue::String(s.clone())), + Value::Array(arr) => { + let items: Result, _> = + arr.iter().map(json_to_template_value).collect(); + Ok(TemplateValue::Array(items?)) + } + Value::Object(obj) => { + let mut map = HashMap::new(); + for (k, v) in obj { + map.insert(k.clone(), json_to_template_value(v)?); + } + Ok(TemplateValue::Object(map)) + } + } +} + +// ===== Fragment (AST) ===== + +#[derive(Debug, Clone)] +enum Fragment { + Literal(String), + Variable { name: String }, + If { + condition: String, + body: Vec, + else_body: Vec, + }, + Each { + variable: String, + body: Vec, + }, + Raw(String), + Include(String), +} + +// ===== PromptTemplate ===== + +pub struct PromptTemplate { + raw: String, + fragments: Vec, + partials: HashMap, +} + +impl PromptTemplate { + /// 从模板字符串编译。 + pub fn compile(template: &str) -> Result { + let fragments = compile_fragments(template)?; + Ok(Self { + raw: template.to_string(), + fragments, + partials: HashMap::new(), + }) + } + + /// 使用上下文渲染。 + pub fn render(&self, ctx: &TemplateContext) -> Result { + let mut output = String::new(); + render_fragments(&self.fragments, ctx, &self.partials, &mut output, 0)?; + Ok(output) + } + + /// 使用上下文和外部 partials 渲染。 + pub fn render_with_partials( + &self, + ctx: &TemplateContext, + partials: &HashMap, + ) -> Result { + let mut output = String::new(); + render_fragments(&self.fragments, ctx, partials, &mut output, 0)?; + Ok(output) + } + + /// 注册可引用的子模板。 + pub fn register_partial(&mut self, name: &str, template: PromptTemplate) { + self.partials.insert(name.to_string(), template); + } +} + +impl fmt::Display for PromptTemplate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.raw) + } +} + +// ===== Compiler ===== + +fn compile_fragments(template: &str) -> Result, PromptError> { + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut fragments = Vec::new(); + let mut i = 0; + let mut literal = String::new(); + + while i < len { + if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' { + if !literal.is_empty() { + fragments.push(Fragment::Literal(literal.clone())); + literal.clear(); + } + let (tag_content, end) = parse_tag(bytes, i)?; + i = end; + + let tag = tag_content.trim(); + if let Some(rest) = tag.strip_prefix("#if ") { + let (body, else_body, new_i) = + parse_block(template, i, "if")?; + let condition = rest.trim().to_string(); + fragments.push(Fragment::If { + condition, + body, + else_body, + }); + i = new_i; + } else if let Some(rest) = tag.strip_prefix("#each ") { + let (body, new_i) = parse_each_block(template, i)?; + let variable = rest.trim().to_string(); + fragments.push(Fragment::Each { variable, body }); + i = new_i; + } else if tag == "#raw" { + let (raw, new_i) = parse_raw_block(template, i)?; + fragments.push(Fragment::Raw(raw)); + i = new_i; + } else if let Some(rest) = tag.strip_prefix("> ") { + let name = rest.trim().to_string(); + fragments.push(Fragment::Include(name)); + } else if tag.starts_with("/") { + break; + } else { + let name = tag.to_string(); + fragments.push(Fragment::Variable { name }); + } + } else { + literal.push(bytes[i] as char); + i += 1; + } + } + + if !literal.is_empty() { + fragments.push(Fragment::Literal(literal)); + } + + Ok(fragments) +} + +fn parse_tag(bytes: &[u8], start: usize) -> Result<(String, usize), PromptError> { + let len = bytes.len(); + let mut i = start + 2; + let mut content = String::new(); + while i < len { + if bytes[i] == b'}' && i + 1 < len && bytes[i + 1] == b'}' { + return Ok((content, i + 2)); + } + content.push(bytes[i] as char); + i += 1; + } + Err(PromptError::Parse("未闭合的 {{ 标签".to_string())) +} + +fn parse_block( + template: &str, + start: usize, + kind: &str, +) -> Result<(Vec, Vec, usize), PromptError> { + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut depth = 1u32; + let mut i = start; + let mut body = String::new(); + let mut else_body = String::new(); + let mut is_else = false; + + while i < len && depth > 0 { + if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' { + let (tag, end) = parse_tag(bytes, i)?; + let tag = tag.trim().to_string(); + if tag == format!("/{kind}") { + depth -= 1; + if depth == 0 { + let (if_fragments, else_fragments) = if is_else { + (compile_fragments(&body)?, compile_fragments(&else_body)?) + } else { + (compile_fragments(&body)?, compile_fragments("")?) + }; + return Ok((if_fragments, else_fragments, end)); + } else { + body.push_str(&template[i..end]); + } + i = end; + } else if tag == "#if " || tag.starts_with("#if ") { + depth += 1; + body.push_str(&template[i..end]); + i = end; + } else if tag == "else" && depth == 1 { + is_else = true; + i = end; + } else { + body.push_str(&template[i..end]); + i = end; + } + } else { + if is_else { + else_body.push(bytes[i] as char); + } else { + body.push(bytes[i] as char); + } + i += 1; + } + } + + Err(PromptError::Parse(format!("未闭合的 {{#{}}} 块", kind))) +} + +fn parse_each_block( + template: &str, + start: usize, +) -> Result<(Vec, usize), PromptError> { + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut depth = 1u32; + let mut i = start; + let mut body = String::new(); + + while i < len && depth > 0 { + if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' { + let (tag, end) = parse_tag(bytes, i)?; + let tag = tag.trim().to_string(); + if tag == "/each" { + depth -= 1; + if depth == 0 { + let fragments = compile_fragments(&body)?; + return Ok((fragments, end)); + } else { + body.push_str(&template[i..end]); + } + i = end; + } else if tag == "#each " || tag.starts_with("#each ") { + depth += 1; + body.push_str(&template[i..end]); + i = end; + } else { + body.push_str(&template[i..end]); + i = end; + } + } else { + body.push(bytes[i] as char); + i += 1; + } + } + + Err(PromptError::Parse( + "未闭合的 {{#each}} 块".to_string(), + )) +} + +fn parse_raw_block(template: &str, start: usize) -> Result<(String, usize), PromptError> { + let bytes = template.as_bytes(); + let len = bytes.len(); + let mut i = start; + let mut content = String::new(); + + while i < len { + if bytes[i] == b'{' && i + 1 < len && bytes[i + 1] == b'{' { + let (tag, end) = parse_tag(bytes, i)?; + let tag = tag.trim().to_string(); + if tag == "/raw" { + return Ok((content, end)); + } else { + content.push_str(&template[i..end]); + i = end; + } + } else { + content.push(bytes[i] as char); + i += 1; + } + } + + Err(PromptError::Parse( + "未闭合的 {{#raw}} 块".to_string(), + )) +} + +// ===== Renderer ===== + +fn render_fragments( + fragments: &[Fragment], + ctx: &TemplateContext, + partials: &HashMap, + output: &mut String, + depth: u8, +) -> Result<(), PromptError> { + if depth > MAX_RENDER_DEPTH { + return Err(PromptError::MaxDepthReached(MAX_RENDER_DEPTH)); + } + + for frag in fragments { + match frag { + Fragment::Literal(text) => { + output.push_str(text); + } + Fragment::Variable { name } => { + match ctx.get(name) { + Some(val) => { + output.push_str(&format!("{}", val)); + } + None => { + return Err(PromptError::VariableNotFound(name.clone())); + } + } + } + Fragment::If { + condition, + body, + else_body, + } => { + let truthy = ctx + .get(condition) + .map(|v| v.is_truthy()) + .unwrap_or(false); + let target = if truthy { body } else { else_body }; + render_fragments(target, ctx, partials, output, depth + 1)?; + } + Fragment::Each { variable, body } => { + let arr = match ctx.get(variable) { + Some(val) => val.as_array().ok_or_else(|| { + PromptError::NotAnArray(variable.clone()) + })?, + None => { + return Err(PromptError::VariableNotFound(variable.clone())); + } + }; + + for item in arr { + let mut child_ctx = ctx.clone(); + child_ctx.vars.insert("item".to_string(), item.clone()); + render_fragments(body, &child_ctx, partials, output, depth + 1)?; + } + } + Fragment::Raw(text) => { + output.push_str(text); + } + Fragment::Include(name) => { + if let Some(partial) = partials.get(name) { + partial.render_with_partials(ctx, partials)?; + } else { + return Err(PromptError::PartialNotFound(name.clone())); + } + } + } + } + + Ok(()) +} + +// ===== PromptTemplateRegistry ===== + +/// 内部存储的模板(支持延迟编译)。 +enum StoredTemplate { + Compiled(PromptTemplate), + Raw(String), +} + +/// 模板注册表——管理多模板实例。 +#[derive(Default)] +pub struct PromptTemplateRegistry { + templates: HashMap, +} + +impl PromptTemplateRegistry { + /// 创建一个空的模板注册表。 + pub fn new() -> Self { + Self { + templates: HashMap::new(), + } + } + + /// 从模板字符串编译并注册(立即编译)。 + pub fn register(&mut self, name: &str, template: &str) -> Result<(), PromptError> { + let compiled = PromptTemplate::compile(template)?; + self.templates + .insert(name.to_string(), StoredTemplate::Compiled(compiled)); + Ok(()) + } + + /// 延迟编译注册:只存储原始字符串,首次渲染时编译。 + pub fn register_lazy(&mut self, name: &str, template: &str) { + self.templates.insert( + name.to_string(), + StoredTemplate::Raw(template.to_string()), + ); + } + + /// 从文件读取并编译注册。 + pub fn register_file(&mut self, name: &str, path: &std::path::Path) -> Result<(), PromptError> { + let content = std::fs::read_to_string(path)?; + let compiled = PromptTemplate::compile(&content)?; + self.templates + .insert(name.to_string(), StoredTemplate::Compiled(compiled)); + Ok(()) + } + + /// 获取已注册的模板(延迟编译的模板在此首次编译)。 + pub fn get(&mut self, name: &str) -> Result<&PromptTemplate, PromptError> { + if let Some(stored) = self.templates.get_mut(name) { + if let StoredTemplate::Raw(raw) = stored { + let compiled = PromptTemplate::compile(raw)?; + *stored = StoredTemplate::Compiled(compiled); + } + match stored { + StoredTemplate::Compiled(tpl) => Ok(tpl), + _ => unreachable!(), + } + } else { + Err(PromptError::PartialNotFound(name.to_string())) + } + } + + /// 按名称渲染。 + pub fn render(&mut self, name: &str, ctx: &TemplateContext) -> Result { + let tpl = self.get(name)?; + tpl.render(ctx) + } +}