"""
Unified gateway for third-party platform calls
==============================================

Single entry point for every third-party platform invocation.

- Synchronous calls: call_sync()
- Asynchronous task submission: submit_task()
- Task status query: query_task()
- Callback handling: handle_webhook()
"""

from __future__ import annotations

import logging
import uuid
from typing import Any

from app.ai.adapters.base import (
    AdapterResponse,
    CallbackCapable,
    PlatformAdapter,
    SyncCapable,
    TaskCapable,
    TaskStatus,
)
from app.core.exceptions import PlatformError, PlatformErrorType
from app.utils.content_fingerprint import (
    compute_content_fingerprint,
    is_vidu_audit_error,
)

logger = logging.getLogger(__name__)

# Redis key prefix: internal task_id -> platform_task_id mapping
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
_TASK_MAPPING_TTL = 7 * 24 * 60 * 60  # 7 days

# Redis key prefix: cache of content-audit rejections
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
_AUDIT_REJECTION_TTL = 24 * 60 * 60  # 24 hours


class PlatformGateway:
    """Unified gateway that routes calls to registered platform adapters."""

    def __init__(
        self,
        adapters: dict[str, PlatformAdapter] | None = None,
        redis: Any = None,
    ):
        """Create a gateway.

        Args:
            adapters: Optional initial ``platform_id -> adapter`` registry.
            redis: Optional Redis client. When omitted, the client is fetched
                lazily on first use (see ``_get_redis``).
        """
        self.adapters: dict[str, PlatformAdapter] = adapters or {}
        self._redis = redis

    def _get_redis(self):
        """Return the Redis client, loading it lazily on first access."""
        if self._redis is None:
            # Imported here so constructing a gateway does not import the
            # Redis client module until it is actually needed.
            from app.core.redis_client import get_redis_client

            self._redis = get_redis_client()
        return self._redis

    def _task_mapping_key(self, internal_task_id: str) -> str:
        """Build the Redis key of the forward mapping for ``internal_task_id``."""
        return f"{_TASK_MAPPING_PREFIX}:{internal_task_id}"

    async def _store_task_mapping(
        self, internal_task_id: str, platform: str, platform_task_id: str
    ) -> None:
        """Store the two-way mapping between internal and platform task ids.

        Both directions are written with a TTL of ``_TASK_MAPPING_TTL``.
        """
        redis = self._get_redis()

        # Forward mapping: internal -> platform
        key = self._task_mapping_key(internal_task_id)
        await redis.hset(
            key,
            mapping={
                "platform": platform,
                "platform_task_id": platform_task_id,
            },
        )
        await redis.expire(key, _TASK_MAPPING_TTL)

        # Reverse mapping: platform -> internal (used to resolve callbacks)
        reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
        await redis.setex(reverse_key, _TASK_MAPPING_TTL, internal_task_id)

    async def _get_task_mapping(self, internal_task_id: str) -> dict[str, str] | None:
        """Look up the platform mapping stored for an internal task id.

        Returns:
            ``{"platform": ..., "platform_task_id": ...}``, or ``None`` when
            no mapping exists. Missing fields default to ``""``.
        """
        redis = self._get_redis()
        key = self._task_mapping_key(internal_task_id)
        data = await redis.hgetall(key)
        if not data:
            return None
        # NOTE(review): the str-keyed lookups assume the Redis client decodes
        # responses to str -- confirm against get_redis_client().
        return {
            "platform": data.get("platform", ""),
            "platform_task_id": data.get("platform_task_id", ""),
        }

    def _audit_rejection_key(self, fingerprint: str) -> str:
        """Build the Redis key of the audit-rejection cache entry."""
        return f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"

    async def _get_audit_rejection(self, fingerprint: str) -> str | None:
        """Check whether this content fingerprint was recently rejected.

        Best-effort: any failure while reading the cache is logged and
        treated as a cache miss.

        Returns:
            The cached error code (e.g. "TaskPromptPolicyViolation"), or
            ``None`` on a miss.
        """
        if not fingerprint:
            return None
        try:
            redis = self._get_redis()
            key = self._audit_rejection_key(fingerprint)
            return await redis.get(key)
        except Exception as e:
            logger.warning(f"[PlatformGateway] 查询审核缓存失败: {e}")
            return None

    async def _set_audit_rejection(self, fingerprint: str, error_code: str) -> None:
        """Cache an audit rejection for ``fingerprint``.

        Best-effort, mirroring ``_get_audit_rejection``: a failure to write
        the cache is logged and swallowed. This method is awaited while a
        platform error is being handled, so letting a Redis failure propagate
        would replace the platform's real error with an unrelated one.
        """
        if not fingerprint or not error_code:
            return
        try:
            redis = self._get_redis()
            key = self._audit_rejection_key(fingerprint)
            await redis.setex(key, _AUDIT_REJECTION_TTL, error_code)
        except Exception as e:
            logger.warning(f"[PlatformGateway] 写入审核缓存失败: {e}")

    async def get_internal_task_id_by_platform_task_id(
        self, platform: str, platform_task_id: str
    ) -> str | None:
        """Resolve an internal task id from a platform task id (for callbacks).

        Returns:
            Whatever the reverse-mapping key holds, or ``None`` if it is
            absent.
        """
        redis = self._get_redis()
        reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
        return await redis.get(reverse_key)

    def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
        """Register (or replace) the adapter for ``platform_id``."""
        self.adapters[platform_id] = adapter
        logger.info(f"PlatformGateway 注册平台: {platform_id}")

    def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
        """Return the adapter for ``platform`` if it supports sync calls.

        Args:
            platform: Platform id.
            method: Accepted for interface compatibility; not inspected here.

        Raises:
            ValueError: The platform is not registered, or its adapter is not
                ``SyncCapable``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, SyncCapable):
            raise ValueError(f"平台 {platform} 不支持同步调用")
        return adapter

    def _get_task_adapter(self, platform: str, task_type: str | None = None) -> TaskCapable:
        """Return the adapter for ``platform`` if it supports async tasks.

        Args:
            platform: Platform id.
            task_type: Task type; callers pass it on submit and omit it on
                query.
                NOTE(review): it is currently not validated by this method.

        Raises:
            ValueError: The platform is not registered, or its adapter is not
                ``TaskCapable``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, TaskCapable):
            raise ValueError(f"平台 {platform} 不支持异步任务")
        return adapter

    def _get_callback_adapter(self, platform: str) -> CallbackCapable:
        """Return the adapter for ``platform`` if it supports callbacks.

        Raises:
            ValueError: The platform is not registered, or its adapter is not
                ``CallbackCapable``.
        """
        adapter = self.adapters.get(platform)
        if adapter is None:
            raise ValueError(f"未注册的平台: {platform}")
        if not isinstance(adapter, CallbackCapable):
            raise ValueError(f"平台 {platform} 不支持回调")
        return adapter

    # -- Synchronous calls --

    async def call_sync(
        self,
        platform: str,
        method: str,
        payload: dict[str, Any],
    ) -> AdapterResponse:
        """Unified entry point for synchronous calls.

        Raises:
            ValueError: The platform is unknown or does not support sync
                calls.
        """
        adapter = self._get_sync_adapter(platform, method)
        return await adapter.call(method, payload)

    # -- Asynchronous tasks --

    @staticmethod
    def _should_cache_audit_rejection(
        platform: str, fingerprint: str | None, code: str | None
    ) -> bool:
        """Tell whether ``code`` is a Vidu audit error worth caching."""
        return bool(
            platform == "vidu" and fingerprint and code and is_vidu_audit_error(code)
        )

    @staticmethod
    def _content_violation_error(platform: str, raw_code: str | None) -> PlatformError:
        """Build the non-retryable content-violation error shown to callers."""
        return PlatformError(
            "人物分镜台词未通过安全审核,请修改后重试",
            platform=platform,
            retryable=False,
            error_type=PlatformErrorType.CONTENT_VIOLATION,
            raw_code=raw_code,
        )

    async def submit_task(
        self,
        platform: str,
        task_type: str,
        payload: dict[str, Any],
        callback_url: str | None = None,
        internal_task_id: str | None = None,
    ) -> str:
        """Unified entry point for async task submission.

        Args:
            platform: Platform id of a registered ``TaskCapable`` adapter.
            task_type: Task type forwarded to the adapter.
            payload: Task payload forwarded to the adapter.
            callback_url: Optional callback URL forwarded to the adapter.
            internal_task_id: Internal task id supplied by the caller (e.g.
                the Async Engine). When given, the mapping is created under
                this id; otherwise one is generated. It must be supplied in
                callback scenarios so the callback can be resolved back to
                the right Registry record.

        Returns:
            The internal task id.

        Raises:
            PlatformError: The content was recently rejected by the audit, the
                adapter raised or reported a failure, or no platform task id
                was returned.
            ValueError: The platform is unknown or does not support tasks.
        """
        internal_task_id = internal_task_id or uuid.uuid4().hex

        # 1. Idempotency: this internal_task_id was already submitted.
        existing = await self._get_task_mapping(internal_task_id)
        if existing:
            return internal_task_id

        # 2. Vidu content-fingerprint guard: block content whose audit
        #    recently failed without calling the platform again.
        fingerprint: str | None = None
        if platform == "vidu":
            fingerprint = compute_content_fingerprint(
                task_type=task_type,
                video_url=payload.get("video_url"),
                audio_url=payload.get("audio_url"),
                ref_photo_url=payload.get("ref_photo_url"),
                text=payload.get("text"),
                voice_id=payload.get("voice_id"),
            )
            rejected_code = await self._get_audit_rejection(fingerprint)
            if rejected_code:
                raise self._content_violation_error(platform, rejected_code)

        # 3. Submit the task through the platform adapter.
        adapter = self._get_task_adapter(platform, task_type)
        try:
            result = await adapter.submit(task_type, payload, callback_url)
        except PlatformError as e:
            # Vidu audit errors: cache the fingerprint to avoid repeat calls.
            err_code = e.raw_code
            if self._should_cache_audit_rejection(platform, fingerprint, err_code):
                await self._set_audit_rejection(fingerprint, err_code)
                raise self._content_violation_error(platform, err_code) from e
            raise

        if not result.success:
            raw_code = result.error_code
            if self._should_cache_audit_rejection(platform, fingerprint, raw_code):
                await self._set_audit_rejection(fingerprint, raw_code)
            # NOTE(review): unlike the exception path above, an audit code
            # reported via result.success=False is still surfaced as UNKNOWN
            # with the adapter's own message -- confirm this is intended.
            raise PlatformError(
                result.error_message or "任务提交失败",
                platform=platform,
                retryable=result.retryable,
                error_type=PlatformErrorType.UNKNOWN,
                raw_code=raw_code,
            )

        platform_task_id = (result.data or {}).get("task_id", "")
        if not platform_task_id:
            raise PlatformError(
                "任务提交成功但未返回平台任务ID",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            )

        await self._store_task_mapping(internal_task_id, platform, platform_task_id)

        logger.info(
            f"Task submitted: internal={internal_task_id}, "
            f"platform={platform}, platform_task_id={platform_task_id}"
        )
        return internal_task_id

    async def query_task(self, platform: str, platform_task_id: str) -> TaskStatus:
        """Unified entry point for task status queries by platform task id.

        Raises:
            ValueError: The platform is unknown or does not support tasks.
        """
        adapter = self._get_task_adapter(platform)
        return await adapter.query(platform_task_id)

    async def query_task_by_internal_id(
        self, internal_task_id: str, task_type: str | None = None
    ) -> TaskStatus:
        """Query task status by internal task id.

        Args:
            internal_task_id: Internal task id.
            task_type: Optional task type used to route to an
                adapter-specific query method.

        Raises:
            PlatformError: No mapping exists for ``internal_task_id``
                (``NOT_FOUND``).
            ValueError: The mapped platform is unknown or does not support
                tasks.
        """
        mapping = await self._get_task_mapping(internal_task_id)
        if not mapping:
            raise PlatformError(
                "任务不存在或已过期",
                platform="",
                retryable=False,
                error_type=PlatformErrorType.NOT_FOUND,
            )

        platform = mapping["platform"]
        platform_task_id = mapping["platform_task_id"]
        adapter = self._get_task_adapter(platform)

        # Route to an adapter-specific query method based on task_type.
        if task_type == "auto_align" and hasattr(adapter, "query_auto_align"):
            return await adapter.query_auto_align(platform_task_id)

        return await adapter.query(platform_task_id)

    # -- Callback handling --

    async def handle_webhook(
        self,
        platform: str,
        headers: dict[str, str],
        body: bytes,
        secret: str | None = None,
        callback_url: str | None = None,
    ) -> TaskStatus:
        """Unified callback entry point (signature check + nonce replay guard).

        Args:
            platform: Platform id of a registered ``CallbackCapable`` adapter.
            headers: Callback request headers.
            body: Raw callback request body.
            secret: Signing secret. The signature is verified only when this
                is truthy.
            callback_url: Optional callback URL passed to signature
                verification.

        Returns:
            The result of the adapter's ``parse_callback(body)``.

        Raises:
            PlatformError: Signature verification failed, or the nonce was
                rejected (``AUTH_FAILED``).
            ValueError: The platform is unknown or does not support callbacks.
        """
        adapter = self._get_callback_adapter(platform)

        # 1. Signature verification.
        # NOTE(review): skipped entirely when ``secret`` is falsy -- confirm
        # callers always supply a secret where verification is required.
        if secret and not await adapter.verify_signature(
            headers, body, secret, callback_url=callback_url
        ):
            raise PlatformError(
                "回调签名验证失败",
                platform=platform,
                retryable=False,
                error_type=PlatformErrorType.AUTH_FAILED,
            )

        # 2. Nonce replay protection (optional: only when the adapter
        #    implements verify_nonce).
        if hasattr(adapter, "verify_nonce"):
            redis = self._get_redis()
            if not await adapter.verify_nonce(headers, redis):
                raise PlatformError(
                    "回调 nonce 已使用,可能为重放攻击",
                    platform=platform,
                    retryable=False,
                    error_type=PlatformErrorType.AUTH_FAILED,
                )

        return await adapter.parse_callback(body)

    # -- Lifecycle --

    async def close_all(self) -> None:
        """Close every registered adapter.

        A failure to close one adapter is logged and does not prevent the
        remaining adapters from being closed.
        """
        for platform_id, adapter in self.adapters.items():
            try:
                await adapter.close()
                logger.info(f"Adapter 关闭: {platform_id}")
            except Exception as e:
                logger.warning(f"Adapter 关闭失败: {platform_id}: {e}")

    # -- Health check --

    async def health_check_all(self) -> dict[str, AdapterResponse]:
        """Check the health of every registered platform.

        Returns:
            ``platform_id -> AdapterResponse``. An adapter whose ``health()``
            raises is reported as ``success=False`` with the exception text
            as ``error_message``.
        """
        results: dict[str, AdapterResponse] = {}
        for platform_id, adapter in self.adapters.items():
            try:
                results[platform_id] = await adapter.health()
            except Exception as e:
                results[platform_id] = AdapterResponse(
                    success=False,
                    error_message=str(e),
                )
        return results