Files
meijiaka-zy/python-api/app/core/config_loader.py
T
小鱼开发 e58159fc42 refactor: 第三方平台架构改造(Adapter Protocol + Gateway)
Phase 1: 异常体系统一
- 新增 PlatformError / PlatformErrorType 标准定义
- 改造所有 Provider 异常抛出为 PlatformError
- 注册全局 PlatformError exception handler

Phase 2: Adapter Protocol
- 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable)
- 新增 app/ai/adapters/constants.py(Method 常量)
- 新增 PlatformConfigLoader(config/platform-config.yaml)

Phase 3: HTTP Client 统一
- ViduProvider 从 aiohttp 迁移到 httpx(注入方式)
- VolcengineCaptionService 改为注入 http_client
- lifespan 统一管理所有 Client 创建和关闭

Phase 4: Gateway 骨架 + Adapter 实现
- 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter
- 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook)
- 新增 LLMGateway(带 Fallback 降级链)
- lifespan 注册所有 Adapter 和 Gateway

Phase 6: 清理与验证
- 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL
- Provider 改为从 PlatformConfigLoader 读取 base_url
- 清理 volcengine_caption_service 全局单例
- config_loader 默认路径改为 platform-config.yaml
- Scheduler 注入共享 HTTP client
- vidu.py 回调路由使用 Adapter 验签和解析
- ruff 全量通过,应用启动测试通过
2026-05-04 16:07:16 +08:00

251 lines
8.3 KiB
Python

