Files
meijiaka-zy/python-api/app/core/platform_config.py
T

267 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
平台统一配置管理
================
从 config/platform-config.yaml 加载所有配置:
- platforms: 平台信息、模型列表、方法配置、限流参数
- task_defaults: 任务默认模型映射
- runtime: 任务超时、TTL 策略
使用 Pydantic Schema 校验,启动时加载,全环境只读。
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
logger = logging.getLogger(__name__)
try:
    import yaml
except ImportError:
    # PyYAML is optional: when it is missing, the loader below parses the
    # config file with the json module instead.
    YAML_AVAILABLE = False
    logger.warning("PyYAML 未安装")
else:
    YAML_AVAILABLE = True
# ═══════════════════════════════════════════════════════════════
# Pydantic Schema
# ═══════════════════════════════════════════════════════════════
class RateLimitConfig(BaseModel):
    """Rate-limit settings: a sustained ``qps`` value and a ``burst`` size.

    How these numbers are enforced is up to the consumer; this model only
    validates them.
    """

    # Sustained requests per second; zero is accepted (ge=0).
    qps: float = Field(10.0, ge=0)
    # Burst allowance; must be at least 1.
    burst: int = Field(20, ge=1)
class ModelConfig(BaseModel):
    """One model entry listed under a platform's ``models:`` list."""

    # Disable Pydantic's protected "model_" namespace so that the
    # ``model_name`` field can be declared without a warning.
    model_config = ConfigDict(protected_namespaces=())

    # Model identifier; required.
    id: str
    # Owning platform id; overwritten by the loader with the id of the
    # platform this model is listed under.
    platform_id: str = Field(default="")
    model_name: str = Field(default="")
    display_name: str = Field(default="")
    # Pydantic copies mutable defaults per instance, so these literals are
    # not shared between models.
    capabilities: list[str] = Field(default=[])
    default_params: dict[str, Any] = Field(default={})
    is_enabled: bool = Field(default=True)
    cost_per_1k_input: float = Field(default=0.0)
    cost_per_1k_output: float = Field(default=0.0)
    max_tokens_limit: int = Field(default=4096)
class MethodConfig(BaseModel):
    """Per-method settings of a platform (timeout, connection cap, rate limit)."""

    # Timeout value; at least 1. Unit is not stated here (presumably
    # seconds -- confirm with callers).
    timeout: int = Field(30, ge=1)
    # Connection cap for this method; at least 1.
    max_connections: int = Field(20, ge=1)
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)

    @property
    def rate_limit_qps(self) -> float:
        """Shortcut for ``rate_limit.qps``."""
        limit = self.rate_limit
        return limit.qps

    @property
    def rate_limit_burst(self) -> int:
        """Shortcut for ``rate_limit.burst``."""
        limit = self.rate_limit
        return limit.burst
class PlatformConfigData(BaseModel):
    """Schema of a single platform entry under ``platforms:`` in the YAML file."""

    name: str = Field(default="")
    provider: str = Field(default="")
    base_url: str = Field(default="")
    # Sort key used by the loader's get_all_platforms (ascending).
    priority: int = Field(default=100)
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    # Pydantic copies mutable defaults per instance.
    models: list[ModelConfig] = Field(default=[])
    # Method name -> per-method settings.
    methods: dict[str, MethodConfig] = Field(default={})

    @property
    def rate_limit_qps(self) -> float:
        """Shortcut for ``rate_limit.qps``."""
        limit = self.rate_limit
        return limit.qps

    @property
    def rate_limit_burst(self) -> int:
        """Shortcut for ``rate_limit.burst``."""
        limit = self.rate_limit
        return limit.burst
class RuntimeSection(BaseModel):
    """Runtime policies keyed by task type."""

    # Task type -> timeout (unit presumably seconds -- confirm).
    task_timeouts: dict[str, int] = Field(default={})
    # Task type -> TTL (unit presumably seconds -- confirm).
    task_ttl: dict[str, int] = Field(default={})
