Files
meijiaka-zy/python-api/app/ai/adapters/vidu_adapter.py
T
小鱼开发 e58159fc42 refactor: 第三方平台架构改造(Adapter Protocol + Gateway)
Phase 1: 异常体系统一
- 新增 PlatformError / PlatformErrorType 标准定义
- 改造所有 Provider 异常抛出为 PlatformError
- 注册全局 PlatformError exception handler

Phase 2: Adapter Protocol
- 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable)
- 新增 app/ai/adapters/constants.py(Method 常量)
- 新增 PlatformConfigLoader(config/platform-config.yaml)

Phase 3: HTTP Client 统一
- ViduProvider 从 aiohttp 迁移到 httpx(注入方式)
- VolcengineCaptionService 改为注入 http_client
- lifespan 统一管理所有 Client 创建和关闭

Phase 4: Gateway 骨架 + Adapter 实现
- 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter
- 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook)
- 新增 LLMGateway(带 Fallback 降级链)
- lifespan 注册所有 Adapter 和 Gateway

Phase 6: 清理与验证
- 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL
- Provider 改为从 PlatformConfigLoader 读取 base_url
- 清理 volcengine_caption_service 全局单例
- config_loader 默认路径改为 platform-config.yaml
- Scheduler 注入共享 HTTP client
- vidu.py 回调路由使用 Adapter 验签和解析
- ruff 全量通过,应用启动测试通过
2026-05-04 16:07:16 +08:00

