""" Vidu Adapter ============ 实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。 直接接入 ViduProvider,提供标准 Protocol 接口。 """ from __future__ import annotations import base64 import hashlib import hmac import json import logging from typing import Any from urllib.parse import urlparse from app.ai.adapters.base import ( AdapterResponse, CallbackCapable, PlatformAdapter, SyncCapable, TaskCapable, TaskStatus, ) from app.ai.adapters.constants import Method from app.ai.providers.vidu_provider import ViduProvider from app.core.exceptions import PlatformError, PlatformErrorType logger = logging.getLogger(__name__) class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable): """Vidu 平台标准 Adapter""" platform_id = "vidu" # Vidu 原生状态 ↔ 标准状态 映射表 _VIDU_TO_STANDARD = { "created": "processing", "queueing": "processing", "pending": "processing", "processing": "processing", "success": "completed", "failed": "failed", } _STANDARD_TO_VIDU = { "completed": "success", "processing": "processing", "pending": "pending", "failed": "failed", } def __init__(self, provider: ViduProvider): self.provider = provider @classmethod def normalize_state(cls, vidu_state: str) -> str: """Vidu 原生状态 → 标准状态(processing / completed / failed)""" return cls._VIDU_TO_STANDARD.get(vidu_state, "failed") @classmethod def denormalize_state(cls, standard_state: str) -> str: """标准状态 → Vidu 原生状态(success / processing / pending / failed)""" return cls._STANDARD_TO_VIDU.get(standard_state, standard_state) # ── PlatformAdapter ── async def health(self) -> AdapterResponse: try: # Vidu 没有专门的健康检查接口,用查询一个空任务测试连通性 # 实际上会 404,但只要网络通就说明服务可用 await self.provider.query_task("health-check") return AdapterResponse(success=True) except PlatformError as e: if e.error_type == PlatformErrorType.NOT_FOUND: return AdapterResponse(success=True) return AdapterResponse( success=False, error_message=str(e), retryable=e.retryable, ) except Exception as e: return AdapterResponse( success=False, error_message=str(e), retryable=False, ) async def close(self) -> None: await self.provider.close() # ── SyncCapable ── async def call(self, method: str, payload: dict) -> AdapterResponse: try: if method == Method.TTS: result = await self.provider.tts_sync( text=payload["text"], voice_id=payload.get("voice_id", "tianxin_xiaoling"), speed=payload.get("speed", 1.0), volume=payload.get("volume", 0), pitch=payload.get("pitch", 0), emotion=payload.get("emotion"), ) return AdapterResponse( success=True, data={"audio_url": result.get("file_url")}, ) elif method == Method.CLONE_VOICE: result = await self.provider.clone_voice( audio_url=payload["audio_url"], voice_id=payload["voice_id"], text=payload.get("text") or "", ) return AdapterResponse( success=True, data={ "voice_id": result.get("voice_id"), "demo_audio": result.get("demo_audio"), }, ) else: return AdapterResponse( success=False, error_message=f"不支持的方法: {method}", retryable=False, ) except PlatformError: raise except Exception as e: raise PlatformError( f"Vidu {method} 调用失败: {e}", platform="vidu", retryable=False, error_type=PlatformErrorType.UNKNOWN, ) from e # ── TaskCapable ── async def submit( self, task_type: str, payload: dict, callback_url: str | None = None, ) -> AdapterResponse: try: if task_type == Method.LIP_SYNC: result = await self.provider.lip_sync( video_url=payload["video_url"], audio_url=payload.get("audio_url"), text=payload.get("text"), voice_id=payload.get("voice_id"), speed=payload.get("speed", 1.0), volume=payload.get("volume", 0), ref_photo_url=payload.get("ref_photo_url"), callback_url=callback_url, ) return AdapterResponse( success=True, data={"task_id": result.get("task_id")}, ) else: return AdapterResponse( success=False, error_message=f"不支持的任务类型: {task_type}", retryable=False, ) except PlatformError: raise except Exception as e: raise PlatformError( f"Vidu {task_type} 提交失败: {e}", platform="vidu", retryable=False, error_type=PlatformErrorType.UNKNOWN, ) from e async def query(self, platform_task_id: str) -> TaskStatus: try: result = await self.provider.query_task(platform_task_id) state = result.get("state", "unknown") creations = result.get("creations", []) video_url = None if state == "success" and creations: video_url = creations[0].get("url") return TaskStatus( state=self.normalize_state(state), result={"video_url": video_url, "creations": creations} if video_url else None, error_message=result.get("message") if state == "failed" else None, ) except PlatformError: raise except Exception as e: raise PlatformError( f"Vidu 任务查询失败: {e}", platform="vidu", retryable=False, error_type=PlatformErrorType.UNKNOWN, ) from e # ── CallbackCapable ── async def verify_signature( self, headers: dict[str, str], body: bytes, secret: str, callback_url: str | None = None, ) -> bool: """验证 Vidu 回调 HMAC-SHA256 签名""" import logging logger = logging.getLogger(__name__) # HTTP 头大小写不敏感:建立小写 key 的查找表 headers_lower = {k.lower(): v for k, v in headers.items()} signature = headers_lower.get("x-hmac-signature") algorithm = headers_lower.get("x-hmac-algorithm") access_key = headers_lower.get("x-hmac-access-key") signed_headers_str = headers_lower.get("x-hmac-signed-headers") date = headers_lower.get("date") if not all([signature, algorithm, access_key, signed_headers_str, date]): logger.warning(f"[Vidu] 签名验证失败: 缺少必要头, headers={list(headers.keys())}") return False assert signature is not None assert signed_headers_str is not None assert date is not None if algorithm != "hmac-sha256": logger.warning(f"[Vidu] 签名验证失败: 不支持的算法 {algorithm}") return False if access_key != "vidu": logger.warning(f"[Vidu] 签名验证失败: access_key 不匹配 {access_key}") return False header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()] header_values: dict[str, str] = {} for name in header_names: # 签名头名也可能大小写不一致,统一用小写查找 value = headers_lower.get(name.lower()) if value is None: logger.warning(f"[Vidu] 签名验证失败: 缺少签名头 {name}") return False header_values[name] = value # 构建 signingString(使用 callback_url 动态解析 path/query) parsed = urlparse(callback_url or "") http_uri = parsed.path or "/" canonical_query_string = parsed.query or "" signing_string = ( f"POST\n" f"{http_uri}\n" f"{canonical_query_string}\n" f"vidu\n" f"{date}\n" ) for name in header_names: signing_string += f"{name}:{header_values[name]}\n" expected = base64.b64encode( hmac.new( secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256 ).digest() ).decode("utf-8") if not hmac.compare_digest(signature, expected): logger.warning( f"[Vidu] 签名验证失败: callback_url={callback_url}, " f"signing_string={repr(signing_string)}, " f"expected={expected[:20]}..., received={signature[:20]}..." ) return False return True async def verify_nonce( self, headers: dict[str, str], redis: Any, ) -> bool: """验证 Vidu 回调 nonce 防重放""" nonce = headers.get("x-request-nonce") if not nonce: return False nonce_key = f"vidu:callback_nonce:{nonce}" if await redis.exists(nonce_key): return False await redis.setex(nonce_key, 300, "1") return True async def parse_callback(self, body: bytes) -> TaskStatus: """解析 Vidu 回调体""" data = json.loads(body) task_id = data.get("id") or data.get("task_id") state = data.get("state") creations = data.get("creations", []) video_url = None if state == "success" and creations: video_url = creations[0].get("url") return TaskStatus( state=self.normalize_state(state), result=( {"video_url": video_url, "creations": creations, "task_id": task_id} if video_url else {"task_id": task_id} ), error_message=( (data.get("err_code") or data.get("message")) if state == "failed" else None ), )