class PlatformConfigRoot(BaseModel):
    """Root schema of the whole configuration file."""

    # Platform id -> platform definition.
    platforms: dict[str, PlatformConfigData] = Field(default={})
    # Task type -> default model.
    task_defaults: dict[str, str] = Field(default={})
    runtime: RuntimeSection = Field(default_factory=RuntimeSection)
    # Free-form version label of the config file; empty when absent.
    version: str = Field(default="")
# ═══════════════════════════════════════════════════════════════
# 兼容层 PlatformConfig(保持与旧代码接口一致)
# ═══════════════════════════════════════════════════════════════
class PlatformConfig:
    """Flat, legacy-compatible view of one platform for external callers.

    Copies the scalar fields of a ``PlatformConfigData`` onto plain
    attributes and flattens its rate limit into ``rate_limit_qps`` /
    ``rate_limit_burst``. ``models`` and ``methods`` are the same objects as
    in the schema instance (shared references, not copies).
    """

    def __init__(self, platform_id: str, data: PlatformConfigData):
        self.id = platform_id
        # Identity / routing attributes, assigned left to right.
        self.name, self.provider, self.base_url = data.name, data.provider, data.base_url
        self.priority = data.priority
        # Flattened rate-limit values read by older callers.
        self.rate_limit_qps, self.rate_limit_burst = (
            data.rate_limit_qps,
            data.rate_limit_burst,
        )
        # Shared references into the validated schema objects.
        self.models, self.methods = data.models, data.methods
