Files
meijiaka-zy/python-api/app/platform_gateway.py
T
小鱼开发 30536276ba refactor(scheduler): 统一异步任务调度架构
核心变更:
- 统一第三方接口架构:所有服务走 PlatformGateway(call_sync/submit_task/query_task/handle_webhook)
- 视频生成(Vidu 对口型)纳入 Async Engine,与 script/subtitle/tts 统一为 POST /tasks/{task_type} 模式
- 新增 VideoHandler、TTSHandler,完善 ScriptHandler/SubtitleHandler
- PlatformGateway 生成 internal_task_id,建立 Redis 双向映射,callback 场景传入 Async Engine task_id 保证映射一致
- SlotManager 新增 acquire_ctx 上下文管理器,所有 Handler 统一使用
- ViduAdapter 状态映射归一化(normalize_state/denormalize_state)
- 移除 ViduService Semaphore 和 tenacity 重试,并发控制完全交予 SlotManager
- nonce 防重放下沉到 CallbackCapable 协议
- Service 层错误统一为 PlatformError,路由层错误信息脱敏
- 废弃 /voice/lip-sync,清理 vidu.py 遗留路由

Bug 修复:
- VideoHandler 轮询阶段后添加 continue,防止已提交任务重复创建
- voice.py synthesize_to_file 变量名冲突(request vs request_body)
- PlatformGateway.submit_task 空 data 防护
- ScriptHandler 动态导入 asyncio 改为模块级导入
- SubtitleHandler 完成时补充 progress=100

文档:
- 更新 AGENTS.md 核心功能、运行时架构、异步调度描述
2026-05-05 20:53:18 +08:00

