From 9edd2008c447e9eae91ab6f7c9b5d271eb4783e2 Mon Sep 17 00:00:00 2001 From: Nightt <87569709+nightt5879@users.noreply.github.com> Date: Wed, 27 May 2026 18:34:11 +0800 Subject: [PATCH] docs: add provider registry drift check --- .github/workflows/ci.yml | 2 + README.md | 1 + docs/PROVIDERS.md | 23 +++- scripts/check-provider-registry.py | 195 +++++++++++++++++++++++++++++ 4 files changed, 218 insertions(+), 3 deletions(-) create mode 100644 scripts/check-provider-registry.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4203a17c..f02b1e3a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,8 @@ jobs: components: rustfmt - name: Check formatting run: cargo fmt --all -- --check + - name: Check provider registry drift + run: python scripts/check-provider-registry.py - name: Linux clippy location run: echo "Linux clippy/test gates run on CNB for mirrored fix/*, rebrand/*, work/v*, and main branches." diff --git a/README.md b/README.md index 3213ee15..fbce91ab 100644 --- a/README.md +++ b/README.md @@ -555,6 +555,7 @@ without recreating skills the user deliberately deleted. |---|---| | [ARCHITECTURE.md](docs/ARCHITECTURE.md) | Codebase internals | | [CONFIGURATION.md](docs/CONFIGURATION.md) | Full config reference | +| [PROVIDERS.md](docs/PROVIDERS.md) | Provider IDs, auth, model defaults, and capability metadata | | [MODES.md](docs/MODES.md) | Plan / Agent / YOLO modes | | [MCP.md](docs/MCP.md) | Model Context Protocol integration | | [RUNTIME_API.md](docs/RUNTIME_API.md) | HTTP/SSE API server | diff --git a/docs/PROVIDERS.md b/docs/PROVIDERS.md index 38969415..d268779e 100644 --- a/docs/PROVIDERS.md +++ b/docs/PROVIDERS.md @@ -21,6 +21,8 @@ Sources to keep in sync: `codewhale model list` and `codewhale model resolve`. - `config.example.toml` and `docs/CONFIGURATION.md` - user-facing config examples and environment variable reference. 
+- `scripts/check-provider-registry.py` - drift check for canonical provider + IDs, TOML table names, static registry rows, and documented defaults. ## Provider Selection @@ -133,6 +135,24 @@ DeepSeek compatibility aliases `deepseek-chat` and `deepseek-reasoner` map to `deepseek-v4-flash` capability metadata and are scheduled to retire on 2026-07-24 at 2026-07-24T15:59:00Z. +## Drift Check + +Run this before changing provider IDs, provider TOML tables, static model +registry rows, or provider default strings: + +```bash +python scripts/check-provider-registry.py +``` + +The check fails when: + +- `docs/PROVIDERS.md` omits a canonical `ProviderKind::as_str()` ID. +- The shipped-provider table omits or adds a `[providers.*]` TOML table. +- The static model registry table drifts from providers used by + `crates/agent/src/lib.rs`. +- A provider default model or base URL constant in `crates/tui/src/config.rs` + is no longer mentioned here. + ## Planned, Not Shipped Yet These items belong to the v0.8.47 provider-abstraction milestone or related @@ -149,9 +169,6 @@ provider docs work, but they are not native shipped behavior in this checkout: - Hugging Face model passport metadata in the picker, including license, base model, context length, chat template, tool-call support, reasoning support, and gated/private status. -- A generated drift-check script that fails when this file diverges from the - provider registry. Until that exists, update this file with a source read of - the files listed at the top. Until native Hugging Face support lands, users can only reach an explicitly configured Hugging Face-compatible OpenAI route through the generic `openai` diff --git a/scripts/check-provider-registry.py b/scripts/check-provider-registry.py new file mode 100644 index 00000000..0a6aab11 --- /dev/null +++ b/scripts/check-provider-registry.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +"""Check that docs/PROVIDERS.md tracks the shipped provider registry. 
# Module overview (text of the module docstring that opens on the preceding
# line of the script, kept verbatim):
#
# This is intentionally lightweight. It does not try to generate prose; it checks
# the stable identifiers and default strings that are easy for docs to drift from:
#
# - canonical ProviderKind IDs
# - provider TOML tables
# - shipped-provider table rows
# - static ModelRegistry provider rows
# - default provider model/base URL constants

from __future__ import annotations

import re
import sys
from pathlib import Path


# The script lives in <repo>/scripts/, so the repository root is one level up
# from the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
CONFIG_RS = ROOT / "crates" / "config" / "src" / "lib.rs"
TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs"
AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs"
PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md"

