fix: harden provider registry drift check
This commit is contained in:
@@ -39,7 +39,7 @@ jobs:
|
||||
- name: Check formatting
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Check provider registry drift
|
||||
run: python scripts/check-provider-registry.py
|
||||
run: python3 scripts/check-provider-registry.py
|
||||
- name: Linux clippy location
|
||||
run: echo "Linux clippy/test gates run on CNB for mirrored fix/*, rebrand/*, work/v*, and main branches."
|
||||
|
||||
|
||||
+1
-1
@@ -141,7 +141,7 @@ Run this before changing provider IDs, provider TOML tables, static model
|
||||
registry rows, or provider default strings:
|
||||
|
||||
```bash
|
||||
python scripts/check-provider-registry.py
|
||||
python3 scripts/check-provider-registry.py
|
||||
```
|
||||
|
||||
The check fails when:
|
||||
|
||||
@@ -24,30 +24,28 @@ TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs"
|
||||
AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs"
|
||||
PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md"
|
||||
|
||||
PROVIDER_VARIANT_TO_TABLE = {
|
||||
"Deepseek": "deepseek",
|
||||
"NvidiaNim": "nvidia_nim",
|
||||
"Openai": "openai",
|
||||
"Atlascloud": "atlascloud",
|
||||
"WanjieArk": "wanjie_ark",
|
||||
"Openrouter": "openrouter",
|
||||
"Novita": "novita",
|
||||
"Fireworks": "fireworks",
|
||||
"Moonshot": "moonshot",
|
||||
"Sglang": "sglang",
|
||||
"Vllm": "vllm",
|
||||
"Ollama": "ollama",
|
||||
}
|
||||
|
||||
|
||||
def read(path: Path) -> str:
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def require_index(source: str, needle: str, context: str, start: int = 0) -> int:
|
||||
try:
|
||||
return source.index(needle, start)
|
||||
except ValueError:
|
||||
raise ValueError(f"{context}: missing {needle!r}") from None
|
||||
|
||||
|
||||
def markdown_section(source: str, heading: str) -> str:
|
||||
start = require_index(source, heading, "docs/PROVIDERS.md")
|
||||
next_heading = source.find("\n## ", start + len(heading))
|
||||
end = len(source) if next_heading == -1 else next_heading
|
||||
return source[start:end]
|
||||
|
||||
|
||||
def extract_match_block(source: str, signature: str) -> str:
|
||||
start = source.index(signature)
|
||||
match_start = source.index("match", start)
|
||||
brace_start = source.index("{", match_start)
|
||||
start = require_index(source, signature, "crates/config/src/lib.rs")
|
||||
match_start = require_index(source, "match", f"match block after {signature!r}", start)
|
||||
brace_start = require_index(source, "{", f"match block after {signature!r}", match_start)
|
||||
depth = 0
|
||||
for index in range(brace_start, len(source)):
|
||||
char = source[index]
|
||||
@@ -69,8 +67,10 @@ def provider_kind_ids(config_rs: str) -> dict[str, str]:
|
||||
|
||||
|
||||
def provider_tables(config_rs: str) -> set[str]:
|
||||
struct_start = config_rs.index("pub struct ProvidersToml")
|
||||
struct_end = config_rs.index("\n}", struct_start)
|
||||
struct_start = require_index(
|
||||
config_rs, "pub struct ProvidersToml", "crates/config/src/lib.rs"
|
||||
)
|
||||
struct_end = require_index(config_rs, "\n}", "ProvidersToml struct", struct_start)
|
||||
fields = re.findall(
|
||||
r"pub\s+([a-z0-9_]+)\s*:\s*ProviderConfigToml",
|
||||
config_rs[struct_start:struct_end],
|
||||
@@ -81,23 +81,17 @@ def provider_tables(config_rs: str) -> set[str]:
|
||||
|
||||
|
||||
def shipped_provider_rows(providers_md: str) -> set[str]:
|
||||
heading = providers_md.index("## Shipped Providers")
|
||||
next_heading = providers_md.index("\n## ", heading + 1)
|
||||
table = providers_md[heading:next_heading]
|
||||
table = markdown_section(providers_md, "## Shipped Providers")
|
||||
return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", table, flags=re.MULTILINE))
|
||||
|
||||
|
||||
def shipped_provider_tables(providers_md: str) -> set[str]:
|
||||
heading = providers_md.index("## Shipped Providers")
|
||||
next_heading = providers_md.index("\n## ", heading + 1)
|
||||
table = providers_md[heading:next_heading]
|
||||
table = markdown_section(providers_md, "## Shipped Providers")
|
||||
return set(re.findall(r"\|\s*`\[providers\.([a-z0-9_]+)\]`\s*\|", table))
|
||||
|
||||
|
||||
def static_registry_provider_rows(providers_md: str) -> set[str]:
|
||||
heading = providers_md.index("## Static Model Registry")
|
||||
next_heading = providers_md.index("\n## ", heading + 1)
|
||||
table = providers_md[heading:next_heading]
|
||||
table = markdown_section(providers_md, "## Static Model Registry")
|
||||
return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", table, flags=re.MULTILINE))
|
||||
|
||||
|
||||
@@ -124,7 +118,8 @@ def default_strings(tui_config_rs: str) -> set[str]:
|
||||
|
||||
|
||||
def missing_default_strings(providers_md: str, defaults: set[str]) -> list[str]:
|
||||
return sorted(value for value in defaults if value not in providers_md)
|
||||
code_spans = set(re.findall(r"`([^`]+)`", providers_md))
|
||||
return sorted(defaults - code_spans)
|
||||
|
||||
|
||||
def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
|
||||
@@ -139,6 +134,7 @@ def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
|
||||
|
||||
|
||||
def main() -> int:
|
||||
try:
|
||||
config_rs = read(CONFIG_RS)
|
||||
tui_config_rs = read(TUI_CONFIG_RS)
|
||||
agent_rs = read(AGENT_RS)
|
||||
@@ -146,15 +142,7 @@ def main() -> int:
|
||||
|
||||
variant_to_id = provider_kind_ids(config_rs)
|
||||
canonical_ids = set(variant_to_id.values())
|
||||
missing_table_mappings = sorted(set(variant_to_id) - set(PROVIDER_VARIANT_TO_TABLE))
|
||||
if missing_table_mappings:
|
||||
raise ValueError(
|
||||
"PROVIDER_VARIANT_TO_TABLE is missing variants: "
|
||||
+ ", ".join(missing_table_mappings)
|
||||
)
|
||||
expected_tables = {
|
||||
PROVIDER_VARIANT_TO_TABLE[variant] for variant in variant_to_id
|
||||
}
|
||||
expected_tables = {provider_id.replace("-", "_") for provider_id in canonical_ids}
|
||||
|
||||
errors: list[str] = []
|
||||
errors += report_set(
|
||||
@@ -177,9 +165,11 @@ def main() -> int:
|
||||
missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs))
|
||||
if missing_defaults:
|
||||
errors.append(
|
||||
"docs/PROVIDERS.md does not mention default strings: "
|
||||
"docs/PROVIDERS.md does not mention default strings as Markdown code spans: "
|
||||
+ ", ".join(missing_defaults)
|
||||
)
|
||||
except ValueError as err:
|
||||
errors = [str(err)]
|
||||
|
||||
if errors:
|
||||
print("Provider registry drift check failed:", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user