e58159fc42
Phase 1: 异常体系统一 - 新增 PlatformError / PlatformErrorType 标准定义 - 改造所有 Provider 异常抛出为 PlatformError - 注册全局 PlatformError exception handler Phase 2: Adapter Protocol - 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable) - 新增 app/ai/adapters/constants.py(Method 常量) - 新增 PlatformConfigLoader(config/platform-config.yaml) Phase 3: HTTP Client 统一 - ViduProvider 从 aiohttp 迁移到 httpx(注入方式) - VolcengineCaptionService 改为注入 http_client - lifespan 统一管理所有 Client 创建和关闭 Phase 4: Gateway 骨架 + Adapter 实现 - 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter - 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook) - 新增 LLMGateway(带 Fallback 降级链) - lifespan 注册所有 Adapter 和 Gateway Phase 6: 清理与验证 - 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL - Provider 改为从 PlatformConfigLoader 读取 base_url - 清理 volcengine_caption_service 全局单例 - config_loader 默认路径改为 platform-config.yaml - Scheduler 注入共享 HTTP client - vidu.py 回调路由使用 Adapter 验签和解析 - ruff 全量通过,应用启动测试通过
266 lines
8.6 KiB
Python
266 lines
8.6 KiB
Python
"""
|
||
Vidu Adapter
|
||
============
|
||
|
||
实现 PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable。
|
||
直接接入 ViduProvider,提供标准 Protocol 接口。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
import hashlib
import hmac
import json
import logging
from urllib.parse import urlparse

from app.ai.adapters.base import (
    AdapterResponse,
    CallbackCapable,
    PlatformAdapter,
    SyncCapable,
    TaskCapable,
    TaskStatus,
)
from app.ai.adapters.constants import Method
from app.ai.providers.vidu_provider import ViduProvider
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
    """Standard adapter for the Vidu platform.

    Wraps a ``ViduProvider`` and exposes it through the PlatformAdapter,
    SyncCapable, TaskCapable and CallbackCapable interfaces.
    """

    platform_id = "vidu"

    # Vidu task state -> standard task state. Any state missing from this
    # table is reported as "failed".
    # NOTE(review): "created" and "queueing" are in-flight states in Vidu's
    # public task API; confirm ViduProvider passes them through unchanged.
    _STATE_MAPPING = {
        "created": "processing",
        "queueing": "processing",
        "pending": "processing",
        "processing": "processing",
        "success": "completed",
        "failed": "failed",
    }

    def __init__(self, provider: ViduProvider):
        """Store the provider that performs the actual Vidu API calls."""
        self.provider = provider

    # ── Internal helpers ──

    @classmethod
    def _map_state(cls, state):
        """Translate a Vidu state into a standard state ("failed" if unknown)."""
        if state not in cls._STATE_MAPPING:
            # Keep the original fallback, but leave a trace instead of
            # silently turning an unrecognised state into a failure.
            logger.warning("Unknown Vidu task state %r; reporting it as failed", state)
        return cls._STATE_MAPPING.get(state, "failed")

    @staticmethod
    def _first_creation_url(state, creations):
        """Return the first creation's ``url`` for a successful task, else None."""
        if state == "success" and creations:
            return creations[0].get("url")
        return None

    @staticmethod
    def _lookup_header(headers: dict[str, str], name: str) -> str | None:
        """Return the value of header *name*, or None when it is absent.

        An exact-key match wins. Otherwise the first key equal to *name*
        ignoring case is used, because HTTP header names are case-insensitive
        and the mapping handed in may have been lower-cased by the framework.
        """
        if name in headers:
            return headers[name]
        wanted = name.lower()
        for key, value in headers.items():
            if key.lower() == wanted:
                return value
        return None

    # ── PlatformAdapter ──

    async def health(self) -> AdapterResponse:
        """Probe connectivity to Vidu.

        Returns:
            A successful AdapterResponse when the probe query returns or fails
            with ``PlatformErrorType.NOT_FOUND``; otherwise an unsuccessful
            response carrying the error text.
        """
        try:
            # Vidu has no dedicated health endpoint, so a placeholder task id
            # is queried instead. A not-found answer still proves the service
            # is reachable.
            await self.provider.query_task("health-check")
            return AdapterResponse(success=True)
        except PlatformError as e:
            if e.error_type == PlatformErrorType.NOT_FOUND:
                return AdapterResponse(success=True)
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=e.retryable,
            )
        except Exception as e:
            return AdapterResponse(
                success=False,
                error_message=str(e),
                retryable=False,
            )

    async def close(self) -> None:
        """Close the underlying provider."""
        await self.provider.close()

    # ── SyncCapable ──

    async def call(self, method: str, payload: dict) -> AdapterResponse:
        """Run a synchronous Vidu operation.

        Args:
            method: ``Method.TTS`` or ``Method.CLONE_VOICE``.
            payload: Operation arguments. TTS requires ``text``; voice cloning
                requires ``audio_url`` and ``voice_id``. Other keys are
                optional.

        Returns:
            A successful AdapterResponse with the operation data, or an
            unsuccessful, non-retryable one for an unsupported *method*.

        Raises:
            PlatformError: Re-raised from the provider, or wrapping any other
                exception (including a missing required payload key).
        """
        try:
            if method == Method.TTS:
                result = await self.provider.tts_sync(
                    text=payload["text"],
                    voice_id=payload.get("voice_id", "tianxin_xiaoling"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    pitch=payload.get("pitch", 0),
                    emotion=payload.get("emotion"),
                )
                return AdapterResponse(
                    success=True,
                    data={"audio_url": result.get("file_url")},
                )

            if method == Method.CLONE_VOICE:
                result = await self.provider.clone_voice(
                    audio_url=payload["audio_url"],
                    voice_id=payload["voice_id"],
                    text=payload.get("text"),
                )
                return AdapterResponse(
                    success=True,
                    data={
                        "voice_id": result.get("voice_id"),
                        "demo_audio": result.get("demo_audio"),
                    },
                )

            return AdapterResponse(
                success=False,
                error_message=f"不支持的方法: {method}",
                retryable=False,
            )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {method} 调用失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── TaskCapable ──

    async def submit(
        self,
        task_type: str,
        payload: dict,
        callback_url: str | None = None,
    ) -> AdapterResponse:
        """Submit an asynchronous Vidu task.

        Args:
            task_type: Only ``Method.LIP_SYNC`` is supported.
            payload: Task arguments; ``video_url`` is required, the remaining
                keys are optional.
            callback_url: Forwarded to the provider as the task's callback URL.

        Returns:
            A successful AdapterResponse whose data holds ``task_id``, or an
            unsuccessful, non-retryable one for an unsupported *task_type*.

        Raises:
            PlatformError: Re-raised from the provider, or wrapping any other
                exception (including a missing required payload key).
        """
        try:
            if task_type == Method.LIP_SYNC:
                result = await self.provider.lip_sync(
                    video_url=payload["video_url"],
                    audio_url=payload.get("audio_url"),
                    text=payload.get("text"),
                    voice_id=payload.get("voice_id"),
                    speed=payload.get("speed", 1.0),
                    volume=payload.get("volume", 0),
                    ref_photo_url=payload.get("ref_photo_url"),
                    callback_url=callback_url,
                )
                return AdapterResponse(
                    success=True,
                    data={"task_id": result.get("task_id")},
                )

            return AdapterResponse(
                success=False,
                error_message=f"不支持的任务类型: {task_type}",
                retryable=False,
            )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu {task_type} 提交失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    async def query(self, platform_job_id: str) -> TaskStatus:
        """Fetch a task from Vidu and convert it to a standard TaskStatus.

        Args:
            platform_job_id: Task id passed to the provider's ``query_task``.

        Returns:
            TaskStatus with the mapped state. ``result`` is set only when a
            video URL is available; ``error_message`` only for a failed task.

        Raises:
            PlatformError: Re-raised from the provider, or wrapping any other
                exception.
        """
        try:
            result = await self.provider.query_task(platform_job_id)
            state = result.get("state", "unknown")
            creations = result.get("creations", [])
            video_url = self._first_creation_url(state, creations)

            return TaskStatus(
                state=self._map_state(state),
                result={"video_url": video_url, "creations": creations} if video_url else None,
                error_message=result.get("message") if state == "failed" else None,
            )

        except PlatformError:
            raise
        except Exception as e:
            raise PlatformError(
                f"Vidu 任务查询失败: {e}",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
            ) from e

    # ── CallbackCapable ──

    async def verify_signature(
        self,
        headers: dict[str, str],
        body: bytes,
        secret: str,
        callback_url: str | None = None,
    ) -> bool:
        """Verify the HMAC-SHA256 signature of a Vidu callback.

        Args:
            headers: Request headers; names are matched case-insensitively.
            body: Raw request body. It is not part of the signing string
                built here.
            secret: Shared secret used as the HMAC key.
            callback_url: URL the callback was sent to; its path and query are
                part of the signing string. Defaults to path ``/`` and an
                empty query when omitted.

        Returns:
            True when the supplied signature matches the expected one. False
            when a required header is missing, the algorithm or access key is
            unexpected, a signed header is absent, or the signature differs.
        """
        signature = self._lookup_header(headers, "X-HMAC-SIGNATURE")
        algorithm = self._lookup_header(headers, "X-HMAC-ALGORITHM")
        access_key = self._lookup_header(headers, "X-HMAC-ACCESS-KEY")
        signed_headers_str = self._lookup_header(headers, "X-HMAC-SIGNED-HEADERS")
        date = self._lookup_header(headers, "Date")

        if not all([signature, algorithm, access_key, signed_headers_str, date]):
            return False
        if algorithm != "hmac-sha256":
            return False
        if access_key != "vidu":
            return False

        header_names = [h.strip() for h in signed_headers_str.split(";") if h.strip()]
        header_values: dict[str, str] = {}
        for name in header_names:
            value = self._lookup_header(headers, name)
            if value is None:
                return False
            header_values[name] = value

        # Build the signing string; path and query come from callback_url.
        parsed = urlparse(callback_url or "")
        http_uri = parsed.path or "/"
        canonical_query_string = parsed.query or ""

        signing_lines = ["POST", http_uri, canonical_query_string, "vidu", date]
        signing_lines.extend(f"{name}:{header_values[name]}" for name in header_names)
        # Every line, including the last one, is newline-terminated.
        signing_string = "\n".join(signing_lines) + "\n"

        expected = base64.b64encode(
            hmac.new(secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256).digest()
        ).decode("utf-8")

        # Compare as bytes: compare_digest() raises TypeError for str
        # arguments containing non-ASCII characters, and the signature header
        # is caller-controlled.
        return hmac.compare_digest(signature.encode("utf-8"), expected.encode("utf-8"))

    async def parse_callback(self, body: bytes) -> TaskStatus:
        """Parse a Vidu callback body into a standard TaskStatus.

        Args:
            body: Raw JSON callback body.

        Returns:
            TaskStatus with the mapped state. ``result`` always carries
            ``task_id`` (taken from ``id`` or ``task_id``) and additionally
            ``video_url`` and ``creations`` when a video URL is available.

        Raises:
            json.JSONDecodeError: If *body* is not valid JSON.
        """
        data = json.loads(body)
        task_id = data.get("id") or data.get("task_id")
        state = data.get("state")
        creations = data.get("creations", [])
        video_url = self._first_creation_url(state, creations)

        if video_url:
            result = {"video_url": video_url, "creations": creations, "task_id": task_id}
        else:
            result = {"task_id": task_id}

        return TaskStatus(
            state=self._map_state(state),
            result=result,
            error_message=data.get("message") if state == "failed" else None,
        )
|