From e2099dd6913880e6035a0981a6784beef75ea9a9 Mon Sep 17 00:00:00 2001 From: Hunter Bown Date: Wed, 27 May 2026 06:19:32 -0500 Subject: [PATCH] fix: harden provider registry drift check --- .github/workflows/ci.yml | 2 +- docs/PROVIDERS.md | 2 +- scripts/check-provider-registry.py | 128 +++++++++++++---------------- 3 files changed, 61 insertions(+), 71 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f02b1e3a..45c212bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: - name: Check formatting run: cargo fmt --all -- --check - name: Check provider registry drift - run: python scripts/check-provider-registry.py + run: python3 scripts/check-provider-registry.py - name: Linux clippy location run: echo "Linux clippy/test gates run on CNB for mirrored fix/*, rebrand/*, work/v*, and main branches." diff --git a/docs/PROVIDERS.md b/docs/PROVIDERS.md index d268779e..ec9657ab 100644 --- a/docs/PROVIDERS.md +++ b/docs/PROVIDERS.md @@ -141,7 +141,7 @@ Run this before changing provider IDs, provider TOML tables, static model registry rows, or provider default strings: ```bash -python scripts/check-provider-registry.py +python3 scripts/check-provider-registry.py ``` The check fails when: diff --git a/scripts/check-provider-registry.py b/scripts/check-provider-registry.py index 0a6aab11..ebc1d556 100644 --- a/scripts/check-provider-registry.py +++ b/scripts/check-provider-registry.py @@ -24,30 +24,28 @@ TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs" AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs" PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md" -PROVIDER_VARIANT_TO_TABLE = { - "Deepseek": "deepseek", - "NvidiaNim": "nvidia_nim", - "Openai": "openai", - "Atlascloud": "atlascloud", - "WanjieArk": "wanjie_ark", - "Openrouter": "openrouter", - "Novita": "novita", - "Fireworks": "fireworks", - "Moonshot": "moonshot", - "Sglang": "sglang", - "Vllm": "vllm", - "Ollama": "ollama", -} - - def read(path: Path) -> str: return path.read_text(encoding="utf-8") +def require_index(source: str, needle: str, context: str, start: int = 0) -> int: + try: + return source.index(needle, start) + except ValueError: + raise ValueError(f"{context}: missing {needle!r}") from None + + +def markdown_section(source: str, heading: str) -> str: + start = require_index(source, heading, "docs/PROVIDERS.md") + next_heading = source.find("\n## ", start + len(heading)) + end = len(source) if next_heading == -1 else next_heading + return source[start:end] + + def extract_match_block(source: str, signature: str) -> str: - start = source.index(signature) - match_start = source.index("match", start) - brace_start = source.index("{", match_start) + start = require_index(source, signature, "crates/config/src/lib.rs") + match_start = require_index(source, "match", f"match block after {signature!r}", start) + brace_start = require_index(source, "{", f"match block after {signature!r}", match_start) depth = 0 for index in range(brace_start, len(source)): char = source[index] @@ -69,8 +67,10 @@ def provider_kind_ids(config_rs: str) -> dict[str, str]: def provider_tables(config_rs: str) -> set[str]: - struct_start = config_rs.index("pub struct ProvidersToml") - struct_end = config_rs.index("\n}", struct_start) + struct_start = require_index( + config_rs, "pub struct ProvidersToml", "crates/config/src/lib.rs" + ) + struct_end = require_index(config_rs, "\n}", "ProvidersToml struct", struct_start) fields = re.findall( r"pub\s+([a-z0-9_]+)\s*:\s*ProviderConfigToml", config_rs[struct_start:struct_end], @@ -81,23 +81,17 @@ def provider_tables(config_rs: str) -> set[str]: def shipped_provider_rows(providers_md: str) -> set[str]: - heading = providers_md.index("## Shipped Providers") - next_heading = providers_md.index("\n## ", heading + 1) - table = providers_md[heading:next_heading] + table = markdown_section(providers_md, "## Shipped Providers") return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", table, flags=re.MULTILINE)) def shipped_provider_tables(providers_md: str) -> set[str]: - heading = providers_md.index("## Shipped Providers") - next_heading = providers_md.index("\n## ", heading + 1) - table = providers_md[heading:next_heading] + table = markdown_section(providers_md, "## Shipped Providers") return set(re.findall(r"\|\s*`\[providers\.([a-z0-9_]+)\]`\s*\|", table)) def static_registry_provider_rows(providers_md: str) -> set[str]: - heading = providers_md.index("## Static Model Registry") - next_heading = providers_md.index("\n## ", heading + 1) - table = providers_md[heading:next_heading] + table = markdown_section(providers_md, "## Static Model Registry") return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", table, flags=re.MULTILINE)) @@ -124,7 +118,8 @@ def default_strings(tui_config_rs: str) -> set[str]: def missing_default_strings(providers_md: str, defaults: set[str]) -> list[str]: - return sorted(value for value in defaults if value not in providers_md) + code_spans = set(re.findall(r"`([^`]+)`", providers_md)) + return sorted(defaults - code_spans) def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]: @@ -139,47 +134,42 @@ def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]: def main() -> int: - config_rs = read(CONFIG_RS) - tui_config_rs = read(TUI_CONFIG_RS) - agent_rs = read(AGENT_RS) - providers_md = read(PROVIDERS_MD) + try: + config_rs = read(CONFIG_RS) + tui_config_rs = read(TUI_CONFIG_RS) + agent_rs = read(AGENT_RS) + providers_md = read(PROVIDERS_MD) - variant_to_id = provider_kind_ids(config_rs) - canonical_ids = set(variant_to_id.values()) - missing_table_mappings = sorted(set(variant_to_id) - set(PROVIDER_VARIANT_TO_TABLE)) - if missing_table_mappings: - raise ValueError( - "PROVIDER_VARIANT_TO_TABLE is missing variants: " - + ", ".join(missing_table_mappings) + variant_to_id = provider_kind_ids(config_rs) + canonical_ids = set(variant_to_id.values()) + expected_tables = {provider_id.replace("-", "_") for provider_id in canonical_ids} + + errors: list[str] = [] + errors += report_set( + "shipped provider rows", + canonical_ids, + shipped_provider_rows(providers_md), ) - expected_tables = { - PROVIDER_VARIANT_TO_TABLE[variant] for variant in variant_to_id - } - - errors: list[str] = [] - errors += report_set( - "shipped provider rows", - canonical_ids, - shipped_provider_rows(providers_md), - ) - errors += report_set("provider TOML tables", expected_tables, provider_tables(config_rs)) - errors += report_set( - "documented provider TOML tables", - expected_tables, - shipped_provider_tables(providers_md), - ) - errors += report_set( - "static ModelRegistry rows", - model_registry_providers(agent_rs, variant_to_id), - static_registry_provider_rows(providers_md), - ) - - missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs)) - if missing_defaults: - errors.append( - "docs/PROVIDERS.md does not mention default strings: " - + ", ".join(missing_defaults) + errors += report_set("provider TOML tables", expected_tables, provider_tables(config_rs)) + errors += report_set( + "documented provider TOML tables", + expected_tables, + shipped_provider_tables(providers_md), ) + errors += report_set( + "static ModelRegistry rows", + model_registry_providers(agent_rs, variant_to_id), + static_registry_provider_rows(providers_md), + ) + + missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs)) + if missing_defaults: + errors.append( + "docs/PROVIDERS.md does not mention default strings as Markdown code spans: " + + ", ".join(missing_defaults) + ) + except ValueError as err: + errors = [str(err)] if errors: print("Provider registry drift check failed:", file=sys.stderr)