""" AI 模型配置加载器 ================ 从 YAML 文件加载模型配置,支持热重载。 API Key 从 Settings 读取(符合配置规范)。 """ import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any logger = logging.getLogger(__name__) # 尝试导入 YAML 库 try: import yaml YAML_AVAILABLE = True except ImportError: YAML_AVAILABLE = False logger.warning("PyYAML 未安装,使用 JSON 备选方案。安装: pip install pyyaml") @dataclass class PlatformConfig: """平台配置""" id: str name: str provider: str priority: int = 100 base_url: str = "" # 从 YAML 读取,可选 @dataclass class ModelConfig: """模型配置""" id: str platform_id: str model_name: str display_name: str capabilities: list[str] = field(default_factory=list) default_params: dict[str, Any] = field(default_factory=dict) is_enabled: bool = True cost_per_1k_input: float = 0.0 cost_per_1k_output: float = 0.0 max_tokens_limit: int = 4096 class AIModelConfigLoader: """AI 模型配置加载器 从 YAML 加载模型配置(支持热重载)。 API Key 从 Settings 读取(通过 get_settings()),符合配置规范。 """ DEFAULT_CONFIG_PATH = ( Path(__file__).parent.parent.parent / "config" / "ai_models.yaml" ) def __init__(self, config_path: str | None = None): self.config_path = ( Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH ) self._platforms: dict[str, PlatformConfig] = {} self._models: dict[str, ModelConfig] = {} self._task_defaults: dict[str, str] = {} self._last_modified = 0 self._load() def _load(self): """加载配置文件""" if not self.config_path.exists(): logger.warning(f"配置文件不存在: {self.config_path},使用默认配置") self._load_defaults() return try: with open(self.config_path, encoding="utf-8") as f: if YAML_AVAILABLE: config = yaml.safe_load(f) else: # 备选:使用 JSON import json config = json.load(f) self._parse_config(config) self._last_modified = self.config_path.stat().st_mtime logger.info( f"已加载模型配置: {len(self._platforms)} 平台, {len(self._models)} 模型" ) except Exception as e: logger.error(f"加载配置文件失败: {e},使用默认配置") self._load_defaults() def _parse_config(self, config: dict): """解析配置(仅解析模型配置,API Key 从 Settings 读取)""" # 解析平台 platforms_data = config.get("platforms", {}) for pid, pdata in platforms_data.items(): self._platforms[pid] = PlatformConfig( id=pid, name=pdata.get("name", pid), provider=pdata.get("provider", pid), priority=pdata.get("priority", 100), base_url=pdata.get("base_url", ""), ) # 解析模型 models_data = config.get("models", {}) for mid, mdata in models_data.items(): self._models[mid] = ModelConfig( id=mid, platform_id=mdata.get("platform_id", ""), model_name=mdata.get("model_name", mid), display_name=mdata.get("display_name", mid), capabilities=mdata.get("capabilities", []), default_params=mdata.get("default_params", {}), is_enabled=mdata.get("is_enabled", True), cost_per_1k_input=mdata.get("cost_per_1k_input", 0.0), cost_per_1k_output=mdata.get("cost_per_1k_output", 0.0), max_tokens_limit=mdata.get("max_tokens_limit", 4096), ) # 解析任务默认映射 self._task_defaults = config.get("task_defaults", {}) def _load_defaults(self): """加载默认配置""" self._platforms = { "mock": PlatformConfig( id="mock", name="Mock 测试平台", provider="mock", priority=999, ) } self._models = { "mock-model": ModelConfig( id="mock-model", platform_id="mock", model_name="mock-model", display_name="Mock 测试模型", capabilities=["script", "polish", "chat"], ) } self._task_defaults = { "script": "mock-model", "polish": "mock-model", "chat": "mock-model", } def reload(self): """重新加载配置(如果文件有更新)""" if self.config_path.exists(): current_mtime = self.config_path.stat().st_mtime if current_mtime > self._last_modified: logger.info("配置文件已更新,重新加载") self._load() return True return False # ============== 查询方法 ============== def get_platform(self, platform_id: str) -> PlatformConfig | None: """获取平台配置""" return self._platforms.get(platform_id) def get_all_platforms(self) -> list[PlatformConfig]: """获取所有平台(按优先级排序)""" return sorted(self._platforms.values(), key=lambda p: p.priority) def get_model(self, model_id: str) -> ModelConfig | None: """获取模型配置""" return self._models.get(model_id) def get_all_models(self) -> list[ModelConfig]: """获取所有模型""" return list(self._models.values()) def get_enabled_models(self) -> list[ModelConfig]: """获取启用的模型""" return [m for m in self._models.values() if m.is_enabled] def get_models_by_capability(self, capability: str) -> list[ModelConfig]: """根据能力获取模型""" return [ m for m in self._models.values() if m.is_enabled and capability in m.capabilities ] def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]: """根据平台获取模型""" return [ m for m in self._models.values() if m.platform_id == platform_id and m.is_enabled ] def get_default_model_for_task(self, task_type: str) -> str | None: """获取任务类型的默认模型 ID""" return self._task_defaults.get(task_type) def set_default_model_for_task(self, task_type: str, model_id: str): """设置任务类型的默认模型(内存中,不保存到文件)""" if model_id in self._models: self._task_defaults[task_type] = model_id # 全局配置加载器实例 _config_loader: AIModelConfigLoader | None = None def get_config_loader() -> AIModelConfigLoader: """获取全局配置加载器""" global _config_loader if _config_loader is None: _config_loader = AIModelConfigLoader() return _config_loader def reload_config() -> bool: """重新加载配置""" loader = get_config_loader() return loader.reload()