Files
meijiaka-zy/python-api/app/ai/adapters/vidu_adapter.py
T

321 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu Adapter
============
实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
直接接入 ViduProvider,提供标准 Protocol 接口。
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import logging
from typing import Any
from urllib.parse import urlparse
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.ai.adapters.constants import Method
from app.ai.providers.vidu_provider import ViduProvider
from app.core.exceptions import PlatformError, PlatformErrorType
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Implements PlatformAdapter, SyncCapable, TaskCapable and CallbackCapable
    on top of a ``ViduProvider``, exposing the standard protocol interface.
    """

    platform_id = "vidu"

    # Vidu-native task state -> standard state.
    _VIDU_TO_STANDARD = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }
    # Standard state -> Vidu-native task state.
    _STANDARD_TO_VIDU = {
        "completed": "success",
        "processing": "processing",
        "pending": "pending",
        "failed": "failed",
    }
    # How long (seconds) a seen callback nonce is remembered for replay protection.
    _NONCE_TTL_SECONDS = 300

    def __init__(self, provider: ViduProvider):
        self.provider = provider

    @classmethod
    def normalize_state(cls, vidu_state: str) -> str:
        """Map a Vidu-native state to a standard one (processing / completed / failed).

        Unknown or missing states are treated as ``"failed"``.
        """
        return cls._VIDU_TO_STANDARD.get(vidu_state, "failed")

    @classmethod
    def denormalize_state(cls, standard_state: str) -> str:
        """Map a standard state to a Vidu-native one (success / processing / pending / failed).

        States without a mapping are returned unchanged.
        """
        return cls._STANDARD_TO_VIDU.get(standard_state, standard_state)

    @staticmethod
    def _first_video_url(state: Any, creations: Any) -> str | None:
        """Return the URL of the first creation of a successful task, else None."""
        if state == "success" and creations:
            return creations[0].get("url")
        return None

    @staticmethod
    def _unknown_error(message: str) -> PlatformError:
        """Build the non-retryable UNKNOWN PlatformError used to wrap unexpected failures."""
        return PlatformError(
            message,
            platform="vidu",
            retryable=False,
            error_type=PlatformErrorType.UNKNOWN,
        )

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Probe Vidu connectivity.

        Vidu has no dedicated health endpoint, so this queries a dummy task id.
        A NOT_FOUND error still proves the service is reachable and is reported
        as healthy; any other error is reported as unhealthy.
        """
        try:
            await self.provider.query_task("health-check")
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )
        return AdapterResponse(success=True)

    async def close(self) -> None:
        """Release the underlying provider's resources."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu operation (TTS or voice cloning).

        Returns an unsuccessful AdapterResponse for unsupported methods.

        Raises:
            PlatformError: provider errors are re-raised as-is; any other
                exception is wrapped as a non-retryable UNKNOWN error.
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )
            if method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text") or "",
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )
            return AdapterResponse(
                success=False,
                error_message=f"不支持的方法: {method}",
                retryable=False,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unknown_error(f"Vidu {method} 调用失败: {e}") from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task (currently only lip sync).

        On success ``data`` holds the platform ``task_id``. Unsupported task
        types yield an unsuccessful AdapterResponse.

        Raises:
            PlatformError: provider errors are re-raised as-is; any other
                exception is wrapped as a non-retryable UNKNOWN error.
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )
            return AdapterResponse(
                success=False,
                error_message=f"不支持的任务类型: {task_type}",
                retryable=False,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unknown_error(f"Vidu {task_type} 提交失败: {e}") from e

    async def query(self, platform_task_id: str) -> TaskStatus:
        """Query a Vidu task and convert it to a standard TaskStatus.

        ``result`` is populated only when the task succeeded with at least one
        creation. For failed tasks, ``err_code`` (falling back to ``message``)
        becomes the error message, the same rule ``parse_callback`` uses.

        Raises:
            PlatformError: provider errors are re-raised as-is; any other
                exception is wrapped as a non-retryable UNKNOWN error.
        """
        try:
            result = await self.provider.query_task(platform_task_id)
            state = result.get("state", "unknown")
            creations = result.get("creations", [])
            video_url = self._first_video_url(state, creations)
            return TaskStatus(
                state=self.normalize_state(state),
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=(
                    (result.get("err_code") or result.get("message"))
                    if state == "failed"
                    else None
                ),
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unknown_error(f"Vidu 任务查询失败: {e}") from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        The signing string is, one item per line (each newline-terminated):
        ``POST``, the path of ``callback_url`` (``/`` if empty), its query
        string, the access key ``vidu``, the ``Date`` header, then
        ``name:value`` for every header listed in ``X-HMAC-Signed-Headers``.
        The expected signature is base64(HMAC-SHA256(secret, signing string)).
        ``body`` is not part of the signing string.

        Returns:
            True if the signature matches; False for any missing, unsupported
            or mismatching header (never raises for malformed headers).
        """
        # HTTP header names are case-insensitive: index them by lowercase name.
        headers_lower = {k.lower(): v for k, v in headers.items()}
        signature = headers_lower.get("x-hmac-signature")
        algorithm = headers_lower.get("x-hmac-algorithm")
        access_key = headers_lower.get("x-hmac-access-key")
        signed_headers_str = headers_lower.get("x-hmac-signed-headers")
        date = headers_lower.get("date")
        # Explicit guard (rather than assert) so the check survives `python -O`
        # and narrows the Optional types for the code below.
        if not signature or not algorithm or not access_key or not signed_headers_str or not date:
            logger.warning(f"[Vidu] 签名验证失败: 缺少必要头, headers={list(headers.keys())}")
            return False
        if algorithm != "hmac-sha256":
            logger.warning(f"[Vidu] 签名验证失败: 不支持的算法 {algorithm}")
            return False
        if access_key != "vidu":
            logger.warning(f"[Vidu] 签名验证失败: access_key 不匹配 {access_key}")
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_values: dict[str, str] = {}
        for name in header_names:
            # Signed header names may differ in case from the actual headers.
            value = headers_lower.get(name.lower())
            if value is None:
                logger.warning(f"[Vidu] 签名验证失败: 缺少签名头 {name}")
                return False
            header_values[name] = value

        # The request path/query come from the callback URL we registered.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query
        signing_lines = ["POST", http_uri, canonical_query_string, "vidu", date]
        signing_lines.extend(f"{name}:{header_values[name]}" for name in header_names)
        signing_string = "".join(f"{line}\n" for line in signing_lines)

        expected = base64.b64encode(
            hmac.new(
                secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256
            ).digest()
        ).decode("utf-8")
        # Compare as bytes: hmac.compare_digest raises TypeError for str inputs
        # containing non-ASCII characters, and the signature header is
        # attacker-controlled.
        if not hmac.compare_digest(signature.encode("utf-8"), expected.encode("ascii")):
            logger.warning(
                f"[Vidu] 签名验证失败: callback_url={callback_url}, "
                f"signing_string={repr(signing_string)}, "
                f"expected={expected[:20]}..., received={signature[:20]}..."
            )
            return False
        return True

    async def verify_nonce(
        self,
        headers: dict[str, str],
        redis: Any,
    ) -> bool:
        """Reject replayed Vidu callbacks using the ``X-Request-Nonce`` header.

        The nonce is recorded in Redis for ``_NONCE_TTL_SECONDS``; a nonce
        already seen within that window is rejected.

        Args:
            headers: Callback request headers (names matched case-insensitively).
            redis: Async Redis client supporting ``set(name, value, ex=, nx=)``
                (assumed redis-py asyncio compatible; confirm for other clients).

        Returns:
            False if the nonce is missing or was already used, True otherwise.
        """
        # Match verify_signature: header names are case-insensitive.
        headers_lower = {k.lower(): v for k, v in headers.items()}
        nonce = headers_lower.get("x-request-nonce")
        if not nonce:
            return False
        nonce_key = f"vidu:callback_nonce:{nonce}"
        # SET NX EX is atomic: among concurrent deliveries of the same nonce only
        # one can create the key, unlike a separate exists() + setex() check.
        was_set = await redis.set(nonce_key, "1", ex=self._NONCE_TTL_SECONDS, nx=True)
        return bool(was_set)

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a standard TaskStatus.

        The task id is read from ``id`` (falling back to ``task_id``) and is
        always present in ``result``; on success with creations, ``video_url``
        and ``creations`` are added. For failed tasks ``err_code`` (falling
        back to ``message``) becomes the error message.

        Raises:
            json.JSONDecodeError: if ``body`` is not valid JSON.
        """
        data = json.loads(body)
        task_id = data.get("id") or data.get("task_id")
        state = data.get("state")
        creations = data.get("creations", [])
        video_url = self._first_video_url(state, creations)
        if video_url:
            result = {"video_url": video_url, "creations": creations, "task_id": task_id}
        else:
            result = {"task_id": task_id}
        return TaskStatus(
            state=self.normalize_state(state),
            result=result,
            error_message=(
                (data.get("err_code") or data.get("message")) if state == "failed" else None
            ),
        )