90 lines
2.7 KiB
Python
90 lines
2.7 KiB
Python
"""
|
||
LLM 调用网关
|
||
============
|
||
|
||
职责:
|
||
1. 按 task_type 选择模型
|
||
2. Fallback 降级链
|
||
3. 调用各平台 Adapter
|
||
4. 流式/非流式统一封装
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from app.ai.adapters.base import SyncCapable
|
||
from app.ai.adapters.constants import Method
|
||
from app.core.exceptions import PlatformError
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LLMGateway:
    """Gateway that sends chat requests to a platform adapter with model fallback.

    The gateway holds a registry of platform adapters and, per model, an
    ordered list of fallback models to try when the primary model fails.
    """

    def __init__(
        self,
        adapters: dict[str, SyncCapable],
        fallback_chains: dict[str, list[str]] | None = None,
    ) -> None:
        """Store the adapter registry and the optional fallback configuration.

        Args:
            adapters: Mapping of platform ID to the adapter used for it.
            fallback_chains: Mapping of model ID to the ordered models to try
                after that model fails. ``None`` (or an empty mapping) means
                no fallback is configured.
        """
        self.adapters = adapters
        self.fallback_chains = fallback_chains or {}

    def _get_adapter(self, platform: str) -> SyncCapable:
        """Return the adapter registered for ``platform``.

        Raises:
            ValueError: If no adapter is registered under ``platform``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的 LLM 平台: {platform}")
        return adapter

    @staticmethod
    def _build_payload(prompt: str, model: str, extra: dict[str, Any]) -> dict[str, Any]:
        """Build the request payload for a single attempt.

        ``extra`` is merged first so that the explicit ``prompt`` and ``model``
        always win. Otherwise a stray ``model`` key in the caller's kwargs
        would override the fallback candidate on every attempt and the
        fallback chain would keep calling the same model.
        """
        return {**extra, "prompt": prompt, "model": model}

    async def chat(
        self,
        model_id: str,
        prompt: str,
        platform: str = "volcengine_ark",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Run a chat request, falling back to other models on failure.

        The primary model is tried first, followed by each model configured
        in ``fallback_chains`` for ``model_id``, in order. All attempts use
        the same platform adapter.

        Args:
            model_id: Model alias (e.g. doubao-seed-2-0-pro).
            prompt: User prompt.
            platform: Platform ID used to look up the adapter.
            **kwargs: Extra payload fields such as temperature, max_tokens,
                system_prompt. ``prompt`` and ``model`` are always set by the
                gateway and cannot be overridden from here.

        Returns:
            The ``data`` of the first successful result, or ``{}`` when that
            data is empty/None.

        Raises:
            ValueError: If ``platform`` has no registered adapter.
            PlatformError: Immediately when the adapter raises a
                non-retryable error; otherwise the last recorded error once
                every candidate model has failed.
        """
        # The platform never changes between attempts, so resolve it once.
        adapter = self._get_adapter(platform)
        candidates = [model_id, *self.fallback_chains.get(model_id, [])]

        last_error: PlatformError | None = None
        for mid in candidates:
            try:
                result = await adapter.call(
                    Method.CHAT,
                    self._build_payload(prompt, mid, kwargs),
                )
            except PlatformError as e:
                if not e.retryable:
                    raise  # Non-retryable errors propagate; no further fallback.
                last_error = e
            else:
                if result.success:
                    if mid != model_id:
                        logger.warning(f"[LLMGateway] 模型降级成功: {model_id} → {mid}")
                    return result.data or {}
                # NOTE(review): a failed result with retryable=False still
                # falls through to the next model, unlike a raised
                # non-retryable PlatformError. Confirm this is intended.
                last_error = PlatformError(
                    result.error_message or f"模型 {mid} 调用失败",
                    platform=platform,
                    retryable=result.retryable,
                )
            # Both failure paths are logged so every skipped model is visible.
            logger.warning(f"[LLMGateway] 模型 {mid} 失败,尝试下一个: {last_error}")

        # Defensive default: candidates is never empty, so last_error is
        # normally set by the time the loop finishes.
        raise last_error or PlatformError(
            f"所有模型均失败: {model_id}",
            platform=platform,
            retryable=False,
        )
|