Files
meijiaka-zy/python-api/app/ai/adapters/vidu_adapter.py
T
2026-05-04 19:18:22 +08:00

267 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu Adapter
============
实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
直接接入 ViduProvider,提供标准 Protocol 接口。
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import logging
from urllib.parse import urlparse
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.ai.adapters.constants import Method
from app.ai.providers.vidu_provider import ViduProvider
from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)


class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Implements PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable
    by delegating every operation to a ``ViduProvider``.
    """

    platform_id = "vidu"

    # Vidu task ``state`` -> standard TaskStatus state.
    # "created" / "queueing" are Vidu's pre-processing states (NOTE(review):
    # taken from the Vidu API docs — confirm against real provider responses);
    # "pending" is kept from the previous mapping for backward compatibility.
    # Any state not listed here is reported as "failed" by _extract_status.
    _STATE_MAPPING = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }

    def __init__(self, provider: ViduProvider):
        self.provider = provider

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Check connectivity to Vidu.

        Vidu has no dedicated health endpoint, so this queries a dummy task
        id. That query is expected to fail with NOT_FOUND; receiving that
        error still proves the service answered, so it counts as healthy.
        Any other PlatformError is reported with its own ``retryable`` flag;
        any other exception is reported as non-retryable.
        """
        try:
            await self.provider.query_task("health-check")
            return AdapterResponse(success=True)
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )

    async def close(self) -> None:
        """Close the underlying provider."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu operation.

        Supported methods:
            Method.TTS: requires ``payload["text"]``; returns
                ``{"audio_url": ...}`` taken from the provider's ``file_url``.
            Method.CLONE_VOICE: requires ``audio_url`` and ``voice_id``;
                returns ``{"voice_id": ..., "demo_audio": ...}``.

        Unsupported methods yield a non-retryable failed AdapterResponse.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged;
                any other exception (including a missing required payload
                key) is wrapped as a non-retryable UNKNOWN PlatformError.
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )
            elif method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text"),
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )
            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的方法: {method}",
                    retryable=False,
                )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {method} 调用失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task.

        Only Method.LIP_SYNC is supported (requires ``payload["video_url"]``);
        on success the response data is ``{"task_id": ...}``. Unsupported task
        types yield a non-retryable failed AdapterResponse.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged;
                any other exception is wrapped as a non-retryable UNKNOWN
                PlatformError.
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )
            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的任务类型: {task_type}",
                    retryable=False,
                )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {task_type} 提交失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    @classmethod
    def _extract_status(cls, data: dict) -> tuple[str, str | None, list, str | None]:
        """Map a Vidu task payload to (state, video_url, creations, error_message).

        Shared by ``query`` (polling) and ``parse_callback`` (webhook) so both
        paths interpret Vidu states identically. ``video_url`` is the first
        creation's ``url`` and is only set when the raw state is "success".
        """
        raw_state = data.get("state")
        # ``or []`` also covers an explicit JSON null.
        creations = data.get("creations") or []

        mapped = cls._STATE_MAPPING.get(raw_state)
        error_message = None
        if mapped is None:
            # Unknown/missing state: keep treating it as terminal failure (as
            # before), but surface why instead of failing silently.
            logger.warning("Unknown Vidu task state %r; treating as failed", raw_state)
            mapped = "failed"
            error_message = f"未知的 Vidu 任务状态: {raw_state}"
        elif mapped == "failed":
            # NOTE(review): Vidu may report failures via ``err_code`` without a
            # ``message`` — fall back to it so the failure reason is not lost.
            error_message = data.get("message") or data.get("err_code")

        video_url = None
        if raw_state == "success" and creations:
            video_url = creations[0].get("url")
        return mapped, video_url, creations, error_message

    async def query(self, platform_task_id: str) -> TaskStatus:
        """Poll a Vidu task and convert it to a standard TaskStatus.

        ``result`` is ``{"video_url", "creations"}`` when the task succeeded
        with at least one creation, otherwise None.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged;
                any other exception is wrapped as a non-retryable UNKNOWN
                PlatformError.
        """
        try:
            result = await self.provider.query_task(platform_task_id)
            state, video_url, creations, error_message = self._extract_status(result)
            return TaskStatus(
                state=state,
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=error_message,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu 任务查询失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        The signing string is::

            POST\\n<path>\\n<query>\\nvidu\\n<Date>\\n<name>:<value>\\n...

        where path/query come from ``callback_url`` (path defaults to "/") and
        the trailing lines list every header named in X-HMAC-SIGNED-HEADERS
        (";"-separated), in that order. The expected signature is the base64
        of HMAC-SHA256(secret, signing_string). ``body`` is not part of the
        signing string as implemented here.

        Header names are matched case-insensitively, so both canonical-case
        dicts and lowercased ones (e.g. ``dict(request.headers)``) work.

        Returns False (never raises) for missing/unsupported headers, an empty
        secret, or a signature mismatch.

        NOTE(review): the Date header is signed but not checked for freshness,
        so a captured callback could be replayed; consider a time window.
        """
        if not secret:
            # An empty HMAC key makes signatures trivially forgeable; fail closed.
            return False

        # HTTP header names are case-insensitive (RFC 9110 §5.1).
        lowered = {k.lower(): v for k, v in headers.items()}
        signature = lowered.get("x-hmac-signature")
        algorithm = lowered.get("x-hmac-algorithm")
        access_key = lowered.get("x-hmac-access-key")
        signed_headers_str = lowered.get("x-hmac-signed-headers")
        date = lowered.get("date")
        if not all([signature, algorithm, access_key, signed_headers_str, date]):
            return False
        if algorithm != "hmac-sha256":
            return False
        if access_key != "vidu":
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_lines: list[str] = []
        for name in header_names:
            value = lowered.get(name.lower())
            if value is None:
                # A header claimed as signed is missing: cannot verify.
                return False
            # Keep the name exactly as listed in X-HMAC-SIGNED-HEADERS.
            header_lines.append(f"{name}:{value}\n")

        # Path and query of the signing string come from the callback URL we
        # registered, not from the incoming request.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query or ""
        signing_string = (
            f"POST\n"
            f"{http_uri}\n"
            f"{canonical_query_string}\n"
            f"vidu\n"
            f"{date}\n"
        ) + "".join(header_lines)

        digest = hmac.new(
            secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256
        ).digest()
        expected = base64.b64encode(digest)
        # Compare bytes: hmac.compare_digest raises TypeError for non-ASCII str
        # input, and the signature header is attacker-controlled.
        return hmac.compare_digest(signature.encode("utf-8"), expected)

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a standard TaskStatus.

        The task id is read from ``id`` (falling back to ``task_id``) and is
        always present in ``result``; ``video_url`` and ``creations`` are added
        when the task succeeded with at least one creation.

        Raises:
            json.JSONDecodeError: if the body is not valid JSON.
            ValueError: if the JSON body is not an object.
        """
        data = json.loads(body)
        if not isinstance(data, dict):
            raise ValueError(f"Vidu 回调体不是 JSON 对象: {type(data).__name__}")

        task_id = data.get("id") or data.get("task_id")
        state, video_url, creations, error_message = self._extract_status(data)
        if video_url:
            result = {"video_url": video_url, "creations": creations, "task_id": task_id}
        else:
            result = {"task_id": task_id}
        return TaskStatus(
            state=state,
            result=result,
            error_message=error_message,
        )