feat(llm): 替换 ToolChoice 的自动派生为手动序列化

处理自定义 JSON 序列化逻辑,支持字符串值(`none`/`auto`/`required`)和对象格式(`{"type":"function","function":{"name":"..."}}`)。反序列化时向前兼容两种格式。
This commit is contained in:
徐涛
2026-05-14 08:36:08 +08:00
parent a4b7b3b9f9
commit e22c176643
+62 -3
View File
@@ -11,8 +11,7 @@ pub struct StreamOptions {
pub include_obfuscation: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
#[derive(Debug, Clone)]
pub enum ToolChoice {
None,
Auto,
@@ -21,11 +20,71 @@ pub enum ToolChoice {
name: String,
},
AllowedTools {
#[serde(rename = "tools")]
tool_names: Vec<String>,
},
}
impl Serialize for ToolChoice {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
ToolChoice::None => serializer.serialize_str("none"),
ToolChoice::Auto => serializer.serialize_str("auto"),
ToolChoice::Required => serializer.serialize_str("required"),
ToolChoice::Named { name } => {
let obj = serde_json::json!({
"type": "function",
"function": { "name": name }
});
obj.serialize(serializer)
}
ToolChoice::AllowedTools { tool_names } => {
let obj = serde_json::json!({
"type": "function",
"function": { "name": tool_names.first().cloned().unwrap_or_default() }
});
obj.serialize(serializer)
}
}
}
}
impl<'de> Deserialize<'de> for ToolChoice {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => match s.as_str() {
"none" => Ok(ToolChoice::None),
"auto" => Ok(ToolChoice::Auto),
"required" => Ok(ToolChoice::Required),
_ => Err(serde::de::Error::custom(format!("unknown tool choice: {s}"))),
},
Value::Object(obj) => {
let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
serde::de::Error::custom("missing 'type' field in tool_choice")
})?;
if typ == "function" {
let func = obj.get("function").and_then(|v| v.as_object()).ok_or_else(|| {
serde::de::Error::custom("missing 'function' field in tool_choice")
})?;
let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
serde::de::Error::custom("missing 'function.name' in tool_choice")
})?;
Ok(ToolChoice::Named { name: name.to_string() })
} else {
Err(serde::de::Error::custom(format!("unknown tool_choice type: {typ}")))
}
}
_ => Err(serde::de::Error::custom("tool_choice must be a string or object")),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum OpenaiTool {