# Maps each ProviderKind variant name to the `[providers.<table>]` TOML table
# name that is expected to exist for it.
PROVIDER_VARIANT_TO_TABLE = {
    "Deepseek": "deepseek",
    "NvidiaNim": "nvidia_nim",
    "Openai": "openai",
    "Atlascloud": "atlascloud",
    "WanjieArk": "wanjie_ark",
    "Openrouter": "openrouter",
    "Novita": "novita",
    "Fireworks": "fireworks",
    "Moonshot": "moonshot",
    "Sglang": "sglang",
    "Vllm": "vllm",
    "Ollama": "ollama",
}

# First cell of a Markdown table row when that cell is a single `code` span.
_ROW_ID_RE = re.compile(r"^\|\s*`([^`]+)`\s*\|", flags=re.MULTILINE)
# A table cell of the form `[providers.<name>]`.
_TOML_TABLE_RE = re.compile(r"\|\s*`\[providers\.([a-z0-9_]+)\]`\s*\|")


def read(path: Path) -> str:
    """Return the contents of ``path`` decoded as UTF-8."""
    return path.read_text(encoding="utf-8")


def extract_match_block(source: str, signature: str) -> str:
    """Return the body of the first ``match`` expression after ``signature``.

    The body is the text strictly between the ``{`` that follows the first
    occurrence of ``match`` after ``signature`` and its balancing ``}``.

    Braces are counted textually: braces inside string literals or comments
    are not treated specially, so this is only suitable for simple match
    blocks.

    Raises:
        ValueError: if ``signature``, the ``match`` keyword, its opening
            brace, or the balancing closing brace cannot be found.
    """
    signature_at = source.find(signature)
    if signature_at < 0:
        raise ValueError(f"could not find {signature!r}")
    match_at = source.find("match", signature_at)
    brace_start = source.find("{", match_at) if match_at >= 0 else -1
    if brace_start < 0:
        raise ValueError(f"could not parse match block after {signature!r}")

    depth = 0
    for index in range(brace_start, len(source)):
        char = source[index]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return source[brace_start + 1 : index]
    raise ValueError(f"could not parse match block after {signature!r}")


def provider_kind_ids(config_rs: str) -> dict[str, str]:
    """Map each ``Self::Variant`` arm of ``as_str`` to its string ID.

    Raises:
        ValueError: if the ``as_str`` match block cannot be located or
            contains no ``Self::Variant => "id"`` arms.
    """
    block = extract_match_block(config_rs, "pub fn as_str(self) -> &'static str")
    pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block)
    if not pairs:
        raise ValueError("ProviderKind::as_str returned no providers")
    return dict(pairs)


def provider_tables(config_rs: str) -> set[str]:
    """Return the field names of ``ProvidersToml`` typed ``ProviderConfigToml``.

    The struct body is taken to end at the first line that starts with ``}``
    after the struct header.

    Raises:
        ValueError: if the struct (or its end) cannot be found, or it has no
            matching fields.
    """
    struct_start = config_rs.find("pub struct ProvidersToml")
    if struct_start < 0:
        raise ValueError("could not find `pub struct ProvidersToml`")
    struct_end = config_rs.find("\n}", struct_start)
    if struct_end < 0:
        raise ValueError("could not find the end of `pub struct ProvidersToml`")
    fields = re.findall(
        r"pub\s+([a-z0-9_]+)\s*:\s*ProviderConfigToml",
        config_rs[struct_start:struct_end],
    )
    if not fields:
        raise ValueError("ProvidersToml returned no provider tables")
    return set(fields)


def _section(providers_md: str, heading: str) -> str:
    """Return the text from ``heading`` up to the next ``## `` heading.

    When ``heading`` introduces the last section of the document, the text up
    to the end of the document is returned instead of failing.

    Raises:
        ValueError: if ``heading`` does not occur in ``providers_md``.
    """
    start = providers_md.find(heading)
    if start < 0:
        raise ValueError(f"docs/PROVIDERS.md has no {heading!r} section")
    end = providers_md.find("\n## ", start + 1)
    return providers_md[start:] if end < 0 else providers_md[start:end]


def shipped_provider_rows(providers_md: str) -> set[str]:
    """Return the first-column code spans of the "Shipped Providers" table."""
    return set(_ROW_ID_RE.findall(_section(providers_md, "## Shipped Providers")))


def shipped_provider_tables(providers_md: str) -> set[str]:
    """Return the ``[providers.<name>]`` names listed under "Shipped Providers"."""
    return set(_TOML_TABLE_RE.findall(_section(providers_md, "## Shipped Providers")))


def static_registry_provider_rows(providers_md: str) -> set[str]:
    """Return the first-column code spans of the "Static Model Registry" table."""
    return set(_ROW_ID_RE.findall(_section(providers_md, "## Static Model Registry")))


