431c54c258
- 前端:ScriptCreation SSE 流式改为 createTask + pollTask 轮询 - 后端:LLM 仅保留 doubao-seed-2-0-pro,删除降级链及相关模型 - 后端:删除所有图片生成代码(ImageParams/ImageTaskParams/generate_image) - 更新 platform-config.yaml、model_router、volcengine_provider、tasks 等
131 lines
4.0 KiB
Python
131 lines
4.0 KiB
Python
"""
|
||
LLM 调用网关
|
||
============
|
||
|
||
职责:
|
||
1. 按 task_type 选择模型
|
||
2. Fallback 降级链
|
||
3. 调用各平台 Adapter
|
||
4. 流式/非流式统一封装
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from collections.abc import AsyncIterator
|
||
from typing import Any
|
||
|
||
from app.ai.adapters.base import SyncCapable
|
||
from app.ai.adapters.constants import Method
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LLMGateway:
    """Gateway that routes LLM chat calls to registered platform adapters.

    Non-streaming calls (``chat``) walk a per-model fallback chain; streaming
    calls (``chat_stream``) only ever use the requested model.
    """

    def __init__(
        self,
        adapters: dict[str, SyncCapable],
        fallback_chains: dict[str, list[str]] | None = None,
    ):
        """Store the adapter registry and the optional fallback chains.

        Args:
            adapters: Mapping of platform ID to the adapter serving it.
            fallback_chains: Mapping of model ID to the ordered model IDs to
                try after that model fails. ``None`` or an empty mapping
                means no fallback.
        """
        self.adapters = adapters
        self.fallback_chains = fallback_chains or {}

    def _get_adapter(self, platform: str) -> SyncCapable:
        """Return the adapter registered for ``platform``.

        Raises:
            ValueError: If no adapter is registered under ``platform``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的 LLM 平台: {platform}")
        return adapter

    async def chat(
        self,
        model_id: str,
        prompt: str,
        platform: str = "volcengine_ark",
        **kwargs,
    ) -> dict[str, Any]:
        """Run a non-streaming chat call, falling back to other models.

        The requested model is tried first, followed by every model listed
        for it in ``self.fallback_chains``. All attempts go through the same
        platform adapter.

        Args:
            model_id: Model alias (e.g. doubao-seed-2-0-pro).
            prompt: User prompt.
            platform: Platform ID used to look up the adapter.
            **kwargs: Extra request fields forwarded to the adapter
                (temperature, max_tokens, system_prompt, ...).

        Returns:
            The ``data`` attribute of the first successful adapter result.

        Raises:
            ValueError: If ``platform`` has no registered adapter.
            PlatformError: Immediately when the adapter raises a
                non-retryable ``PlatformError``; otherwise the last recorded
                error once every candidate model has failed.
        """
        # The platform is fixed for the whole call, so resolve the adapter
        # once instead of on every fallback attempt.
        adapter = self._get_adapter(platform)
        models_to_try = [model_id, *self.fallback_chains.get(model_id, [])]

        last_error: PlatformError | None = None
        for mid in models_to_try:
            try:
                result = await adapter.call(
                    Method.CHAT,
                    {"prompt": prompt, "model": mid, **kwargs},
                )
            except PlatformError as e:
                if not e.retryable:
                    # Non-retryable errors are raised as-is; no fallback.
                    raise
                last_error = e
                logger.warning("[LLMGateway] 模型 %s 失败,尝试下一个: %s", mid, e)
                continue

            if result.success:
                if mid != model_id:
                    logger.warning(
                        "[LLMGateway] 模型降级成功: %s → %s", model_id, mid
                    )
                return result.data

            # NOTE(review): a failed result keeps walking the chain even when
            # result.retryable is False, unlike the exception path above.
            # Preserved as-is; confirm whether that asymmetry is intended.
            last_error = PlatformError(
                result.error_message or f"模型 {mid} 调用失败",
                platform=platform,
                retryable=result.retryable,
            )

        if last_error is not None:
            raise last_error
        # Defensive: every loop iteration above returns, raises, or records
        # an error, so this is only a safety net.
        raise PlatformError(
            f"所有模型均失败: {model_id}",
            platform=platform,
            retryable=False,
        )

    async def chat_stream(
        self,
        model_id: str,
        prompt: str,
        platform: str = "volcengine_ark",
        **kwargs,
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream a chat completion chunk by chunk.

        Streaming does not use the fallback chain: switching models midway
        would mix output from different models.

        Args:
            model_id: Model alias sent to the adapter.
            prompt: User prompt.
            platform: Platform ID used to look up the adapter.
            **kwargs: Extra request fields forwarded to the adapter.

        Yields:
            Each chunk produced by the adapter's ``call_stream``.

        Raises:
            ValueError: If ``platform`` has no registered adapter.
            PlatformError: If the adapter has no ``call_stream`` attribute
                (non-retryable, BAD_REQUEST); if the adapter itself raises a
                ``PlatformError`` (propagated unchanged); or if any other
                exception occurs before the first chunk (wrapped as
                retryable, UNKNOWN).
            Exception: Any non-``PlatformError`` raised after at least one
                chunk was yielded is logged and re-raised unchanged.
        """
        adapter = self._get_adapter(platform)

        # Check whether the adapter supports streaming at all.
        if not hasattr(adapter, "call_stream"):
            raise PlatformError(
                "平台不支持流式输出",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        yielded_any = False
        try:
            async for chunk in adapter.call_stream(
                Method.CHAT_STREAM,
                {"prompt": prompt, "model": model_id, **kwargs},
            ):
                yielded_any = True
                yield chunk
        except PlatformError as e:
            # Already a platform error: keep its own retryable flag and
            # error type instead of re-wrapping it as retryable/UNKNOWN.
            if yielded_any:
                logger.error("[LLMGateway] 流式生成中途失败: %s", e)
            raise
        except Exception as e:
            if yielded_any:
                # Output was already emitted; surface the failure unchanged.
                logger.error("[LLMGateway] 流式生成中途失败: %s", e)
                raise
            raise PlatformError(
                f"流式生成失败: {e}",
                platform=platform,
                retryable=True,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e
|
||
|
||
|