fix: harden provider registry drift check

This commit is contained in:
Hunter Bown
2026-05-27 06:19:32 -05:00
parent b0e7b67386
commit e2099dd691
3 changed files with 61 additions and 71 deletions
+59 -69
View File
@@ -24,30 +24,28 @@ TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs"
AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs"
PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md"
# Maps a provider variant name to the snake_case table name expected under
# ``[providers.<table>]``. main() indexes this dict with every variant found
# in ``variant_to_id`` and reports any variant that has no entry here, so a
# new variant needs a matching row.
PROVIDER_VARIANT_TO_TABLE = {
"Deepseek": "deepseek",
"NvidiaNim": "nvidia_nim",
"Openai": "openai",
"Atlascloud": "atlascloud",
"WanjieArk": "wanjie_ark",
"Openrouter": "openrouter",
"Novita": "novita",
"Fireworks": "fireworks",
"Moonshot": "moonshot",
"Sglang": "sglang",
"Vllm": "vllm",
"Ollama": "ollama",
}
def read(path: Path) -> str:
    """Return the entire contents of *path* decoded as UTF-8 text.

    Args:
        path: File to load.

    Returns:
        The decoded file contents as a single string.
    """
    with path.open(encoding="utf-8") as handle:
        return handle.read()
def require_index(source: str, needle: str, context: str, start: int = 0) -> int:
    """Locate *needle* in *source*, failing with a descriptive error if absent.

    Args:
        source: Text to search.
        needle: Substring that must be present.
        context: Label placed at the front of the error message so the
            failure names what was being scanned.
        start: Offset at which the search begins.

    Returns:
        The offset of the first occurrence of *needle* at or after *start*.

    Raises:
        ValueError: If *needle* does not occur; the message has the form
            ``"<context>: missing <needle repr>"``.
    """
    position = source.find(needle, start)
    if position == -1:
        raise ValueError(f"{context}: missing {needle!r}")
    return position
def markdown_section(source: str, heading: str) -> str:
    """Return the text from *heading* up to the next ``## `` heading.

    Args:
        source: Markdown text to slice.
        heading: Exact heading text that opens the wanted section.

    Returns:
        The substring starting at *heading* and ending just before the next
        line that begins with ``## ``, or at the end of *source* when no such
        line follows.

    Raises:
        ValueError: If *heading* does not occur in *source* (raised by
            ``require_index``).
    """
    begin = require_index(source, heading, "docs/PROVIDERS.md")
    # Start looking after the heading text so the search only sees what follows it.
    boundary = source.find("\n## ", begin + len(heading))
    if boundary == -1:
        return source[begin:]
    return source[begin:boundary]
