Files
meijiaka-zy/python-api/app/ai/model_router.py
T

418 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 模型路由 V2 - 基于文件配置
=================================
从 YAML 配置文件加载平台/模型配置,支持热重载。
"""
import asyncio
import logging
from collections.abc import AsyncIterator
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
from app.ai.providers.generic_llm_provider import MockProvider
from app.ai.providers.klingai_provider import KlingAIProvider
from app.ai.providers.volcengine_provider import VolcengineProvider
from app.config import get_settings
from app.core.config_loader import AIModelConfigLoader, get_config_loader
logger = logging.getLogger(__name__)
class PlatformInstance:
    """Bundle one platform's config dict with the provider client built from it."""

    def __init__(self, config: dict):
        """Store ``config`` and immediately build the matching provider.

        Args:
            config: Platform settings. This class reads the ``"provider"`` and
                ``"base_url"`` keys.

        Raises:
            ProviderError: Propagated from ``_create_provider``.
        """
        self.config = config
        self.provider = self._create_provider()

    def _create_provider(self):
        """Build the provider selected by ``config["provider"]``.

        Credentials are taken from the application settings object, not from
        the platform config. A missing ``"provider"`` key selects the mock
        provider.

        Raises:
            ProviderError: If required credentials are empty or the provider
                type is not one of ``volcengine``, ``klingai`` or ``mock``.
        """
        provider_type = self.config.get("provider", "mock")
        settings = get_settings()

        if provider_type == "mock":
            return MockProvider()

        if provider_type == "volcengine":
            volc_key = settings.VOLCENGINE_API_KEY
            if volc_key:
                # A base_url in the platform config wins over the settings default.
                endpoint = (
                    self.config.get("base_url") or settings.VOLCENGINE_BASE_URL
                )
                return VolcengineProvider(api_key=volc_key, base_url=endpoint)
            raise ProviderError(
                "Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY"
            )

        if provider_type == "klingai":
            ak = settings.KLINGAI_ACCESS_KEY
            sk = settings.KLINGAI_SECRET_KEY
            if ak and sk:
                kling_config = {
                    "access_key": ak,
                    "secret_key": sk,
                    # May be None when the platform config has no base_url.
                    "base_url": self.config.get("base_url"),
                }
                return KlingAIProvider(config=kling_config)
            raise ProviderError(
                "KlingAI Access/Secret Key 未配置,请在 .env 中设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY"
            )

        raise ProviderError(f"不支持的 Provider 类型: {provider_type}")

    async def generate(
        self, model_name: str, prompt: str, **kwargs
    ) -> GenerationResult:
        """Forward a one-shot generation request to the provider.

        ``model_name`` is passed to the provider as its ``model`` argument;
        extra keyword arguments are passed through unchanged.
        """
        outcome = await self.provider.generate(
            prompt=prompt, model=model_name, **kwargs
        )
        return outcome

    async def generate_stream(
        self, model_name: str, prompt: str, **kwargs
    ) -> AsyncIterator[str]:
        """Yield every chunk produced by the provider's ``generate_stream``."""
        stream = self.provider.generate_stream(
            prompt=prompt, model=model_name, **kwargs
        )
        async for piece in stream:
            yield piece

    async def health_check(self, model_name: str | None = None) -> ModelHealth:
        """Return the provider's health report for ``model_name``."""
        report = await self.provider.health_check(model_name)
        return report
