e58159fc42
Phase 1: 异常体系统一 - 新增 PlatformError / PlatformErrorType 标准定义 - 改造所有 Provider 异常抛出为 PlatformError - 注册全局 PlatformError exception handler Phase 2: Adapter Protocol - 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable) - 新增 app/ai/adapters/constants.py(Method 常量) - 新增 PlatformConfigLoader(config/platform-config.yaml) Phase 3: HTTP Client 统一 - ViduProvider 从 aiohttp 迁移到 httpx(注入方式) - VolcengineCaptionService 改为注入 http_client - lifespan 统一管理所有 Client 创建和关闭 Phase 4: Gateway 骨架 + Adapter 实现 - 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter - 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook) - 新增 LLMGateway(带 Fallback 降级链) - lifespan 注册所有 Adapter 和 Gateway Phase 6: 清理与验证 - 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL - Provider 改为从 PlatformConfigLoader 读取 base_url - 清理 volcengine_caption_service 全局单例 - config_loader 默认路径改为 platform-config.yaml - Scheduler 注入共享 HTTP client - vidu.py 回调路由使用 Adapter 验签和解析 - ruff 全量通过,应用启动测试通过
272 lines
9.4 KiB
Python
272 lines
9.4 KiB
Python
"""
|
||
Vidu API Provider
=================

Wraps the Vidu voice/video HTTP APIs:
- Synchronous TTS (/ent/v2/audio-tts)
- Voice cloning (/ent/v2/audio-clone)
- Lip sync (/ent/v2/lip-sync)
- Task query (/ent/v2/tasks/{id}/creations)

All requests go through httpx.AsyncClient; its lifecycle is managed centrally by the application lifespan.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _map_vidu_error(status: int, message: str) -> PlatformError:
    """Build the standard PlatformError for a failed Vidu HTTP response.

    Status handling:
        - 429             -> RATE_LIMIT, retryable
        - 401 / 403       -> AUTH_FAILED, not retryable
        - 400             -> BAD_REQUEST, not retryable
        - 404             -> NOT_FOUND, not retryable
        - 500 / 502 / 503 -> SERVER_ERROR, retryable
        - anything else   -> UNKNOWN, not retryable

    Args:
        status: HTTP status code of the Vidu response.
        message: Human-readable error message stored on the exception.

    Returns:
        A PlatformError tagged with platform "vidu"; the caller is expected
        to raise it.
    """
    if status == 429:
        kind, can_retry = PlatformErrorType.RATE_LIMIT, True
    elif status in (401, 403):
        kind, can_retry = PlatformErrorType.AUTH_FAILED, False
    elif status == 400:
        kind, can_retry = PlatformErrorType.BAD_REQUEST, False
    elif status == 404:
        kind, can_retry = PlatformErrorType.NOT_FOUND, False
    elif status in (500, 502, 503):
        kind, can_retry = PlatformErrorType.SERVER_ERROR, True
    else:
        # Any status without an explicit rule is reported as unknown and not retried.
        kind, can_retry = PlatformErrorType.UNKNOWN, False

    return PlatformError(
        message=message,
        platform="vidu",
        retryable=can_retry,
        error_type=kind,
        status_code=status,
    )
|
||
|
||
|
||
class ViduProvider:
    """Async client wrapper for the Vidu HTTP API.

    Every request is sent through an ``httpx.AsyncClient``. The client can be
    injected, in which case this class never closes it and adds no headers to
    it (the injector must configure the ``Authorization`` header itself).
    Otherwise the provider creates and owns a client configured with
    ``Token <api_key>`` authorization.

    Endpoint methods raise ``PlatformError`` for transport failures, non-200
    responses, responses reporting ``state == "failed"`` (POST endpoints) and
    response bodies that are not a JSON object.
    """

    # Fallback used when neither the caller nor the platform config gives a base URL.
    _DEFAULT_BASE_URL = "https://api.vidu.cn"

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Resolve credentials, base URL and the HTTP client.

        Args:
            api_key: Vidu API key; falls back to ``settings.VIDU_API_KEY``.
            base_url: API root; falls back to the "vidu" entry of the platform
                config, then to ``https://api.vidu.cn``. Trailing slashes are
                stripped.
            client: Externally managed ``httpx.AsyncClient``. When omitted, a
                client owned by this provider is created.

        Raises:
            ValueError: If no API key is available.
        """
        settings = get_settings()
        self.api_key = api_key or settings.VIDU_API_KEY

        if base_url:
            self.base_url = base_url.rstrip("/")
        else:
            # Function-scope import kept from the original code
            # (presumably avoids an import cycle - TODO confirm).
            from app.core.platform_config import get_platform_config_loader

            platform_config = get_platform_config_loader().get_platform("vidu")
            configured = platform_config.base_url if platform_config else None
            # An empty/None configured value also falls back to the default
            # instead of failing on .rstrip or yielding an empty base URL.
            self.base_url = (configured or self._DEFAULT_BASE_URL).rstrip("/")

        if not self.api_key:
            raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(
                timeout=httpx.Timeout(30.0, connect=5.0),
                limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
                headers={
                    "Authorization": f"Token {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
            self._owns_client = True

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and not self.client.is_closed:
            await self.client.aclose()

    # ==================== Shared helpers ====================

    @staticmethod
    def _truthy_only(fields: dict[str, Any]) -> dict[str, Any]:
        """Return ``fields`` without the entries whose value is falsy."""
        return {key: value for key, value in fields.items() if value}

    async def _request(
        self,
        url: str,
        *,
        tag: str,
        error_label: str,
        body: dict[str, Any] | None = None,
        fail_on_state: bool = True,
    ) -> dict[str, Any]:
        """Send one request and return the decoded JSON object.

        Args:
            url: Absolute endpoint URL.
            tag: Label used in log lines and network-error messages
                (e.g. "TTS").
            error_label: Label used in API-error messages (e.g. "lip-sync").
            body: JSON body. When given the request is a POST, otherwise a GET.
            fail_on_state: Also treat a response whose ``state`` is
                ``"failed"`` as an error.

        Returns:
            The response body decoded as a dict.

        Raises:
            PlatformError: On network/timeout errors, a non-200 status, a
                failed state (when checked), or a body that is not a JSON
                object.
        """
        try:
            if body is None:
                resp = await self.client.get(url)
            else:
                resp = await self.client.post(url, json=body)
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            logger.error("[Vidu %s] 网络错误: %s", tag, e)
            raise PlatformError(
                f"Vidu {tag} 网络错误: {e}",
                platform="vidu",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e

        status = resp.status_code

        # Error pages from proxies/gateways are often HTML; decoding must not
        # leak a raw JSONDecodeError (a ValueError subclass) to callers.
        decode_error: ValueError | None = None
        try:
            data = resp.json()
        except ValueError as e:
            data, decode_error = None, e

        if not isinstance(data, dict):
            logger.error("[Vidu %s] 响应不是合法的 JSON 对象: url=%s, status=%s", tag, url, status)
            raise _map_vidu_error(
                status,
                f"Vidu {error_label} error: HTTP {status} (non-JSON response)",
            ) from decode_error

        state_failed = fail_on_state and data.get("state") == "failed"
        if status != 200 or state_failed:
            msg = data.get("err_code") or data.get("message") or f"HTTP {status}"
            logger.error("[Vidu %s] 请求失败: url=%s, status=%s, response=%s", tag, url, status, data)
            raise _map_vidu_error(status, f"Vidu {error_label} error: {msg}")

        return data

    # ==================== TTS ====================

    async def tts_sync(
        self,
        text: str,
        voice_id: str,
        speed: float = 1.0,
        volume: int = 0,
        pitch: int = 0,
        emotion: str | None = None,
        pronunciation_dict_tone: list[str] | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Synchronous speech synthesis.

        POST /ent/v2/audio-tts

        Args:
            text: Text to synthesize.
            voice_id: Sent as ``voice_setting_voice_id``.
            speed: Sent as ``voice_setting_speed``.
            volume: Sent as ``voice_setting_volume``.
            pitch: Sent as ``voice_setting_pitch``.
            emotion: Sent as ``voice_setting_emotion`` when truthy.
            pronunciation_dict_tone: Sent under the same key when truthy.
            payload: Sent under ``payload`` when truthy.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: See ``_request``.
        """
        body: dict[str, Any] = {
            "text": text,
            "voice_setting_voice_id": voice_id,
            "voice_setting_speed": speed,
            "voice_setting_volume": volume,
            "voice_setting_pitch": pitch,
        }
        body.update(
            self._truthy_only(
                {
                    "voice_setting_emotion": emotion,
                    "pronunciation_dict_tone": pronunciation_dict_tone,
                    "payload": payload,
                }
            )
        )

        # Only the length is logged, not the text itself.
        logger.info("[Vidu TTS] 请求参数: text_length=%d", len(text))

        return await self._request(
            f"{self.base_url}/ent/v2/audio-tts",
            tag="TTS",
            error_label="TTS",
            body=body,
        )

    # ==================== Voice cloning ====================

    async def clone_voice(
        self,
        audio_url: str,
        voice_id: str,
        text: str,
        prompt_audio_url: str | None = None,
        prompt_text: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Voice cloning (synchronous endpoint).

        POST /ent/v2/audio-clone

        Args:
            audio_url: Sent as ``audio_url``.
            voice_id: Sent as ``voice_id``.
            text: Sent as ``text``.
            prompt_audio_url: Sent under the same key when truthy.
            prompt_text: Sent under the same key when truthy.
            payload: Sent under ``payload`` when truthy.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: See ``_request``.
        """
        body: dict[str, Any] = {
            "audio_url": audio_url,
            "voice_id": voice_id,
            "text": text,
        }
        body.update(
            self._truthy_only(
                {
                    "prompt_audio_url": prompt_audio_url,
                    "prompt_text": prompt_text,
                    "payload": payload,
                }
            )
        )

        return await self._request(
            f"{self.base_url}/ent/v2/audio-clone",
            tag="Clone",
            error_label="clone",
            body=body,
        )

    # ==================== Lip sync ====================

    async def lip_sync(
        self,
        video_url: str,
        audio_url: str | None = None,
        text: str | None = None,
        voice_id: str | None = None,
        speed: float = 1.0,
        volume: int = 0,
        ref_photo_url: str | None = None,
        callback_url: str | None = None,
        payload: str | None = None,
    ) -> dict[str, Any]:
        """Lip sync (asynchronous endpoint).

        POST /ent/v2/lip-sync

        Optional fields are only sent when they carry a non-default value:
        string fields when truthy, ``speed`` when it differs from 1.0 and
        ``volume`` when it differs from 0.

        Args:
            video_url: Sent as ``video_url`` (always included).
            audio_url: Sent under the same key when truthy.
            text: Sent under the same key when truthy.
            voice_id: Sent under the same key when truthy.
            speed: Sent as ``speed`` when not 1.0.
            volume: Sent as ``volume`` when not 0.
            ref_photo_url: Sent under the same key when truthy.
            callback_url: Sent under the same key when truthy.
            payload: Sent under ``payload`` when truthy.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: See ``_request``.
        """
        body: dict[str, Any] = {"video_url": video_url}
        body.update(self._truthy_only({"audio_url": audio_url, "text": text, "voice_id": voice_id}))
        if speed != 1.0:
            body["speed"] = speed
        if volume != 0:
            body["volume"] = volume
        body.update(
            self._truthy_only(
                {
                    "ref_photo_url": ref_photo_url,
                    "callback_url": callback_url,
                    "payload": payload,
                }
            )
        )

        return await self._request(
            f"{self.base_url}/ent/v2/lip-sync",
            tag="LipSync",
            error_label="lip-sync",
            body=body,
        )

    # ==================== Task query ====================

    async def query_task(self, task_id: str) -> dict[str, Any]:
        """Query a task's state and its creations.

        GET /ent/v2/tasks/{task_id}/creations

        Only the HTTP status decides failure here; a ``state`` of "failed" in
        the body is returned to the caller rather than raised.

        Args:
            task_id: Task identifier interpolated into the URL path.

        Returns:
            The decoded JSON response.

        Raises:
            PlatformError: See ``_request``.
        """
        return await self._request(
            f"{self.base_url}/ent/v2/tasks/{task_id}/creations",
            tag="Query",
            error_label="query task",
            fail_on_state=False,
        )
|