"""Unified platform configuration management.

Loads every setting from ``config/platform-config.yaml``:

- ``platforms``: platform info, model lists, per-method settings, rate limits
- ``task_defaults``: mapping from task type to its default model
- ``runtime``: task timeouts and TTL policy

The file is validated against Pydantic schemas, loaded once at startup and
treated as read-only afterwards.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

logger = logging.getLogger(__name__)

# PyYAML is optional at import time; ``PlatformConfigLoader._load`` falls back
# to the JSON parser when it is missing.
try:
    import yaml

    YAML_AVAILABLE = True
except ImportError:
    YAML_AVAILABLE = False
    logger.warning("PyYAML 未安装")


# ═══════════════════════════════════════════════════════════════
# Pydantic schema
# ═══════════════════════════════════════════════════════════════
#
# NOTE: the literal ``[]`` / ``{}`` field defaults below rely on pydantic
# copying field defaults for each model instance.


class RateLimitConfig(BaseModel):
    """Rate-limit settings: ``qps`` (>= 0) and ``burst`` (>= 1)."""

    qps: float = Field(default=10.0, ge=0)
    burst: int = Field(default=20, ge=1)


class ModelConfig(BaseModel):
    """Configuration of a single model entry."""

    # Clear pydantic's protected namespaces so that a field named
    # ``model_name`` can be declared on this schema.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    platform_id: str = ""  # injected by PlatformConfigLoader at load time
    model_name: str = ""
    display_name: str = ""
    capabilities: list[str] = []
    default_params: dict[str, Any] = {}
    is_enabled: bool = True
    cost_per_1k_input: float = 0.0
    cost_per_1k_output: float = 0.0
    max_tokens_limit: int = 4096


class MethodConfig(BaseModel):
    """Per-method settings: timeout, connection cap and rate limit."""

    timeout: int = Field(default=30, ge=1)
    max_connections: int = Field(default=20, ge=1)
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)

    @property
    def rate_limit_qps(self) -> float:
        """Shortcut for ``rate_limit.qps``."""
        return self.rate_limit.qps

    @property
    def rate_limit_burst(self) -> int:
        """Shortcut for ``rate_limit.burst``."""
        return self.rate_limit.burst


class PlatformConfigData(BaseModel):
    """Platform settings (one entry under ``platforms`` in the YAML file)."""

    name: str = ""
    provider: str = ""
    base_url: str = ""
    priority: int = 100
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    models: list[ModelConfig] = []
    methods: dict[str, MethodConfig] = {}

    @property
    def rate_limit_qps(self) -> float:
        """Shortcut for ``rate_limit.qps``."""
        return self.rate_limit.qps

    @property
    def rate_limit_burst(self) -> int:
        """Shortcut for ``rate_limit.burst``."""
        return self.rate_limit.burst


class RuntimeSection(BaseModel):
    """Runtime policy: per-task-type timeouts and TTLs."""

    task_timeouts: dict[str, int] = {}
    task_ttl: dict[str, int] = {}


class PlatformConfigRoot(BaseModel):
    """Root schema of the configuration file."""

    platforms: dict[str, PlatformConfigData] = {}
    task_defaults: dict[str, str] = {}
    runtime: RuntimeSection = Field(default_factory=RuntimeSection)
    version: str = ""


# ═══════════════════════════════════════════════════════════════
# Compatibility layer: PlatformConfig (keeps the legacy interface)
# ═══════════════════════════════════════════════════════════════


class PlatformConfig:
    """Platform configuration exposed through the legacy attribute interface.

    Copies the fields of a ``PlatformConfigData`` onto plain attributes and
    adds the platform id. ``models`` and ``methods`` reference the same
    objects as ``data``; the rate-limit values are read once at construction.
    """

    def __init__(self, platform_id: str, data: PlatformConfigData):
        self.id = platform_id
        self.name = data.name
        self.provider = data.provider
        self.base_url = data.base_url
        self.priority = data.priority
        self.rate_limit_qps = data.rate_limit_qps
        self.rate_limit_burst = data.rate_limit_burst
        self.models = data.models
        self.methods = data.methods


# ═══════════════════════════════════════════════════════════════
# Unified loader
# ═══════════════════════════════════════════════════════════════


class PlatformConfigLoader:
    """Unified loader for the platform configuration.

    Loads and validates ``platform-config.yaml`` in one pass and offers lookup
    helpers for platforms, models, methods and runtime policy. The file is
    read in the constructor only; this class has no reload method.
    """

    DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"

    def __init__(self, config_path: str | Path | None = None):
        """Load the configuration immediately.

        Args:
            config_path: Location of the configuration file. Any falsy value
                (``None`` or ``""``) selects ``DEFAULT_CONFIG_PATH``.

        Raises:
            FileNotFoundError: The configuration file does not exist.
            RuntimeError: The file could not be read, parsed or validated.
        """
        self.config_path = Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
        self._raw: PlatformConfigRoot | None = None
        self._platforms: dict[str, PlatformConfig] = {}
        self._models: dict[str, ModelConfig] = {}
        self._load()

    def _load(self) -> None:
        """Read, validate and index the configuration file.

        Raises:
            FileNotFoundError: ``self.config_path`` does not exist.
            RuntimeError: Any failure while opening, parsing, validating or
                indexing; the original exception is chained as the cause.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"平台配置文件不存在: {self.config_path}")

        try:
            with open(self.config_path, encoding="utf-8") as f:
                if YAML_AVAILABLE:
                    raw_dict = yaml.safe_load(f)
                else:
                    # Without PyYAML the file is parsed as JSON instead.
                    import json

                    raw_dict = json.load(f)

            # Schema validation: a failure raises right here (fail fast).
            self._raw = PlatformConfigRoot.model_validate(raw_dict)

            # Platform index.
            self._platforms = {
                pid: PlatformConfig(pid, pdata)
                for pid, pdata in self._raw.platforms.items()
            }

            # Model index; every model is stamped with its owning platform id.
            self._models = {}
            for pid, pdata in self._raw.platforms.items():
                for model in pdata.models:
                    model.platform_id = pid
                    existing = self._models.get(model.id)
                    if existing is not None:
                        # The index is keyed by model id alone, so a repeated
                        # id replaces the earlier entry. Keep last-wins, but
                        # make the collision visible instead of silent.
                        logger.warning(
                            "模型 ID 重复: %s (平台 %s 的定义被平台 %s 覆盖)",
                            model.id,
                            existing.platform_id,
                            pid,
                        )
                    self._models[model.id] = model

            logger.info(
                f"平台配置加载完成: {len(self._platforms)} 平台, "
                f"{len(self._models)} 模型, "
                f"version={self._raw.version or 'none'}"
            )
        except Exception as e:
            logger.error(f"平台配置加载失败: {e}")
            raise RuntimeError(f"平台配置加载失败: {e}") from e

    # ── Platform lookups ──────────────────────────────────────

    def get_platform(self, platform_id: str) -> PlatformConfig | None:
        """Return the platform with the given id, or ``None`` if unknown."""
        return self._platforms.get(platform_id)

    def get_all_platforms(self) -> list[PlatformConfig]:
        """Return all platforms sorted by ascending ``priority``."""
        return sorted(self._platforms.values(), key=lambda p: p.priority)

    # ── Model lookups ─────────────────────────────────────────

    def get_model(self, model_id: str) -> ModelConfig | None:
        """Return the model with the given id, or ``None`` if unknown."""
        return self._models.get(model_id)

    def get_all_models(self) -> list[ModelConfig]:
        """Return every indexed model, enabled or not."""
        return list(self._models.values())

    def get_enabled_models(self) -> list[ModelConfig]:
        """Return the models whose ``is_enabled`` flag is set."""
        return [m for m in self._models.values() if m.is_enabled]

    def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
        """Return enabled models that list ``capability``."""
        return [
            m for m in self._models.values()
            if m.is_enabled and capability in m.capabilities
        ]

    def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
        """Return enabled models that belong to ``platform_id``."""
        return [
            m for m in self._models.values()
            if m.platform_id == platform_id and m.is_enabled
        ]

    def get_default_model_for_task(self, task_type: str) -> str | None:
        """Return the ``task_defaults`` entry for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.task_defaults.get(task_type)

    # ── Method lookups ────────────────────────────────────────

    def get_method_config(self, platform_id: str, method: str) -> MethodConfig | None:
        """Return the settings of ``method`` on ``platform_id``.

        Returns ``None`` when the platform or the method is not configured.
        """
        platform = self._raw.platforms.get(platform_id) if self._raw else None
        if platform:
            return platform.methods.get(method)
        return None

    # ── Runtime policy ────────────────────────────────────────

    def get_runtime_raw(self) -> dict[str, Any]:
        """Return the runtime section as a plain dict (for the Admin API)."""
        if self._raw is None:
            return {}
        return self._raw.runtime.model_dump()

    def get_task_timeout(self, task_type: str) -> int | None:
        """Return the configured timeout for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.runtime.task_timeouts.get(task_type)

    def get_task_ttl(self, task_type: str) -> int | None:
        """Return the configured TTL for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.runtime.task_ttl.get(task_type)

    @property
    def version(self) -> str:
        """The ``version`` string of the loaded file (``""`` if none)."""
        return self._raw.version if self._raw else ""


# ═══════════════════════════════════════════════════════════════
# Module-level singleton
# ═══════════════════════════════════════════════════════════════

_platform_config_loader: PlatformConfigLoader | None = None


def get_platform_config_loader() -> PlatformConfigLoader:
    """Return the shared loader, creating it on first use.

    The first call constructs ``PlatformConfigLoader()`` with the default
    path, so it can raise the same errors as that constructor.
    """
    global _platform_config_loader
    if _platform_config_loader is None:
        _platform_config_loader = PlatformConfigLoader()
    return _platform_config_loader