feat(agent): classify model families

This commit is contained in:
cyq
2026-06-02 00:53:47 +08:00
committed by Hunter Bown
parent 195dd6b9ab
commit 45562822f0
+118
View File
@@ -3,6 +3,22 @@ use std::collections::HashMap;
use codewhale_config::ProviderKind;
use serde::{Deserialize, Serialize};
/// High-level model family used for shared identity affordances across clients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelFamily {
DeepSeek,
Anthropic,
OpenAI,
Google,
Meta,
Mistral,
Qwen,
Grok,
Cohere,
GptOss,
Inferencer,
}
/// Metadata for a single model entry in the registry.
///
/// Each model has a canonical `id` used by the provider, a list of `aliases`
@@ -703,6 +719,58 @@ fn normalize(value: &str) -> String {
value.trim().to_ascii_lowercase()
}
#[must_use]
/// Classify a model identifier by its underlying model family.
pub fn model_family(model_id: &str) -> ModelFamily {
let normalized = normalize(model_id);
if normalized.is_empty() {
return ModelFamily::Inferencer;
}
if normalized.contains("deepseek") {
return ModelFamily::DeepSeek;
}
if normalized.contains("claude") || normalized.contains("anthropic") {
return ModelFamily::Anthropic;
}
if normalized.contains("gpt-oss") || normalized.contains("gpt_oss") {
return ModelFamily::GptOss;
}
if normalized.starts_with("gpt-")
|| normalized.contains("/gpt-")
|| normalized.contains("openai/")
{
return ModelFamily::OpenAI;
}
if normalized.contains("gemini")
|| normalized.contains("gemma")
|| normalized.contains("google/")
{
return ModelFamily::Google;
}
if normalized.contains("llama") || normalized.contains("meta-") || normalized.contains("meta/")
{
return ModelFamily::Meta;
}
if normalized.contains("mistral")
|| normalized.contains("mixtral")
|| normalized.contains("codestral")
{
return ModelFamily::Mistral;
}
if normalized.contains("qwen") {
return ModelFamily::Qwen;
}
if normalized.contains("grok") {
return ModelFamily::Grok;
}
if normalized.contains("cohere") || normalized.contains("command-r") {
return ModelFamily::Cohere;
}
ModelFamily::Inferencer
}
fn model_matches(model: &ModelInfo, requested: &str) -> bool {
let requested = normalize(requested);
normalize(&model.id) == requested
@@ -1171,4 +1239,54 @@ mod tests {
assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek);
assert_eq!(resolved.resolved.id, "deepseek-v4-flash");
}
#[test]
fn model_family_classifies_known_model_ids() {
assert_eq!(model_family("deepseek-v4-pro"), ModelFamily::DeepSeek);
assert_eq!(model_family("openai/gpt-5.4"), ModelFamily::OpenAI);
assert_eq!(
model_family("anthropic/claude-opus-4-7"),
ModelFamily::Anthropic
);
assert_eq!(
model_family("meta-llama/llama-3.3-70b-instruct"),
ModelFamily::Meta
);
assert_eq!(model_family("Qwen/Qwen3-Coder"), ModelFamily::Qwen);
}
#[test]
fn model_family_uses_underlying_model_for_router_ids() {
assert_eq!(
model_family("groq/llama-3.3-70b-versatile"),
ModelFamily::Meta
);
assert_eq!(
model_family("openrouter/openai/gpt-5.4"),
ModelFamily::OpenAI
);
assert_eq!(
model_family("fireworks/accounts/fireworks/models/deepseek-v4-pro"),
ModelFamily::DeepSeek
);
}
#[test]
fn model_family_covers_prominent_google_and_mistral_model_names() {
assert_eq!(model_family("google/gemma-3-27b-it"), ModelFamily::Google);
assert_eq!(
model_family("mistralai/mixtral-8x22b"),
ModelFamily::Mistral
);
assert_eq!(model_family("codestral-latest"), ModelFamily::Mistral);
}
#[test]
fn model_family_falls_back_to_inferencer_for_unknown_models() {
assert_eq!(
model_family("custom-gateway/my-private-model"),
ModelFamily::Inferencer
);
assert_eq!(model_family(""), ModelFamily::Inferencer);
}
}