279 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
第三方平台统一调用网关
========================
所有第三方平台调用的唯一入口。
- 同步调用:call_sync()
- 异步任务提交:submit_task()
- 任务状态查询:query_task()
- 回调处理:handle_webhook()
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
# Redis key 前缀:内部 task_id → platform_task_id 映射
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
_TASK_MAPPING_TTL = 7 * 24 * 60 * 60 # 7 天
class PlatformGateway:
"""第三方平台统一调用网关"""
def __init__(
self,
adapters: dict[str, PlatformAdapter] | None = None,
redis=None,
):
self.adapters: dict[str, PlatformAdapter] = adapters or {}
self._redis = redis
def _get_redis(self):
"""懒加载 Redis 客户端"""
if self._redis is None:
from app.core.redis_client import get_redis_client
self._redis = get_redis_client()
return self._redis
def _task_mapping_key(self, internal_task_id: str) -> str:
return f"{_TASK_MAPPING_PREFIX}:{internal_task_id}"
async def _store_task_mapping(
self, internal_task_id: str, platform: str, platform_task_id: str
) -> None:
"""存储内部 task_id 与平台 task_id 的双向映射关系"""
redis = self._get_redis()
# 正向映射:internal → platform
key = self._task_mapping_key(internal_task_id)
await redis.hset(key, mapping={
"platform": platform,
"platform_task_id": platform_task_id,
})
await redis.expire(key, _TASK_MAPPING_TTL)
# 反向映射:platform → internal(供回调查找)
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
await redis.setex(reverse_key, _TASK_MAPPING_TTL, internal_task_id)
async def _get_task_mapping(self, internal_task_id: str) -> dict[str, str] | None:
"""查询内部 task_id 对应的平台映射"""
redis = self._get_redis()
key = self._task_mapping_key(internal_task_id)
data = await redis.hgetall(key)
if not data:
return None
return {
"platform": data.get("platform", ""),
"platform_task_id": data.get("platform_task_id", ""),
}
async def get_internal_task_id_by_platform_task_id(
self, platform: str, platform_task_id: str
) -> str | None:
"""通过平台 task_id 反查内部 task_id(供回调使用)"""
redis = self._get_redis()
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
return await redis.get(reverse_key)
def register(self, platform_id: str, adapter: PlatformAdapter) -> None:
"""注册平台 Adapter"""
self.adapters[platform_id] = adapter
logger.info(f"PlatformGateway 注册平台: {platform_id}")
def _get_sync_adapter(self, platform: str, method: str) -> SyncCapable:
"""获取支持同步调用的 Adapter"""
adapter = self.adapters.get(platform)
if adapter is None:
raise ValueError(f"未注册的平台: {platform}")
if not isinstance(adapter, SyncCapable):
raise ValueError(f"平台 {platform} 不支持同步调用")
return adapter
def _get_task_adapter(self, platform: str, task_type: str | None = None) -> TaskCapable:
"""获取支持异步任务的 Adapter
Args:
platform: 平台 ID
task_type: 任务类型(仅在提交时需要校验,查询时可不传)
"""
adapter = self.adapters.get(platform)
if adapter is None:
raise ValueError(f"未注册的平台: {platform}")
if not isinstance(adapter, TaskCapable):
raise ValueError(f"平台 {platform} 不支持异步任务")
return adapter
def _get_callback_adapter(self, platform: str) -> CallbackCapable:
"""获取支持回调的 Adapter"""
adapter = self.adapters.get(platform)
if adapter is None:
raise ValueError(f"未注册的平台: {platform}")
if not isinstance(adapter, CallbackCapable):
raise ValueError(f"平台 {platform} 不支持回调")
return adapter
# ── 同步调用 ──
async def call_sync(
self,
platform: str,
method: str,
payload: dict[str, Any],
) -> AdapterResponse:
"""同步调用统一入口"""
adapter = self._get_sync_adapter(platform, method)
return await adapter.call(method, payload)
# ── 异步任务 ──
async def submit_task(
self,
platform: str,
task_type: str,
payload: dict[str, Any],
callback_url: str | None = None,
internal_task_id: str | None = None,
) -> str:
"""异步任务提交统一入口,返回 internal_task_id
Args:
internal_task_id: 调用方(如 Async Engine)传入的内部任务 ID。
若提供,则直接使用该 ID 建立映射;否则自动生成。
callback 场景必须传入,确保回调能反查到正确的 Registry 记录。
"""
adapter = self._get_task_adapter(platform, task_type)
result = await adapter.submit(task_type, payload, callback_url)
if not result.success:
raise PlatformError(
result.error_message or "任务提交失败",
platform=platform,
retryable=result.retryable,
error_type=PlatformErrorType.UNKNOWN,
)
platform_task_id = (result.data or {}).get("task_id", "")
if not platform_task_id:
raise PlatformError(
"任务提交成功但未返回平台任务ID",
platform=platform,
retryable=False,
error_type=PlatformErrorType.UNKNOWN,
)
internal_task_id = internal_task_id or uuid.uuid4().hex
await self._store_task_mapping(internal_task_id, platform, platform_task_id)
logger.info(
f"Task submitted: internal={internal_task_id}, "
f"platform={platform}, platform_task_id={platform_task_id}"
)
return internal_task_id
async def query_task(self, platform: str, platform_task_id: str) -> TaskStatus:
"""任务状态查询统一入口(传入 platform_task_id"""
adapter = self._get_task_adapter(platform)
return await adapter.query(platform_task_id)
async def query_task_by_internal_id(
self, internal_task_id: str, task_type: str | None = None
) -> TaskStatus:
"""通过内部 task_id 查询任务状态
Args:
internal_task_id: 内部任务 ID
task_type: 可选的任务类型,用于路由到 Adapter 的特定查询方法
"""
mapping = await self._get_task_mapping(internal_task_id)
if not mapping:
raise PlatformError(
"任务不存在或已过期",
platform="",
retryable=False,
error_type=PlatformErrorType.NOT_FOUND,
)
platform = mapping["platform"]
platform_task_id = mapping["platform_task_id"]
adapter = self._get_task_adapter(platform)
# 根据 task_type 路由到 Adapter 的特定查询方法
if task_type == "auto_align" and hasattr(adapter, "query_auto_align"):
return await adapter.query_auto_align(platform_task_id)
return await adapter.query(platform_task_id)
# ── 回调处理 ──
async def handle_webhook(
self,
platform: str,
headers: dict[str, str],
body: bytes,
secret: str | None = None,
callback_url: str | None = None,
) -> TaskStatus:
"""统一回调处理入口(含签名验证 + nonce 防重放)"""
adapter = self._get_callback_adapter(platform)
# 1. 签名验证
if secret and not await adapter.verify_signature(
headers, body, secret, callback_url=callback_url
):
raise PlatformError(
"回调签名验证失败",
platform=platform,
retryable=False,
error_type=PlatformErrorType.AUTH_FAILED,
)
# 2. nonce 防重放(可选,仅 Adapter 实现了 verify_nonce 时)
if hasattr(adapter, "verify_nonce"):
redis = self._get_redis()
if not await adapter.verify_nonce(headers, redis):
raise PlatformError(
"回调 nonce 已使用,可能为重放攻击",
platform=platform,
retryable=False,
error_type=PlatformErrorType.AUTH_FAILED,
)
return await adapter.parse_callback(body)
# ── 生命周期 ──
async def close_all(self) -> None:
"""关闭所有 Adapter"""
for platform_id, adapter in self.adapters.items():
try:
await adapter.close()
logger.info(f"Adapter 关闭: {platform_id}")
except Exception as e:
logger.warning(f"Adapter 关闭失败: {platform_id}: {e}")
# ── 健康检查 ──
async def health_check_all(self) -> dict[str, AdapterResponse]:
"""检查所有平台健康状态"""
results = {}
for platform_id, adapter in self.adapters.items():
try:
results[platform_id] = await adapter.health()
except Exception as e:
results[platform_id] = AdapterResponse(
success=False,
error_message=str(e),
)
return results