""" LLM 调用网关 ============ 职责: 1. 按 task_type 选择模型 2. Fallback 降级链 3. 调用各平台 Adapter 4. 流式/非流式统一封装 """ from __future__ import annotations import logging from typing import Any from app.ai.adapters.base import SyncCapable from app.ai.adapters.constants import Method from app.core.exceptions import PlatformError logger = logging.getLogger(__name__) class LLMGateway: """LLM 调用网关""" def __init__( self, adapters: dict[str, SyncCapable], fallback_chains: dict[str, list[str]] | None = None ): self.adapters = adapters self.fallback_chains = fallback_chains or {} def _get_adapter(self, platform: str) -> SyncCapable: adapter = self.adapters.get(platform) if adapter is None: raise ValueError(f"未注册的 LLM 平台: {platform}") return adapter async def chat( self, model_id: str, prompt: str, platform: str = "volcengine_ark", **kwargs, ) -> dict[str, Any]: """同步聊天,带 Fallback Args: model_id: 模型别名(如 doubao-seed-2-0-pro) prompt: 用户提示词 platform: 平台 ID **kwargs: temperature, max_tokens, system_prompt 等 """ models_to_try = [model_id] + self.fallback_chains.get(model_id, []) last_error = None for mid in models_to_try: adapter = self._get_adapter(platform) try: result = await adapter.call( Method.CHAT, { "prompt": prompt, "model": mid, **kwargs, }, ) if result.success: if mid != model_id: logger.warning(f"[LLMGateway] 模型降级成功: {model_id} → {mid}") return result.data or {} else: last_error = PlatformError( result.error_message or f"模型 {mid} 调用失败", platform=platform, retryable=result.retryable, ) except PlatformError as e: last_error = e if not e.retryable: raise # 不可重试的错误直接抛,不再 Fallback logger.warning(f"[LLMGateway] 模型 {mid} 失败,尝试下一个: {e}") continue raise last_error or PlatformError( f"所有模型均失败: {model_id}", platform=platform, retryable=False, )