"""
Vidu API Provider
=================

封装 Vidu 语音/视频相关 HTTP API:
- 同步 TTS(/ent/v2/audio-tts)
- 声音复刻(/ent/v2/audio-clone)
- 视频生成(/ent/v2/lip-sync)
- 查询任务(/ent/v2/tasks/{id}/creations)

统一使用 httpx.AsyncClient,由 lifespan 统一管理生命周期。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# Vidu 错误码分类
|
||
_VIDU_AUDIT_ERROR_CODES = {
|
||
"TaskPromptPolicyViolation",
|
||
"AuditSubmitIllegal",
|
||
"CreationPolicyViolation",
|
||
"PhotoAuditNotPass",
|
||
"AuditFailed",
|
||
"ImageCheckBodyJointsFailed",
|
||
"ImageCheckFaceFailed",
|
||
"ImageObjectsUndetected",
|
||
"FaceDetectFailure",
|
||
"FaceDetectNotPass",
|
||
"NoFaceDetected",
|
||
"MultiFaceDetected",
|
||
}
|
||
|
||
_VIDU_RETRYABLE_ERROR_CODES = {
|
||
"InternalServiceFailure",
|
||
"ModelUnavailable",
|
||
"Unknown",
|
||
}
|
||
|
||
_VIDU_RATE_LIMIT_ERROR_CODES = {
|
||
"QuotaExceeded",
|
||
"TooManyRequests",
|
||
"SystemThrottling",
|
||
"OperationInProcess",
|
||
}
|
||
|
||
|
||
def _extract_vidu_error_code(message: str | None) -> str | None:
|
||
"""从 Vidu 错误信息中提取错误码"""
|
||
if not message:
|
||
return None
|
||
# Vidu 错误码格式:"ErrorCode: 中文描述"
|
||
return message.split(":")[0].strip() or None
|
||
|
||
|
||
def _map_vidu_error(
    status: int,
    message: str,
    *,
    err_code: str | None = None,
) -> PlatformError:
    """Map a Vidu error response to a standard ``PlatformError``.

    The Vidu business error code takes precedence. The HTTP status is only a
    fallback for codes that are missing or not in any known category.

    Args:
        status: HTTP status code of the Vidu response.
        message: Error message stored on the returned exception.
        err_code: Vidu business error code. If falsy, it is parsed from
            ``message`` with ``_extract_vidu_error_code``.

    Returns:
        A ``PlatformError`` for platform ``"vidu"`` carrying ``status`` and the
        resolved ``raw_code``. It is returned, not raised; callers raise it.
    """
    raw_code = err_code or _extract_vidu_error_code(message)

    if raw_code in _VIDU_AUDIT_ERROR_CODES:
        # Content-safety / moderation rejections: retrying the same input
        # will not help.
        error_type, retryable = PlatformErrorType.CONTENT_VIOLATION, False
    elif raw_code in _VIDU_RETRYABLE_ERROR_CODES:
        # Platform-internal failure or model temporarily unavailable.
        error_type, retryable = PlatformErrorType.SERVER_ERROR, True
    elif raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
        error_type, retryable = PlatformErrorType.RATE_LIMIT, True
    else:
        # Fall back to the HTTP status. 504 (gateway timeout) is as transient
        # as 502/503, so it is retryable too. Unlisted statuses are not retried.
        status_rules = {
            400: (PlatformErrorType.BAD_REQUEST, False),
            401: (PlatformErrorType.AUTH_FAILED, False),
            403: (PlatformErrorType.AUTH_FAILED, False),
            404: (PlatformErrorType.NOT_FOUND, False),
            429: (PlatformErrorType.RATE_LIMIT, True),
            500: (PlatformErrorType.SERVER_ERROR, True),
            502: (PlatformErrorType.SERVER_ERROR, True),
            503: (PlatformErrorType.SERVER_ERROR, True),
            504: (PlatformErrorType.SERVER_ERROR, True),
        }
        error_type, retryable = status_rules.get(
            status, (PlatformErrorType.UNKNOWN, False)
        )

    return PlatformError(
        message=message,
        platform="vidu",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
        raw_code=raw_code,
    )
|
||
|
||
|
||
class ViduProvider:
    """Async client for the Vidu HTTP API.

    Wraps ``httpx.AsyncClient``. A client may be injected, e.g. one whose
    lifecycle is managed by the application lifespan and shared with other
    code. An injected client is never closed by this provider, and its default
    headers are never modified. Authentication headers are sent with every
    request instead.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create the provider.

        Args:
            api_key: Vidu API key; defaults to ``settings.VIDU_API_KEY``.
            base_url: API root URL. Defaults to the ``"vidu"`` platform config
                entry, or ``https://api.vidu.cn`` if there is none. Trailing
                slashes are stripped.
            client: Optional externally owned ``httpx.AsyncClient``.

        Raises:
            ValueError: If no API key is configured.
        """
        settings = get_settings()
        self.api_key = api_key or settings.VIDU_API_KEY

        if base_url:
            self.base_url = base_url.rstrip("/")
        else:
            # Imported lazily, as in the original code; presumably this
            # avoids an import cycle.
            from app.core.platform_config import get_platform_config_loader

            platform_config = get_platform_config_loader().get_platform("vidu")
            self.base_url = (
                platform_config.base_url if platform_config else "https://api.vidu.cn"
            ).rstrip("/")

        if not self.api_key:
            raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")

        # Sent with every request. A shared, injected client therefore never
        # has its default headers overwritten, which would affect other users
        # of that client.
        self._headers = {
            "Authorization": f"Token {self.api_key}",
            "Content-Type": "application/json",
        }

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0, connect=5.0),
                limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
                headers=self._headers,
            )
            self._owns_client = True

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and not self.client.is_closed:
            await self.client.aclose()

    async def _request(
        self,
        method: str,
        url: str,
        *,
        tag: str,
        error_label: str,
        json_body: dict[str, Any] | None = None,
        timeout: httpx.Timeout | None = None,
        check_state: bool = True,
    ) -> dict[str, Any]:
        """Send one Vidu request and convert every failure into ``PlatformError``.

        Args:
            method: HTTP method (``"GET"``, ``"POST"``).
            url: Absolute request URL.
            tag: Short name used in log lines and network-error messages,
                e.g. ``"TTS"`` -> ``"[Vidu TTS]"`` / ``"Vidu TTS 网络错误"``.
            error_label: Name used in API-error messages,
                ``"Vidu {error_label} error: ..."``.
            json_body: JSON request body, if any.
            timeout: Per-request timeout. ``None`` keeps the client default;
                passing ``None`` through to httpx would disable timeouts.
            check_state: Also treat a 200 response whose ``state`` field is
                ``"failed"`` as an error.

        Returns:
            The decoded JSON response body.

        Raises:
            PlatformError: On transport failures (retryable, ``TIMEOUT``); on
                a non-200 status or failed state (classified by
                ``_map_vidu_error``); or on a 200 response that is not valid
                JSON.
        """
        request_kwargs: dict[str, Any] = {"headers": self._headers}
        if json_body is not None:
            request_kwargs["json"] = json_body
        if timeout is not None:
            request_kwargs["timeout"] = timeout

        try:
            resp = await self.client.request(method, url, **request_kwargs)
        except (httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError) as e:
            # RemoteProtocolError (e.g. the server closed a keep-alive
            # connection mid-request) is not a NetworkError subclass.
            logger.error("[Vidu %s] 网络错误: %s", tag, e)
            raise PlatformError(
                f"Vidu {tag} 网络错误: {e}",
                platform="vidu",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e

        try:
            data: Any = resp.json()
        except ValueError:
            # Non-JSON body, e.g. an HTML error page from a gateway/proxy.
            data = None
        fields: dict[str, Any] = data if isinstance(data, dict) else {}
        logged_body = data if data is not None else resp.text[:500]

        if resp.status_code != 200 or (check_state and fields.get("state") == "failed"):
            message = fields.get("message")
            if not isinstance(message, str):
                message = None
            err_code = fields.get("err_code") or _extract_vidu_error_code(message)
            msg = err_code or message or f"HTTP {resp.status_code}"
            logger.error(
                "[Vidu %s] 请求失败: url=%s, status=%s, response=%s",
                tag,
                url,
                resp.status_code,
                logged_body,
            )
            raise _map_vidu_error(
                resp.status_code, f"Vidu {error_label} error: {msg}", err_code=err_code
            )

        if data is None:
            logger.error(
                "[Vidu %s] 请求失败: url=%s, status=%s, response=%s",
                tag,
                url,
                resp.status_code,
                logged_body,
            )
            raise PlatformError(
                f"Vidu {error_label} error: invalid JSON response",
                platform="vidu",
                retryable=False,
                error_type=PlatformErrorType.UNKNOWN,
                status_code=resp.status_code,
            )
        return data

    # ==================== TTS ====================

    async def tts_sync(
        self,
        text: str,
        voice_id: str,
        speed: float = 1.0,
        volume: int = 0,
        pitch: int = 0,
        emotion: str | None = None,
        pronunciation_dict_tone: list[str] | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Synthesize speech synchronously via ``POST /ent/v2/audio-tts``.

        ``voice_id``, ``speed``, ``volume``, ``pitch`` and ``emotion`` are sent
        as the corresponding ``voice_setting_*`` fields. Optional fields are
        sent only when truthy.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: On network errors, non-200 responses or
                ``state == "failed"``.
        """
        url = f"{self.base_url}/ent/v2/audio-tts"

        body: dict[str, Any] = {
            "text": text,
            "voice_setting_voice_id": voice_id,
            "voice_setting_speed": speed,
            "voice_setting_volume": volume,
            "voice_setting_pitch": pitch,
        }
        if emotion:
            body["voice_setting_emotion"] = emotion
        if pronunciation_dict_tone:
            body["pronunciation_dict_tone"] = pronunciation_dict_tone
        if payload:
            body["payload"] = payload

        # Log only the text length, not the (possibly long or sensitive) text.
        logger.info(
            "[Vidu TTS] 提交请求: url=%s, text_length=%d, params=%s",
            url,
            len(text),
            {key: value for key, value in body.items() if key != "text"},
        )

        # Synthesizing long texts synchronously can be slow: allow 120 s.
        return await self._request(
            "POST",
            url,
            tag="TTS",
            error_label="TTS",
            json_body=body,
            timeout=httpx.Timeout(120.0, connect=5.0),
        )

    # ==================== Voice cloning ====================

    async def clone_voice(
        self,
        audio_url: str,
        voice_id: str,
        text: str,
        prompt_audio_url: str | None = None,
        prompt_text: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Clone a voice (synchronous endpoint) via ``POST /ent/v2/audio-clone``.

        Optional fields are sent only when truthy.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: On network errors, non-200 responses or
                ``state == "failed"``.
        """
        url = f"{self.base_url}/ent/v2/audio-clone"

        body: dict[str, Any] = {
            "audio_url": audio_url,
            "voice_id": voice_id,
            "text": text,
        }
        if prompt_audio_url:
            body["prompt_audio_url"] = prompt_audio_url
        if prompt_text:
            body["prompt_text"] = prompt_text
        if payload:
            body["payload"] = payload

        # Processing the reference audio can be slow: allow 120 s.
        return await self._request(
            "POST",
            url,
            tag="Clone",
            error_label="clone",
            json_body=body,
            timeout=httpx.Timeout(120.0, connect=5.0),
        )

    # ==================== Lip-sync video ====================

    async def lip_sync(
        self,
        video_url: str,
        audio_url: str | None = None,
        text: str | None = None,
        voice_id: str | None = None,
        speed: float = 1.0,
        volume: int = 0,
        ref_photo_url: str | None = None,
        callback_url: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Submit a lip-sync video task (asynchronous) via ``POST /ent/v2/lip-sync``.

        Only ``video_url`` is always sent. Other fields are sent when truthy.
        ``speed`` and ``volume`` are sent only when they differ from 1.0 and 0.
        Uses the client's default timeout.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: On network errors, non-200 responses or
                ``state == "failed"``.
        """
        url = f"{self.base_url}/ent/v2/lip-sync"

        body: dict[str, Any] = {"video_url": video_url}
        if audio_url:
            body["audio_url"] = audio_url
        if text:
            body["text"] = text
        if voice_id:
            body["voice_id"] = voice_id
        if speed != 1.0:
            body["speed"] = speed
        if volume != 0:
            body["volume"] = volume
        if ref_photo_url:
            body["ref_photo_url"] = ref_photo_url
        if callback_url:
            body["callback_url"] = callback_url
        if payload:
            body["payload"] = payload

        return await self._request(
            "POST",
            url,
            tag="LipSync",
            error_label="lip-sync",
            json_body=body,
        )

    # ==================== Task query ====================

    async def query_task(self, task_id: str) -> dict[str, Any]:
        """Fetch a task's status and creations via ``GET /ent/v2/tasks/{task_id}/creations``.

        Unlike the submit endpoints, a ``state`` of ``"failed"`` is returned to
        the caller rather than raised, so the task outcome can be inspected.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: On network errors or non-200 responses.
        """
        url = f"{self.base_url}/ent/v2/tasks/{task_id}/creations"
        return await self._request(
            "GET",
            url,
            tag="Query",
            error_label="query task",
            check_state=False,
        )
|