95e55293c6
后端安全: - DEBUG 默认 True → False - 彻底移除 AUTH_BYPASS 认证绕过 - 验证码不再明文打印到日志 - 上传接口增加大小限制(500MB/20MB/100MB)与魔数校验 - python-jose → PyJWT, 更新 requirements.lock/uv.lock - Bandit 恢复关键规则(B104/B301/B305/B314/B324/B603/B607) - 修复 5 处 try_except_pass, 15 处加 nosec 注释 - 启用 Bandit pre-commit 钩子 前端安全: - 配置完整 CSP 策略 - 收紧 Capabilities(fs:allow-read-file → $RESOURCE/**) - 移除硬编码 devToken - 清理前端 TODO(美家卡智影命名统一) 部署修复: - docker-compose.prod 增加 alembic 迁移步骤 - api + scheduler 增加 Redis 心跳健康检查 - Nginx 添加安全响应头 - Nginx client_max_body_size 100M → 500M - .env.example 补充 UPLOAD_MAX_* 配置与安全注释 其他: - /voice/upload 合并到 /upload/audio - Rust 上传增加文件大小检查 - 清理 Rust 19 处 println! + 前端 21 处 console.info - 修复 VideoCompose.tsx toast 未导入(已有bug)
358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""
|
||
AI 模型路由 V2 - 基于文件配置
|
||
=================================
|
||
|
||
从 YAML 配置文件加载平台/模型配置。
|
||
配置在启动时加载,运行时只读,不支持热重载。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
|
||
from app.ai.adapters.constants import Method
|
||
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
|
||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||
from app.core.config_loader import AIModelConfigLoader, get_config_loader
|
||
from app.platform_gateway import PlatformGateway
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class _PlatformInstance:
    """Compatibility wrapper around a single configured platform.

    Kept so existing call sites keep working; every operation is forwarded
    to the ``PlatformGateway`` held in ``self.gateway``.
    """

    def __init__(self, config: dict, gateway: PlatformGateway | None = None):
        """Store the platform configuration and the optional gateway.

        Args:
            config: Platform settings. Its ``"id"`` entry becomes
                ``provider_id`` (empty string when the key is absent).
            gateway: Gateway used for outbound calls. It may also be attached
                later by assigning ``self.gateway``.
        """
        self.config = config
        self.gateway = gateway
        self.provider_id = config.get("id", "")

    async def generate(
        self, model_name: str, prompt: str, **kwargs
    ) -> GenerationResult:
        """Run a chat generation through the gateway.

        Args:
            model_name: Platform-side model name, sent as ``payload["model"]``.
            prompt: Prompt text, sent as ``payload["prompt"]``.
            **kwargs: Extra payload entries; they are merged last, so a
                ``model`` entry here overrides ``model_name``.

        Returns:
            A ``GenerationResult`` built from the gateway response data.

        Raises:
            ProviderError: If no gateway is attached or the gateway reports
                an unsuccessful call.
        """
        if not self.gateway:
            # There is no direct-provider fallback: without a gateway the call cannot be made.
            raise ProviderError("PlatformGateway 未初始化")

        result = await self.gateway.call_sync(
            platform=self.provider_id,
            method=Method.CHAT,
            payload={
                "prompt": prompt,
                "model": model_name,
                **kwargs,
            },
        )
        if not result.success:
            raise ProviderError(
                result.error_message or f"{self.provider_id} 调用失败"
            )

        data = result.data or {}
        return GenerationResult(
            content=data.get("content", ""),
            usage=data.get("usage"),
            model=data.get("model", model_name),
        )

    async def health_check(self, model_name: str | None = None) -> ModelHealth:
        """Report this platform's health as seen by the gateway.

        Never raises: any failure is folded into an unavailable
        ``ModelHealth`` whose ``last_error`` names the actual cause
        (no gateway, platform missing from the gateway results, or the
        exception text).

        Args:
            model_name: Used as the ``id`` of the returned record; falls back
                to ``provider_id`` when omitted.
        """
        failure_reason = "PlatformGateway 未初始化"
        if self.gateway:
            try:
                all_results = await self.gateway.health_check_all()
                adapter_result = all_results.get(self.provider_id)
                if adapter_result:
                    data = adapter_result.data
                    return ModelHealth(
                        id=model_name or self.provider_id,
                        name=self.provider_id,
                        is_available=adapter_result.success,
                        response_time=data.get("response_time_ms", 0) if data else 0,
                        last_error=adapter_result.error_message,
                    )
                # The gateway answered but has no entry for this platform.
                failure_reason = "平台未注册到 Gateway"
            except Exception as e:
                logger.warning(f"平台 {self.provider_id} 健康检查失败: {e}")
                failure_reason = str(e)

        return ModelHealth(
            id=model_name or self.provider_id,
            name=self.provider_id,
            is_available=False,
            response_time=0,
            last_error=failure_reason,
        )
|
||
|
||
|
||
class ModelRouter:
    """Model router V2, driven by file-based configuration.

    Supports:
    - loading configuration from YAML files
    - multiple platforms
    - multiple models per platform
    - automatic model selection
    """

    def __init__(self):
        """Create an empty, uninitialized router; call ``initialize`` before use."""
        self.platforms: dict[str, _PlatformInstance] = {}
        self._config_loader: AIModelConfigLoader | None = None
        self._initialized = False
        self._gateway: PlatformGateway | None = None

    async def initialize(self, db_session=None, gateway: PlatformGateway | None = None):
        """Load platforms and models from the config loader.

        A no-op once the router has been initialized successfully.

        Args:
            db_session: Kept for backward compatibility; not used here.
            gateway: PlatformGateway instance used for all third-party calls.
        """
        if self._initialized:
            return

        self._gateway = gateway

        # Configuration comes from the file-based loader.
        self._config_loader = get_config_loader()
        self._load_from_config()

        self._initialized = True
        logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")

    def _load_from_config(self):
        """Rebuild ``self.platforms`` and register Volcengine model names."""
        self.platforms = {}

        # Platforms: a failure on one platform is logged and does not stop the others.
        for platform in self._config_loader.get_all_platforms():
            try:
                self.platforms[platform.id] = _PlatformInstance(
                    {
                        "id": platform.id,
                        "name": platform.name,
                        "provider": platform.provider,
                        "base_url": platform.base_url,
                    },
                    gateway=self._gateway,
                )
                logger.info(f"平台 {platform.id} 初始化成功")
            except Exception as e:
                logger.warning(f"平台 {platform.id} 初始化失败: {e}")

        # Hand the enabled Volcengine models to the provider (used for model-name mapping).
        volcengine_models = [
            {"id": model.id, "model_name": model.model_name}
            for model in self._config_loader.get_enabled_models()
            if model.platform_id == "volcengine"
        ]
        if volcengine_models:
            VolcengineProvider.load_models_from_config(volcengine_models)
            logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")

    @staticmethod
    def _model_to_dict(model) -> dict:
        """Flatten a config model object into the dict shape returned to callers."""
        return {
            "id": model.id,
            "platform_id": model.platform_id,
            "model_name": model.model_name,
            "display_name": model.display_name,
            "capabilities": model.capabilities,
            "default_params": model.default_params,
            "cost_per_1k_input": model.cost_per_1k_input,
            "cost_per_1k_output": model.cost_per_1k_output,
            "max_tokens_limit": model.max_tokens_limit,
        }

    def get_model_config(self, model_id: str) -> dict | None:
        """Return the configuration dict for ``model_id``, or ``None`` if unknown."""
        if not self._config_loader:
            return None
        model = self._config_loader.get_model(model_id)
        return self._model_to_dict(model) if model else None

    def list_models(
        self, capability: str | None = None, platform_id: str | None = None
    ) -> list[dict]:
        """List models, optionally filtered by capability and/or platform.

        When both filters are given, only models satisfying both are
        returned. With no filter, the enabled models are returned. Returns
        an empty list before the router has a config loader.
        """
        loader = self._config_loader
        if not loader:
            return []

        if capability:
            selected = loader.get_models_by_capability(capability)
            if platform_id:
                # Apply the platform filter on top instead of silently dropping it.
                selected = [m for m in selected if m.platform_id == platform_id]
        elif platform_id:
            selected = loader.get_models_by_platform(platform_id)
        else:
            selected = loader.get_enabled_models()

        return [self._model_to_dict(m) for m in selected]

    def list_platforms(self) -> list[dict]:
        """List every configured platform as ``{"id", "name", "provider"}``."""
        if not self._config_loader:
            return []
        return [
            {
                "id": p.id,
                "name": p.name,
                "provider": p.provider,
            }
            for p in self._config_loader.get_all_platforms()
        ]

    def select_model_for_task(self, task_type: str) -> str | None:
        """Pick a model id for ``task_type``.

        The task's configured default wins when it exists and is enabled;
        otherwise the first model whose capabilities match ``task_type`` is
        used. Returns ``None`` when nothing matches or no loader is set.
        """
        loader = self._config_loader
        if not loader:
            return None

        # 1) Task-level default from configuration.
        default_model = loader.get_default_model_for_task(task_type)
        if default_model:
            model = loader.get_model(default_model)
            if model and model.is_enabled:
                return default_model

        # 2) Fall back to capability matching.
        candidates = loader.get_models_by_capability(task_type)
        if candidates:
            return candidates[0].id

        return None

    async def generate(
        self,
        prompt: str,
        model_id: str | None = None,
        task_type: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate text with an explicit or automatically selected model.

        Args:
            prompt: Prompt text.
            model_id: Model to use; ``None`` selects one automatically.
            task_type: Task type used for automatic selection.
            **kwargs: Call parameters; they override the model's
                ``default_params``.

        Raises:
            ProviderError: If no model is available, the model or its
                platform is unknown, or the platform call fails.
        """
        loader = self._config_loader

        # Resolve the model: task-based selection first, then the first enabled model.
        if model_id is None and task_type:
            model_id = self.select_model_for_task(task_type)
        if model_id is None:
            enabled = loader.get_enabled_models() if loader else []
            if not enabled:
                raise ProviderError("没有可用的模型")
            model_id = enabled[0].id

        model = loader.get_model(model_id) if loader else None
        if not model:
            raise ProviderError(f"模型不存在: {model_id}")

        platform = self.platforms.get(model.platform_id)
        if not platform:
            raise ProviderError(f"平台不存在: {model.platform_id}")

        # Tolerate a model configured without default params.
        params = {**(model.default_params or {}), **kwargs}

        try:
            return await platform.generate(
                prompt=prompt, model_name=model.model_name, **params
            )
        except Exception as e:
            raise ProviderError(f"模型 {model_id} 生成失败: {e}") from e

    def _models_to_check(self, model_id: str | None) -> list:
        """Return the config models a health check should cover."""
        loader = self._config_loader
        if not loader:
            return []
        if model_id:
            model = loader.get_model(model_id)
            return [model] if model else []
        return loader.get_enabled_models()

    async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
        """Check model health, keyed by model id.

        Args:
            model_id: Restrict the check to this model; ``None`` checks every
                enabled model. An unknown id yields an empty dict.

        The gateway's aggregated health check is used when a gateway is
        attached; otherwise each platform instance is asked directly.
        """
        targets = self._models_to_check(model_id)
        results: dict[str, ModelHealth] = {}

        # Preferred path: a single aggregated health check via the gateway.
        if self._gateway:
            gateway_results = await self._gateway.health_check_all()
            for model in targets:
                key = model_id or model.id
                adapter_result = gateway_results.get(model.platform_id)
                if adapter_result:
                    data = adapter_result.data
                    results[key] = ModelHealth(
                        id=model.id,
                        name=model.display_name,
                        is_available=adapter_result.success,
                        response_time=data.get("response_time_ms", 0) if data else 0,
                        last_error=adapter_result.error_message,
                    )
                else:
                    results[key] = ModelHealth(
                        id=model.id,
                        name=model.display_name,
                        is_available=False,
                        response_time=0,
                        last_error="平台未注册到 Gateway",
                    )
            return results

        # Fallback: ask each platform instance directly.
        for model in targets:
            platform = self.platforms.get(model.platform_id)
            if not platform:
                continue
            key = model_id or model.id
            try:
                results[key] = await platform.health_check(model.model_name)
            except Exception as e:
                results[key] = ModelHealth(
                    id=model.id,
                    name=model.display_name,
                    is_available=False,
                    response_time=0,
                    last_error=str(e),
                )
        return results
|
||
|
||
|
||
# Process-wide singleton and the lock that serializes its initialization.
_model_router: ModelRouter | None = None
_init_lock = asyncio.Lock()


async def get_model_router(db_session=None, gateway: PlatformGateway | None = None) -> ModelRouter:
    """Return the shared ``ModelRouter``, initializing it on first use.

    Initialization is guarded by an ``asyncio.Lock`` with a second check
    inside the lock, so concurrent coroutines on the same event loop
    initialize the router only once. This is coroutine-safe, not
    thread-safe. If a previous initialization failed (``_initialized`` is
    still False), the next call retries it.

    Args:
        db_session: Forwarded to ``ModelRouter.initialize``.
        gateway: Gateway to use. If the router currently has no gateway,
            this one is attached to the router and to every platform
            instance, even when the router was initialized earlier.
    """
    global _model_router
    if _model_router is None or not getattr(_model_router, "_initialized", False):
        async with _init_lock:
            if _model_router is None:
                _model_router = ModelRouter()
            # Re-check under the lock: another coroutine may have finished while we waited.
            if not _model_router._initialized:
                logger.info("Initializing ModelRouter singleton...")
                await _model_router.initialize(db_session, gateway=gateway)
                logger.info("ModelRouter singleton initialized")

    # Late gateway binding. Done unconditionally (not as an else-branch) so it also
    # covers the case where another coroutine initialized the router without a gateway.
    if gateway is not None and _model_router._gateway is None:
        _model_router._gateway = gateway
        for platform_instance in _model_router.platforms.values():
            platform_instance.gateway = gateway
    return _model_router
|