Files
meijiaka-zy/python-api/app/ai/providers/vidu_provider.py
T
小鱼开发 ab9962d333 refactor(vidu): reusable session, semaphore, retry, lifespan management
- vidu_provider: single ClientSession with TCP connector pool and explicit timeouts
- vidu_service: Semaphore(10) concurrency limit + tenacity retry (3 attempts, exponential backoff)
- voice/vidu routes: use FastAPI Depends injection instead of new Service() per request
- main.py: initialize Vidu Provider & Service in lifespan, close on shutdown
- add tenacity to dependencies
- remove vidu_tts_service.py
2026-05-02 21:55:20 +08:00

212 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Vidu API Provider
=================
Wraps the Vidu speech/video HTTP APIs:
- Synchronous TTS (/ent/v2/audio-tts)
- Voice cloning (/ent/v2/audio-clone)
- Lip sync (/ent/v2/lip-sync)
- Task query (/ent/v2/tasks/{id}/creations)

Engineering practices (2024):
- A single ClientSession instance, reused application-wide
- Connection-pool size aligned with the third-party concurrency limit
- Explicit, layered timeout configuration
"""
from __future__ import annotations

import logging
from typing import Any
from urllib.parse import quote

import aiohttp

from app.config import get_settings
logger = logging.getLogger(__name__)
class ViduProvider:
"""Vidu API 客户端封装
单一 ClientSession 实例,应用生命周期内复用。
由 FastAPI lifespan 负责创建和关闭。
"""
def __init__(self, api_key: str | None = None, base_url: str | None = None):
settings = get_settings()
self.api_key = api_key or settings.VIDU_API_KEY
self.base_url = (base_url or settings.VIDU_BASE_URL).rstrip("/")
if not self.api_key:
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
connector = aiohttp.TCPConnector(
limit=20,
limit_per_host=20,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=30,
connect=5,
sock_read=10,
)
self.session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={
"Authorization": f"Token {self.api_key}",
"Content-Type": "application/json",
},
)
async def close(self) -> None:
"""关闭 HTTP Session,释放连接池。"""
await self.session.close()
# ==================== TTS 语音合成 ====================
async def tts_sync(
self,
text: str,
voice_id: str,
speed: float = 1.0,
volume: int = 0,
pitch: int = 0,
emotion: str | None = None,
pronunciation_dict_tone: list[str] | None = None,
payload: str | None = None,
) -> dict[str, Any]:
"""同步语音合成
POST /ent/v2/audio-tts
"""
url = f"{self.base_url}/ent/v2/audio-tts"
body: dict[str, Any] = {
"text": text,
"voice_setting_voice_id": voice_id,
"voice_setting_speed": speed,
"voice_setting_volume": volume,
"voice_setting_pitch": pitch,
}
if emotion:
body["voice_setting_emotion"] = emotion
if pronunciation_dict_tone:
body["pronunciation_dict_tone"] = pronunciation_dict_tone
if payload:
body["payload"] = payload
logger.info(f"[Vidu TTS] 请求参数: text_length={len(text)}")
async with self.session.post(url, json=body) as resp:
data = await resp.json()
if resp.status != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status}"
logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status}, response={data}")
raise Exception(f"Vidu TTS error: {msg}")
return data
# ==================== 声音复刻 ====================
async def clone_voice(
self,
audio_url: str,
voice_id: str,
text: str,
prompt_audio_url: str | None = None,
prompt_text: str | None = None,
payload: str | None = None,
) -> dict[str, Any]:
"""声音复刻(同步接口)
POST /ent/v2/audio-clone
"""
url = f"{self.base_url}/ent/v2/audio-clone"
body: dict[str, Any] = {
"audio_url": audio_url,
"voice_id": voice_id,
"text": text,
}
if prompt_audio_url:
body["prompt_audio_url"] = prompt_audio_url
if prompt_text:
body["prompt_text"] = prompt_text
if payload:
body["payload"] = payload
async with self.session.post(url, json=body) as resp:
data = await resp.json()
if resp.status != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status}"
logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status}, response={data}")
raise Exception(f"Vidu clone error: {msg}")
return data
# ==================== 对口型 ====================
async def lip_sync(
self,
video_url: str,
audio_url: str | None = None,
text: str | None = None,
voice_id: str | None = None,
speed: float = 1.0,
volume: int = 0,
ref_photo_url: str | None = None,
callback_url: str | None = None,
payload: str | None = None,
) -> dict[str, Any]:
"""对口型(异步接口)
POST /ent/v2/lip-sync
"""
url = f"{self.base_url}/ent/v2/lip-sync"
body: dict[str, Any] = {"video_url": video_url}
if audio_url:
body["audio_url"] = audio_url
if text:
body["text"] = text
if voice_id:
body["voice_id"] = voice_id
if speed != 1.0:
body["speed"] = speed
if volume != 0:
body["volume"] = volume
if ref_photo_url:
body["ref_photo_url"] = ref_photo_url
if callback_url:
body["callback_url"] = callback_url
if payload:
body["payload"] = payload
async with self.session.post(url, json=body) as resp:
data = await resp.json()
if resp.status != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status}"
logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status}, response={data}")
raise Exception(f"Vidu lip-sync error: {msg}")
return data
# ==================== 查询任务 ====================
async def query_task(self, task_id: str) -> dict[str, Any]:
"""查询任务状态及生成物
GET /ent/v2/tasks/{task_id}/creations
"""
url = f"{self.base_url}/ent/v2/tasks/{task_id}/creations"
async with self.session.get(url) as resp:
data = await resp.json()
if resp.status != 200:
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status}"
logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status}, response={data}")
raise Exception(f"Vidu query task error: {msg}")
return data