Files
meijiaka-zy/python-api/app/platform_gateway.py
T
2026-05-04 19:18:22 +08:00

163 lines
5.3 KiB
Python

"""
第三方平台统一调用网关
========================
所有第三方平台调用的唯一入口。
- 同步调用:call_sync()
- 异步任务提交:submit_task()
- 任务状态查询:query_task()
- 回调处理:handle_webhook()
"""
from __future__ import annotations
import logging
from typing import Any
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.core.exceptions import PlatformError, PlatformErrorType
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class PlatformGateway:
    """Unified gateway for all third-party platform calls.

    Adapters are stored per platform id. Every public entry point looks up
    the adapter registered for the requested platform, checks that it
    implements the capability the call needs, and delegates to it.
    """

    def __init__(self, adapters: dict[str, PlatformAdapter] | None = None):
        """Create a gateway.

        Args:
            adapters: Optional initial ``platform_id -> adapter`` mapping.
                A non-empty mapping is stored as-is (not copied); ``None``
                or an empty mapping yields a fresh empty dict.
        """
        self.adapters: dict[str, PlatformAdapter] = adapters or {}

    def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
        """Register (or replace) the adapter for ``platform_id``."""
        self.adapters[platform_id] = adapter
        logger.info("PlatformGateway 注册平台: %s", platform_id)

    # -- Adapter lookup --

    def _require_adapter(self, platform: str) -> PlatformAdapter:
        """Return the adapter registered for ``platform``.

        Raises:
            ValueError: If no adapter is registered under ``platform``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        return adapter

    def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
        """Return the adapter for ``platform`` if it supports sync calls.

        ``method`` is accepted for signature compatibility but is not used
        by the lookup.

        Raises:
            ValueError: If the platform is unregistered or its adapter is
                not ``SyncCapable``.
        """
        adapter = self._require_adapter(platform)
        if not isinstance(adapter, SyncCapable):
            raise ValueError(f"平台 {platform} 不支持同步调用")
        return adapter

    def _get_task_adapter(self, platform: str, task_type: str) -> TaskCapable:
        """Return the adapter for ``platform`` if it supports async tasks.

        ``task_type`` is accepted for signature compatibility but is not
        used by the lookup.

        Raises:
            ValueError: If the platform is unregistered or its adapter is
                not ``TaskCapable``.
        """
        adapter = self._require_adapter(platform)
        if not isinstance(adapter, TaskCapable):
            raise ValueError(f"平台 {platform} 不支持异步任务")
        return adapter

    def _get_callback_adapter(self, platform: str) -> CallbackCapable:
        """Return the adapter for ``platform`` if it supports callbacks.

        Raises:
            ValueError: If the platform is unregistered or its adapter is
                not ``CallbackCapable``.
        """
        adapter = self._require_adapter(platform)
        if not isinstance(adapter, CallbackCapable):
            raise ValueError(f"平台 {platform} 不支持回调")
        return adapter

    # -- Synchronous calls --

    async def call_sync(
        self,
        platform: str,
        method: str,
        payload: dict[str, Any],
    ) -> AdapterResponse:
        """Single entry point for synchronous platform calls.

        Delegates to ``adapter.call(method, payload)`` and returns its
        result unchanged.

        Raises:
            ValueError: If the platform is unregistered or does not support
                synchronous calls.
        """
        adapter = self._get_sync_adapter(platform, method)
        return await adapter.call(method, payload)

    # -- Asynchronous tasks --

    async def submit_task(
        self,
        platform: str,
        task_type: str,
        payload: dict[str, Any],
        callback_url: str | None = None,
        idempotency_key: str | None = None,
    ) -> str:
        """Single entry point for submitting an asynchronous task.

        Currently returns the platform's own ``task_id`` taken from the
        adapter response payload.

        TODO: once the Async Engine is wired in, generate an
        internal_job_id, record it in the TaskRegistry and return that
        instead.

        Args:
            platform: Id of a registered platform.
            task_type: Forwarded to ``adapter.submit``.
            payload: Forwarded to ``adapter.submit``.
            callback_url: Forwarded to ``adapter.submit``.
            idempotency_key: Accepted but not used yet; it is not forwarded
                to the adapter.

        Returns:
            The ``task_id`` entry of the response payload, or ``""`` when
            the payload has no such entry.

        Raises:
            ValueError: If the platform is unregistered or does not support
                asynchronous tasks.
            PlatformError: If the adapter reports an unsuccessful submit.
        """
        adapter = self._get_task_adapter(platform, task_type)
        result = await adapter.submit(task_type, payload, callback_url)
        if not result.success:
            raise PlatformError(
                result.error_message or "任务提交失败",
                platform=platform,
                retryable=result.retryable,
                error_type=PlatformErrorType.UNKNOWN,
            )
        # Treat a missing/empty payload like an empty dict so that a
        # successful response without data does not raise AttributeError.
        data = result.data or {}
        task_id = data.get("task_id", "")
        if not task_id:
            # The empty id is still returned for backward compatibility,
            # but it is no longer silent.
            logger.warning("平台 %s 任务提交成功但未返回 task_id", platform)
        return task_id

    async def query_task(self, platform: str, platform_task_id: str) -> TaskStatus:
        """Single entry point for querying a task's status.

        Delegates to ``adapter.query(platform_task_id)``.

        Raises:
            ValueError: If the platform is unregistered or does not support
                asynchronous tasks.
        """
        # The task type plays no part in the adapter lookup, so an empty
        # placeholder is passed.
        adapter = self._get_task_adapter(platform, "")
        return await adapter.query(platform_task_id)

    # -- Webhook handling --

    async def handle_webhook(
        self,
        platform: str,
        headers: dict[str, str],
        body: bytes,
        secret: str | None = None,
        callback_url: str | None = None,
    ) -> TaskStatus:
        """Single entry point for handling a platform callback.

        When ``secret`` is truthy the adapter's ``verify_signature`` must
        succeed before the body is parsed.

        Returns:
            Whatever ``adapter.parse_callback(body)`` returns.

        Raises:
            ValueError: If the platform is unregistered or does not support
                callbacks.
            PlatformError: With ``AUTH_FAILED`` when signature verification
                fails.
        """
        adapter = self._get_callback_adapter(platform)
        # NOTE(review): a missing or empty ``secret`` skips signature
        # verification entirely and the body is parsed unauthenticated.
        # Confirm callers always supply a secret for platforms that sign
        # their callbacks.
        if secret:
            verified = await adapter.verify_signature(
                headers, body, secret, callback_url=callback_url
            )
            if not verified:
                raise PlatformError(
                    "回调签名验证失败",
                    platform=platform,
                    retryable=False,
                    error_type=PlatformErrorType.AUTH_FAILED,
                )
        return await adapter.parse_callback(body)

    # -- Lifecycle --

    async def close_all(self) -> None:
        """Close every registered adapter, logging failures and continuing."""
        # Iterate over a snapshot: ``close()`` is awaited, and an adapter
        # registered while we are suspended would otherwise raise
        # "dictionary changed size during iteration" and abort shutdown.
        for platform_id, adapter in list(self.adapters.items()):
            try:
                await adapter.close()
                logger.info("Adapter 关闭: %s", platform_id)
            except Exception as e:
                # Best effort: one failing adapter must not prevent the
                # remaining ones from being closed.
                logger.warning("Adapter 关闭失败: %s: %s", platform_id, e)

    # -- Health check --

    async def health_check_all(self) -> dict[str, AdapterResponse]:
        """Run ``health()`` on every registered adapter.

        Returns:
            A ``platform_id -> response`` mapping. An adapter whose
            ``health()`` raises is reported as an unsuccessful
            ``AdapterResponse`` carrying the exception text.
        """
        results: dict[str, AdapterResponse] = {}
        # Snapshot for the same reason as in ``close_all``: the registry
        # may change while ``health()`` is awaited.
        for platform_id, adapter in list(self.adapters.items()):
            try:
                results[platform_id] = await adapter.health()
            except Exception as e:
                results[platform_id] = AdapterResponse(
                    success=False,
                    error_message=str(e),
                )
        return results