Files
meijiaka-zy/python-api/app/ai/providers/vidu_provider.py
T

382 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu API Provider
=================
封装 Vidu 语音/视频相关 HTTP API
- 同步 TTS(/ent/v2/audio-tts)
- 声音复刻(/ent/v2/audio-clone)
- 视频生成(/ent/v2/lip-sync)
- 查询任务(/ent/v2/tasks/{id}/creations)
统一使用 httpx.AsyncClient,由 lifespan 统一管理生命周期。
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from app.config import get_settings
from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
# Vidu 错误码分类
_VIDU_AUDIT_ERROR_CODES = {
"TaskPromptPolicyViolation",
"AuditSubmitIllegal",
"CreationPolicyViolation",
"PhotoAuditNotPass",
"AuditFailed",
"ImageCheckBodyJointsFailed",
"ImageCheckFaceFailed",
"ImageObjectsUndetected",
"FaceDetectFailure",
"FaceDetectNotPass",
"NoFaceDetected",
"MultiFaceDetected",
}
_VIDU_RETRYABLE_ERROR_CODES = {
"InternalServiceFailure",
"ModelUnavailable",
"Unknown",
}
_VIDU_RATE_LIMIT_ERROR_CODES = {
"QuotaExceeded",
"TooManyRequests",
"SystemThrottling",
"OperationInProcess",
}
def _extract_vidu_error_code(message: str | None) -> str | None:
"""从 Vidu 错误信息中提取错误码"""
if not message:
return None
# Vidu 错误码格式:"ErrorCode: 中文描述"
return message.split(":")[0].strip() or None
def _map_vidu_error(
    status: int,
    message: str,
    *,
    err_code: str | None = None,
) -> PlatformError:
    """Translate a Vidu failure into the standard ``PlatformError``.

    The Vidu business error code decides the error type whenever it belongs
    to one of the known code sets; the HTTP status is only a fallback.

    Args:
        status: HTTP status code of the response.
        message: Error message stored on the resulting error.
        err_code: Vidu business error code. When falsy, the code is extracted
            from ``message`` instead.

    Returns:
        A ``PlatformError`` for platform ``"vidu"`` carrying the error type,
        retryability, HTTP status and the raw business code.
    """
    raw_code = err_code or _extract_vidu_error_code(message)

    # Business-code classification, checked in order:
    #   1. content-safety / audit  -> not retryable
    #   2. platform internal error -> retryable
    #   3. throttling              -> retryable
    code_rules = (
        (_VIDU_AUDIT_ERROR_CODES, PlatformErrorType.CONTENT_VIOLATION, False),
        (_VIDU_RETRYABLE_ERROR_CODES, PlatformErrorType.SERVER_ERROR, True),
        (_VIDU_RATE_LIMIT_ERROR_CODES, PlatformErrorType.RATE_LIMIT, True),
    )
    for known_codes, error_type, retryable in code_rules:
        if raw_code in known_codes:
            break
    else:
        # 4. No known business code: fall back to the HTTP status.
        status_rules = {
            400: (PlatformErrorType.BAD_REQUEST, False),
            401: (PlatformErrorType.AUTH_FAILED, False),
            403: (PlatformErrorType.AUTH_FAILED, False),
            404: (PlatformErrorType.NOT_FOUND, False),
            429: (PlatformErrorType.RATE_LIMIT, True),
            500: (PlatformErrorType.SERVER_ERROR, True),
            502: (PlatformErrorType.SERVER_ERROR, True),
            503: (PlatformErrorType.SERVER_ERROR, True),
        }
        error_type, retryable = status_rules.get(
            status, (PlatformErrorType.UNKNOWN, False)
        )

    return PlatformError(
        message=message,
        platform="vidu",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
        raw_code=raw_code,
    )
