Merge PR #2879: Hugging Face provider docs and tests
Harvested from PR #2879 by @mvanhorn

Co-authored-by: mvanhorn <455140+mvanhorn@users.noreply.github.com>
This commit is contained in:
@@ -31,6 +31,10 @@ API_PROVIDER_ONLY_IDS = {"deepseek-cn"}
|
||||
# Provider ids whose table name is NOT simply the id with "-" replaced by "_"
# (the default applied by provider_table_name); listed here as explicit overrides.
SHARED_PROVIDER_TABLES = {
    "siliconflow-CN": "siliconflow_cn",
}
# Every spelling the Hugging Face coverage check expects to find among the parsed
# enum aliases and among the inline code spans of the provider docs.
HUGGINGFACE_ALIASES = {"huggingface", "hugging-face", "hugging_face", "hf"}
# Environment variable names in the order they are expected to first appear in the
# checked sources and docs; the coverage check labels this order "precedence".
HUGGINGFACE_API_KEY_ENV_ORDER = ["HUGGINGFACE_API_KEY", "HF_TOKEN"]
HUGGINGFACE_BASE_URL_ENV_ORDER = ["HUGGINGFACE_BASE_URL", "HF_BASE_URL"]
HUGGINGFACE_MODEL_ENV_ORDER = ["HUGGINGFACE_MODEL", "HF_MODEL"]
|
||||
|
||||
|
||||
def read(path: Path) -> str:
|
||||
@@ -69,6 +73,35 @@ def extract_match_block(
|
||||
raise ValueError(f"could not parse match block after {signature!r}")
|
||||
|
||||
|
||||
def parse_aliases_for_variant(source: str, enum_name: str, variant: str, context: str) -> set[str]:
    """Collect the string aliases that parse to ``Self::<variant>`` for a Rust enum.

    The primary lookup scans the ``parse`` match block found after
    ``impl <enum_name>`` in *source* for an arm of the form
    ``"a" | "b" => Some(Self::<variant>)`` and returns the quoted strings.

    When no such arm exists and *enum_name* is ``ProviderKind`` or
    ``ApiProvider``, the file at ``PROVIDER_RS`` is searched for a
    ``provider!(...)`` invocation whose second argument is *variant*; the
    quoted string that follows it and every entry of its ``aliases: [...]``
    list are returned together.

    Args:
        source: Rust source text containing the enum's ``impl`` block.
        enum_name: Name of the enum whose ``parse`` function is inspected.
        variant: Variant identifier; interpolated verbatim into the patterns.
        context: Label used in error messages (callers pass a file path).

    Raises:
        ValueError: If neither lookup finds the variant. The helpers used to
            locate the ``impl`` and match block may raise as well.
    """
    quoted_string = r'"([^"]+)"'
    not_found = f"{context}: missing parse arm for {variant}"

    impl_offset = require_index(source, f"impl {enum_name}", context)
    parse_body = extract_match_block(
        source,
        "pub fn parse(value: &str) -> Option<Self>",
        context,
        impl_offset,
    )

    # Preferred form: an explicit `"x" | "y" => Some(Self::Variant)` arm.
    arm = re.search(
        rf'((?:"[^"]+"\s*\|\s*)*"[^"]+")\s*=>\s*Some\(Self::{variant}\)',
        parse_body,
    )
    if arm is not None:
        return set(re.findall(quoted_string, arm.group(1)))

    # Only these two enums have the macro-based fallback.
    if enum_name not in {"ProviderKind", "ApiProvider"}:
        raise ValueError(not_found)

    macro_call = re.search(
        rf'provider!\(\s*\n\s*\w+,\s*\n\s*{variant},\s*\n\s*"([^"]+)".*?'
        r"aliases:\s*\[(.*?)\]\s*\);",
        read(PROVIDER_RS),
        re.DOTALL,
    )
    if macro_call is None:
        raise ValueError(not_found)

    primary_name, alias_list = macro_call.group(1, 2)
    return {primary_name} | set(re.findall(quoted_string, alias_list))
|
||||
|
||||
|
||||
def provider_kind_ids(config_rs: str) -> dict[str, str]:
|
||||
provider_rs = read(PROVIDER_RS)
|
||||
pairs = re.findall(
|
||||
@@ -201,6 +234,76 @@ def report_provider_enum_drift(
|
||||
return errors
|
||||
|
||||
|
||||
def report_huggingface_coverage(
    config_rs: str, tui_config_rs: str, providers_md: str
) -> list[str]:
    """Check that Hugging Face aliases and env-var ordering agree everywhere.

    Three alias sources are compared against ``HUGGINGFACE_ALIASES``: the
    ``ProviderKind`` enum in *config_rs*, the ``ApiProvider`` enum in
    *tui_config_rs*, and the inline code spans of *providers_md*. Each found
    set is intersected with the expected set before comparison, so aliases
    outside ``HUGGINGFACE_ALIASES`` never reach ``report_set``.

    The API key, base URL and model env-var orders are then checked in both
    Rust sources and in the docs.

    Returns:
        The accumulated error messages, empty when nothing was reported.
    """
    config_path = "crates/config/src/lib.rs"
    tui_path = "crates/tui/src/config.rs"

    provider_kind_aliases = parse_aliases_for_variant(
        config_rs, "ProviderKind", "Huggingface", config_path
    )
    api_provider_aliases = parse_aliases_for_variant(
        tui_config_rs, "ApiProvider", "Huggingface", tui_path
    )

    # Drop fenced code blocks first so only inline `code` spans are collected.
    without_fences = re.sub(r"```.*?```", "", providers_md, flags=re.DOTALL)
    documented_spans = set(re.findall(r"`([^`]+)`", without_fences))

    problems: list[str] = []

    alias_checks = (
        ("ProviderKind Hugging Face aliases", provider_kind_aliases),
        ("ApiProvider Hugging Face aliases", api_provider_aliases),
        ("documented Hugging Face aliases", documented_spans),
    )
    for heading, found in alias_checks:
        problems.extend(
            report_set(heading, HUGGINGFACE_ALIASES, found & HUGGINGFACE_ALIASES)
        )

    env_checks = (
        ("Hugging Face API key env precedence", HUGGINGFACE_API_KEY_ENV_ORDER),
        ("Hugging Face base URL env precedence", HUGGINGFACE_BASE_URL_ENV_ORDER),
        ("Hugging Face model env precedence", HUGGINGFACE_MODEL_ENV_ORDER),
    )
    for heading, env_names in env_checks:
        # Rust sources are matched on their env lookup calls, the docs on the
        # bare variable names.
        problems.extend(
            report_env_lookup_order(heading, config_rs, env_names, config_path)
        )
        problems.extend(
            report_env_lookup_order(heading, tui_config_rs, env_names, tui_path)
        )
        problems.extend(
            report_string_order(heading, providers_md, env_names, "docs/PROVIDERS.md")
        )

    return problems
|
||||
|
||||
|
||||
def report_env_lookup_order(
    label: str, source: str, expected_order: list[str], context: str
) -> list[str]:
    """Check that env-var lookups appear in *source* in the expected order.

    Each name in *expected_order* is wrapped as the literal Rust call
    ``std::env::var("NAME")`` and the resulting needles are handed to
    ``report_string_order``, whose result is returned unchanged.
    """
    rust_lookups = []
    for env_name in expected_order:
        rust_lookups.append(f'std::env::var("{env_name}")')
    return report_string_order(label, source, rust_lookups, context)
|
||||
|
||||
|
||||
def report_string_order(
    label: str, source: str, expected_order: list[str], context: str
) -> list[str]:
    """Check that each needle's first occurrence in *source* follows the given order.

    Args:
        label: Prefix for any error message.
        source: Text to search.
        expected_order: Needles, listed in the order they should first appear.
        context: Location label included in error messages.

    Returns:
        ``[]`` when every needle is present and first occurrences are
        non-decreasing; otherwise a one-element list. A missing needle is
        reported as soon as it is encountered, before any order comparison.
    """
    first_seen: list[int] = []
    for wanted in expected_order:
        offset = source.find(wanted)
        if offset < 0:
            return [f"{label} missing {wanted!r} in {context}"]
        first_seen.append(offset)

    # Equal offsets (one needle being a prefix of the next) count as ordered.
    ordered = all(
        earlier <= later for earlier, later in zip(first_seen, first_seen[1:])
    )
    if ordered:
        return []

    prefix = f"{label} has wrong order in {context}: expected "
    return [prefix + " before ".join(expected_order)]
|
||||
|
||||
|
||||
def provider_table_name(provider_id: str) -> str:
    """Return the table name for *provider_id*.

    An entry in ``SHARED_PROVIDER_TABLES`` takes priority; otherwise the id
    is used with every ``-`` replaced by ``_``.
    """
    default_name = provider_id.replace("-", "_")
    return SHARED_PROVIDER_TABLES.get(provider_id, default_name)
|
||||
|
||||
@@ -219,6 +322,7 @@ def main() -> int:
|
||||
|
||||
errors: list[str] = []
|
||||
errors += report_provider_enum_drift(canonical_ids, live_api_provider_ids)
|
||||
errors += report_huggingface_coverage(config_rs, tui_config_rs, providers_md)
|
||||
errors += report_set(
|
||||
"shipped provider rows",
|
||||
canonical_ids,
|
||||
@@ -256,4 +360,4 @@ def main() -> int:
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    # The second, identical `raise SystemExit(main())` was unreachable and is removed.
    raise SystemExit(main())
|
||||
|
||||
Reference in New Issue
Block a user