From 45562822f0e718a9519763606551ee9206400957 Mon Sep 17 00:00:00 2001 From: cyq <15000851237@163.com> Date: Tue, 2 Jun 2026 00:53:47 +0800 Subject: [PATCH] feat(agent): classify model families --- crates/agent/src/lib.rs | 118 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index 36f6e0bf..1d7640f3 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -3,6 +3,22 @@ use std::collections::HashMap; use codewhale_config::ProviderKind; use serde::{Deserialize, Serialize}; +/// High-level model family used for shared identity affordances across clients. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelFamily { + DeepSeek, + Anthropic, + OpenAI, + Google, + Meta, + Mistral, + Qwen, + Grok, + Cohere, + GptOss, + Inferencer, +} + /// Metadata for a single model entry in the registry. /// /// Each model has a canonical `id` used by the provider, a list of `aliases` @@ -703,6 +719,58 @@ fn normalize(value: &str) -> String { value.trim().to_ascii_lowercase() } +#[must_use] +/// Classify a model identifier by its underlying model family. +pub fn model_family(model_id: &str) -> ModelFamily { + let normalized = normalize(model_id); + if normalized.is_empty() { + return ModelFamily::Inferencer; + } + + if normalized.contains("deepseek") { + return ModelFamily::DeepSeek; + } + if normalized.contains("claude") || normalized.contains("anthropic") { + return ModelFamily::Anthropic; + } + if normalized.contains("gpt-oss") || normalized.contains("gpt_oss") { + return ModelFamily::GptOss; + } + if normalized.starts_with("gpt-") + || normalized.contains("/gpt-") + || normalized.contains("openai/") + { + return ModelFamily::OpenAI; + } + if normalized.contains("gemini") + || normalized.contains("gemma") + || normalized.contains("google/") + { + return ModelFamily::Google; + } + if normalized.contains("llama") || normalized.contains("meta-") || normalized.contains("meta/") + { + return ModelFamily::Meta; + } + if normalized.contains("mistral") + || normalized.contains("mixtral") + || normalized.contains("codestral") + { + return ModelFamily::Mistral; + } + if normalized.contains("qwen") { + return ModelFamily::Qwen; + } + if normalized.contains("grok") { + return ModelFamily::Grok; + } + if normalized.contains("cohere") || normalized.contains("command-r") { + return ModelFamily::Cohere; + } + + ModelFamily::Inferencer +} + fn model_matches(model: &ModelInfo, requested: &str) -> bool { let requested = normalize(requested); normalize(&model.id) == requested @@ -1171,4 +1239,54 @@ mod tests { assert_eq!(resolved.resolved.provider, ProviderKind::Deepseek); assert_eq!(resolved.resolved.id, "deepseek-v4-flash"); } + + #[test] + fn model_family_classifies_known_model_ids() { + assert_eq!(model_family("deepseek-v4-pro"), ModelFamily::DeepSeek); + assert_eq!(model_family("openai/gpt-5.4"), ModelFamily::OpenAI); + assert_eq!( + model_family("anthropic/claude-opus-4-7"), + ModelFamily::Anthropic + ); + assert_eq!( + model_family("meta-llama/llama-3.3-70b-instruct"), + ModelFamily::Meta + ); + assert_eq!(model_family("Qwen/Qwen3-Coder"), ModelFamily::Qwen); + } + + #[test] + fn model_family_uses_underlying_model_for_router_ids() { + assert_eq!( + model_family("groq/llama-3.3-70b-versatile"), + ModelFamily::Meta + ); + assert_eq!( + model_family("openrouter/openai/gpt-5.4"), + ModelFamily::OpenAI + ); + assert_eq!( + model_family("fireworks/accounts/fireworks/models/deepseek-v4-pro"), + ModelFamily::DeepSeek + ); + } + + #[test] + fn model_family_covers_prominent_google_and_mistral_model_names() { + assert_eq!(model_family("google/gemma-3-27b-it"), ModelFamily::Google); + assert_eq!( + model_family("mistralai/mixtral-8x22b"), + ModelFamily::Mistral + ); + assert_eq!(model_family("codestral-latest"), ModelFamily::Mistral); + } + + #[test] + fn model_family_falls_back_to_inferencer_for_unknown_models() { + assert_eq!( + model_family("custom-gateway/my-private-model"), + ModelFamily::Inferencer + ); + assert_eq!(model_family(""), ModelFamily::Inferencer); + } }