diff --git a/docs/PROVIDERS.md b/docs/PROVIDERS.md index ec9657ab..2899b980 100644 --- a/docs/PROVIDERS.md +++ b/docs/PROVIDERS.md @@ -22,7 +22,8 @@ Sources to keep in sync: - `config.example.toml` and `docs/CONFIGURATION.md` - user-facing config examples and environment variable reference. - `scripts/check-provider-registry.py` - drift check for canonical provider - IDs, TOML table names, static registry rows, and documented defaults. + IDs, live TUI provider IDs, TOML table names, static registry rows, and + documented defaults. ## Provider Selection @@ -147,6 +148,8 @@ python3 scripts/check-provider-registry.py The check fails when: - `docs/PROVIDERS.md` omits a canonical `ProviderKind::as_str()` ID. +- `crates/tui/src/config.rs` `ApiProvider::as_str()` diverges from + `ProviderKind::as_str()` except for the explicit `deepseek-cn` legacy alias. - The shipped-provider table omits or adds a `[providers.*]` TOML table. - The static model registry table drifts from providers used by `crates/agent/src/lib.rs`. diff --git a/scripts/check-provider-registry.py b/scripts/check-provider-registry.py index ebc1d556..ed6b28b8 100644 --- a/scripts/check-provider-registry.py +++ b/scripts/check-provider-registry.py @@ -6,6 +6,7 @@ the stable identifiers and default strings that are easy for docs to drift from: - canonical ProviderKind IDs - provider TOML tables +- live TUI ApiProvider IDs - shipped-provider table rows - static ModelRegistry provider rows - default provider model/base URL constants @@ -24,6 +25,10 @@ TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs" AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs" PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md" + +API_PROVIDER_ONLY_IDS = {"deepseek-cn"} + + def read(path: Path) -> str: return path.read_text(encoding="utf-8") @@ -42,8 +47,10 @@ def markdown_section(source: str, heading: str) -> str: return source[start:end] -def extract_match_block(source: str, signature: str) -> str: - start = require_index(source, signature, "crates/config/src/lib.rs") +def extract_match_block( + source: str, signature: str, context: str, start: int = 0 +) -> str: + start = require_index(source, signature, context, start) match_start = require_index(source, "match", f"match block after {signature!r}", start) brace_start = require_index(source, "{", f"match block after {signature!r}", match_start) depth = 0 @@ -59,13 +66,37 @@ def extract_match_block(source: str, signature: str) -> str: def provider_kind_ids(config_rs: str) -> dict[str, str]: - block = extract_match_block(config_rs, "pub fn as_str(self) -> &'static str") + impl_start = require_index( + config_rs, "impl ProviderKind", "crates/config/src/lib.rs" + ) + block = extract_match_block( + config_rs, + "pub fn as_str(self) -> &'static str", + "crates/config/src/lib.rs", + impl_start, + ) pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block) if not pairs: raise ValueError("ProviderKind::as_str returned no providers") return {variant: provider_id for variant, provider_id in pairs} +def api_provider_ids(tui_config_rs: str) -> dict[str, str]: + impl_start = require_index( + tui_config_rs, "impl ApiProvider", "crates/tui/src/config.rs" + ) + block = extract_match_block( + tui_config_rs, + "pub fn as_str(self) -> &'static str", + "crates/tui/src/config.rs", + impl_start, + ) + pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block) + if not pairs: + raise ValueError("ApiProvider::as_str returned no providers") + return {variant: provider_id for variant, provider_id in pairs} + + def provider_tables(config_rs: str) -> set[str]: struct_start = require_index( config_rs, "pub struct ProvidersToml", "crates/config/src/lib.rs" @@ -133,6 +164,34 @@ def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]: return errors +def report_provider_enum_drift( + provider_kind_ids: set[str], api_provider_ids: set[str] +) -> list[str]: + errors = [] + missing_from_api_provider = sorted(provider_kind_ids - api_provider_ids) + unexpected_api_provider_ids = sorted( + api_provider_ids - provider_kind_ids - API_PROVIDER_ONLY_IDS + ) + missing_allowlisted_ids = sorted(API_PROVIDER_ONLY_IDS - api_provider_ids) + + if missing_from_api_provider: + errors.append( + "ApiProvider missing ProviderKind IDs: " + + ", ".join(missing_from_api_provider) + ) + if unexpected_api_provider_ids: + errors.append( + "ApiProvider has non-whitelisted IDs absent from ProviderKind: " + + ", ".join(unexpected_api_provider_ids) + ) + if missing_allowlisted_ids: + errors.append( + "ApiProvider-only whitelist entries are absent from ApiProvider: " + + ", ".join(missing_allowlisted_ids) + ) + return errors + + def main() -> int: try: config_rs = read(CONFIG_RS) @@ -142,9 +201,11 @@ def main() -> int: variant_to_id = provider_kind_ids(config_rs) canonical_ids = set(variant_to_id.values()) + live_api_provider_ids = set(api_provider_ids(tui_config_rs).values()) expected_tables = {provider_id.replace("-", "_") for provider_id in canonical_ids} errors: list[str] = [] + errors += report_provider_enum_drift(canonical_ids, live_api_provider_ids) errors += report_set( "shipped provider rows", canonical_ids,