# Source path: meijiaka-zy/python-api/app/platform_gateway.py
# NOTE: this file intentionally contains CJK text (Chinese docstrings, log and
# error messages); GitHub flags some of its full-width punctuation as
# "ambiguous Unicode characters".
"""
第三方平台统一调用网关
========================
所有第三方平台调用的唯一入口。
- 同步调用:call_sync()
- 异步任务提交:submit_task()
- 任务状态查询:query_task()
- 回调处理:handle_webhook()
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.core.exceptions import PlatformError, PlatformErrorType
from app.utils.content_fingerprint import (
compute_content_fingerprint,
is_vidu_audit_error,
)
logger = logging.getLogger(__name__)

# Redis namespace for the internal task_id -> platform task_id mapping
# (the gateway also stores a reverse entry under "<prefix>:reverse:...").
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
_TASK_MAPPING_TTL = 60 * 60 * 24 * 7  # seven days, in seconds

# Redis namespace for cached content-audit rejections, keyed by fingerprint.
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
_AUDIT_REJECTION_TTL = 60 * 60 * 24  # one day, in seconds
class PlatformGateway:
    """Single entry point for every third-party platform call.

    Responsibilities:
    - route synchronous calls to ``SyncCapable`` adapters (``call_sync``);
    - submit / query asynchronous tasks on ``TaskCapable`` adapters and keep
      an internal-task-id <-> platform-task-id mapping in Redis;
    - verify and parse webhook callbacks on ``CallbackCapable`` adapters;
    - for the ``"vidu"`` platform, cache content fingerprints that failed the
      platform's content audit so identical content is rejected locally
      instead of being resubmitted.
    """

    # User-facing message raised for content that failed (now or recently)
    # the Vidu content audit. Shared by every audit-rejection path.
    _CONTENT_VIOLATION_MESSAGE = "人物分镜台词未通过安全审核,请修改后重试"

    def __init__(
        self,
        adapters: dict[str, PlatformAdapter] | None = None,
        redis=None,
    ):
        """Create a gateway.

        Args:
            adapters: Initial ``platform_id -> adapter`` registry; more can be
                added later with ``register``.
            redis: Async Redis client. When omitted it is created lazily on
                first use (see ``_get_redis``).
        """
        self.adapters: dict[str, PlatformAdapter] = adapters or {}
        self._redis = redis

    # ── Redis helpers ──

    def _get_redis(self):
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            # Imported lazily so constructing a gateway does not require a
            # Redis connection (e.g. when a client is injected for tests).
            from app.core.redis_client import get_redis_client

            self._redis = get_redis_client()
        return self._redis

    def _task_mapping_key(self, internal_task_id: str) -> str:
        """Redis hash key holding the forward mapping for an internal task."""
        return f"{_TASK_MAPPING_PREFIX}:{internal_task_id}"

    def _reverse_mapping_key(self, platform: str, platform_task_id: str) -> str:
        """Redis string key mapping a platform task back to the internal id."""
        return f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"

    async def _store_task_mapping(
        self, internal_task_id: str, platform: str, platform_task_id: str
    ) -> None:
        """Store the forward and reverse internal/platform task-id mappings.

        Both entries expire after ``_TASK_MAPPING_TTL`` seconds.
        NOTE(review): the three Redis commands are issued separately, not in a
        transaction; a failure in between can leave a partial mapping.
        """
        redis = self._get_redis()
        # Forward mapping: internal -> (platform, platform_task_id).
        key = self._task_mapping_key(internal_task_id)
        await redis.hset(
            key,
            mapping={
                "platform": platform,
                "platform_task_id": platform_task_id,
            },
        )
        await redis.expire(key, _TASK_MAPPING_TTL)
        # Reverse mapping: platform task -> internal id (used by callbacks).
        await redis.setex(
            self._reverse_mapping_key(platform, platform_task_id),
            _TASK_MAPPING_TTL,
            internal_task_id,
        )

    async def _get_task_mapping(self, internal_task_id: str) -> dict[str, str] | None:
        """Return ``{"platform", "platform_task_id"}`` for an internal task.

        Returns None when no mapping exists (never stored, or expired).
        NOTE(review): assumes the Redis client decodes responses to ``str``;
        with bytes keys both fields would come back as "".
        """
        redis = self._get_redis()
        data = await redis.hgetall(self._task_mapping_key(internal_task_id))
        if not data:
            return None
        return {
            "platform": data.get("platform", ""),
            "platform_task_id": data.get("platform_task_id", ""),
        }

    def _audit_rejection_key(self, fingerprint: str) -> str:
        """Redis key under which an audit rejection for *fingerprint* is cached."""
        return f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"

    async def _get_audit_rejection(self, fingerprint: str) -> str | None:
        """Return the cached audit error code for *fingerprint*, if any.

        Best-effort: an empty fingerprint or any Redis error yields None (the
        error is logged), so a cache outage never blocks submission.

        Returns:
            The cached error code (e.g. "TaskPromptPolicyViolation"), or None.
        """
        if not fingerprint:
            return None
        try:
            redis = self._get_redis()
            return await redis.get(self._audit_rejection_key(fingerprint))
        except Exception as e:
            logger.warning("[PlatformGateway] 查询审核缓存失败: %s", e)
            return None

    async def _set_audit_rejection(self, fingerprint: str, error_code: str) -> None:
        """Cache an audit rejection for ``_AUDIT_REJECTION_TTL`` seconds.

        Best-effort, mirroring ``_get_audit_rejection``: a Redis error is
        logged and swallowed so it cannot replace the audit error the caller
        is about to raise.
        """
        if not fingerprint or not error_code:
            return
        try:
            redis = self._get_redis()
            await redis.setex(
                self._audit_rejection_key(fingerprint),
                _AUDIT_REJECTION_TTL,
                error_code,
            )
        except Exception as e:
            logger.warning("[PlatformGateway] 写入审核缓存失败: %s", e)

    async def get_internal_task_id_by_platform_task_id(
        self, platform: str, platform_task_id: str
    ) -> str | None:
        """Look up the internal task id for a platform task (used by callbacks).

        Returns None when no reverse mapping exists or it has expired.
        """
        redis = self._get_redis()
        return await redis.get(self._reverse_mapping_key(platform, platform_task_id))

    # ── Adapter registry ──

    def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
        """Register (or replace) the adapter for *platform_id*."""
        self.adapters[platform_id] = adapter
        logger.info("PlatformGateway 注册平台: %s", platform_id)

    def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
        """Return the adapter for *platform* if it supports synchronous calls.

        *method* is accepted for interface symmetry but not validated here.

        Raises:
            ValueError: the platform is not registered or is not SyncCapable.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, SyncCapable):
            raise ValueError(f"平台 {platform} 不支持同步调用")
        return adapter

    def _get_task_adapter(self, platform: str, task_type: str | None = None) -> TaskCapable:
        """Return the adapter for *platform* if it supports asynchronous tasks.

        Args:
            platform: Platform id.
            task_type: Task type; accepted for future validation but currently
                not checked.

        Raises:
            ValueError: the platform is not registered or is not TaskCapable.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, TaskCapable):
            raise ValueError(f"平台 {platform} 不支持异步任务")
        return adapter

    def _get_callback_adapter(self, platform: str) -> CallbackCapable:
        """Return the adapter for *platform* if it supports callbacks.

        Raises:
            ValueError: the platform is not registered or is not CallbackCapable.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, CallbackCapable):
            raise ValueError(f"平台 {platform} 不支持回调")
        return adapter

    # ── Synchronous calls ──

    async def call_sync(
        self,
        platform: str,
        method: str,
        payload: dict[str, Any],
    ) -> AdapterResponse:
        """Invoke *method* synchronously on *platform* and return its response.

        Raises:
            ValueError: see ``_get_sync_adapter``.
        """
        adapter = self._get_sync_adapter(platform, method)
        return await adapter.call(method, payload)

    # ── Asynchronous tasks ──

    def _content_violation_error(self, platform: str, raw_code: str | None) -> PlatformError:
        """Build the non-retryable CONTENT_VIOLATION error for audit rejections."""
        return PlatformError(
            self._CONTENT_VIOLATION_MESSAGE,
            platform=platform,
            retryable=False,
            error_type=PlatformErrorType.CONTENT_VIOLATION,
            raw_code=raw_code,
        )

    @staticmethod
    def _is_vidu_audit_failure(
        platform: str, fingerprint: str | None, error_code: str | None
    ) -> bool:
        """True when a failed Vidu submit is a content-audit rejection that can
        be cached under *fingerprint*."""
        return bool(
            platform == "vidu"
            and fingerprint
            and error_code
            and is_vidu_audit_error(error_code)
        )

    async def submit_task(
        self,
        platform: str,
        task_type: str,
        payload: dict[str, Any],
        callback_url: str | None = None,
        internal_task_id: str | None = None,
    ) -> str:
        """Submit an asynchronous task and return its internal task id.

        Args:
            platform: Platform id of a registered TaskCapable adapter.
            task_type: Platform-specific task type.
            payload: Task parameters forwarded to the adapter.
            callback_url: Optional webhook URL forwarded to the adapter.
            internal_task_id: Caller-supplied internal id (e.g. from the async
                engine). If omitted a random one is generated. Callback-based
                flows must pass it so the callback can be mapped back to the
                right registry record.

        Returns:
            The internal task id. Resubmitting an id that already has a
            mapping returns immediately without calling the platform.

        Raises:
            PlatformError: CONTENT_VIOLATION (non-retryable) when Vidu content
                fails, or recently failed, the content audit; otherwise the
                adapter's error, or UNKNOWN when the submit fails or returns
                no platform task id.
            ValueError: the platform is not registered / not TaskCapable.
        """
        internal_task_id = internal_task_id or uuid.uuid4().hex

        # 1. Idempotency: an id that was already submitted is returned as-is.
        if await self._get_task_mapping(internal_task_id):
            return internal_task_id

        # 2. Vidu only: reject content whose fingerprint recently failed audit,
        #    without calling the platform again.
        fingerprint: str | None = None
        if platform == "vidu":
            fingerprint = compute_content_fingerprint(
                task_type=task_type,
                video_url=payload.get("video_url"),
                audio_url=payload.get("audio_url"),
                ref_photo_url=payload.get("ref_photo_url"),
                text=payload.get("text"),
                voice_id=payload.get("voice_id"),
            )
            rejected_code = await self._get_audit_rejection(fingerprint)
            if rejected_code:
                raise self._content_violation_error(platform, rejected_code)

        # 3. Submit through the platform adapter.
        adapter = self._get_task_adapter(platform, task_type)
        try:
            result = await adapter.submit(task_type, payload, callback_url)
        except PlatformError as e:
            err_code = e.raw_code
            if self._is_vidu_audit_failure(platform, fingerprint, err_code):
                await self._set_audit_rejection(fingerprint, err_code)
                raise self._content_violation_error(platform, err_code) from e
            raise

        if not result.success:
            raw_code = result.error_code
            if self._is_vidu_audit_failure(platform, fingerprint, raw_code):
                # Report an audit rejection delivered via the response the same
                # way as one delivered via exception or the cache: as a
                # non-retryable CONTENT_VIOLATION.
                await self._set_audit_rejection(fingerprint, raw_code)
                logger.warning(
                    "[PlatformGateway] Vidu 内容审核未通过: code=%s, message=%s",
                    raw_code,
                    result.error_message,
                )
                raise self._content_violation_error(platform, raw_code)
            raise PlatformError(
                result.error_message or "任务提交失败",
                platform=platform,
                retryable=result.retryable,
                error_type=PlatformErrorType.UNKNOWN,
                raw_code=raw_code,
            )

        platform_task_id = (result.data or {}).get("task_id", "")
        if not platform_task_id:
            raise PlatformError(
                "任务提交成功但未返回平台任务ID",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            )

        await self._store_task_mapping(internal_task_id, platform, platform_task_id)
        logger.info(
            "Task submitted: internal=%s, platform=%s, platform_task_id=%s",
            internal_task_id,
            platform,
            platform_task_id,
        )
        return internal_task_id

    async def query_task(self, platform: str, platform_task_id: str) -> TaskStatus:
        """Query a task's status by its platform task id (not the internal id)."""
        adapter = self._get_task_adapter(platform)
        return await adapter.query(platform_task_id)

    async def query_task_by_internal_id(
        self, internal_task_id: str, task_type: str | None = None
    ) -> TaskStatus:
        """Query a task's status by its internal task id.

        Args:
            internal_task_id: Internal id returned by ``submit_task``.
            task_type: Optional task type. ``"auto_align"`` is routed to the
                adapter's ``query_auto_align`` when it defines one.

        Raises:
            PlatformError: NOT_FOUND when no mapping exists or it expired.
        """
        mapping = await self._get_task_mapping(internal_task_id)
        if not mapping:
            raise PlatformError(
                "任务不存在或已过期",
                platform="",
                retryable=False,
                error_type=PlatformErrorType.NOT_FOUND,
            )

        platform = mapping["platform"]
        platform_task_id = mapping["platform_task_id"]
        adapter = self._get_task_adapter(platform)

        if task_type == "auto_align" and hasattr(adapter, "query_auto_align"):
            return await adapter.query_auto_align(platform_task_id)
        return await adapter.query(platform_task_id)

    # ── Callbacks ──

    async def handle_webhook(
        self,
        platform: str,
        headers: dict[str, str],
        body: bytes,
        secret: str | None = None,
        callback_url: str | None = None,
    ) -> TaskStatus:
        """Verify and parse a platform callback.

        Args:
            platform: Platform id of a registered CallbackCapable adapter.
            headers: Request headers, forwarded to signature / nonce checks.
            body: Raw request body.
            secret: Signing secret. NOTE: when empty or None, signature
                verification is skipped entirely.
            callback_url: Forwarded to ``verify_signature``.

        Raises:
            PlatformError: AUTH_FAILED on a bad signature or a reused nonce.
            ValueError: the platform is not registered / not CallbackCapable.
        """
        adapter = self._get_callback_adapter(platform)

        # 1. Signature verification (only when a secret is configured).
        if secret and not await adapter.verify_signature(
            headers, body, secret, callback_url=callback_url
        ):
            raise PlatformError(
                "回调签名验证失败",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.AUTH_FAILED,
            )

        # 2. Replay protection, only for adapters that implement verify_nonce.
        if hasattr(adapter, "verify_nonce"):
            redis = self._get_redis()
            if not await adapter.verify_nonce(headers, redis):
                raise PlatformError(
                    "回调 nonce 已使用,可能为重放攻击",
                    platform=platform,
                    retryable=False,
                    error_type=PlatformErrorType.AUTH_FAILED,
                )

        return await adapter.parse_callback(body)

    # ── Lifecycle ──

    async def close_all(self) -> None:
        """Close every registered adapter; failures are logged, not raised."""
        for platform_id, adapter in self.adapters.items():
            try:
                await adapter.close()
                logger.info("Adapter 关闭: %s", platform_id)
            except Exception as e:
                logger.warning("Adapter 关闭失败: %s: %s", platform_id, e)

    # ── Health checks ──

    async def health_check_all(self) -> dict[str, AdapterResponse]:
        """Run each adapter's health check in turn.

        An exception from an adapter is reported as a failed AdapterResponse
        carrying the exception text rather than propagated.
        """
        results: dict[str, AdapterResponse] = {}
        for platform_id, adapter in self.adapters.items():
            try:
                results[platform_id] = await adapter.health()
            except Exception as e:
                results[platform_id] = AdapterResponse(
                    success=False,
                    error_message=str(e),
                )
        return results