def extract_match_block(source: str, signature: str) -> str:
start = source.index(signature)
match_start = source.index("match", start)
brace_start = source.index("{", match_start)
start = require_index(source, signature, "crates/config/src/lib.rs")
match_start = require_index(source, "match", f"match block after {signature!r}", start)
brace_start = require_index(source, "{", f"match block after {signature!r}", match_start)
depth = 0
for index in range(brace_start, len(source)):
char = source[index]
@@ -69,8 +67,10 @@ def provider_kind_ids(config_rs: str) -> dict[str, str]:
def provider_tables(config_rs: str) -> set[str]:
struct_start = config_rs.index("pub struct ProvidersToml")
struct_end = config_rs.index("\n}", struct_start)
struct_start = require_index(
config_rs, "pub struct ProvidersToml", "crates/config/src/lib.rs"
)
struct_end = require_index(config_rs, "\n}", "ProvidersToml struct", struct_start)
fields = re.findall(
r"pub\s+([a-z0-9_]+)\s*:\s*ProviderConfigToml",
config_rs[struct_start:struct_end],
@@ -81,23 +81,17 @@ def provider_tables(config_rs: str) -> set[str]:
def shipped_provider_rows(providers_md: str) -> set[str]:
    """Collect the leading back-quoted cell of each "Shipped Providers" row.

    The section is obtained through ``markdown_section`` only. The earlier
    manual ``.index("\\n## ", ...)`` slicing was dead code (its result was
    overwritten) and raised an unhelpful "substring not found" error whenever
    the section was the last heading in the document.

    Args:
        providers_md: Full Markdown text of the providers document.

    Returns:
        The set of values found in back-quotes in the first cell of a table
        row inside the "## Shipped Providers" section.

    Raises:
        ValueError: If the "## Shipped Providers" heading is missing
            (raised by ``markdown_section``).
    """
    section = markdown_section(providers_md, "## Shipped Providers")
    # MULTILINE makes ``^`` match at every line start, so only a row's first
    # cell is captured, not back-quoted values in later columns.
    return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", section, flags=re.MULTILINE))
def shipped_provider_tables(providers_md: str) -> set[str]:
    """Collect ``[providers.<name>]`` table names from "Shipped Providers".

    The section is obtained through ``markdown_section`` only. The earlier
    manual ``.index("\\n## ", ...)`` slicing was dead code (its result was
    overwritten) and raised an unhelpful "substring not found" error whenever
    the section was the last heading in the document.

    Args:
        providers_md: Full Markdown text of the providers document.

    Returns:
        The set of ``<name>`` parts of every back-quoted
        ``[providers.<name>]`` table cell inside the "## Shipped Providers"
        section.

    Raises:
        ValueError: If the "## Shipped Providers" heading is missing
            (raised by ``markdown_section``).
    """
    section = markdown_section(providers_md, "## Shipped Providers")
    # No line anchor here: the table-name cell may sit in any column of a row.
    return set(re.findall(r"\|\s*`\[providers\.([a-z0-9_]+)\]`\s*\|", section))
def static_registry_provider_rows(providers_md: str) -> set[str]:
    """Collect the leading back-quoted cell of each "Static Model Registry" row.

    The section is obtained through ``markdown_section`` only. The earlier
    manual ``.index("\\n## ", ...)`` slicing was dead code (its result was
    overwritten) and raised an unhelpful "substring not found" error whenever
    the section was the last heading in the document.

    Args:
        providers_md: Full Markdown text of the providers document.

    Returns:
        The set of values found in back-quotes in the first cell of a table
        row inside the "## Static Model Registry" section.

    Raises:
        ValueError: If the "## Static Model Registry" heading is missing
            (raised by ``markdown_section``).
    """
    section = markdown_section(providers_md, "## Static Model Registry")
    # MULTILINE makes ``^`` match at every line start, so only a row's first
    # cell is captured.
    return set(re.findall(r"^\|\s*`([^`]+)`\s*\|", section, flags=re.MULTILINE))
@@ -124,7 +118,8 @@ def default_strings(tui_config_rs: str) -> set[str]:
def missing_default_strings(providers_md: str, defaults: set[str]) -> list[str]:
    """List the default strings that the document never shows as a code span.

    A default counts as documented only when it appears verbatim between
    back-quotes. The earlier plain substring test returned first and left
    this check unreachable; it also accepted a default that merely occurred
    inside a longer word (``abc`` inside ``abcd``).

    Args:
        providers_md: Full Markdown text of the providers document.
        defaults: Default strings that are expected to be documented.

    Returns:
        The undocumented defaults in sorted order; empty when all are present.
    """
    code_spans = set(re.findall(r"`([^`]+)`", providers_md))
    # set() keeps the subtraction working if a caller passes a list or tuple.
    return sorted(set(defaults) - code_spans)
def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
@@ -139,47 +134,42 @@ def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
def main() -> int:
config_rs = read(CONFIG_RS)
tui_config_rs = read(TUI_CONFIG_RS)
agent_rs = read(AGENT_RS)
providers_md = read(PROVIDERS_MD)
try:
config_rs = read(CONFIG_RS)
tui_config_rs = read(TUI_CONFIG_RS)
agent_rs = read(AGENT_RS)
providers_md = read(PROVIDERS_MD)
variant_to_id = provider_kind_ids(config_rs)
canonical_ids = set(variant_to_id.values())
missing_table_mappings = sorted(set(variant_to_id) - set(PROVIDER_VARIANT_TO_TABLE))
if missing_table_mappings:
raise ValueError(
"PROVIDER_VARIANT_TO_TABLE is missing variants: "
+ ", ".join(missing_table_mappings)
variant_to_id = provider_kind_ids(config_rs)
canonical_ids = set(variant_to_id.values())
expected_tables = {provider_id.replace("-", "_") for provider_id in canonical_ids}
errors: list[str] = []
errors += report_set(
"shipped provider rows",
canonical_ids,
shipped_provider_rows(providers_md),
)
expected_tables = {
PROVIDER_VARIANT_TO_TABLE[variant] for variant in variant_to_id
}
errors: list[str] = []
errors += report_set(
"shipped provider rows",
canonical_ids,
shipped_provider_rows(providers_md),
)
errors += report_set("provider TOML tables", expected_tables, provider_tables(config_rs))
errors += report_set(
"documented provider TOML tables",
expected_tables,
shipped_provider_tables(providers_md),
)
errors += report_set(
"static ModelRegistry rows",
model_registry_providers(agent_rs, variant_to_id),
static_registry_provider_rows(providers_md),
)
missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs))
if missing_defaults:
errors.append(
"docs/PROVIDERS.md does not mention default strings: "
+ ", ".join(missing_defaults)
errors += report_set("provider TOML tables", expected_tables, provider_tables(config_rs))
errors += report_set(
"documented provider TOML tables",
expected_tables,
shipped_provider_tables(providers_md),
)
errors += report_set(
"static ModelRegistry rows",
model_registry_providers(agent_rs, variant_to_id),
static_registry_provider_rows(providers_md),
)
missing_defaults = missing_default_strings(providers_md, default_strings(tui_config_rs))
if missing_defaults:
errors.append(
"docs/PROVIDERS.md does not mention default strings as Markdown code spans: "
+ ", ".join(missing_defaults)
)
except ValueError as err:
errors = [str(err)]
if errors:
print("Provider registry drift check failed:", file=sys.stderr)