232 lines
7.3 KiB
Python
232 lines
7.3 KiB
Python
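"""
AI model configuration loader
=============================

Loads model configuration from a YAML file and supports hot reload.

API keys are read from Settings (in line with the configuration conventions).
"""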
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)

# Optional dependency: PyYAML. When it is missing, the loader falls back to
# parsing the config file with the stdlib JSON parser instead.
try:
    import yaml
except ImportError:
    YAML_AVAILABLE = False
    logger.warning("PyYAML 未安装,使用 JSON 备选方案。安装: pip install pyyaml")
else:
    YAML_AVAILABLE = True
|
|
|
|
|
|
@dataclass
class PlatformConfig:
    """Configuration of a single AI provider platform.

    Only ``id``, ``name`` and ``provider`` are required; the remaining
    fields have defaults.
    """

    # Key identifying the platform.
    id: str
    # Human-readable platform name.
    name: str
    # Provider identifier for this platform.
    provider: str
    # Sort key used when listing platforms: smaller values sort first.
    priority: int = field(default=100)
    # Optional base URL read from the YAML file; empty string when unset.
    base_url: str = field(default="")
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration of a single AI model.

    The four identifying fields are required; everything else has a default.
    Mutable defaults use ``default_factory`` so instances never share them.
    """

    # Key identifying the model.
    id: str
    # Id of the platform this model belongs to.
    platform_id: str
    # Model name (presumably the name sent to the provider API — confirm).
    model_name: str
    # Human-readable model name.
    display_name: str
    # Task types this model can serve (e.g. "script", "polish", "chat").
    capabilities: list[str] = field(default_factory=list)
    # Default request parameters for this model.
    default_params: dict[str, Any] = field(default_factory=dict)
    # Disabled models are filtered out by the "enabled" queries.
    is_enabled: bool = field(default=True)
    # Cost per 1k input / output tokens (currency not specified here).
    cost_per_1k_input: float = field(default=0.0)
    cost_per_1k_output: float = field(default=0.0)
    # Upper bound on max_tokens for this model.
    max_tokens_limit: int = field(default=4096)
|
|
|
|
|
|
class AIModelConfigLoader:
    """AI model configuration loader.

    Loads platform and model definitions from a YAML file (falling back to
    the JSON parser when PyYAML is unavailable) and supports hot reload via
    :meth:`reload`. API keys are not handled here; they are read from
    Settings (via ``get_settings()``), per the project's config conventions.
    """

    # config/ai_models.yaml, two directories above this module's directory.
    DEFAULT_CONFIG_PATH = (
        Path(__file__).parent.parent.parent / "config" / "ai_models.yaml"
    )

    def __init__(self, config_path: str | None = None):
        """Create the loader and perform the initial load.

        Args:
            config_path: path to the config file; ``DEFAULT_CONFIG_PATH`` is
                used when this is None or empty.
        """
        self.config_path = (
            Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
        )
        self._platforms: dict[str, PlatformConfig] = {}
        self._models: dict[str, ModelConfig] = {}
        self._task_defaults: dict[str, str] = {}
        # mtime of the file behind the currently applied config (0 = none yet).
        self._last_modified: float = 0
        self._load()

    def _load(self, *, keep_current_on_error: bool = False) -> bool:
        """Read the config file and apply it.

        On the initial load, a missing or unparsable file falls back to the
        built-in mock defaults. With ``keep_current_on_error`` (used by hot
        reload), such failures leave the configuration currently in memory
        untouched instead of replacing it with the defaults.

        Returns:
            True if the file was parsed and applied, False otherwise.
        """
        if not self.config_path.exists():
            if keep_current_on_error:
                # File vanished during a hot reload: keep what we have.
                return False
            logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
            self._load_defaults()
            return False

        try:
            # Capture mtime *before* reading: if the file is edited while we
            # read it, the next reload() still sees a different mtime.
            mtime = self.config_path.stat().st_mtime
            with open(self.config_path, encoding="utf-8") as f:
                if YAML_AVAILABLE:
                    config = yaml.safe_load(f)
                else:
                    # Fallback parser: only succeeds if the file is valid JSON.
                    import json

                    config = json.load(f)
            # _parse_config swaps in new state only after parsing everything.
            self._parse_config(config)
        except Exception as e:
            if keep_current_on_error:
                logger.error(f"重新加载配置文件失败: {e},保留当前配置")
                return False
            logger.error(f"加载配置文件失败: {e},使用默认配置")
            self._load_defaults()
            return False

        self._last_modified = mtime
        logger.info(
            f"已加载模型配置: {len(self._platforms)} 平台, {len(self._models)} 模型"
        )
        return True

    def _parse_config(self, config: dict):
        """Parse a loaded config mapping and make it the active configuration.

        Only platform/model metadata is parsed; API keys come from Settings.
        Every section is built into fresh dicts and swapped in at the end, so
        entries removed from the file disappear on reload and a parse error
        leaves the previously applied configuration intact.

        Raises:
            ValueError: if ``config`` is not a mapping (an empty YAML file
                parses as None).
        """
        if not isinstance(config, dict):
            raise ValueError(
                f"配置文件内容无效: 顶层应为映射, 实际为 {type(config).__name__}"
            )

        # `or {}` also covers sections/entries present but empty (YAML null).
        platforms: dict[str, PlatformConfig] = {}
        for pid, pdata in (config.get("platforms") or {}).items():
            pdata = pdata or {}
            platforms[pid] = PlatformConfig(
                id=pid,
                name=pdata.get("name", pid),
                provider=pdata.get("provider", pid),
                priority=pdata.get("priority", 100),
                base_url=pdata.get("base_url", ""),
            )

        models: dict[str, ModelConfig] = {}
        for mid, mdata in (config.get("models") or {}).items():
            mdata = mdata or {}
            models[mid] = ModelConfig(
                id=mid,
                platform_id=mdata.get("platform_id", ""),
                model_name=mdata.get("model_name", mid),
                display_name=mdata.get("display_name", mid),
                capabilities=mdata.get("capabilities") or [],
                default_params=mdata.get("default_params") or {},
                is_enabled=mdata.get("is_enabled", True),
                cost_per_1k_input=mdata.get("cost_per_1k_input", 0.0),
                cost_per_1k_output=mdata.get("cost_per_1k_output", 0.0),
                max_tokens_limit=mdata.get("max_tokens_limit", 4096),
            )

        # Task type -> default model id.
        task_defaults = dict(config.get("task_defaults") or {})

        self._platforms = platforms
        self._models = models
        self._task_defaults = task_defaults

    def _load_defaults(self):
        """Replace the configuration with a single mock platform and model."""
        self._platforms = {
            "mock": PlatformConfig(
                id="mock",
                name="Mock 测试平台",
                provider="mock",
                priority=999,
            )
        }
        self._models = {
            "mock-model": ModelConfig(
                id="mock-model",
                platform_id="mock",
                model_name="mock-model",
                display_name="Mock 测试模型",
                capabilities=["script", "polish", "chat"],
            )
        }
        self._task_defaults = {
            "script": "mock-model",
            "polish": "mock-model",
            "chat": "mock-model",
        }

    def reload(self) -> bool:
        """Reload the configuration if the file changed since it was last applied.

        Any mtime change triggers a reload, including a file restored with an
        older timestamp. If the changed file cannot be parsed, the current
        configuration is kept and the load is retried on the next call.

        Returns:
            True if a changed file was loaded and applied, otherwise False.
        """
        try:
            current_mtime = self.config_path.stat().st_mtime
        except OSError:
            # Missing/unreadable file: keep whatever is loaded now.
            return False
        if current_mtime == self._last_modified:
            return False
        logger.info("配置文件已更新,重新加载")
        return self._load(keep_current_on_error=True)

    # ============== Query methods ==============

    def get_platform(self, platform_id: str) -> PlatformConfig | None:
        """Return the platform with ``platform_id``, or None if unknown."""
        return self._platforms.get(platform_id)

    def get_all_platforms(self) -> list[PlatformConfig]:
        """Return all platforms sorted by ascending ``priority``."""
        return sorted(self._platforms.values(), key=lambda p: p.priority)

    def get_model(self, model_id: str) -> ModelConfig | None:
        """Return the model with ``model_id``, or None if unknown."""
        return self._models.get(model_id)

    def get_all_models(self) -> list[ModelConfig]:
        """Return all models, enabled or not."""
        return list(self._models.values())

    def get_enabled_models(self) -> list[ModelConfig]:
        """Return only the models with ``is_enabled`` set."""
        return [m for m in self._models.values() if m.is_enabled]

    def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
        """Return enabled models whose ``capabilities`` include ``capability``."""
        return [
            m
            for m in self._models.values()
            if m.is_enabled and capability in m.capabilities
        ]

    def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
        """Return enabled models belonging to ``platform_id``."""
        return [
            m
            for m in self._models.values()
            if m.platform_id == platform_id and m.is_enabled
        ]

    def get_default_model_for_task(self, task_type: str) -> str | None:
        """Return the default model id for ``task_type``, or None if unmapped."""
        return self._task_defaults.get(task_type)

    def set_default_model_for_task(self, task_type: str, model_id: str) -> bool:
        """Set the default model for a task type (in memory only; not saved).

        The override is lost when the configuration is reloaded from file.

        Returns:
            True if the mapping was updated; False (with a warning logged) if
            ``model_id`` is not a known model.
        """
        if model_id not in self._models:
            logger.warning(f"未知模型 ID: {model_id},忽略默认模型设置")
            return False
        self._task_defaults[task_type] = model_id
        return True
|
|
|
|
|
# Process-wide loader instance, created lazily on first access.
_config_loader: AIModelConfigLoader | None = None


def get_config_loader() -> AIModelConfigLoader:
    """Return the shared configuration loader, constructing it on first use."""
    global _config_loader
    if _config_loader is not None:
        return _config_loader
    _config_loader = AIModelConfigLoader()
    return _config_loader
|
|
|
|
|
|
def reload_config() -> bool:
    """Hot-reload the shared configuration; see ``AIModelConfigLoader.reload``."""
    return get_config_loader().reload()
|