# ═══════════════════════════════════════════════════════════════
# 统一加载器
# ═══════════════════════════════════════════════════════════════
class PlatformConfigLoader:
    """Unified loader for the platform configuration file.

    Reads ``platform-config.yaml`` once, validates it against
    ``PlatformConfigRoot`` and exposes lookups for platforms, models,
    per-method settings and runtime policies. Loading happens in
    ``__init__``; there is no hot reload.
    """

    # <three directories above this module>/config/platform-config.yaml
    DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"

    def __init__(self, config_path: str | None = None):
        """Create the loader and load the configuration immediately.

        Args:
            config_path: Path of the config file; ``DEFAULT_CONFIG_PATH`` is
                used when it is omitted or empty.

        Raises:
            FileNotFoundError: The config file does not exist.
            RuntimeError: Reading, parsing, validating or indexing failed.
        """
        self.config_path = Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
        self._raw: PlatformConfigRoot | None = None
        self._platforms: dict[str, PlatformConfig] = {}
        self._models: dict[str, ModelConfig] = {}
        self._load()

    def _load(self) -> None:
        """Read, validate and index the config file.

        The new state is committed to ``self`` only after every step has
        succeeded, so a failure never leaves half-built indexes behind.

        Raises:
            FileNotFoundError: ``self.config_path`` does not exist.
            RuntimeError: Any later step failed; the original exception is
                chained as ``__cause__``.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"平台配置文件不存在: {self.config_path}")
        try:
            raw_dict = self._read_config_file()
            # Pydantic schema validation: any error aborts loading (fail fast).
            root = PlatformConfigRoot.model_validate(raw_dict)
            platforms, models = self._build_indexes(root)
        except Exception as e:
            logger.error("平台配置加载失败: %s", e)
            raise RuntimeError(f"平台配置加载失败: {e}") from e
        self._raw = root
        self._platforms = platforms
        self._models = models
        logger.info(
            "平台配置加载完成: %d 平台, %d 模型, version=%s",
            len(platforms),
            len(models),
            root.version or "none",
        )

    def _read_config_file(self) -> Any:
        """Parse the config file as YAML, or as JSON when PyYAML is missing.

        Raises:
            ValueError: PyYAML is not installed and the file is not valid JSON.
        """
        with open(self.config_path, encoding="utf-8") as f:
            if YAML_AVAILABLE:
                # safe_load never constructs arbitrary Python objects.
                return yaml.safe_load(f)
            import json

            try:
                return json.load(f)
            except json.JSONDecodeError as e:
                # Without PyYAML only JSON-syntax content can be read; name the
                # real cause instead of surfacing a bare JSON syntax error.
                raise ValueError(f"PyYAML 未安装,无法解析非 JSON 格式的配置文件: {e}") from e

    @staticmethod
    def _build_indexes(
        root: PlatformConfigRoot,
    ) -> tuple[dict[str, PlatformConfig], dict[str, ModelConfig]]:
        """Build the platform-id and model-id lookup tables from a validated root.

        Sets each model's ``platform_id`` to the id of the platform it is
        listed under. Model ids share one global namespace. When an id repeats,
        the later entry still wins, but a warning is logged instead of
        silently dropping the earlier model.
        """
        platforms = {pid: PlatformConfig(pid, pdata) for pid, pdata in root.platforms.items()}
        models: dict[str, ModelConfig] = {}
        for pid, pdata in root.platforms.items():
            for model in pdata.models:
                model.platform_id = pid
                shadowed = models.get(model.id)
                if shadowed is not None:
                    logger.warning(
                        "模型 ID 重复: %s (平台 %s 覆盖平台 %s 中的同名模型)",
                        model.id,
                        pid,
                        shadowed.platform_id,
                    )
                models[model.id] = model
        return platforms, models

    # ── Platform lookups ─────────────────────────────────────

    def get_platform(self, platform_id: str) -> PlatformConfig | None:
        """Return the platform with ``platform_id``, or ``None`` if unknown."""
        return self._platforms.get(platform_id)

    def get_all_platforms(self) -> list[PlatformConfig]:
        """Return every platform sorted by ascending ``priority`` (stable sort)."""
        return sorted(self._platforms.values(), key=lambda p: p.priority)

    # ── Model lookups ────────────────────────────────────────

    def get_model(self, model_id: str) -> ModelConfig | None:
        """Return the model with ``model_id`` (enabled or not), or ``None``."""
        return self._models.get(model_id)

    def get_all_models(self) -> list[ModelConfig]:
        """Return every indexed model, including disabled ones."""
        return list(self._models.values())

    def get_enabled_models(self) -> list[ModelConfig]:
        """Return only the models whose ``is_enabled`` is true."""
        return [m for m in self._models.values() if m.is_enabled]

    def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
        """Return the enabled models that list ``capability``."""
        return [m for m in self._models.values() if m.is_enabled and capability in m.capabilities]

    def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
        """Return the enabled models that belong to ``platform_id``."""
        return [m for m in self._models.values() if m.platform_id == platform_id and m.is_enabled]

    def get_default_model_for_task(self, task_type: str) -> str | None:
        """Return the configured default model for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.task_defaults.get(task_type)

    # ── Method lookups ───────────────────────────────────────

    def get_method_config(self, platform_id: str, method: str) -> MethodConfig | None:
        """Return the settings of ``method`` on ``platform_id``.

        Returns ``None`` when either the platform or the method is unknown.
        """
        platform = self._platforms.get(platform_id)
        if platform is None:
            return None
        return platform.methods.get(method)

    # ── Runtime policies ─────────────────────────────────────

    def get_runtime_raw(self) -> dict[str, Any]:
        """Return the runtime section as a plain dict (used for Admin API display)."""
        if self._raw is None:
            return {}
        return self._raw.runtime.model_dump()

    def get_task_timeout(self, task_type: str) -> int | None:
        """Return the configured timeout for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.runtime.task_timeouts.get(task_type)

    def get_task_ttl(self, task_type: str) -> int | None:
        """Return the configured TTL for ``task_type``, or ``None``."""
        if self._raw is None:
            return None
        return self._raw.runtime.task_ttl.get(task_type)

    @property
    def version(self) -> str:
        """Version label from the config file; empty string when unavailable."""
        return self._raw.version if self._raw else ""
# ═══════════════════════════════════════════════════════════════
# 全局单例
# ═══════════════════════════════════════════════════════════════
# Process-wide loader instance, created lazily by get_platform_config_loader().
_platform_config_loader: PlatformConfigLoader | None = None


def get_platform_config_loader() -> PlatformConfigLoader:
    """Return the shared ``PlatformConfigLoader``, creating it on first use.

    If construction raises, nothing is cached, so the next call retries.
    """
    global _platform_config_loader
    loader = _platform_config_loader
    if loader is None:
        loader = _platform_config_loader = PlatformConfigLoader()
    return loader