Files
meijiaka-zy/python-api/app/ai/model_router.py
T

359 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 模型路由 V2 - 基于文件配置
=================================
从 YAML 配置文件加载平台/模型配置。
配置在启动时加载,运行时只读,不支持热重载。
"""
import asyncio
import logging
from app.ai.adapters.constants import Method
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
from app.ai.providers.volcengine_provider import VolcengineProvider
from app.core.config_loader import AIModelConfigLoader, get_config_loader
from app.core.exceptions import AppException, PlatformError
from app.platform_gateway import PlatformGateway
# Module logger, named after this module so its records can be filtered by package path.
logger = logging.getLogger(name=__name__)
class _PlatformInstance:
    """Per-platform wrapper kept for backward compatibility.

    Every call is forwarded to a ``PlatformGateway``; this object only stores the
    platform config and the gateway reference.

    Attributes:
        config: Platform config dict; ``config["id"]`` is the gateway platform key.
        gateway: Gateway used for all calls, or ``None`` when not bound yet.
        provider_id: Platform id taken from ``config["id"]`` (``""`` if missing).
    """

    def __init__(self, config: dict, gateway: PlatformGateway | None = None):
        self.config = config
        self.gateway = gateway
        self.provider_id = config.get("id", "")

    async def generate(self, model_name: str, prompt: str, **kwargs) -> GenerationResult:
        """Run a chat generation on this platform through the gateway.

        Args:
            model_name: Platform-side model name, sent as ``payload["model"]``.
            prompt: Prompt text, sent as ``payload["prompt"]``.
            **kwargs: Extra generation parameters merged into the payload. They are
                spread last, so a ``prompt``/``model`` key here overrides the above.

        Returns:
            A ``GenerationResult`` built from the gateway response data.

        Raises:
            ProviderError: No gateway is bound, or the gateway reports failure.
        """
        if not self.gateway:
            # There is no direct-Provider path any more; without a gateway we cannot call out.
            raise ProviderError("PlatformGateway 未初始化")
        result = await self.gateway.call_sync(
            platform=self.provider_id,
            method=Method.CHAT,
            payload={
                "prompt": prompt,
                "model": model_name,
                **kwargs,
            },
        )
        if not result.success:
            raise ProviderError(result.error_message or f"{self.provider_id} 调用失败")
        data = result.data or {}
        return GenerationResult(
            content=data.get("content", ""),
            usage=data.get("usage"),
            model=data.get("model", model_name),
        )

    async def health_check(self, model_name: str | None = None) -> ModelHealth:
        """Report this platform's health as seen by the gateway.

        Args:
            model_name: Used as the ``id`` of the returned ``ModelHealth``; falls back
                to the platform id when not given.

        Returns:
            ``ModelHealth`` from the gateway's result for this platform. An unavailable
            entry is returned instead when no gateway is bound, the platform is absent
            from the gateway results, or the gateway check raises. ``last_error`` says
            which of these happened. This method does not raise.
        """
        health_id = model_name or self.provider_id
        if not self.gateway:
            last_error = "PlatformGateway 未初始化"
        else:
            try:
                # NOTE(review): this checks every platform just to read one entry.
                all_results = await self.gateway.health_check_all()
                adapter_result = all_results.get(self.provider_id)
                if adapter_result:
                    data = adapter_result.data
                    return ModelHealth(
                        id=health_id,
                        name=self.provider_id,
                        is_available=adapter_result.success,
                        response_time=data.get("response_time_ms", 0) if data else 0,
                        last_error=adapter_result.error_message,
                    )
                # Gateway is present but knows nothing about this platform.
                last_error = "平台未注册到 Gateway"
            except Exception as e:
                logger.warning(f"平台 {self.provider_id} 健康检查失败: {e}")
                last_error = str(e) or type(e).__name__
        return ModelHealth(
            id=health_id,
            name=self.provider_id,
            is_available=False,
            response_time=0,
            last_error=last_error,
        )
class ModelRouter:
    """Model router V2, driven by file-based configuration.

    Supports:
    - loading platform/model configuration through the config loader
    - multiple platforms
    - multiple models per platform
    - automatic model selection by task type

    Configuration is loaded once in :meth:`initialize`.
    """

    def __init__(self):
        # platform id -> wrapper forwarding calls to the gateway
        self.platforms: dict[str, _PlatformInstance] = {}
        self._config_loader: AIModelConfigLoader | None = None
        self._initialized = False
        self._gateway: PlatformGateway | None = None

    async def initialize(self, db_session=None, gateway: PlatformGateway | None = None):
        """Initialize the router; calls after a successful initialization are no-ops.

        Args:
            db_session: Not used; kept for backward compatibility.
            gateway: ``PlatformGateway`` used for all third-party platform calls.
        """
        if self._initialized:
            return
        self._gateway = gateway
        self._config_loader = get_config_loader()
        self._load_from_config()
        # Set only after loading succeeded, so a failed load leaves the router retryable.
        self._initialized = True
        logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")

    def _load_from_config(self):
        """Rebuild ``self.platforms`` from the config loader and register model mappings.

        Enabled models whose platform is ``"volcengine"`` are also passed to
        ``VolcengineProvider.load_models_from_config`` for model-name mapping.
        """
        self.platforms = {}
        for platform in self._config_loader.get_all_platforms():
            try:
                self.platforms[platform.id] = _PlatformInstance(
                    {
                        "id": platform.id,
                        "name": platform.name,
                        "provider": platform.provider,
                        "base_url": platform.base_url,
                    },
                    gateway=self._gateway,
                )
                logger.info(f"平台 {platform.id} 初始化成功")
            except Exception as e:
                # One bad platform entry must not prevent the others from loading.
                logger.warning(f"平台 {platform.id} 初始化失败: {e}")

        volcengine_models = [
            {
                "id": model.id,
                "model_name": model.model_name,
            }
            for model in self._config_loader.get_enabled_models()
            if model.platform_id == "volcengine"
        ]
        if volcengine_models:
            VolcengineProvider.load_models_from_config(volcengine_models)
            logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")

    @staticmethod
    def _model_to_dict(model) -> dict:
        """Convert a config-loader model entry into the public dict shape."""
        return {
            "id": model.id,
            "platform_id": model.platform_id,
            "model_name": model.model_name,
            "display_name": model.display_name,
            "capabilities": model.capabilities,
            "default_params": model.default_params,
            "cost_per_1k_input": model.cost_per_1k_input,
            "cost_per_1k_output": model.cost_per_1k_output,
            "max_tokens_limit": model.max_tokens_limit,
        }

    def get_model_config(self, model_id: str) -> dict | None:
        """Return the config dict for ``model_id``, or ``None`` if unknown/not initialized."""
        if self._config_loader:
            model = self._config_loader.get_model(model_id)
            if model:
                return self._model_to_dict(model)
        return None

    def list_models(
        self, capability: str | None = None, platform_id: str | None = None
    ) -> list[dict]:
        """List models as dicts, optionally filtered.

        Args:
            capability: Only models with this capability.
            platform_id: Only models on this platform.

        When both filters are given, only models matching both are returned; with no
        filter, all enabled models are returned. Returns ``[]`` before initialization.
        """
        if not self._config_loader:
            return []
        if capability:
            config_models = self._config_loader.get_models_by_capability(capability)
            if platform_id:
                # Previously platform_id was silently ignored when capability was also set.
                config_models = [m for m in config_models if m.platform_id == platform_id]
        elif platform_id:
            config_models = self._config_loader.get_models_by_platform(platform_id)
        else:
            config_models = self._config_loader.get_enabled_models()
        return [self._model_to_dict(m) for m in config_models]

    def list_platforms(self) -> list[dict]:
        """List all configured platforms (id, name, provider); ``[]`` before initialization."""
        if not self._config_loader:
            return []
        return [
            {
                "id": p.id,
                "name": p.name,
                "provider": p.provider,
            }
            for p in self._config_loader.get_all_platforms()
        ]

    def select_model_for_task(self, task_type: str) -> str | None:
        """Pick a model id for ``task_type``.

        Prefers the task's configured default model if it exists and is enabled,
        otherwise the first model returned for the capability named ``task_type``.

        Returns:
            A model id, or ``None`` when nothing matches or the router is not initialized.
        """
        if not self._config_loader:
            return None
        default_model = self._config_loader.get_default_model_for_task(task_type)
        if default_model:
            model = self._config_loader.get_model(default_model)
            if model and model.is_enabled:
                return default_model
        candidates = self._config_loader.get_models_by_capability(task_type)
        if candidates:
            return candidates[0].id
        return None

    async def generate(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate text with a configured model.

        Model resolution: explicit ``model_id``, else ``select_model_for_task(task_type)``,
        else the first enabled model.

        Args:
            prompt: Prompt text.
            model_id: Model id to use; ``None`` selects one automatically.
            task_type: Task type used for automatic model selection.
            **kwargs: Generation parameters; they override the model's ``default_params``.

        Raises:
            ProviderError: No model is available, the model or its platform is unknown,
                or the call fails with an unexpected error.
            PlatformError, AppException: Re-raised unchanged from the platform call.
        """
        if model_id is None and task_type:
            model_id = self.select_model_for_task(task_type)
        if model_id is None:
            models = self._config_loader.get_enabled_models() if self._config_loader else []
            if not models:
                raise ProviderError("没有可用的模型")
            model_id = models[0].id

        model = self._config_loader.get_model(model_id) if self._config_loader else None
        if not model:
            raise ProviderError(f"模型不存在: {model_id}")
        platform = self.platforms.get(model.platform_id)
        if not platform:
            raise ProviderError(f"平台不存在: {model.platform_id}")

        # `or {}` guards against a model configured without default params (None).
        params = {**(model.default_params or {}), **kwargs}
        try:
            return await platform.generate(prompt=prompt, model_name=model.model_name, **params)
        except (PlatformError, AppException):
            raise
        except Exception as e:
            raise ProviderError(f"模型 {model_id} 生成失败: {e}") from e

    async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
        """Check model health.

        Uses a single ``PlatformGateway.health_check_all()`` call when a gateway is
        bound; otherwise asks each platform wrapper directly.

        Args:
            model_id: Only check this model; when not given, check every enabled model.

        Returns:
            Mapping of model id -> ``ModelHealth``. On the non-gateway path, models whose
            platform is not loaded are omitted.

        Raises:
            ProviderError: ``model_id`` is given but not found in the configuration.
        """
        # Resolve the (result key, model) pairs to check; both paths honor model_id.
        if model_id:
            target_model = self._config_loader.get_model(model_id) if self._config_loader else None
            if target_model is None:
                raise ProviderError(f"模型不存在: {model_id}")
            targets = [(model_id, target_model)]
        else:
            enabled = self._config_loader.get_enabled_models() if self._config_loader else []
            targets = [(m.id, m) for m in enabled]

        results: dict[str, ModelHealth] = {}
        if self._gateway:
            gateway_results = await self._gateway.health_check_all()
            for key, model in targets:
                adapter_result = gateway_results.get(model.platform_id)
                if adapter_result:
                    data = adapter_result.data
                    results[key] = ModelHealth(
                        id=model.id,
                        name=model.display_name,
                        is_available=adapter_result.success,
                        response_time=data.get("response_time_ms", 0) if data else 0,
                        last_error=adapter_result.error_message,
                    )
                else:
                    results[key] = ModelHealth(
                        id=model.id,
                        name=model.display_name,
                        is_available=False,
                        response_time=0,
                        last_error="平台未注册到 Gateway",
                    )
            return results

        # Fallback: ask each platform wrapper directly.
        for key, model in targets:
            platform = self.platforms.get(model.platform_id)
            if not platform:
                continue
            try:
                results[key] = await platform.health_check(model.model_name)
            except Exception as e:
                results[key] = ModelHealth(
                    id=model.id,
                    name=model.display_name,
                    is_available=False,
                    response_time=0,
                    last_error=str(e),
                )
        return results
