321 lines
11 KiB
Python
321 lines
11 KiB
Python
"""
Vidu Adapter
============

实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
直接接入 ViduProvider,提供标准 Protocol 接口。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
from urllib.parse import urlparse
|
||
|
||
from app.ai.adapters.base import (
|
||
AdapterResponse,
|
||
CallbackCapable,
|
||
PlatformAdapter,
|
||
SyncCapable,
|
||
TaskCapable,
|
||
TaskStatus,
|
||
)
|
||
from app.ai.adapters.constants import Method
|
||
from app.ai.providers.vidu_provider import ViduProvider
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Wraps a ``ViduProvider`` and exposes it through the adapter protocols:

    * ``PlatformAdapter`` -- ``health`` / ``close``
    * ``SyncCapable``     -- ``call`` (TTS, voice cloning)
    * ``TaskCapable``     -- ``submit`` / ``query`` (lip-sync tasks)
    * ``CallbackCapable`` -- ``verify_signature`` / ``verify_nonce`` / ``parse_callback``
    """

    platform_id = "vidu"

    # Vidu native task state -> standard state (processing / completed / failed).
    _VIDU_TO_STANDARD = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }
    # Standard state -> Vidu native state (reverse direction of the table above).
    _STANDARD_TO_VIDU = {
        "completed": "success",
        "processing": "processing",
        "pending": "pending",
        "failed": "failed",
    }

    # How long (seconds) a seen callback nonce is remembered for replay protection.
    _NONCE_TTL_SECONDS = 300

    def __init__(self, provider: ViduProvider):
        self.provider = provider

    @classmethod
    def normalize_state(cls, vidu_state: str) -> str:
        """Map a Vidu native state to a standard state.

        Returns ``"processing"``, ``"completed"`` or ``"failed"``. Any state that
        is not in the mapping table is treated as ``"failed"``.
        """
        return cls._VIDU_TO_STANDARD.get(vidu_state, "failed")

    @classmethod
    def denormalize_state(cls, standard_state: str) -> str:
        """Map a standard state back to a Vidu native state.

        Returns ``"success"``, ``"processing"``, ``"pending"`` or ``"failed"``;
        unknown values are passed through unchanged.
        """
        return cls._STANDARD_TO_VIDU.get(standard_state, standard_state)

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Check connectivity to the Vidu API.

        Vidu has no dedicated health endpoint, so this queries a dummy task id.
        The query is expected to fail with ``NOT_FOUND``; that still proves the
        service is reachable and is reported as healthy. Any other error is
        reported as unhealthy (``retryable`` copied from ``PlatformError``,
        ``False`` for unexpected exceptions).
        """
        try:
            await self.provider.query_task("health-check")
            return AdapterResponse(success=True)
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )

    async def close(self) -> None:
        """Close the underlying provider."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu operation.

        Supported methods:

        * ``Method.TTS`` -- requires ``payload["text"]``; optional ``voice_id``
          (default ``"tianxin_xiaoling"``), ``speed`` (1.0), ``volume`` (0),
          ``pitch`` (0), ``emotion``. Returns ``{"audio_url": ...}``.
        * ``Method.CLONE_VOICE`` -- requires ``audio_url`` and ``voice_id``;
          optional ``text``. Returns ``{"voice_id": ..., "demo_audio": ...}``.

        Any other method yields an unsuccessful, non-retryable response.

        Raises:
            PlatformError: re-raised from the provider as-is; any other exception
                (including a missing required payload key) is wrapped in a
                non-retryable ``PlatformError`` of type ``UNKNOWN``.
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )

            elif method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text") or "",
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )

            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的方法: {method}",
                    retryable=False,
                )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {method} 调用失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task.

        Only ``Method.LIP_SYNC`` is supported. It requires ``payload["video_url"]``
        and forwards the optional ``audio_url``, ``text``, ``voice_id``,
        ``speed`` (default 1.0), ``volume`` (default 0) and ``ref_photo_url``
        together with ``callback_url``. On success returns ``{"task_id": ...}``.
        Any other task type yields an unsuccessful, non-retryable response.

        Raises:
            PlatformError: re-raised from the provider as-is; any other exception
                is wrapped in a non-retryable ``PlatformError`` of type ``UNKNOWN``.
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )

            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的任务类型: {task_type}",
                    retryable=False,
                )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {task_type} 提交失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    async def query(self, platform_task_id: str) -> TaskStatus:
        """Poll the status of a Vidu task.

        Returns a ``TaskStatus`` whose ``state`` is normalized via
        ``normalize_state``. When the native state is ``"success"`` and there is
        at least one creation, ``result`` is
        ``{"video_url": <first creation url>, "creations": [...]}``; otherwise
        ``None``. For a native ``"failed"`` state, ``error_message`` is the
        payload's ``err_code`` (falling back to ``message``), matching
        ``parse_callback``.

        Raises:
            PlatformError: re-raised from the provider as-is; any other exception
                is wrapped in a non-retryable ``PlatformError`` of type ``UNKNOWN``.
        """
        try:
            result = await self.provider.query_task(platform_task_id)
            state = result.get("state", "unknown")

            creations = result.get("creations", [])
            video_url = None
            if state == "success" and creations:
                video_url = creations[0].get("url")

            error_message = None
            if state == "failed":
                # Same field precedence as parse_callback.
                error_message = result.get("err_code") or result.get("message")

            return TaskStatus(
                state=self.normalize_state(state),
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=error_message,
            )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu 任务查询失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        The signing string is these newline-terminated lines, in order:
        ``POST``, the path of ``callback_url`` (``/`` if empty), its query
        string, ``vidu``, the ``Date`` header, then one ``name:value`` line for
        every header listed (``;``-separated) in ``X-HMAC-Signed-Headers``.
        It is HMAC-SHA256-signed with ``secret`` and base64-encoded.

        Args:
            headers: Request headers; names are matched case-insensitively.
            body: Raw request body (not part of the signing string).
            secret: Shared HMAC secret.
            callback_url: URL the callback was delivered to; its path and query
                are part of the signing string.

        Returns:
            ``True`` if the signature matches. ``False`` when a required header is
            missing, the algorithm or access key is unexpected, or the signature
            does not match -- malformed input yields ``False`` rather than raising.
        """
        # HTTP header names are case-insensitive: index them by lowercase name.
        headers_lower = {k.lower(): v for k, v in headers.items()}

        signature = headers_lower.get("x-hmac-signature")
        algorithm = headers_lower.get("x-hmac-algorithm")
        access_key = headers_lower.get("x-hmac-access-key")
        signed_headers_str = headers_lower.get("x-hmac-signed-headers")
        date = headers_lower.get("date")

        # Explicit guard (instead of all() + assert) so type checkers narrow the
        # Optionals and the check is not stripped under `python -O`.
        if not signature or not algorithm or not access_key or not signed_headers_str or not date:
            logger.warning("[Vidu] 签名验证失败: 缺少必要头, headers=%s", list(headers.keys()))
            return False

        if algorithm != "hmac-sha256":
            logger.warning("[Vidu] 签名验证失败: 不支持的算法 %s", algorithm)
            return False

        if access_key != "vidu":
            logger.warning("[Vidu] 签名验证失败: access_key 不匹配 %s", access_key)
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_lines: list[str] = []
        for name in header_names:
            # Signed header names may differ in case too: look them up lowercased,
            # but keep the name exactly as listed when building the signing string.
            value = headers_lower.get(name.lower())
            if value is None:
                logger.warning("[Vidu] 签名验证失败: 缺少签名头 %s", name)
                return False
            header_lines.append(f"{name}:{value}\n")

        # Path and query of the signed request are taken from callback_url.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query or ""

        signing_string = "".join(
            [f"POST\n{http_uri}\n{canonical_query_string}\nvidu\n{date}\n", *header_lines]
        )

        expected = base64.b64encode(
            hmac.new(
                secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256
            ).digest()
        ).decode("ascii")

        # Compare bytes, not str: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, and the signature header is
        # attacker-controlled. "surrogatepass" makes the encode itself total.
        received = signature.encode("utf-8", "surrogatepass")
        if not hmac.compare_digest(received, expected.encode("ascii")):
            # NOTE(review): this logs a prefix of the expected HMAC; consider
            # dropping it or lowering to DEBUG if logs are widely readable.
            logger.warning(
                "[Vidu] 签名验证失败: callback_url=%s, signing_string=%r, "
                "expected=%s..., received=%s...",
                callback_url,
                signing_string,
                expected[:20],
                signature[:20],
            )
            return False

        return True

    async def verify_nonce(
        self,
        headers: dict[str, str],
        redis: Any,
    ) -> bool:
        """Reject replayed Vidu callbacks using the ``X-Request-Nonce`` header.

        A new nonce is recorded in Redis for ``_NONCE_TTL_SECONDS`` seconds.
        A callback with no nonce, or with a nonce already recorded, is rejected.

        Args:
            headers: Request headers; the nonce header is matched case-insensitively.
            redis: Async Redis client. Assumed to expose the redis-py style
                ``set(name, value, ex=..., nx=...)`` coroutine --
                NOTE(review): confirm against the injected client.

        Returns:
            ``True`` if the nonce is present and not seen before, else ``False``.
        """
        # Case-insensitive lookup, consistent with verify_signature.
        nonce = next(
            (v for k, v in headers.items() if k.lower() == "x-request-nonce"),
            None,
        )
        if not nonce:
            return False

        nonce_key = f"vidu:callback_nonce:{nonce}"
        # SET NX EX is one atomic command; the previous EXISTS + SETEX pair let two
        # concurrent deliveries of the same nonce both pass the check.
        was_set = await redis.set(nonce_key, "1", ex=self._NONCE_TTL_SECONDS, nx=True)
        return bool(was_set)

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a ``TaskStatus``.

        The task id is read from ``id`` (falling back to ``task_id``). ``result``
        always contains ``task_id``; when the native state is ``"success"`` and
        there is at least one creation it also contains ``video_url`` (first
        creation's url) and ``creations``. For a native ``"failed"`` state,
        ``error_message`` is ``err_code`` falling back to ``message``.

        Raises:
            json.JSONDecodeError: if ``body`` is not valid JSON.
        """
        data = json.loads(body)
        task_id = data.get("id") or data.get("task_id")
        state = data.get("state")
        creations = data.get("creations", [])

        video_url = None
        if state == "success" and creations:
            video_url = creations[0].get("url")

        return TaskStatus(
            state=self.normalize_state(state),
            result=(
                {"video_url": video_url, "creations": creations, "task_id": task_id}
                if video_url
                else {"task_id": task_id}
            ),
            error_message=(
                (data.get("err_code") or data.get("message")) if state == "failed" else None
            ),
        )