feat(llm): 替换 ToolChoice 的自动派生为手动序列化
处理自定义 JSON 序列化逻辑,支持字符串值(`none`/`auto`/`required`)和对象格式(`{"type":"function","function":{"name":"..."}}`)。反序列化时向前兼容两种格式。
This commit is contained in:
@@ -11,8 +11,7 @@ pub struct StreamOptions {
|
||||
pub include_obfuscation: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolChoice {
|
||||
None,
|
||||
Auto,
|
||||
@@ -21,11 +20,71 @@ pub enum ToolChoice {
|
||||
name: String,
|
||||
},
|
||||
AllowedTools {
|
||||
#[serde(rename = "tools")]
|
||||
tool_names: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Serialize for ToolChoice {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
ToolChoice::None => serializer.serialize_str("none"),
|
||||
ToolChoice::Auto => serializer.serialize_str("auto"),
|
||||
ToolChoice::Required => serializer.serialize_str("required"),
|
||||
ToolChoice::Named { name } => {
|
||||
let obj = serde_json::json!({
|
||||
"type": "function",
|
||||
"function": { "name": name }
|
||||
});
|
||||
obj.serialize(serializer)
|
||||
}
|
||||
ToolChoice::AllowedTools { tool_names } => {
|
||||
let obj = serde_json::json!({
|
||||
"type": "function",
|
||||
"function": { "name": tool_names.first().cloned().unwrap_or_default() }
|
||||
});
|
||||
obj.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ToolChoice {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(ToolChoice::None),
|
||||
"auto" => Ok(ToolChoice::Auto),
|
||||
"required" => Ok(ToolChoice::Required),
|
||||
_ => Err(serde::de::Error::custom(format!("unknown tool choice: {s}"))),
|
||||
},
|
||||
Value::Object(obj) => {
|
||||
let typ = obj.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'type' field in tool_choice")
|
||||
})?;
|
||||
if typ == "function" {
|
||||
let func = obj.get("function").and_then(|v| v.as_object()).ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'function' field in tool_choice")
|
||||
})?;
|
||||
let name = func.get("name").and_then(|v| v.as_str()).ok_or_else(|| {
|
||||
serde::de::Error::custom("missing 'function.name' in tool_choice")
|
||||
})?;
|
||||
Ok(ToolChoice::Named { name: name.to_string() })
|
||||
} else {
|
||||
Err(serde::de::Error::custom(format!("unknown tool_choice type: {typ}")))
|
||||
}
|
||||
}
|
||||
_ => Err(serde::de::Error::custom("tool_choice must be a string or object")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "type")]
|
||||
pub enum OpenaiTool {
|
||||
|
||||
Reference in New Issue
Block a user