# Module-level singleton, created lazily by get_model_router().
_model_router: ModelRouter | None = None
# Serializes first-time initialization between coroutines. asyncio.Lock coordinates
# coroutines on a single event loop; it is not a cross-thread lock.
_init_lock = asyncio.Lock()


async def get_model_router(db_session=None, gateway: PlatformGateway | None = None) -> ModelRouter:
    """Return the shared ModelRouter, initializing it on first use.

    Checks the initialized flag, then re-checks it under ``_init_lock`` so concurrent
    coroutines initialize the router only once. This is coroutine-safe within one
    event loop, not thread-safe. If a previous initialization attempt failed
    (``_initialized`` is still False), the next call retries it.

    If ``gateway`` is given and the router has no gateway yet (for example, it was
    initialized by an earlier call without one), the gateway is bound to the router
    and to all of its platform wrappers.

    Args:
        db_session: Passed through to ``ModelRouter.initialize``.
        gateway: Optional ``PlatformGateway`` to use for platform calls.

    Returns:
        The singleton router.
    """
    global _model_router
    if _model_router is None or not getattr(_model_router, "_initialized", False):
        async with _init_lock:
            if _model_router is None:
                _model_router = ModelRouter()
            if not _model_router._initialized:
                logger.info("Initializing ModelRouter singleton...")
                await _model_router.initialize(db_session, gateway=gateway)
                logger.info("ModelRouter singleton initialized")

    router = _model_router
    # Late-bind the gateway regardless of which path above was taken; there is no
    # await in this section, so it cannot interleave with another coroutine.
    if gateway is not None and router._gateway is None:
        router._gateway = gateway
        for platform_instance in router.platforms.values():
            platform_instance.gateway = gateway
    return router