30536276ba
核心变更:
- 统一第三方接口架构:所有服务走 PlatformGateway(call_sync/submit_task/query_task/handle_webhook)
- 视频生成(Vidu 对口型)纳入 Async Engine,与 script/subtitle/tts 统一为 POST /tasks/{task_type} 模式
- 新增 VideoHandler、TTSHandler,完善 ScriptHandler/SubtitleHandler
- PlatformGateway 生成 internal_task_id,建立 Redis 双向映射,callback 场景传入 Async Engine task_id 保证映射一致
- SlotManager 新增 acquire_ctx 上下文管理器,所有 Handler 统一使用
- ViduAdapter 状态映射归一化(normalize_state/denormalize_state)
- 移除 ViduService Semaphore 和 tenacity 重试,并发控制完全交予 SlotManager
- nonce 防重放下沉到 CallbackCapable 协议
- Service 层错误统一为 PlatformError,路由层错误信息脱敏
- 废弃 /voice/lip-sync,清理 vidu.py 遗留路由
Bug 修复:
- VideoHandler 轮询阶段后添加 continue,防止已提交任务重复创建
- voice.py synthesize_to_file 变量名冲突(request vs request_body)
- PlatformGateway.submit_task 空 data 防护
- ScriptHandler 动态导入 asyncio 改为模块级导入
- SubtitleHandler 完成时补充 progress=100
文档:
- 更新 AGENTS.md 核心功能、运行时架构、异步调度描述
279 lines
9.9 KiB
Python
279 lines
9.9 KiB
Python
"""
|
||
第三方平台统一调用网关
|
||
========================
|
||
|
||
所有第三方平台调用的唯一入口。
|
||
- 同步调用:call_sync()
|
||
- 异步任务提交:submit_task()
|
||
- 任务状态查询:query_task()
|
||
- 回调处理:handle_webhook()
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import uuid
|
||
from typing import Any
|
||
|
||
from app.ai.adapters.base import (
|
||
AdapterResponse,
|
||
CallbackCapable,
|
||
PlatformAdapter,
|
||
SyncCapable,
|
||
TaskCapable,
|
||
TaskStatus,
|
||
)
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Redis key 前缀:内部 task_id → platform_task_id 映射
|
||
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
|
||
_TASK_MAPPING_TTL = 7 * 24 * 60 * 60 # 7 天
|
||
|
||
|
||
class PlatformGateway:
|
||
"""第三方平台统一调用网关"""
|
||
|
||
def __init__(
|
||
self,
|
||
adapters: dict[str, PlatformAdapter] | None = None,
|
||
redis=None,
|
||
):
|
||
self.adapters: dict[str, PlatformAdapter] = adapters or {}
|
||
self._redis = redis
|
||
|
||
def _get_redis(self):
|
||
"""懒加载 Redis 客户端"""
|
||
if self._redis is None:
|
||
from app.core.redis_client import get_redis_client
|
||
|
||
self._redis = get_redis_client()
|
||
return self._redis
|
||
|
||
def _task_mapping_key(self, internal_task_id: str) -> str:
|
||
return f"{_TASK_MAPPING_PREFIX}:{internal_task_id}"
|
||
|
||
async def _store_task_mapping(
|
||
self, internal_task_id: str, platform: str, platform_task_id: str
|
||
) -> None:
|
||
"""存储内部 task_id 与平台 task_id 的双向映射关系"""
|
||
redis = self._get_redis()
|
||
# 正向映射:internal → platform
|
||
key = self._task_mapping_key(internal_task_id)
|
||
await redis.hset(key, mapping={
|
||
"platform": platform,
|
||
"platform_task_id": platform_task_id,
|
||
})
|
||
await redis.expire(key, _TASK_MAPPING_TTL)
|
||
# 反向映射:platform → internal(供回调查找)
|
||
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
|
||
await redis.setex(reverse_key, _TASK_MAPPING_TTL, internal_task_id)
|
||
|
||
async def _get_task_mapping(self, internal_task_id: str) -> dict[str, str] | None:
|
||
"""查询内部 task_id 对应的平台映射"""
|
||
redis = self._get_redis()
|
||
key = self._task_mapping_key(internal_task_id)
|
||
data = await redis.hgetall(key)
|
||
if not data:
|
||
return None
|
||
return {
|
||
"platform": data.get("platform", ""),
|
||
"platform_task_id": data.get("platform_task_id", ""),
|
||
}
|
||
|
||
async def get_internal_task_id_by_platform_task_id(
|
||
self, platform: str, platform_task_id: str
|
||
) -> str | None:
|
||
"""通过平台 task_id 反查内部 task_id(供回调使用)"""
|
||
redis = self._get_redis()
|
||
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
|
||
return await redis.get(reverse_key)
|
||
|
||
def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
|
||
"""注册平台 Adapter"""
|
||
self.adapters[platform_id] = adapter
|
||
logger.info(f"PlatformGateway 注册平台: {platform_id}")
|
||
|
||
def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
|
||
"""获取支持同步调用的 Adapter"""
|
||
adapter = self.adapters.get(platform)
|
||
if adapter is None:
|
||
raise ValueError(f"未注册的平台: {platform}")
|
||
if not isinstance(adapter, SyncCapable):
|
||
raise ValueError(f"平台 {platform} 不支持同步调用")
|
||
return adapter
|
||
|
||
def _get_task_adapter(self, platform: str, task_type: str | None = None) -> TaskCapable:
|
||
"""获取支持异步任务的 Adapter
|
||
|
||
Args:
|
||
platform: 平台 ID
|
||
task_type: 任务类型(仅在提交时需要校验,查询时可不传)
|
||
"""
|
||
adapter = self.adapters.get(platform)
|
||
if adapter is None:
|
||
raise ValueError(f"未注册的平台: {platform}")
|
||
if not isinstance(adapter, TaskCapable):
|
||
raise ValueError(f"平台 {platform} 不支持异步任务")
|
||
return adapter
|
||
|
||
def _get_callback_adapter(self, platform: str) -> CallbackCapable:
|
||
"""获取支持回调的 Adapter"""
|
||
adapter = self.adapters.get(platform)
|
||
if adapter is None:
|
||
raise ValueError(f"未注册的平台: {platform}")
|
||
if not isinstance(adapter, CallbackCapable):
|
||
raise ValueError(f"平台 {platform} 不支持回调")
|
||
return adapter
|
||
|
||
# ── 同步调用 ──
|
||
|
||
async def call_sync(
|
||
self,
|
||
platform: str,
|
||
method: str,
|
||
payload: dict[str, Any],
|
||
) -> AdapterResponse:
|
||
"""同步调用统一入口"""
|
||
adapter = self._get_sync_adapter(platform, method)
|
||
return await adapter.call(method, payload)
|
||
|
||
# ── 异步任务 ──
|
||
|
||
async def submit_task(
|
||
self,
|
||
platform: str,
|
||
task_type: str,
|
||
payload: dict[str, Any],
|
||
callback_url: str | None = None,
|
||
internal_task_id: str | None = None,
|
||
) -> str:
|
||
"""异步任务提交统一入口,返回 internal_task_id
|
||
|
||
Args:
|
||
internal_task_id: 调用方(如 Async Engine)传入的内部任务 ID。
|
||
若提供,则直接使用该 ID 建立映射;否则自动生成。
|
||
callback 场景必须传入,确保回调能反查到正确的 Registry 记录。
|
||
"""
|
||
adapter = self._get_task_adapter(platform, task_type)
|
||
result = await adapter.submit(task_type, payload, callback_url)
|
||
|
||
if not result.success:
|
||
raise PlatformError(
|
||
result.error_message or "任务提交失败",
|
||
platform=platform,
|
||
retryable=result.retryable,
|
||
error_type=PlatformErrorType.UNKNOWN,
|
||
)
|
||
|
||
platform_task_id = (result.data or {}).get("task_id", "")
|
||
if not platform_task_id:
|
||
raise PlatformError(
|
||
"任务提交成功但未返回平台任务ID",
|
||
platform=platform,
|
||
retryable=False,
|
||
error_type=PlatformErrorType.UNKNOWN,
|
||
)
|
||
internal_task_id = internal_task_id or uuid.uuid4().hex
|
||
await self._store_task_mapping(internal_task_id, platform, platform_task_id)
|
||
logger.info(
|
||
f"Task submitted: internal={internal_task_id}, "
|
||
f"platform={platform}, platform_task_id={platform_task_id}"
|
||
)
|
||
return internal_task_id
|
||
|
||
async def query_task(self, platform: str, platform_task_id: str) -> TaskStatus:
|
||
"""任务状态查询统一入口(传入 platform_task_id)"""
|
||
adapter = self._get_task_adapter(platform)
|
||
return await adapter.query(platform_task_id)
|
||
|
||
async def query_task_by_internal_id(
|
||
self, internal_task_id: str, task_type: str | None = None
|
||
) -> TaskStatus:
|
||
"""通过内部 task_id 查询任务状态
|
||
|
||
Args:
|
||
internal_task_id: 内部任务 ID
|
||
task_type: 可选的任务类型,用于路由到 Adapter 的特定查询方法
|
||
"""
|
||
mapping = await self._get_task_mapping(internal_task_id)
|
||
if not mapping:
|
||
raise PlatformError(
|
||
"任务不存在或已过期",
|
||
platform="",
|
||
retryable=False,
|
||
error_type=PlatformErrorType.NOT_FOUND,
|
||
)
|
||
platform = mapping["platform"]
|
||
platform_task_id = mapping["platform_task_id"]
|
||
adapter = self._get_task_adapter(platform)
|
||
|
||
# 根据 task_type 路由到 Adapter 的特定查询方法
|
||
if task_type == "auto_align" and hasattr(adapter, "query_auto_align"):
|
||
return await adapter.query_auto_align(platform_task_id)
|
||
return await adapter.query(platform_task_id)
|
||
|
||
# ── 回调处理 ──
|
||
|
||
async def handle_webhook(
|
||
self,
|
||
platform: str,
|
||
headers: dict[str, str],
|
||
body: bytes,
|
||
secret: str | None = None,
|
||
callback_url: str | None = None,
|
||
) -> TaskStatus:
|
||
"""统一回调处理入口(含签名验证 + nonce 防重放)"""
|
||
adapter = self._get_callback_adapter(platform)
|
||
|
||
# 1. 签名验证
|
||
if secret and not await adapter.verify_signature(
|
||
headers, body, secret, callback_url=callback_url
|
||
):
|
||
raise PlatformError(
|
||
"回调签名验证失败",
|
||
platform=platform,
|
||
retryable=False,
|
||
error_type=PlatformErrorType.AUTH_FAILED,
|
||
)
|
||
|
||
# 2. nonce 防重放(可选,仅 Adapter 实现了 verify_nonce 时)
|
||
if hasattr(adapter, "verify_nonce"):
|
||
redis = self._get_redis()
|
||
if not await adapter.verify_nonce(headers, redis):
|
||
raise PlatformError(
|
||
"回调 nonce 已使用,可能为重放攻击",
|
||
platform=platform,
|
||
retryable=False,
|
||
error_type=PlatformErrorType.AUTH_FAILED,
|
||
)
|
||
|
||
return await adapter.parse_callback(body)
|
||
|
||
# ── 生命周期 ──
|
||
|
||
async def close_all(self) -> None:
|
||
"""关闭所有 Adapter"""
|
||
for platform_id, adapter in self.adapters.items():
|
||
try:
|
||
await adapter.close()
|
||
logger.info(f"Adapter 关闭: {platform_id}")
|
||
except Exception as e:
|
||
logger.warning(f"Adapter 关闭失败: {platform_id}: {e}")
|
||
|
||
# ── 健康检查 ──
|
||
|
||
async def health_check_all(self) -> dict[str, AdapterResponse]:
|
||
"""检查所有平台健康状态"""
|
||
results = {}
|
||
for platform_id, adapter in self.adapters.items():
|
||
try:
|
||
results[platform_id] = await adapter.health()
|
||
except Exception as e:
|
||
results[platform_id] = AdapterResponse(
|
||
success=False,
|
||
error_message=str(e),
|
||
)
|
||
return results
|