diff --git a/src/llm/types/request.rs b/src/llm/types/request.rs index 5530b9f..e706917 100644 --- a/src/llm/types/request.rs +++ b/src/llm/types/request.rs @@ -11,8 +11,7 @@ pub struct StreamOptions { pub include_obfuscation: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case", tag = "type")] +#[derive(Debug, Clone)] pub enum ToolChoice { None, Auto, @@ -21,11 +20,71 @@ pub enum ToolChoice { name: String, }, AllowedTools { - #[serde(rename = "tools")] tool_names: Vec, }, } +impl Serialize for ToolChoice { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + ToolChoice::None => serializer.serialize_str("none"), + ToolChoice::Auto => serializer.serialize_str("auto"), + ToolChoice::Required => serializer.serialize_str("required"), + ToolChoice::Named { name } => { + let obj = serde_json::json!({ + "type": "function", + "function": { "name": name } + }); + obj.serialize(serializer) + } + ToolChoice::AllowedTools { tool_names } => { + let obj = serde_json::json!({ + "type": "function", + "function": { "name": tool_names.first().cloned().unwrap_or_default() } + }); + obj.serialize(serializer) + } + } + } +} + +impl<'de> Deserialize<'de> for ToolChoice { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + match value { + Value::String(s) => match s.as_str() { + "none" => Ok(ToolChoice::None), + "auto" => Ok(ToolChoice::Auto), + "required" => Ok(ToolChoice::Required), + _ => Err(serde::de::Error::custom(format!("unknown tool choice: {s}"))), + }, + Value::Object(obj) => { + let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| { + serde::de::Error::custom("missing 'type' field in tool_choice") + })?; + if typ == "function" { + let func = obj.get("function").and_then(|v| v.as_object()).ok_or_else(|| { + serde::de::Error::custom("missing 'function' field in tool_choice") + })?; + let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| { + serde::de::Error::custom("missing 'function.name' in tool_choice") + })?; + Ok(ToolChoice::Named { name: name.to_string() }) + } else { + Err(serde::de::Error::custom(format!("unknown tool_choice type: {typ}"))) + } + } + _ => Err(serde::de::Error::custom("tool_choice must be a string or object")), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] pub enum OpenaiTool {