class ViduProvider:
    """Client wrapper around the Vidu HTTP API.

    All requests go through an ``httpx.AsyncClient``. The client may be
    injected from outside (its owner is then responsible for closing it) or
    is created by the provider itself.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Configure credentials, base URL and the HTTP client.

        Args:
            api_key: Vidu API key; falls back to ``settings.VIDU_API_KEY``.
            base_url: API base URL; falls back to the ``vidu`` platform
                config and finally to ``https://api.vidu.cn``. A trailing
                slash is removed.
            client: Optional externally managed client. Its ``Authorization``
                and ``Content-Type`` headers are overwritten in place.

        Raises:
            ValueError: If no API key is available.
        """
        settings = get_settings()
        self.api_key = api_key or settings.VIDU_API_KEY

        if base_url:
            self.base_url = base_url.rstrip("/")
        else:
            # Imported lazily, as in the original code.
            from app.core.platform_config import get_platform_config_loader

            platform_config = get_platform_config_loader().get_platform("vidu")
            self.base_url = (
                platform_config.base_url if platform_config else "https://api.vidu.cn"
            ).rstrip("/")

        if not self.api_key:
            raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")

        if client is not None:
            self.client = client
            self._owns_client = False
            # An injected (possibly shared) client still needs the auth
            # headers, e.g. when main.py and the scheduler share one client.
            self.client.headers["Authorization"] = f"Token {self.api_key}"
            self.client.headers["Content-Type"] = "application/json"
        else:
            self.client = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0, connect=5.0),
                limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
                headers={
                    "Authorization": f"Token {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
            self._owns_client = True

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and not self.client.is_closed:
            await self.client.aclose()

    # ==================== Shared helpers ====================

    @staticmethod
    def _network_error(label: str, exc: Exception) -> PlatformError:
        """Log a transport failure and build the retryable error for it."""
        logger.error(f"[Vidu {label}] 网络错误: {exc}")
        return PlatformError(
            f"Vidu {label} 网络错误: {exc}",
            platform="vidu",
            retryable=True,
            error_type=PlatformErrorType.TIMEOUT,
        )

    @staticmethod
    def _parse_response(
        resp: httpx.Response,
        *,
        url: str,
        label: str,
        error_prefix: str,
        fail_on_failed_state: bool = True,
    ) -> dict[str, Any]:
        """Decode a Vidu response and raise a mapped error on failure.

        Args:
            resp: The HTTP response to inspect.
            url: Request URL, used only for logging.
            label: Short operation name used in the log tag.
            error_prefix: Prefix of the error message on failure.
            fail_on_failed_state: Whether a body with ``state == "failed"``
                counts as an error even when the HTTP status is 200.

        Returns:
            The decoded JSON object.

        Raises:
            PlatformError: If the body is not a JSON object, the status is
                not 200, or (optionally) the reported state is ``failed``.
        """
        status = resp.status_code
        try:
            data = resp.json()
        except ValueError:
            # json.JSONDecodeError is a ValueError; e.g. a gateway returning
            # an HTML error page instead of JSON.
            data = None

        if not isinstance(data, dict):
            logger.error(
                f"[Vidu {label}] 请求失败: url={url}, status={status}, "
                f"response={resp.text[:500]!r}"
            )
            raise _map_vidu_error(
                status, f"{error_prefix}: HTTP {status} (non-JSON response)"
            )

        failed = status != 200 or (
            fail_on_failed_state and data.get("state") == "failed"
        )
        if failed:
            err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
            msg = err_code or data.get("message") or f"HTTP {status}"
            logger.error(
                f"[Vidu {label}] 请求失败: url={url}, status={status}, response={data}"
            )
            raise _map_vidu_error(status, f"{error_prefix}: {msg}", err_code=err_code)
        return data

    # ==================== TTS speech synthesis ====================

    async def tts_sync(
        self,
        text: str,
        voice_id: str,
        speed: float = 1.0,
        volume: int = 0,
        pitch: int = 0,
        emotion: str | None = None,
        pronunciation_dict_tone: list[str] | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Synchronous speech synthesis.

        POST /ent/v2/audio-tts

        ``text``, ``voice_id``, ``speed``, ``volume`` and ``pitch`` are always
        sent; ``emotion``, ``pronunciation_dict_tone`` and ``payload`` are
        only sent when truthy.

        Returns:
            The decoded JSON response body.

        Raises:
            PlatformError: On network failure, a non-200 status, a non-JSON
                body, or a response whose ``state`` is ``failed``.
        """
        url = f"{self.base_url}/ent/v2/audio-tts"
        body: dict[str, Any] = {
            "text": text,
            "voice_setting_voice_id": voice_id,
            "voice_setting_speed": speed,
            "voice_setting_volume": volume,
            "voice_setting_pitch": pitch,
        }
        if emotion:
            body["voice_setting_emotion"] = emotion
        if pronunciation_dict_tone:
            body["pronunciation_dict_tone"] = pronunciation_dict_tone
        if payload:
            body["payload"] = payload

        logger.info(f"[Vidu TTS] 请求参数: text_length={len(text)}")
        logger.info(f"[Vidu TTS] 提交请求: url={url}, body={body}")

        try:
            # Synchronous synthesis of long text can be slow, so the timeout
            # is relaxed to 120 seconds.
            resp = await self.client.post(
                url, json=body, timeout=httpx.Timeout(120.0, connect=5.0)
            )
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise self._network_error("TTS", e) from e
        return self._parse_response(
            resp, url=url, label="TTS", error_prefix="Vidu TTS error"
        )

    # ==================== Voice cloning ====================

    async def clone_voice(
        self,
        audio_url: str,
        voice_id: str,
        text: str,
        prompt_audio_url: str | None = None,
        prompt_text: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Voice cloning (synchronous endpoint).

        POST /ent/v2/audio-clone

        ``audio_url``, ``voice_id`` and ``text`` are always sent; the other
        arguments are only sent when truthy.

        Returns:
            The decoded JSON response body.

        Raises:
            PlatformError: On network failure, a non-200 status, a non-JSON
                body, or a response whose ``state`` is ``failed``.
        """
        url = f"{self.base_url}/ent/v2/audio-clone"
        body: dict[str, Any] = {
            "audio_url": audio_url,
            "voice_id": voice_id,
            "text": text,
        }
        if prompt_audio_url:
            body["prompt_audio_url"] = prompt_audio_url
        if prompt_text:
            body["prompt_text"] = prompt_text
        if payload:
            body["payload"] = payload

        try:
            # Processing the audio can be slow, so the timeout is relaxed to
            # 120 seconds.
            resp = await self.client.post(
                url, json=body, timeout=httpx.Timeout(120.0, connect=5.0)
            )
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise self._network_error("Clone", e) from e
        return self._parse_response(
            resp, url=url, label="Clone", error_prefix="Vidu clone error"
        )

    # ==================== Video generation ====================

    async def lip_sync(
        self,
        video_url: str,
        audio_url: str | None = None,
        text: str | None = None,
        voice_id: str | None = None,
        speed: float = 1.0,
        volume: int = 0,
        ref_photo_url: str | None = None,
        callback_url: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Video generation (asynchronous endpoint).

        POST /ent/v2/lip-sync

        Only ``video_url`` is always sent. Optional string arguments are sent
        when truthy; ``speed`` and ``volume`` are sent only when they differ
        from their defaults (``1.0`` and ``0``).

        Returns:
            The decoded JSON response body.

        Raises:
            PlatformError: On network failure, a non-200 status, a non-JSON
                body, or a response whose ``state`` is ``failed``.
        """
        url = f"{self.base_url}/ent/v2/lip-sync"
        body: dict[str, Any] = {"video_url": video_url}
        if audio_url:
            body["audio_url"] = audio_url
        if text:
            body["text"] = text
        if voice_id:
            body["voice_id"] = voice_id
        if speed != 1.0:
            body["speed"] = speed
        if volume != 0:
            body["volume"] = volume
        if ref_photo_url:
            body["ref_photo_url"] = ref_photo_url
        if callback_url:
            body["callback_url"] = callback_url
        if payload:
            body["payload"] = payload

        try:
            resp = await self.client.post(url, json=body)
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise self._network_error("LipSync", e) from e
        return self._parse_response(
            resp, url=url, label="LipSync", error_prefix="Vidu lip-sync error"
        )

    # ==================== Task query ====================

    async def query_task(self, task_id: str) -> dict[str, Any]:
        """Query a task's status and its creations.

        GET /ent/v2/tasks/{task_id}/creations

        Unlike the submit endpoints, a body whose ``state`` is ``failed`` is
        returned to the caller rather than raised; only a non-200 status or a
        non-JSON body is treated as an error.

        Returns:
            The decoded JSON response body.

        Raises:
            PlatformError: On network failure, a non-200 status, or a
                non-JSON body.
        """
        url = f"{self.base_url}/ent/v2/tasks/{task_id}/creations"
        try:
            resp = await self.client.get(url)
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise self._network_error("Query", e) from e
        return self._parse_response(
            resp,
            url=url,
            label="Query",
            error_prefix="Vidu query task error",
            fail_on_failed_state=False,
        )