431c54c258
- 前端:ScriptCreation SSE 流式改为 createTask + pollTask 轮询 - 后端:LLM 仅保留 doubao-seed-2-0-pro,删除降级链及相关模型 - 后端:删除所有图片生成代码(ImageParams/ImageTaskParams/generate_image) - 更新 platform-config.yaml、model_router、volcengine_provider、tasks 等
410 lines
14 KiB
Python
410 lines
14 KiB
Python
"""
|
||
AI 模型路由 V2 - 基于文件配置
|
||
=================================
|
||
|
||
从 YAML 配置文件加载平台/模型配置,支持热重载。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from collections.abc import AsyncIterator
|
||
|
||
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
|
||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||
from app.config import get_settings
|
||
from app.core.config_loader import AIModelConfigLoader, get_config_loader
|
||
|
||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
||
class PlatformInstance:
    """Wrapper binding one platform configuration to its provider client.

    The provider is created eagerly in ``__init__``, so constructing an
    instance raises ``ProviderError`` when the platform cannot be set up.
    """

    # Last-resort endpoint, used when neither the platform entry nor the
    # platform config loader supplies a usable base URL.
    _DEFAULT_VOLCENGINE_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"

    def __init__(self, config: dict):
        """Store ``config`` and build the matching provider.

        Args:
            config: Platform settings. The keys read by this class are
                ``provider`` and ``base_url``.

        Raises:
            ProviderError: If the provider type is unsupported or the
                Volcengine API key is not configured.
        """
        self.config = config
        self.provider = self._create_provider()

    @classmethod
    def _fallback_base_url(cls, platform_config) -> str:
        """Return ``platform_config.base_url``, or the built-in default URL.

        The default is used both when ``platform_config`` is ``None`` and
        when it exists but carries an empty/missing ``base_url``.
        """
        configured = getattr(platform_config, "base_url", None)
        return configured or cls._DEFAULT_VOLCENGINE_BASE_URL

    def _create_provider(self):
        """Create the provider object for this platform's provider type.

        The API key is read from Settings rather than from the platform
        config.

        Returns:
            A ``VolcengineProvider`` built from the API key and base URL.

        Raises:
            ProviderError: If the API key is empty or the provider type is
                anything other than ``"volcengine"``.
        """
        provider_type = self.config.get("provider", "volcengine")
        settings = get_settings()

        if provider_type != "volcengine":
            raise ProviderError(f"不支持的 Provider 类型: {provider_type}")

        api_key = settings.VOLCENGINE_API_KEY
        if not api_key:
            raise ProviderError(
                "Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY"
            )

        base_url = self.config.get("base_url")
        if not base_url:
            # Imported lazily, only when the platform entry has no base URL.
            from app.core.platform_config import get_platform_config_loader

            platform_config = get_platform_config_loader().get_platform(
                "volcengine_ark"
            )
            base_url = self._fallback_base_url(platform_config)

        return VolcengineProvider(api_key=api_key, base_url=base_url)

    async def generate(
        self, model_name: str, prompt: str, **kwargs
    ) -> GenerationResult:
        """Run a non-streaming generation on the underlying provider.

        Args:
            model_name: Model name, passed to the provider as ``model``.
            prompt: Prompt text.
            **kwargs: Extra parameters forwarded to the provider unchanged.

        Returns:
            Whatever the provider's ``generate`` returns.
        """
        return await self.provider.generate(prompt=prompt, model=model_name, **kwargs)

    async def generate_stream(
        self, model_name: str, prompt: str, **kwargs
    ) -> AsyncIterator[str]:
        """Stream a generation, re-yielding each chunk from the provider.

        Args:
            model_name: Model name, passed to the provider as ``model``.
            prompt: Prompt text.
            **kwargs: Extra parameters forwarded to the provider unchanged.

        Yields:
            Chunks exactly as produced by the provider's ``generate_stream``.
        """
        async for chunk in self.provider.generate_stream(
            prompt=prompt, model=model_name, **kwargs
        ):
            yield chunk

    async def health_check(self, model_name: str | None = None) -> ModelHealth:
        """Delegate a health check for ``model_name`` to the provider."""
        return await self.provider.health_check(model_name)
|
||
|
||
|
||
class ModelRouter:
    """File-configured model router (V2).

    Platforms and models come from the config loader returned by
    ``get_config_loader()``. The router provides:

    - multiple platforms, each with multiple models
    - automatic model selection per task type
    - configuration reload via ``reload_config``

    Each request is served by exactly one model: when that model fails the
    error is wrapped in ``ProviderError`` and no other model is tried.
    """

    def __init__(self):
        """Create an empty router; call ``initialize`` before using it."""
        # platform id -> successfully constructed PlatformInstance
        self.platforms: dict[str, PlatformInstance] = {}
        self._config_loader: AIModelConfigLoader | None = None
        self._initialized = False

    async def initialize(self, db_session=None):
        """Load platforms and models from the file configuration, once.

        Args:
            db_session: Unused; kept so existing callers keep working.
        """
        if self._initialized:
            return

        self._config_loader = get_config_loader()
        self._load_from_config()

        self._initialized = True
        logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")

    def _load_from_config(self):
        """Rebuild the platform map and register models with the provider.

        A platform whose construction fails is logged and skipped, so one bad
        platform entry does not prevent the others from loading.
        """
        # Build into a local dict and swap at the end, so ``self.platforms``
        # is never observed half-populated and survives a loader failure.
        platforms: dict[str, PlatformInstance] = {}
        for platform in self._config_loader.get_all_platforms():
            try:
                # PlatformInstance reads the API key from Settings itself.
                platforms[platform.id] = PlatformInstance(
                    {
                        "id": platform.id,
                        "name": platform.name,
                        "provider": platform.provider,
                        "base_url": platform.base_url,
                    }
                )
                logger.info(f"平台 {platform.id} 初始化成功")
            except Exception as e:
                logger.warning(f"平台 {platform.id} 初始化失败: {e}")
        self.platforms = platforms

        # Hand the enabled models to the provider (used for name mapping).
        # NOTE(review): this matches the literal platform id "volcengine",
        # not the platform's provider type - confirm against the YAML ids.
        volcengine_models = [
            {"id": model.id, "model_name": model.model_name}
            for model in self._config_loader.get_enabled_models()
            if model.platform_id == "volcengine"
        ]
        if volcengine_models:
            VolcengineProvider.load_models_from_config(volcengine_models)
            logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")

    def reload_config(self) -> bool:
        """Reload the configuration.

        Returns:
            ``True`` when the loader reports a successful reload (platforms
            and models are then rebuilt), ``False`` otherwise or when the
            router has no loader yet.
        """
        if self._config_loader and self._config_loader.reload():
            self._load_from_config()
            return True
        return False

    @staticmethod
    def _model_to_dict(model) -> dict:
        """Convert a model config object into the dict shape of the public API."""
        return {
            "id": model.id,
            "platform_id": model.platform_id,
            "model_name": model.model_name,
            "display_name": model.display_name,
            "capabilities": model.capabilities,
            "default_params": model.default_params,
            "cost_per_1k_input": model.cost_per_1k_input,
            "cost_per_1k_output": model.cost_per_1k_output,
            "max_tokens_limit": model.max_tokens_limit,
        }

    def get_model_config(self, model_id: str) -> dict | None:
        """Return the configuration of ``model_id`` as a dict.

        Returns:
            The model's fields, or ``None`` when the router has no loader or
            the loader does not know the model.
        """
        if not self._config_loader:
            return None
        model = self._config_loader.get_model(model_id)
        return self._model_to_dict(model) if model else None

    def list_models(
        self, capability: str | None = None, platform_id: str | None = None
    ) -> list[dict]:
        """List models, optionally filtered by capability and/or platform.

        Args:
            capability: Only return models the loader reports for this
                capability.
            platform_id: Only return models belonging to this platform.

        Returns:
            One dict per matching model. When both filters are given, a model
            must satisfy both. With no filter, the enabled models are listed.
            Empty when the router has no loader.
        """
        loader = self._config_loader
        if not loader:
            return []

        if capability:
            config_models = loader.get_models_by_capability(capability)
            if platform_id:
                # Previously ``platform_id`` was silently ignored here.
                config_models = [
                    m for m in config_models if m.platform_id == platform_id
                ]
        elif platform_id:
            config_models = loader.get_models_by_platform(platform_id)
        else:
            config_models = loader.get_enabled_models()

        return [self._model_to_dict(m) for m in config_models]

    def list_platforms(self) -> list[dict]:
        """List every configured platform as ``{"id", "name", "provider"}``."""
        if not self._config_loader:
            return []
        return [
            {"id": p.id, "name": p.name, "provider": p.provider}
            for p in self._config_loader.get_all_platforms()
        ]

    def select_model_for_task(self, task_type: str) -> str | None:
        """Pick a model id for ``task_type``.

        The task's configured default model wins when it exists and is
        enabled; otherwise the first model the loader returns for
        ``task_type`` treated as a capability is used.

        Returns:
            The chosen model id, or ``None`` when nothing matches or the
            router has no loader.
        """
        loader = self._config_loader
        if not loader:
            return None

        # Task-level default first.
        default_model = loader.get_default_model_for_task(task_type)
        if default_model:
            model = loader.get_model(default_model)
            if model and model.is_enabled:
                return default_model

        # Otherwise match by capability.
        candidates = loader.get_models_by_capability(task_type)
        return candidates[0].id if candidates else None

    def _resolve_model_id(self, model_id: str | None, task_type: str | None) -> str:
        """Return the model id to use: explicit, task-selected, or first enabled."""
        if model_id is None and task_type:
            model_id = self.select_model_for_task(task_type)
        if model_id is None:
            models = (
                self._config_loader.get_enabled_models()
                if self._config_loader
                else []
            )
            if not models:
                raise ProviderError("没有可用的模型")
            model_id = models[0].id
        return model_id

    def _resolve_target(self, model_id: str):
        """Return ``(model, platform)`` for ``model_id`` or raise ``ProviderError``."""
        model = self._config_loader.get_model(model_id) if self._config_loader else None
        if not model:
            raise ProviderError(f"模型不存在: {model_id}")

        platform = self.platforms.get(model.platform_id)
        if not platform:
            raise ProviderError(f"平台不存在: {model.platform_id}")
        return model, platform

    async def generate(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate text with a single model.

        Args:
            prompt: Prompt text.
            model_id: Model to use; ``None`` selects one automatically.
            task_type: Task type used for automatic selection.
            **kwargs: Call parameters; they override the model's
                ``default_params``.

        Returns:
            The result returned by the platform's ``generate``.

        Raises:
            ProviderError: If no model is available, the model or its
                platform is unknown, or the platform call fails (the original
                exception is chained as the cause).
        """
        model_id = self._resolve_model_id(model_id, task_type)
        model, platform = self._resolve_target(model_id)

        # ``default_params`` may be unset; treat that as "no defaults".
        params = {**(model.default_params or {}), **kwargs}

        try:
            return await platform.generate(
                prompt=prompt, model_name=model.model_name, **params
            )
        except Exception as e:
            raise ProviderError(f"模型 {model_id} 生成失败: {e}") from e

    async def _try_generate_stream(
        self,
        model_id: str,
        prompt: str,
        **kwargs,
    ):
        """Stream a generation from one specific model.

        When the provider offers ``generate_stream_with_progress`` its chunks
        are forwarded unchanged. Otherwise each chunk from ``generate_stream``
        is wrapped in a ``{"type": "chunk", ...}`` dict carrying the running
        character count, followed by a final ``{"type": "usage", ...}`` record
        whose token counts are reported as zero.

        Raises:
            ProviderError: If the model or its platform is unknown.
        """
        model, platform = self._resolve_target(model_id)

        # ``default_params`` may be unset; treat that as "no defaults".
        params = {**(model.default_params or {}), **kwargs}
        provider = platform.provider

        if hasattr(provider, "generate_stream_with_progress"):
            async for chunk in provider.generate_stream_with_progress(
                prompt=prompt, model=model.model_name, **params
            ):
                yield chunk
            return

        # Only the length is reported, so keep a counter instead of
        # concatenating the whole response.
        total_chars = 0
        async for content in provider.generate_stream(
            prompt=prompt, model=model.model_name, **params
        ):
            total_chars += len(content)
            yield {
                "type": "chunk",
                "content": content,
                "total_chars": total_chars,
            }
        yield {
            "type": "usage",
            "prompt_tokens": 0,
            "completion_tokens": 0,
        }

    async def generate_stream_with_progress(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ):
        """Stream generated text with progress information.

        A single model serves the stream; no fallback to another model is
        attempted on failure.

        Args:
            prompt: Prompt text.
            model_id: Model to use; ``None`` selects one automatically.
            task_type: Task type used for automatic selection.
            **kwargs: Call parameters; they override the model's
                ``default_params``.

        Yields:
            dict: Records with fields such as ``type``, ``content`` and
            ``total_chars``.

        Raises:
            ProviderError: If no model is available or streaming fails (the
                original exception is chained as the cause).
        """
        model_id = self._resolve_model_id(model_id, task_type)

        yielded_any = False
        try:
            async for chunk in self._try_generate_stream(model_id, prompt, **kwargs):
                yielded_any = True
                yield chunk
        except Exception as e:
            # A failure after output has started is logged separately: the
            # consumer has already received a partial stream.
            if yielded_any:
                logger.error(f"[ModelRouter] 模型 {model_id} 流式生成中途失败: {e}")
            raise ProviderError(f"模型 {model_id} 流式生成失败: {e}") from e

    async def _probe_model(self, model, platform) -> ModelHealth:
        """Health-check one model, turning an exception into an unavailable result."""
        try:
            return await platform.health_check(model.model_name)
        except Exception as e:
            return ModelHealth(
                id=model.id,
                name=model.display_name,
                is_available=False,
                response_time=0,
                last_error=str(e),
            )

    async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
        """Check the health of one model or of every enabled model.

        Args:
            model_id: Model to check; ``None`` checks all enabled models.

        Returns:
            Mapping of model id to health result. Models that are unknown or
            whose platform is not loaded are omitted. A check that raises is
            reported as an unavailable ``ModelHealth`` carrying the error
            text, for the single-model and all-models paths alike.
        """
        results: dict[str, ModelHealth] = {}
        loader = self._config_loader
        if not loader:
            return results

        if model_id:
            model = loader.get_model(model_id)
            platform = self.platforms.get(model.platform_id) if model else None
            if platform:
                results[model_id] = await self._probe_model(model, platform)
            return results

        for model in loader.get_enabled_models():
            platform = self.platforms.get(model.platform_id)
            if platform:
                results[model.id] = await self._probe_model(model, platform)
        return results
