Files
meijiaka-zy/python-api/app/ai/gateways/llm_gateway.py
T
小鱼开发 431c54c258 refactor: 前端脚本生成改为异步任务轮询,精简LLM模型,删除图片生成代码
- 前端:ScriptCreation SSE 流式改为 createTask + pollTask 轮询
- 后端:LLM 仅保留 doubao-seed-2-0-pro,删除降级链及相关模型
- 后端:删除所有图片生成代码(ImageParams/ImageTaskParams/generate_image)
- 更新 platform-config.yaml、model_router、volcengine_provider、tasks 等
2026-05-04 19:58:32 +08:00

131 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM 调用网关
============
职责:
1. 按 task_type 选择模型
2. Fallback 降级链
3. 调用各平台 Adapter
4. 流式/非流式统一封装
"""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator
from typing import Any
from app.ai.adapters.base import SyncCapable
from app.ai.adapters.constants import Method
from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
class LLMGateway:
    """LLM call gateway.

    Looks up a registered platform adapter and issues chat calls through it.
    Non-streaming calls may walk an optional per-model fallback chain;
    streaming calls never fall back.
    """

    def __init__(
        self,
        adapters: dict[str, SyncCapable],
        fallback_chains: dict[str, list[str]] | None = None,
    ) -> None:
        """
        Args:
            adapters: Mapping of platform ID -> adapter used to issue calls.
            fallback_chains: Optional mapping of model ID -> ordered list of
                model IDs to try next (on the same platform) when the primary
                model fails with a retryable error. Defaults to no fallback.
        """
        self.adapters = adapters
        self.fallback_chains = fallback_chains or {}

    def _get_adapter(self, platform: str) -> SyncCapable:
        """Return the adapter registered under ``platform``.

        Raises:
            ValueError: if no adapter is registered for ``platform``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的 LLM 平台: {platform}")
        return adapter

    async def chat(
        self,
        model_id: str,
        prompt: str,
        platform: str = "volcengine_ark",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Non-streaming chat with fallback.

        Tries ``model_id`` first, then each model listed in
        ``self.fallback_chains[model_id]`` in order, all on ``platform``.
        A non-retryable failure -- whether raised by the adapter or reported
        through an unsuccessful result -- stops the chain immediately.

        Args:
            model_id: Model alias (e.g. ``doubao-seed-2-0-pro``).
            prompt: User prompt.
            platform: Platform ID used to look up the adapter.
            **kwargs: Extra parameters forwarded to the adapter call
                (e.g. temperature, max_tokens, system_prompt).

        Returns:
            ``result.data`` of the first successful adapter call.

        Raises:
            ValueError: if ``platform`` has no registered adapter.
            PlatformError: on a non-retryable failure, or when every model in
                the chain failed (the last error is raised).
        """
        # Every model in the chain runs on the same platform, so resolve once.
        adapter = self._get_adapter(platform)
        models_to_try = [model_id, *self.fallback_chains.get(model_id, [])]
        last_error: PlatformError | None = None

        for mid in models_to_try:
            try:
                result = await adapter.call(Method.CHAT, {
                    "prompt": prompt,
                    "model": mid,
                    **kwargs,
                })
            except PlatformError as e:
                if not e.retryable:
                    raise  # non-retryable: surface it, do not fall back
                last_error = e
                logger.warning("[LLMGateway] 模型 %s 失败,尝试下一个: %s", mid, e)
                continue

            if result.success:
                if mid != model_id:
                    logger.warning("[LLMGateway] 模型降级成功: %s -> %s", model_id, mid)
                return result.data

            last_error = PlatformError(
                result.error_message or f"模型 {mid} 调用失败",
                platform=platform,
                retryable=result.retryable,
            )
            if not result.retryable:
                # Same policy as the exception path above: switching models
                # will not fix a non-retryable failure.
                raise last_error
            logger.warning("[LLMGateway] 模型 %s 失败,尝试下一个: %s", mid, last_error)

        # models_to_try is never empty, so last_error is set here; the
        # generic error is only a defensive fallback.
        raise last_error or PlatformError(
            f"所有模型均失败: {model_id}",
            platform=platform,
            retryable=False,
        )

    async def chat_stream(
        self,
        model_id: str,
        prompt: str,
        platform: str = "volcengine_ark",
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        """Streaming chat (no fallback).

        Switching models mid-stream would mix output from different models,
        so only ``model_id`` is used.

        This is an async generator: the adapter lookup and the capability
        check below run when iteration starts, not when the method is called.

        Args:
            model_id: Model alias.
            prompt: User prompt.
            platform: Platform ID used to look up the adapter.
            **kwargs: Extra parameters forwarded to the adapter call.

        Yields:
            Chunks produced by the adapter's ``call_stream``, unchanged.

        Raises:
            ValueError: if ``platform`` has no registered adapter.
            PlatformError: BAD_REQUEST if the adapter has no ``call_stream``;
                the adapter's own PlatformError, unchanged, if it fails before
                the first chunk; otherwise a retryable UNKNOWN error wrapping
                the original exception when nothing was yielded yet.
            Exception: whatever the adapter raised after at least one chunk
                was yielded (logged, then re-raised unchanged).
        """
        adapter = self._get_adapter(platform)
        # Not every adapter implements streaming.
        if not hasattr(adapter, "call_stream"):
            raise PlatformError(
                "平台不支持流式输出",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        yielded_any = False
        try:
            async for chunk in adapter.call_stream(Method.CHAT_STREAM, {
                "prompt": prompt,
                "model": model_id,
                **kwargs,
            }):
                yielded_any = True
                yield chunk
        except Exception as e:
            if yielded_any:
                # Partial output already reached the consumer; re-raise as-is.
                logger.error("[LLMGateway] 流式生成中途失败: %s", e)
                raise
            if isinstance(e, PlatformError):
                # Preserve the adapter's own retryable flag and error_type
                # instead of forcing a retryable UNKNOWN error.
                raise
            raise PlatformError(
                f"流式生成失败: {e}",
                platform=platform,
                retryable=True,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e