"""
AI 模型配置加载器
================
从 YAML 文件加载模型配置,支持热重载。
API Key 从 Settings 读取(符合配置规范)。
"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# PyYAML is optional: record whether it could be imported so the loader can
# pick a parser at read time, and warn once at import when it is missing.
try:
    import yaml
except ImportError:
    YAML_AVAILABLE = False
    logger.warning("PyYAML 未安装,使用 JSON 备选方案。安装: pip install pyyaml")
else:
    YAML_AVAILABLE = True
@dataclass
class PlatformConfig:
    """Settings describing a single platform entry."""

    # Platform identifier.
    id: str
    # Human-readable platform name.
    name: str
    # Provider key associated with this platform.
    provider: str
    # Ordering weight; defaults to 100.
    priority: int = 100
    # Base URL read from the YAML file; optional, empty when not configured.
    base_url: str = ""
@dataclass
class ModelConfig:
    """Settings describing a single model entry."""

    # Model identifier used as the lookup key.
    id: str
    # Identifier of the platform this model belongs to.
    platform_id: str
    # Model name as configured (presumably the name sent to the provider).
    model_name: str
    # Name intended for display.
    display_name: str
    # Capability tags; each instance gets its own list.
    capabilities: list[str] = field(default_factory=list)
    # Default call parameters; each instance gets its own dict.
    default_params: dict[str, Any] = field(default_factory=dict)
    # Disabled models are skipped by the "enabled" queries of the loader.
    is_enabled: bool = True
    # Cost figures per 1k input / output units (presumably tokens).
    cost_per_1k_input: float = 0.0
    cost_per_1k_output: float = 0.0
    # Upper bound for max tokens; defaults to 4096.
    max_tokens_limit: int = 4096
class AIModelConfigLoader:
    """Load AI platform and model definitions from a config file.

    The file is parsed with PyYAML when it is installed and with ``json``
    otherwise. When the file is missing or cannot be parsed, a built-in
    "mock" platform/model set is used instead. ``reload()`` re-reads the
    file when its modification time has changed.

    API keys are not handled by this class; they are read from Settings.
    """

    DEFAULT_CONFIG_PATH = (
        Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"
    )

    def __init__(self, config_path: str | None = None):
        """Create the loader and load the configuration immediately.

        Args:
            config_path: Path of the config file. ``DEFAULT_CONFIG_PATH`` is
                used when this is omitted or empty.
        """
        self.config_path = (
            Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
        )
        self._platforms: dict[str, PlatformConfig] = {}
        self._models: dict[str, ModelConfig] = {}
        self._task_defaults: dict[str, str] = {}
        self._last_modified: float = 0.0
        self._load()

    def _load(self):
        """Read and parse the config file, falling back to the mock defaults.

        The fallback is used when the file does not exist or when reading or
        parsing raises. ``_last_modified`` is only advanced after a
        successful parse.
        """
        if not self.config_path.exists():
            logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
            self._load_defaults()
            return

        try:
            # Take the mtime before reading: if the file is rewritten while it
            # is being parsed, the next reload() still sees a newer mtime.
            mtime = self.config_path.stat().st_mtime
            with open(self.config_path, encoding="utf-8") as f:
                if YAML_AVAILABLE:
                    config = yaml.safe_load(f)
                else:
                    # Fallback parser when PyYAML is not installed.
                    import json

                    config = json.load(f)
            self._parse_config(config)
            self._last_modified = mtime
            logger.info(
                f"已加载模型配置: {len(self._platforms)} 平台, {len(self._models)} 模型"
            )
        except Exception as e:
            # Deliberate best effort: any failure degrades to the mock setup.
            logger.error(f"加载配置文件失败: {e},使用默认配置")
            self._load_defaults()

    def _parse_config(self, config: dict):
        """Rebuild the in-memory tables from a parsed config document.

        Models are taken from each platform's nested ``models`` list first;
        the legacy top-level ``models`` mapping only contributes ids that the
        nested lists did not define. Sections that are present but empty
        (parsed as ``None``) are treated as empty.

        The tables are built from scratch and swapped in together, so entries
        removed from the file (or left over from the mock defaults) do not
        survive a reload, and a failure midway leaves the old tables intact.

        Args:
            config: Top-level mapping of the parsed document.

        Raises:
            ValueError: If the top level of the document is not a mapping
                (for example, an empty file parsed as ``None``).
        """
        if not isinstance(config, dict):
            raise ValueError(
                f"配置文件顶层必须是映射,实际为 {type(config).__name__}"
            )

        platforms: dict[str, PlatformConfig] = {}
        models: dict[str, ModelConfig] = {}

        for pid, pdata in (config.get("platforms") or {}).items():
            pdata = pdata or {}
            platforms[pid] = PlatformConfig(
                id=pid,
                name=pdata.get("name", pid),
                provider=pdata.get("provider", pid),
                priority=pdata.get("priority", 100),
                base_url=pdata.get("base_url", ""),
            )
            # Nested model list (platform-config.yaml format).
            for mdata in pdata.get("models") or []:
                mid = mdata.get("id")
                if not mid:
                    continue
                models[mid] = self._build_model(mid, pid, mdata)

        # Legacy format: top-level ``models`` mapping keyed by model id.
        for mid, mdata in (config.get("models") or {}).items():
            if mid in models:
                continue
            mdata = mdata or {}
            models[mid] = self._build_model(
                mid, mdata.get("platform_id", ""), mdata
            )

        # Copy so later in-memory overrides never touch the parsed document.
        task_defaults = dict(config.get("task_defaults") or {})

        self._platforms = platforms
        self._models = models
        self._task_defaults = task_defaults

    @staticmethod
    def _build_model(model_id: str, platform_id: str, mdata: dict) -> ModelConfig:
        """Create a ``ModelConfig`` from one raw model mapping."""
        return ModelConfig(
            id=model_id,
            platform_id=platform_id,
            model_name=mdata.get("model_name", model_id),
            display_name=mdata.get("display_name", model_id),
            capabilities=list(mdata.get("capabilities") or []),
            default_params=dict(mdata.get("default_params") or {}),
            is_enabled=mdata.get("is_enabled", True),
            cost_per_1k_input=mdata.get("cost_per_1k_input", 0.0),
            cost_per_1k_output=mdata.get("cost_per_1k_output", 0.0),
            max_tokens_limit=mdata.get("max_tokens_limit", 4096),
        )

    def _load_defaults(self):
        """Replace all tables with the built-in mock platform and model."""
        self._platforms = {
            "mock": PlatformConfig(
                id="mock",
                name="Mock 测试平台",
                provider="mock",
                priority=999,
            )
        }
        self._models = {
            "mock-model": ModelConfig(
                id="mock-model",
                platform_id="mock",
                model_name="mock-model",
                display_name="Mock 测试模型",
                capabilities=["script", "polish", "chat"],
            )
        }
        self._task_defaults = {
            "script": "mock-model",
            "polish": "mock-model",
            "chat": "mock-model",
        }

    def reload(self) -> bool:
        """Re-read the config file if its modification time has changed.

        Returns:
            ``True`` when a reload was attempted (even if parsing failed and
            the mock defaults were loaded), ``False`` when the file is
            missing or unchanged.
        """
        try:
            # stat() directly instead of exists()+stat(): the file may vanish
            # between the two calls.
            current_mtime = self.config_path.stat().st_mtime
        except (FileNotFoundError, NotADirectoryError):
            return False

        # "!=" rather than ">": a file restored with an older timestamp
        # (rollback, checkout) is still a change.
        if current_mtime == self._last_modified:
            return False

        logger.info("配置文件已更新,重新加载")
        self._load()
        return True

    # ============== Query methods ==============

    def get_platform(self, platform_id: str) -> PlatformConfig | None:
        """Return the platform with the given id, or ``None`` if unknown."""
        return self._platforms.get(platform_id)

    def get_all_platforms(self) -> list[PlatformConfig]:
        """Return all platforms sorted by ascending ``priority``."""
        return sorted(self._platforms.values(), key=lambda p: p.priority)

    def get_model(self, model_id: str) -> ModelConfig | None:
        """Return the model with the given id, or ``None`` if unknown."""
        return self._models.get(model_id)

    def get_all_models(self) -> list[ModelConfig]:
        """Return every known model, enabled or not."""
        return list(self._models.values())

    def get_enabled_models(self) -> list[ModelConfig]:
        """Return the models whose ``is_enabled`` flag is set."""
        return [m for m in self._models.values() if m.is_enabled]

    def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
        """Return the enabled models that list ``capability``."""
        return [
            m
            for m in self._models.values()
            if m.is_enabled and capability in m.capabilities
        ]

    def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
        """Return the enabled models that belong to ``platform_id``."""
        return [
            m
            for m in self._models.values()
            if m.platform_id == platform_id and m.is_enabled
        ]

    def get_default_model_for_task(self, task_type: str) -> str | None:
        """Return the default model id for ``task_type``, or ``None``."""
        return self._task_defaults.get(task_type)

    def set_default_model_for_task(self, task_type: str, model_id: str):
        """Set the default model for a task type, in memory only.

        The change is not written to the file. Unknown model ids are ignored.
        """
        if model_id in self._models:
            self._task_defaults[task_type] = model_id
# Module-wide loader instance, created on first request.
_config_loader: AIModelConfigLoader | None = None


def get_config_loader() -> AIModelConfigLoader:
    """Return the shared loader, constructing it on the first call."""
    global _config_loader
    if _config_loader is not None:
        return _config_loader
    _config_loader = AIModelConfigLoader()
    return _config_loader
def reload_config() -> bool:
    """Ask the shared loader to reload and return the result of its ``reload()``."""
    return get_config_loader().reload()