e58159fc42
Phase 1: 异常体系统一 - 新增 PlatformError / PlatformErrorType 标准定义 - 改造所有 Provider 异常抛出为 PlatformError - 注册全局 PlatformError exception handler Phase 2: Adapter Protocol - 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable) - 新增 app/ai/adapters/constants.py(Method 常量) - 新增 PlatformConfigLoader(config/platform-config.yaml) Phase 3: HTTP Client 统一 - ViduProvider 从 aiohttp 迁移到 httpx(注入方式) - VolcengineCaptionService 改为注入 http_client - lifespan 统一管理所有 Client 创建和关闭 Phase 4: Gateway 骨架 + Adapter 实现 - 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter - 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook) - 新增 LLMGateway(带 Fallback 降级链) - lifespan 注册所有 Adapter 和 Gateway Phase 6: 清理与验证 - 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL - Provider 改为从 PlatformConfigLoader 读取 base_url - 清理 volcengine_caption_service 全局单例 - config_loader 默认路径改为 platform-config.yaml - Scheduler 注入共享 HTTP client - vidu.py 回调路由使用 Adapter 验签和解析 - ruff 全量通过,应用启动测试通过
163 lines
5.3 KiB
Python
163 lines
5.3 KiB
Python
"""
|
|
第三方平台统一调用网关
|
|
========================
|
|
|
|
所有第三方平台调用的唯一入口。
|
|
- 同步调用:call_sync()
|
|
- 异步任务提交:submit_task()
|
|
- 任务状态查询:query_task()
|
|
- 回调处理:handle_webhook()
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from app.ai.adapters.base import (
|
|
AdapterResponse,
|
|
CallbackCapable,
|
|
PlatformAdapter,
|
|
SyncCapable,
|
|
TaskCapable,
|
|
TaskStatus,
|
|
)
|
|
from app.core.exceptions import PlatformError, PlatformErrorType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PlatformGateway:
|
|
"""第三方平台统一调用网关"""
|
|
|
|
def __init__(self, adapters: dict[str, PlatformAdapter] | None = None):
|
|
self.adapters: dict[str, PlatformAdapter] = adapters or {}
|
|
|
|
def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
|
|
"""注册平台 Adapter"""
|
|
self.adapters[platform_id] = adapter
|
|
logger.info(f"PlatformGateway 注册平台: {platform_id}")
|
|
|
|
def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
|
|
"""获取支持同步调用的 Adapter"""
|
|
adapter = self.adapters.get(platform)
|
|
if adapter is None:
|
|
raise ValueError(f"未注册的平台: {platform}")
|
|
if not isinstance(adapter, SyncCapable):
|
|
raise ValueError(f"平台 {platform} 不支持同步调用")
|
|
return adapter
|
|
|
|
def _get_task_adapter(self, platform: str, task_type: str) -> TaskCapable:
|
|
"""获取支持异步任务的 Adapter"""
|
|
adapter = self.adapters.get(platform)
|
|
if adapter is None:
|
|
raise ValueError(f"未注册的平台: {platform}")
|
|
if not isinstance(adapter, TaskCapable):
|
|
raise ValueError(f"平台 {platform} 不支持异步任务")
|
|
return adapter
|
|
|
|
def _get_callback_adapter(self, platform: str) -> CallbackCapable:
|
|
"""获取支持回调的 Adapter"""
|
|
adapter = self.adapters.get(platform)
|
|
if adapter is None:
|
|
raise ValueError(f"未注册的平台: {platform}")
|
|
if not isinstance(adapter, CallbackCapable):
|
|
raise ValueError(f"平台 {platform} 不支持回调")
|
|
return adapter
|
|
|
|
# ── 同步调用 ──
|
|
|
|
async def call_sync(
|
|
self,
|
|
platform: str,
|
|
method: str,
|
|
payload: dict[str, Any],
|
|
) -> AdapterResponse:
|
|
"""同步调用统一入口"""
|
|
adapter = self._get_sync_adapter(platform, method)
|
|
return await adapter.call(method, payload)
|
|
|
|
# ── 异步任务 ──
|
|
|
|
async def submit_task(
|
|
self,
|
|
platform: str,
|
|
task_type: str,
|
|
payload: dict[str, Any],
|
|
callback_url: str | None = None,
|
|
idempotency_key: str | None = None,
|
|
) -> str:
|
|
"""异步任务提交统一入口,返回 internal_job_id
|
|
|
|
TODO: 接入 Async Engine 后,生成 internal_job_id 并写入 JobRegistry
|
|
"""
|
|
adapter = self._get_task_adapter(platform, task_type)
|
|
result = await adapter.submit(task_type, payload, callback_url)
|
|
|
|
if not result.success:
|
|
raise PlatformError(
|
|
result.error_message or "任务提交失败",
|
|
platform=platform,
|
|
retryable=result.retryable,
|
|
error_type=PlatformErrorType.UNKNOWN,
|
|
)
|
|
|
|
# 当前直接返回 platform_task_id,后续接入 Async Engine 后返回 internal_job_id
|
|
return result.data.get("task_id", "")
|
|
|
|
async def query_task(self, platform: str, platform_job_id: str) -> TaskStatus:
|
|
"""任务状态查询统一入口"""
|
|
adapter = self._get_task_adapter(platform, "")
|
|
return await adapter.query(platform_job_id)
|
|
|
|
# ── 回调处理 ──
|
|
|
|
async def handle_webhook(
|
|
self,
|
|
platform: str,
|
|
headers: dict[str, str],
|
|
body: bytes,
|
|
secret: str | None = None,
|
|
callback_url: str | None = None,
|
|
) -> TaskStatus:
|
|
"""统一回调处理入口"""
|
|
adapter = self._get_callback_adapter(platform)
|
|
|
|
if secret and not await adapter.verify_signature(
|
|
headers, body, secret, callback_url=callback_url
|
|
):
|
|
raise PlatformError(
|
|
"回调签名验证失败",
|
|
platform=platform,
|
|
retryable=False,
|
|
error_type=PlatformErrorType.AUTH_FAILED,
|
|
)
|
|
|
|
return await adapter.parse_callback(body)
|
|
|
|
# ── 生命周期 ──
|
|
|
|
async def close_all(self) -> None:
|
|
"""关闭所有 Adapter"""
|
|
for platform_id, adapter in self.adapters.items():
|
|
try:
|
|
await adapter.close()
|
|
logger.info(f"Adapter 关闭: {platform_id}")
|
|
except Exception as e:
|
|
logger.warning(f"Adapter 关闭失败: {platform_id}: {e}")
|
|
|
|
# ── 健康检查 ──
|
|
|
|
async def health_check_all(self) -> dict[str, AdapterResponse]:
|
|
"""检查所有平台健康状态"""
|
|
results = {}
|
|
for platform_id, adapter in self.adapters.items():
|
|
try:
|
|
results[platform_id] = await adapter.health()
|
|
except Exception as e:
|
|
results[platform_id] = AdapterResponse(
|
|
success=False,
|
|
error_message=str(e),
|
|
)
|
|
return results
|