class ModelRouter:
    """Route generation requests to platform providers (V2, file-based config).

    Platforms and models are read from the config loader returned by
    ``get_config_loader()``. The router supports several platforms, several
    models per platform, picking a model by task type, and reloading the
    configuration at runtime.

    NOTE: no automatic failover to another model is implemented here;
    ``generate`` logs a failure and re-raises it.
    """

    def __init__(self):
        # platform id -> ready-to-use PlatformInstance
        self.platforms: dict[str, PlatformInstance] = {}
        # Stays None until initialize() has run.
        self._config_loader: AIModelConfigLoader | None = None
        self._initialized = False

    async def initialize(self, db_session=None):
        """Load platforms and models once; later calls are no-ops.

        Args:
            db_session: Unused. Kept only so existing callers keep working.
        """
        if self._initialized:
            return

        self._config_loader = get_config_loader()
        self._load_from_config()

        self._initialized = True
        logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")

    def _load_from_config(self):
        """Rebuild the platform map and push Volcengine model names to the provider.

        A platform whose provider cannot be constructed is logged and skipped
        so that one bad platform does not take down the others.
        """
        loader = self._config_loader

        # Build into a local dict and publish it with a single assignment, so
        # a failure part-way through never leaves self.platforms empty or
        # half-populated.
        platforms: dict[str, PlatformInstance] = {}
        for platform in loader.get_all_platforms():
            try:
                # PlatformInstance reads its credentials from Settings itself.
                platforms[platform.id] = PlatformInstance(
                    {
                        "id": platform.id,
                        "name": platform.name,
                        "provider": platform.provider,
                        "base_url": platform.base_url,
                    }
                )
                logger.info(f"平台 {platform.id} 初始化成功")
            except Exception as e:  # deliberate best-effort: skip this platform
                logger.warning(f"平台 {platform.id} 初始化失败: {e}")
        self.platforms = platforms

        # Hand the enabled Volcengine models to the provider class, which uses
        # them for model-name mapping.
        volcengine_models = [
            {"id": model.id, "model_name": model.model_name}
            for model in loader.get_enabled_models()
            if model.platform_id == "volcengine"
        ]
        if volcengine_models:
            VolcengineProvider.load_models_from_config(volcengine_models)
            logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")

    def reload_config(self) -> bool:
        """Reload the configuration.

        Returns:
            True if the loader reported a successful reload and the platforms
            were rebuilt; False otherwise, including when the router has not
            been initialized.
        """
        if self._config_loader and self._config_loader.reload():
            self._load_from_config()
            return True
        return False

    @staticmethod
    def _model_to_dict(model) -> dict:
        """Convert a loader model object into the plain dict returned to callers."""
        return {
            "id": model.id,
            "platform_id": model.platform_id,
            "model_name": model.model_name,
            "display_name": model.display_name,
            "capabilities": model.capabilities,
            "default_params": model.default_params,
            "cost_per_1k_input": model.cost_per_1k_input,
            "cost_per_1k_output": model.cost_per_1k_output,
            "max_tokens_limit": model.max_tokens_limit,
        }

    def get_model_config(self, model_id: str) -> dict | None:
        """Return the config dict for ``model_id``.

        Returns None when the model is unknown or the router is not
        initialized.
        """
        if not self._config_loader:
            return None
        model = self._config_loader.get_model(model_id)
        return self._model_to_dict(model) if model else None

    def list_models(
        self, capability: str | None = None, platform_id: str | None = None
    ) -> list[dict]:
        """List models, optionally filtered by capability and/or platform.

        Args:
            capability: Only return models the loader reports for this
                capability.
            platform_id: Only return models that belong to this platform.

        When both filters are given they are combined: the capability lookup
        is narrowed to the requested platform. With no filter, the loader's
        enabled models are returned. An uninitialized router returns ``[]``.
        """
        loader = self._config_loader
        if not loader:
            return []

        if capability:
            config_models = loader.get_models_by_capability(capability)
            if platform_id:
                config_models = [
                    m for m in config_models if m.platform_id == platform_id
                ]
        elif platform_id:
            config_models = loader.get_models_by_platform(platform_id)
        else:
            config_models = loader.get_enabled_models()

        return [self._model_to_dict(m) for m in config_models]

    def list_platforms(self) -> list[dict]:
        """Return ``id``, ``name`` and ``provider`` for every configured platform."""
        if not self._config_loader:
            return []
        return [
            {"id": p.id, "name": p.name, "provider": p.provider}
            for p in self._config_loader.get_all_platforms()
        ]

    def select_model_for_task(self, task_type: str) -> str | None:
        """Pick a model id for ``task_type``.

        The task's configured default model is used when it exists and is
        enabled. Otherwise the first model the loader returns for a capability
        named ``task_type`` is used. Returns None when nothing matches or the
        router is not initialized.
        """
        loader = self._config_loader
        if not loader:
            return None

        default_model = loader.get_default_model_for_task(task_type)
        if default_model:
            model = loader.get_model(default_model)
            if model and model.is_enabled:
                return default_model

        candidates = loader.get_models_by_capability(task_type)
        if candidates:
            return candidates[0].id
        return None

    def _resolve_model(self, model_id: str | None, task_type: str | None):
        """Resolve a request to ``(model_id, model, platform)``.

        With no explicit ``model_id``, the model is chosen from ``task_type``
        and, failing that, the first enabled model is used.

        Raises:
            ProviderError: If no model is available, the model id is unknown
                (including when the router is not initialized), or the model's
                platform has no initialized instance.
        """
        loader = self._config_loader

        if model_id is None:
            if task_type:
                model_id = self.select_model_for_task(task_type)
            if model_id is None:
                models = loader.get_enabled_models() if loader else []
                if not models:
                    raise ProviderError("没有可用的模型")
                model_id = models[0].id

        model = loader.get_model(model_id) if loader else None
        if not model:
            raise ProviderError(f"模型不存在: {model_id}")

        platform = self.platforms.get(model.platform_id)
        if not platform:
            raise ProviderError(f"平台不存在: {model.platform_id}")

        return model_id, model, platform

    async def generate(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate text with the selected model.

        Args:
            prompt: Prompt text.
            model_id: Model to use; when None a model is selected
                automatically.
            task_type: Task type used for automatic model selection.
            **kwargs: Generation parameters; they override the model's
                ``default_params``.

        Raises:
            ProviderError: If no usable model or platform can be resolved.
            Exception: Any provider failure is logged and re-raised unchanged.
        """
        model_id, model, platform = self._resolve_model(model_id, task_type)

        # Caller-supplied values take precedence over the model defaults.
        params = {**(model.default_params or {}), **kwargs}

        try:
            return await platform.generate(
                prompt=prompt, model_name=model.model_name, **params
            )
        except Exception as e:
            logger.error(f"模型 {model_id} 生成失败: {e}")
            raise

    async def generate_stream_with_progress(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ):
        """Stream generated text as progress events.

        Args:
            prompt: Prompt text.
            model_id: Model to use; when None a model is selected
                automatically.
            task_type: Task type used for automatic model selection.
            **kwargs: Generation parameters; they override the model's
                ``default_params``.

        Yields:
            dict: When the provider implements
            ``generate_stream_with_progress`` its events are forwarded as-is.
            Otherwise each text chunk becomes
            ``{"type": "chunk", "content", "total_chars"}`` and the stream ends
            with a ``{"type": "usage", ...}`` event whose token counts are 0,
            because the plain stream carries no usage data.

        Raises:
            ProviderError: If no usable model or platform can be resolved.
                As this is an async generator, the error surfaces on the
                first iteration rather than at call time.
        """
        _, model, platform = self._resolve_model(model_id, task_type)
        params = {**(model.default_params or {}), **kwargs}

        provider = platform.provider
        if hasattr(provider, "generate_stream_with_progress"):
            async for chunk in provider.generate_stream_with_progress(
                prompt=prompt, model=model.model_name, **params
            ):
                yield chunk
            return

        # Fall back to the plain text stream. Only the running length is
        # reported, so keep a counter instead of accumulating the full text.
        total_chars = 0
        async for content in provider.generate_stream(
            prompt=prompt, model=model.model_name, **params
        ):
            total_chars += len(content)
            yield {
                "type": "chunk",
                "content": content,
                "total_chars": total_chars,
            }
        yield {
            "type": "usage",
            "prompt_tokens": 0,
            "completion_tokens": 0,
        }

    async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
        """Check model health.

        Args:
            model_id: Check only this model; when None every enabled model
                is checked.

        Returns:
            A mapping of model id to health report. Models that are unknown
            or whose platform has no initialized instance are omitted.

        In the all-models case a failing check is recorded as an unavailable
        ``ModelHealth`` so the remaining models are still checked. In the
        single-model case a provider exception propagates to the caller.
        """
        results = {}
        loader = self._config_loader

        if model_id:
            model = loader.get_model(model_id) if loader else None
            if model:
                platform = self.platforms.get(model.platform_id)
                if platform:
                    results[model_id] = await platform.health_check(model.model_name)
            return results

        if not loader:
            return results

        for model in loader.get_enabled_models():
            platform = self.platforms.get(model.platform_id)
            if not platform:
                continue
            try:
                results[model.id] = await platform.health_check(model.model_name)
            except Exception as e:
                results[model.id] = ModelHealth(
                    id=model.id,
                    name=model.display_name,
                    is_available=False,
                    response_time=0,
                    last_error=str(e),
                )
        return results
# Process-wide singleton and the lock that serializes its first initialization.
_model_router: ModelRouter | None = None
_init_lock = asyncio.Lock()


async def get_model_router(db_session=None) -> ModelRouter:
    """Return the shared ModelRouter, creating and initializing it on first use.

    Concurrent coroutines are serialized by an ``asyncio.Lock`` and the
    instance is re-checked after the lock is acquired, so initialization runs
    at most once. This protects coroutines on one event loop only;
    ``asyncio.Lock`` is not thread-safe.

    Args:
        db_session: Forwarded to ``ModelRouter.initialize``.

    Raises:
        Exception: Whatever ``ModelRouter.initialize`` raises. In that case
            no singleton is stored and the next call retries initialization.
    """
    global _model_router
    if _model_router is None:
        async with _init_lock:
            # Re-check: another coroutine may have finished while we waited.
            if _model_router is None:
                logger.info("Initializing ModelRouter singleton...")
                router = ModelRouter()
                await router.initialize(db_session)
                # Publish only after a successful initialize(), so a failure
                # never leaves an uninitialized router as the global.
                _model_router = router
                logger.info("ModelRouter singleton initialized")
    return _model_router