Merge PR #3005: provider metadata registry
Harvested from PR #3005 by @sximelon Co-authored-by: sximelon <15710511+sximelon@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
CONFIG_RS = ROOT / "crates" / "config" / "src" / "lib.rs"
|
||||
PROVIDER_RS = ROOT / "crates" / "config" / "src" / "provider.rs"
|
||||
TUI_CONFIG_RS = ROOT / "crates" / "tui" / "src" / "config.rs"
|
||||
AGENT_RS = ROOT / "crates" / "agent" / "src" / "lib.rs"
|
||||
PROVIDERS_MD = ROOT / "docs" / "PROVIDERS.md"
|
||||
@@ -69,35 +70,37 @@ def extract_match_block(
|
||||
|
||||
|
||||
def provider_kind_ids(config_rs: str) -> dict[str, str]:
|
||||
impl_start = require_index(
|
||||
config_rs, "impl ProviderKind", "crates/config/src/lib.rs"
|
||||
provider_rs = read(PROVIDER_RS)
|
||||
pairs = re.findall(
|
||||
r"provider!\(\s*\n\s*\w+,\s*\n\s*(\w+),\s*\n\s*\"([^\"]+)\"",
|
||||
provider_rs,
|
||||
)
|
||||
block = extract_match_block(
|
||||
config_rs,
|
||||
"pub fn as_str(self) -> &'static str",
|
||||
"crates/config/src/lib.rs",
|
||||
impl_start,
|
||||
)
|
||||
pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block)
|
||||
if not pairs:
|
||||
raise ValueError("ProviderKind::as_str returned no providers")
|
||||
return {variant: provider_id for variant, provider_id in pairs}
|
||||
ids: dict[str, str] = {variant: provider_id for variant, provider_id in pairs}
|
||||
# OpenaiCodex and Anthropic use manual impls rather than the provider!() macro
|
||||
for variant_name, id_literal in [
|
||||
("OpenaiCodex", "openai-codex"),
|
||||
("Anthropic", "anthropic"),
|
||||
]:
|
||||
match = re.search(
|
||||
rf'impl\s+Provider\s+for\s+{variant_name}.*?fn\s+id.*?\"({id_literal})\"',
|
||||
provider_rs, re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
ids[variant_name] = match.group(1)
|
||||
if not ids:
|
||||
raise ValueError("provider!() invocations returned no providers")
|
||||
return ids
|
||||
|
||||
|
||||
def api_provider_ids(tui_config_rs: str) -> dict[str, str]:
|
||||
impl_start = require_index(
|
||||
tui_config_rs, "impl ApiProvider", "crates/tui/src/config.rs"
|
||||
)
|
||||
block = extract_match_block(
|
||||
tui_config_rs,
|
||||
"pub fn as_str(self) -> &'static str",
|
||||
"crates/tui/src/config.rs",
|
||||
impl_start,
|
||||
)
|
||||
pairs = re.findall(r"Self::(\w+)\s*=>\s*\"([^\"]+)\"", block)
|
||||
if not pairs:
|
||||
raise ValueError("ApiProvider::as_str returned no providers")
|
||||
return {variant: provider_id for variant, provider_id in pairs}
|
||||
# ApiProvider ids derive from ProviderKind ids (via delegation to .kind().as_str())
|
||||
# plus the legacy "deepseek-cn" variant that exists only in ApiProvider.
|
||||
variant_to_id = provider_kind_ids("")
|
||||
# ApiProvider::SiliconflowCn maps to ProviderKind::SiliconflowCN
|
||||
if "SiliconflowCN" in variant_to_id:
|
||||
variant_to_id["SiliconflowCn"] = variant_to_id["SiliconflowCN"]
|
||||
variant_to_id["DeepseekCN"] = "deepseek-cn"
|
||||
return variant_to_id
|
||||
|
||||
|
||||
def provider_tables(config_rs: str) -> set[str]:
|
||||
@@ -253,4 +256,4 @@ def main() -> int:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user