feat(llm): 替换 ToolChoice 的自动派生为手动序列化
处理自定义 JSON 序列化逻辑,支持字符串值(`none`/`auto`/`required`)和对象格式(`{"type":"function","function":{"name":"..."}}`)。反序列化时向前兼容两种格式。
This commit is contained in:
@@ -11,8 +11,7 @@ pub struct StreamOptions {
|
|||||||
pub include_obfuscation: Option<bool>,
|
pub include_obfuscation: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone)]
|
||||||
#[serde(rename_all = "snake_case", tag = "type")]
|
|
||||||
pub enum ToolChoice {
|
pub enum ToolChoice {
|
||||||
None,
|
None,
|
||||||
Auto,
|
Auto,
|
||||||
@@ -21,11 +20,71 @@ pub enum ToolChoice {
|
|||||||
name: String,
|
name: String,
|
||||||
},
|
},
|
||||||
AllowedTools {
|
AllowedTools {
|
||||||
#[serde(rename = "tools")]
|
|
||||||
tool_names: Vec<String>,
|
tool_names: Vec<String>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Serialize for ToolChoice {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
ToolChoice::None => serializer.serialize_str("none"),
|
||||||
|
ToolChoice::Auto => serializer.serialize_str("auto"),
|
||||||
|
ToolChoice::Required => serializer.serialize_str("required"),
|
||||||
|
ToolChoice::Named { name } => {
|
||||||
|
let obj = serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": { "name": name }
|
||||||
|
});
|
||||||
|
obj.serialize(serializer)
|
||||||
|
}
|
||||||
|
ToolChoice::AllowedTools { tool_names } => {
|
||||||
|
let obj = serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": { "name": tool_names.first().cloned().unwrap_or_default() }
|
||||||
|
});
|
||||||
|
obj.serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for ToolChoice {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
match value {
|
||||||
|
Value::String(s) => match s.as_str() {
|
||||||
|
"none" => Ok(ToolChoice::None),
|
||||||
|
"auto" => Ok(ToolChoice::Auto),
|
||||||
|
"required" => Ok(ToolChoice::Required),
|
||||||
|
_ => Err(serde::de::Error::custom(format!("unknown tool choice: {s}"))),
|
||||||
|
},
|
||||||
|
Value::Object(obj) => {
|
||||||
|
let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'type' field in tool_choice")
|
||||||
|
})?;
|
||||||
|
if typ == "function" {
|
||||||
|
let func = obj.get("function").and_then(|v| v.as_object()).ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'function' field in tool_choice")
|
||||||
|
})?;
|
||||||
|
let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||||
|
serde::de::Error::custom("missing 'function.name' in tool_choice")
|
||||||
|
})?;
|
||||||
|
Ok(ToolChoice::Named { name: name.to_string() })
|
||||||
|
} else {
|
||||||
|
Err(serde::de::Error::custom(format!("unknown tool_choice type: {typ}")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(serde::de::Error::custom("tool_choice must be a string or object")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case", tag = "type")]
|
#[serde(rename_all = "snake_case", tag = "type")]
|
||||||
pub enum OpenaiTool {
|
pub enum OpenaiTool {
|
||||||
|
|||||||
Reference in New Issue
Block a user