Files
meijiaka-zy/python-api/app/ai/adapters/vidu_adapter.py
T
小鱼开发 9ddcb2347d ci: 构建流程优化 - test环境固定/平台选择/版本号自动更新/缓存
- VITE_API_BASE_URL 固定为 dev.tapi.meijiaka.cn(test环境)
- 添加 platform 选择(all/macos/windows),支持单独构建
- 添加版本号自动更新(tauri.conf.json + Cargo.toml)
- 添加 Rust + Node 构建缓存,节省CI额度
- 修复 ViduAdapter parse_callback 运算符优先级bug
- 修复 ViduProvider tts_sync 日志前缀误写
- VoiceSynthesis 空状态UI优化
2026-05-19 15:17:36 +08:00

294 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu Adapter
============
实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
直接接入 ViduProvider,提供标准 Protocol 接口。
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import logging
from typing import Any
from urllib.parse import urlparse
from app.ai.adapters.base import (
AdapterResponse,
CallbackCapable,
PlatformAdapter,
SyncCapable,
TaskCapable,
TaskStatus,
)
from app.ai.adapters.constants import Method
from app.ai.providers.vidu_provider import ViduProvider
from app.core.exceptions import PlatformError, PlatformErrorType
# Module-level logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Wraps a ``ViduProvider`` and exposes it through the ``PlatformAdapter``,
    ``SyncCapable``, ``TaskCapable`` and ``CallbackCapable`` protocols.
    """

    platform_id = "vidu"

    # Vidu native task state -> standard state.
    _VIDU_TO_STANDARD = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }
    # Standard state -> Vidu native task state.
    _STANDARD_TO_VIDU = {
        "completed": "success",
        "processing": "processing",
        "pending": "pending",
        "failed": "failed",
    }

    # How long (seconds) a callback nonce is remembered in Redis for replay protection.
    _NONCE_TTL_SECONDS = 300

    def __init__(self, provider: ViduProvider):
        self.provider = provider

    @classmethod
    def normalize_state(cls, vidu_state: str) -> str:
        """Map a Vidu native state to a standard state.

        Returns one of ``processing`` / ``completed`` / ``failed``. Any state
        that is not in the mapping table (including ``None``) maps to ``failed``.
        """
        return cls._VIDU_TO_STANDARD.get(vidu_state, "failed")

    @classmethod
    def denormalize_state(cls, standard_state: str) -> str:
        """Map a standard state back to a Vidu native state.

        Returns ``success`` / ``processing`` / ``pending`` / ``failed``. A value
        that is not in the mapping table is returned unchanged.
        """
        return cls._STANDARD_TO_VIDU.get(standard_state, standard_state)

    # ── Internal helpers ──

    @staticmethod
    def _first_creation_url(state: Any, creations: Any) -> str | None:
        """Return the ``url`` of the first creation when ``state`` is ``success``.

        Returns ``None`` for any other state, or when ``creations`` is empty.
        """
        if state == "success" and creations:
            return creations[0].get("url")
        return None

    @staticmethod
    def _lowercase_headers(headers: Any) -> dict[str, str]:
        """Return a copy of *headers* keyed by lower-cased header name.

        HTTP field names are case-insensitive. Normalising them here lets every
        lookup succeed regardless of how the caller's mapping cased its keys.
        """
        return {str(key).lower(): value for key, value in headers.items()}

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Probe Vidu connectivity.

        Vidu has no dedicated health endpoint, so this queries a placeholder
        task id. A ``NOT_FOUND`` ``PlatformError`` still counts as healthy,
        because the service answered. Any other failure yields
        ``success=False``.
        """
        try:
            await self.provider.query_task("health-check")
            return AdapterResponse(success=True)
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )

    async def close(self) -> None:
        """Close the underlying provider."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu operation.

        Supported methods are ``Method.TTS`` and ``Method.CLONE_VOICE``.
        Any other method returns a non-retryable failure response.

        Raises:
            PlatformError: provider errors are re-raised unchanged. Any other
                exception (e.g. a missing required payload key) is wrapped as
                a non-retryable ``UNKNOWN`` ``PlatformError``.
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )
            elif method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text"),
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )
            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的方法: {method}",
                    retryable=False,
                )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {method} 调用失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task.

        Only ``Method.LIP_SYNC`` is supported. On success, ``data`` holds the
        provider's ``task_id``. Any other task type returns a non-retryable
        failure response.

        Raises:
            PlatformError: provider errors are re-raised unchanged. Any other
                exception is wrapped as a non-retryable ``UNKNOWN`` error.
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )
            else:
                return AdapterResponse(
                    success=False,
                    error_message=f"不支持的任务类型: {task_type}",
                    retryable=False,
                )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {task_type} 提交失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    async def query(self, platform_task_id: str) -> TaskStatus:
        """Poll a Vidu task and convert it to a standard ``TaskStatus``.

        ``result`` is populated only when the task succeeded and has at least
        one creation with a URL. For failed tasks, ``error_message`` prefers
        ``err_code`` and falls back to ``message``, the same rule that
        ``parse_callback`` applies to callback bodies.

        Raises:
            PlatformError: provider errors are re-raised unchanged. Any other
                exception is wrapped as a non-retryable ``UNKNOWN`` error.
        """
        try:
            result = await self.provider.query_task(platform_task_id)
            state = result.get("state", "unknown")
            creations = result.get("creations", [])
            video_url = self._first_creation_url(state, creations)
            if state == "failed":
                error_message = result.get("err_code") or result.get("message")
            else:
                error_message = None
            return TaskStatus(
                state=self.normalize_state(state),
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=error_message,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu 任务查询失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        The signing string is built from the literal ``POST``, the path and
        query of ``callback_url``, the access key ``vidu``, the ``Date``
        header, and every header listed in ``X-HMAC-SIGNED-HEADERS`` (in the
        listed order). ``body`` is not part of the signing string.

        Header lookups are case-insensitive. Returns ``False`` on any missing
        header, an unexpected algorithm or access key, or a signature mismatch.
        """
        lowered = self._lowercase_headers(headers)
        signature = lowered.get("x-hmac-signature")
        algorithm = lowered.get("x-hmac-algorithm")
        access_key = lowered.get("x-hmac-access-key")
        signed_headers_str = lowered.get("x-hmac-signed-headers")
        date = lowered.get("date")
        if not all([signature, algorithm, access_key, signed_headers_str, date]):
            return False
        if algorithm != "hmac-sha256":
            return False
        if access_key != "vidu":
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_values: dict[str, str] = {}
        for name in header_names:
            value = lowered.get(name.lower())
            if value is None:
                return False
            # Keep the name exactly as listed: it is part of the signed text.
            header_values[name] = value

        # The path/query of the signing string come from the registered callback URL.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query or ""
        signing_string = (
            "POST\n"
            f"{http_uri}\n"
            f"{canonical_query_string}\n"
            "vidu\n"
            f"{date}\n"
        )
        signing_string += "".join(f"{name}:{header_values[name]}\n" for name in header_names)

        expected = base64.b64encode(
            hmac.new(secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256).digest()
        ).decode("utf-8")
        # Compare as bytes: compare_digest raises TypeError for non-ASCII str
        # input, and the signature header is attacker-controlled.
        return hmac.compare_digest(signature.encode("utf-8"), expected.encode("utf-8"))

    async def verify_nonce(
        self,
        headers: dict[str, str],
        redis: Any,
    ) -> bool:
        """Reject replayed Vidu callbacks using the ``x-request-nonce`` header.

        Each nonce is recorded in Redis for ``_NONCE_TTL_SECONDS`` seconds.
        Returns ``False`` when the header is missing or the nonce was already
        seen inside that window. Otherwise returns ``True``.
        """
        nonce = self._lowercase_headers(headers).get("x-request-nonce")
        if not nonce:
            return False
        nonce_key = f"vidu:callback_nonce:{nonce}"
        # A single atomic SET NX EX. A separate EXISTS + SETEX pair would let
        # two concurrent deliveries of the same nonce both pass.
        # NOTE(review): assumes a redis-py compatible asyncio client whose
        # ``set`` accepts ``ex``/``nx`` and returns a falsy value when NX
        # blocks the write — confirm against the injected client.
        was_set = await redis.set(nonce_key, "1", ex=self._NONCE_TTL_SECONDS, nx=True)
        return bool(was_set)

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a standard ``TaskStatus``.

        The task id is read from ``id`` (falling back to ``task_id``) and is
        always present in ``result``. ``video_url`` and ``creations`` are
        added only for a successful task with at least one creation URL. For
        failed tasks, ``error_message`` prefers ``err_code`` over ``message``.
        """
        data = json.loads(body)
        task_id = data.get("id") or data.get("task_id")
        state = data.get("state")
        creations = data.get("creations", [])
        video_url = self._first_creation_url(state, creations)
        if video_url:
            result = {"video_url": video_url, "creations": creations, "task_id": task_id}
        else:
            result = {"task_id": task_id}
        return TaskStatus(
            state=self.normalize_state(state),
            result=result,
            error_message=(data.get("err_code") or data.get("message")) if state == "failed" else None,
        )