def model_registry_providers(agent_rs: str, variant_to_id: dict[str, str]) -> set[str]:
    """Return the IDs of every ``provider: ProviderKind::X`` found in ``agent_rs``.

    Raises:
        ValueError: if ``agent_rs`` names a variant absent from ``variant_to_id``.
    """
    variants = set(re.findall(r"provider:\s*ProviderKind::(\w+)", agent_rs))
    missing = variants - set(variant_to_id)
    if missing:
        raise ValueError(f"ModelRegistry uses unknown provider variants: {sorted(missing)}")
    return {variant_to_id[variant] for variant in variants}


def default_strings(tui_config_rs: str) -> set[str]:
    """Return the values of ``DEFAULT_*MODEL`` / ``DEFAULT_*BASE_URL`` constants.

    Raises:
        ValueError: if no such constant (other than the skipped one) is found.
    """
    defaults = set()
    for name, value in re.findall(
        r'const\s+(DEFAULT_[A-Z0-9_]+(?:MODEL|BASE_URL)):\s*&str\s*=\s*"([^"]+)"',
        tui_config_rs,
    ):
        # NOTE(review): this constant is deliberately excluded from the check;
        # presumably it is not expected to appear in PROVIDERS.md - confirm
        # before removing the exception.
        if name == "DEFAULT_DEEPSEEKCN_BASE_URL":
            continue
        defaults.add(value)
    if not defaults:
        raise ValueError("no default provider model/base URL constants found")
    return defaults
def missing_default_strings(providers_md: str, defaults: set[str]) -> list[str]:
    """Return, sorted, the default strings that do not occur verbatim in the docs."""
    return sorted(value for value in defaults if value not in providers_md)


def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
    """Describe how ``actual`` differs from ``expected``.

    Returns zero, one, or two messages: ``"<label> missing: ..."`` for items
    only in ``expected`` and ``"<label> extra: ..."`` for items only in
    ``actual``. Items are sorted so the output is deterministic.
    """
    errors = []
    missing = sorted(expected - actual)
    extra = sorted(actual - expected)
    if missing:
        errors.append(f"{label} missing: {', '.join(missing)}")
    if extra:
        errors.append(f"{label} extra: {', '.join(extra)}")
    return errors


def _collect_drift_errors(
    config_rs: str,
    tui_config_rs: str,
    agent_rs: str,
    providers_md: str,
) -> list[str]:
    """Compare the registry sources with the docs and return every finding.

    Pure function over the already-loaded file contents; an empty list means
    no drift was detected. ``ValueError`` raised by the parsing helpers (for
    example when an expected anchor is absent from a source file) propagates
    unchanged.
    """
    variant_to_id = provider_kind_ids(config_rs)
    canonical_ids = set(variant_to_id.values())

    errors: list[str] = []

    # A ProviderKind variant this script does not know about is itself drift.
    # Report it alongside the other findings instead of aborting, so one run
    # shows everything that needs updating.
    missing_table_mappings = sorted(set(variant_to_id) - set(PROVIDER_VARIANT_TO_TABLE))
    if missing_table_mappings:
        errors.append(
            "PROVIDER_VARIANT_TO_TABLE is missing variants: "
            + ", ".join(missing_table_mappings)
        )
    expected_tables = {
        PROVIDER_VARIANT_TO_TABLE[variant]
        for variant in variant_to_id
        if variant in PROVIDER_VARIANT_TO_TABLE
    }

    errors += report_set(
        "shipped provider rows",
        canonical_ids,
        shipped_provider_rows(providers_md),
    )
    errors += report_set("provider TOML tables", expected_tables, provider_tables(config_rs))
    errors += report_set(
        "documented provider TOML tables",
        expected_tables,
        shipped_provider_tables(providers_md),
    )
    errors += report_set(
        "static ModelRegistry rows",
        model_registry_providers(agent_rs, variant_to_id),
        static_registry_provider_rows(providers_md),
    )

    missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs))
    if missing_defaults:
        errors.append(
            "docs/PROVIDERS.md does not mention default strings: "
            + ", ".join(missing_defaults)
        )
    return errors


def main() -> int:
    """Run the drift check and return the process exit status.

    Returns:
        0 when docs and registry agree; 1 when drift is found or when one of
        the input files cannot be read. Findings are written to stderr.
    """
    try:
        config_rs = read(CONFIG_RS)
        tui_config_rs = read(TUI_CONFIG_RS)
        agent_rs = read(AGENT_RS)
        providers_md = read(PROVIDERS_MD)
    except OSError as exc:
        # The OSError text already names the offending path; a traceback adds
        # nothing useful to the CI log.
        print(f"Provider registry drift check could not read its inputs: {exc}", file=sys.stderr)
        return 1

    errors = _collect_drift_errors(config_rs, tui_config_rs, agent_rs, providers_md)
    if errors:
        print("Provider registry drift check failed:", file=sys.stderr)
        for error in errors:
            print(f"- {error}", file=sys.stderr)
        return 1

    print("Provider registry drift check passed.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())