266 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu Adapter
============
实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
直接接入 ViduProvider,提供标准 Protocol 接口。
"""
from __future__ import annotations

import base64
import hashlib
import hmac
import json
import logging
from urllib.parse import urlparse

from app.ai.adapters.base import (
    AdapterResponse,
    CallbackCapable,
    PlatformAdapter,
    SyncCapable,
    TaskCapable,
    TaskStatus,
)
from app.ai.adapters.constants import Method
from app.ai.providers.vidu_provider import ViduProvider
from app.core.exceptions import PlatformError, PlatformErrorType

logger = logging.getLogger(__name__)


class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Wraps a ``ViduProvider`` and exposes it through the PlatformAdapter,
    SyncCapable, TaskCapable and CallbackCapable protocols.
    """

    platform_id = "vidu"

    # Vidu task state -> standard TaskStatus.state. Any state missing from
    # this table is reported as "failed" (see ``_extract_outcome``).
    # NOTE(review): "created"/"queueing" are the pre-processing states listed
    # in the Vidu API docs; "pending" is kept for backward compatibility.
    _STATE_MAPPING: dict[str, str] = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }

    def __init__(self, provider: ViduProvider):
        self.provider = provider

    # ── Internal helpers ──

    def _unexpected_error(self, message: str) -> PlatformError:
        """Build the non-retryable UNKNOWN PlatformError used to wrap unexpected failures."""
        return PlatformError(
            message,
            platform=self.platform_id,
            retryable=False,
            error_type=PlatformErrorType.UNKNOWN,
        )

    @classmethod
    def _extract_outcome(
        cls, data: dict
    ) -> tuple[str, str | None, list, str | None]:
        """Map a Vidu task payload to (standard_state, video_url, creations, error_message).

        Shared by ``query`` (polling) and ``parse_callback`` (webhook) so both
        paths interpret Vidu states identically.
        """
        raw_state = data.get("state")
        state = cls._STATE_MAPPING.get(raw_state) if isinstance(raw_state, str) else None
        if state is None:
            # Unknown or missing states are treated as terminal failures so a
            # poller does not wait forever on a state it cannot interpret.
            logger.warning("Unknown Vidu task state %r; treating as failed", raw_state)
            state = "failed"

        creations = data.get("creations") or []
        video_url = None
        if raw_state == "success" and creations:
            first = creations[0]
            if isinstance(first, dict):
                video_url = first.get("url")

        error_message = None
        if raw_state == "failed":
            # Prefer a human-readable message; fall back to Vidu's error code.
            detail = data.get("message") or data.get("err_code")
            error_message = str(detail) if detail is not None else None

        return state, video_url, creations, error_message

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Probe Vidu connectivity.

        Vidu has no dedicated health endpoint, so this queries a dummy task id.
        A NOT_FOUND error still proves the service answered, so it counts as
        healthy. Any other PlatformError yields a failure carrying that error's
        ``retryable`` flag; unexpected exceptions yield a non-retryable failure.
        """
        try:
            await self.provider.query_task("health-check")
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            logger.warning("Vidu health check failed unexpectedly", exc_info=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )
        return AdapterResponse(success=True)

    async def close(self) -> None:
        """Close the underlying provider."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu call.

        Supported methods:
          * ``Method.TTS``: requires ``payload["text"]``; returns
            ``{"audio_url": ...}`` taken from the provider's ``file_url``.
          * ``Method.CLONE_VOICE``: requires ``payload["audio_url"]`` and
            ``payload["voice_id"]``; returns ``{"voice_id": ..., "demo_audio": ...}``.

        Unsupported methods return a non-retryable failure response.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged; any
                other exception (e.g. a missing required payload key) is
                wrapped as a non-retryable ``PlatformErrorType.UNKNOWN`` error.
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )
            if method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text"),
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )
            return AdapterResponse(
                success=False,
                error_message=f"不支持的方法: {method}",
                retryable=False,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unexpected_error(f"Vidu {method} 调用失败: {e}") from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task.

        Only ``Method.LIP_SYNC`` is supported; it requires
        ``payload["video_url"]`` and returns ``{"task_id": ...}``.
        ``callback_url`` is forwarded to the provider. Unsupported task types
        return a non-retryable failure response.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged; any
                other exception is wrapped as a non-retryable UNKNOWN error.
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )
            return AdapterResponse(
                success=False,
                error_message=f"不支持的任务类型: {task_type}",
                retryable=False,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unexpected_error(f"Vidu {task_type} 提交失败: {e}") from e

    async def query(self, platform_job_id: str) -> TaskStatus:
        """Fetch a task's status from Vidu and map it to a standard TaskStatus.

        ``result`` holds ``video_url`` (the first creation's URL) and the raw
        ``creations`` list when the task succeeded with output, else None.
        ``error_message`` is only populated for failed tasks.

        Raises:
            PlatformError: provider PlatformErrors are re-raised unchanged; any
                other exception is wrapped as a non-retryable UNKNOWN error.
        """
        try:
            result = await self.provider.query_task(platform_job_id)
            state, video_url, creations, error_message = self._extract_outcome(result)
            return TaskStatus(
                state=state,
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=error_message,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._unexpected_error(f"Vidu 任务查询失败: {e}") from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        The signing string is, one item per newline-terminated line: the
        literal ``POST``, the path of ``callback_url`` (``/`` if empty), its
        query string, the literal ``vidu``, the ``Date`` header, and then
        ``name:value`` for every header listed in ``X-HMAC-SIGNED-HEADERS``.
        The expected signature is the base64 HMAC-SHA256 of that string keyed
        with ``secret``.

        Args:
            headers: Request headers; names are matched case-insensitively.
            body: Raw request body (not part of the signing string).
            secret: Shared HMAC secret. An empty secret always fails.
            callback_url: URL Vidu called back; its path and query are signed.

        Returns:
            True only if all required headers are present, the algorithm is
            ``hmac-sha256``, the access key is ``vidu`` and the signature matches.
        """
        if not secret:
            # An empty key makes the HMAC forgeable by anyone; fail closed.
            logger.warning("Vidu callback secret is empty; rejecting signature")
            return False

        # HTTP header names are case-insensitive (RFC 9110), and frameworks
        # such as Starlette lower-case them, so normalise before lookups.
        lowered = {str(name).lower(): value for name, value in headers.items()}
        signature = lowered.get("x-hmac-signature")
        algorithm = lowered.get("x-hmac-algorithm")
        access_key = lowered.get("x-hmac-access-key")
        signed_headers_str = lowered.get("x-hmac-signed-headers")
        date = lowered.get("date")

        if not all([signature, algorithm, access_key, signed_headers_str, date]):
            return False
        if algorithm != "hmac-sha256":
            return False
        if access_key != "vidu":
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_values: dict[str, str] = {}
        for name in header_names:
            value = lowered.get(name.lower())
            if value is None:
                return False
            header_values[name] = value

        # The signed path/query come from the configured callback URL rather
        # than the incoming request, so they must match what Vidu was given.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query or ""
        signing_string = (
            f"POST\n"
            f"{http_uri}\n"
            f"{canonical_query_string}\n"
            f"vidu\n"
            f"{date}\n"
        ) + "".join(f"{name}:{header_values[name]}\n" for name in header_names)

        expected = base64.b64encode(
            hmac.new(secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256).digest()
        )
        # Compare bytes: compare_digest raises TypeError for non-ASCII str
        # input, and the signature header is attacker-controlled.
        return hmac.compare_digest(signature.encode("utf-8"), expected)

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a standard TaskStatus.

        ``result`` always carries ``task_id`` (the body's ``id``, falling back
        to ``task_id``). On success with at least one creation it also carries
        ``video_url`` and ``creations``.

        Raises:
            ValueError: if the body is not valid UTF-8 JSON
                (``json.JSONDecodeError`` / ``UnicodeDecodeError``) or does
                not decode to a JSON object.
        """
        data = json.loads(body)
        if not isinstance(data, dict):
            raise ValueError(f"Vidu 回调体必须是 JSON 对象, 实际为 {type(data).__name__}")

        task_id = data.get("id") or data.get("task_id")
        state, video_url, creations, error_message = self._extract_outcome(data)
        if video_url:
            result = {"video_url": video_url, "creations": creations, "task_id": task_id}
        else:
            result = {"task_id": task_id}
        return TaskStatus(
            state=state,
            result=result,
            error_message=error_message,
        )