|
||
|
||
|
||
# Process-wide singleton, plus the lock that serialises its first creation.
_model_router: ModelRouter | None = None
_init_lock = asyncio.Lock()


async def get_model_router(db_session=None) -> ModelRouter:
    """Return the shared ``ModelRouter``, creating it on first use.

    Concurrent coroutines are serialised by an ``asyncio.Lock`` and the
    global is re-checked once the lock is held, so only one of them creates
    the router. ``asyncio.Lock`` coordinates coroutines, not OS threads, so
    this gives no cross-thread guarantee.

    Args:
        db_session: Passed through to ``ModelRouter.initialize``.

    Returns:
        The initialized singleton router.

    Raises:
        Exception: Whatever ``ModelRouter.initialize`` raises. In that case
            nothing is cached and the next call retries initialization.
    """
    global _model_router
    if _model_router is None:
        async with _init_lock:
            # Re-check: another coroutine may have finished initialization
            # while this one was waiting for the lock.
            if _model_router is None:
                logger.info("Initializing ModelRouter singleton...")
                router = ModelRouter()
                await router.initialize(db_session)
                # Publish only after initialize() succeeded; assigning first
                # would cache an unusable router forever if it raised.
                _model_router = router
                logger.info("ModelRouter singleton initialized")
    return _model_router
|