test: check tui provider enum drift
This commit is contained in:
@@ -6,6 +6,7 @@ the stable identifiers and default strings that are easy for docs to drift from:
|
||||
|
||||
- canonical ProviderKind IDs
|
||||
- provider TOML tables
|
||||
- live TUI ApiProvider IDs
|
||||
- shipped-provider table rows
|
||||
- static ModelRegistry provider rows
|
||||
- default provider model/base URL constants
|
||||
@@ -24,6 +25,10 @@ TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs"
|
||||
AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs"
|
||||
PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md"
|
||||
|
||||
|
||||
API_PROVIDER_ONLY_IDS = {"deepseek-cn"}
|
||||
|
||||
|
||||
def read(path: Path) -> str:
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
@@ -42,8 +47,10 @@ def markdown_section(source: str, heading: str) -> str:
|
||||
return source[start:end]
|
||||
|
||||
|
||||
def extract_match_block(source: str, signature: str) -> str:
|
||||
start = require_index(source, signature, "crates/config/src/lib.rs")
|
||||
def extract_match_block(
|
||||
source: str, signature: str, context: str, start: int = 0
|
||||
) -> str:
|
||||
start = require_index(source, signature, context, start)
|
||||
match_start = require_index(source, "match", f"match block after {signature!r}", start)
|
||||
brace_start = require_index(source, "{", f"match block after {signature!r}", match_start)
|
||||
depth = 0
|
||||
@@ -59,13 +66,37 @@ def extract_match_block(source: str, signature: str) -> str:
|
||||
|
||||
|
||||
def provider_kind_ids(config_rs: str) -> dict[str, str]:
|
||||
block = extract_match_block(config_rs, "pub fn as_str(self) -> &'static str")
|
||||
impl_start = require_index(
|
||||
config_rs, "impl ProviderKind", "crates/config/src/lib.rs"
|
||||
)
|
||||
block = extract_match_block(
|
||||
config_rs,
|
||||
"pub fn as_str(self) -> &'static str",
|
||||
"crates/config/src/lib.rs",
|
||||
impl_start,
|
||||
)
|
||||
pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block)
|
||||
if not pairs:
|
||||
raise ValueError("ProviderKind::as_str returned no providers")
|
||||
return {variant: provider_id for variant, provider_id in pairs}
|
||||
|
||||
|
||||
def api_provider_ids(tui_config_rs: str) -> dict[str, str]:
|
||||
impl_start = require_index(
|
||||
tui_config_rs, "impl ApiProvider", "crates/tui/src/config.rs"
|
||||
)
|
||||
block = extract_match_block(
|
||||
tui_config_rs,
|
||||
"pub fn as_str(self) -> &'static str",
|
||||
"crates/tui/src/config.rs",
|
||||
impl_start,
|
||||
)
|
||||
pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block)
|
||||
if not pairs:
|
||||
raise ValueError("ApiProvider::as_str returned no providers")
|
||||
return {variant: provider_id for variant, provider_id in pairs}
|
||||
|
||||
|
||||
def provider_tables(config_rs: str) -> set[str]:
|
||||
struct_start = require_index(
|
||||
config_rs, "pub struct ProvidersToml", "crates/config/src/lib.rs"
|
||||
@@ -133,6 +164,34 @@ def report_set(label: str, expected: set[str], actual: set[str]) -> list[str]:
|
||||
return errors
|
||||
|
||||
|
||||
def report_provider_enum_drift(
|
||||
provider_kind_ids: set[str], api_provider_ids: set[str]
|
||||
) -> list[str]:
|
||||
errors = []
|
||||
missing_from_api_provider = sorted(provider_kind_ids - api_provider_ids)
|
||||
unexpected_api_provider_ids = sorted(
|
||||
api_provider_ids - provider_kind_ids - API_PROVIDER_ONLY_IDS
|
||||
)
|
||||
missing_allowlisted_ids = sorted(API_PROVIDER_ONLY_IDS - api_provider_ids)
|
||||
|
||||
if missing_from_api_provider:
|
||||
errors.append(
|
||||
"ApiProvider missing ProviderKind IDs: "
|
||||
+ ", ".join(missing_from_api_provider)
|
||||
)
|
||||
if unexpected_api_provider_ids:
|
||||
errors.append(
|
||||
"ApiProvider has non-whitelisted IDs absent from ProviderKind: "
|
||||
+ ", ".join(unexpected_api_provider_ids)
|
||||
)
|
||||
if missing_allowlisted_ids:
|
||||
errors.append(
|
||||
"ApiProvider-only whitelist entries are absent from ApiProvider: "
|
||||
+ ", ".join(missing_allowlisted_ids)
|
||||
)
|
||||
return errors
|
||||
|
||||
|
||||
def main() -> int:
|
||||
try:
|
||||
config_rs = read(CONFIG_RS)
|
||||
@@ -142,9 +201,11 @@ def main() -> int:
|
||||
|
||||
variant_to_id = provider_kind_ids(config_rs)
|
||||
canonical_ids = set(variant_to_id.values())
|
||||
live_api_provider_ids = set(api_provider_ids(tui_config_rs).values())
|
||||
expected_tables = {provider_id.replace("-", "_") for provider_id in canonical_ids}
|
||||
|
||||
errors: list[str] = []
|
||||
errors += report_provider_enum_drift(canonical_ids, live_api_provider_ids)
|
||||
errors += report_set(
|
||||
"shipped provider rows",
|
||||
canonical_ids,
|
||||
|
||||
Reference in New Issue
Block a user