288 lines
9.6 KiB
Python
288 lines
9.6 KiB
Python
"""
Vidu API 代理路由
================

提供 Vidu 对口型(lip-sync)任务的提交、查询和回调接口。
前端通过此接口提交任务并轮询状态,无需直接访问 Vidu API。
"""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.config import get_settings
|
|
from app.core.redis_client import get_redis_client
|
|
from app.models.user import User
|
|
from app.schemas.common import ApiResponse, success_response
|
|
from app.services.vidu_tts_service import ViduTTSService
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# All routes declared in this module are registered under the "/vidu" prefix
# and grouped under the "Vidu" tag.
router = APIRouter(tags=["Vidu"], prefix="/vidu")
|
|
|
|
# ========== 请求/响应模型 ==========
|
|
|
|
|
|
class LipSyncRequest(BaseModel):
    """Request payload for creating a lip-sync task.

    The audio track is described either by ``audio_url`` or by ``text``;
    ``voice_id`` applies to the text-driven case.
    """

    # Source video to be lip-synced; must be a non-empty string.
    video_url: str = Field(..., min_length=1, description="原视频 URL")
    # Audio-driven mode: URL of the audio track (alternative to ``text``).
    audio_url: str | None = Field(default=None, description="音频 URL(与 text 二选一)")
    # Text-driven mode: text to be spoken (alternative to ``audio_url``).
    text: str | None = Field(default=None, description="文本内容(与 audio_url 二选一)")
    # Voice identifier, used when the task is text-driven.
    voice_id: str | None = Field(default=None, description="音色 ID(文字驱动时生效)")
    # Speech speed, constrained to the range 0.5 - 2.0.
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")
    # Volume, constrained to the range 0 - 10.
    volume: int = Field(default=0, ge=0, le=10, description="音量")
    # Optional face reference image.
    ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")

    @staticmethod
    def validate_at_least_one_audio_source(values: dict) -> dict:
        """Check that ``values`` names at least one audio source.

        Args:
            values: Mapping that may contain ``audio_url`` and/or ``text``.

        Returns:
            ``values`` unchanged when at least one of the two entries is truthy.

        Raises:
            ValueError: If both ``audio_url`` and ``text`` are missing or falsy.
        """
        # NOTE(review): this helper carries only @staticmethod and no pydantic
        # validator decorator, so presumably pydantic never calls it on its
        # own - confirm whether it is meant to be registered as a validator.
        sources = (values.get("audio_url"), values.get("text"))
        if not any(sources):
            raise ValueError("必须提供 audio_url 或 text 之一")
        return values
|
|
|
|
|
|
class LipSyncResponse(BaseModel):
    """Response body returned once a lip-sync task has been submitted."""

    # Identifier of the Vidu task (required).
    task_id: str = Field(..., description="Vidu 任务 ID")
    # Human-readable status message; a default text is used when omitted.
    message: str = Field(description="状态消息", default="任务已提交")
|
|
|
|
|
|
class LipSyncQueryResponse(BaseModel):
    """Response body for a lip-sync task query."""

    # Identifier of the queried task (required).
    task_id: str = Field(..., description="任务 ID")
    # Task state (required); the description lists pending/processing/success/failed.
    state: str = Field(..., description="任务状态: pending/processing/success/failed")
    # URL of the generated video, populated on success.
    video_url: str | None = Field(default=None, description="生成后的视频 URL(成功时)")
    # Status description or error message.
    message: str | None = Field(default=None, description="状态描述或错误信息")
    # Raw ``creations`` data as returned by Vidu.
    creations: list[dict] | None = Field(default=None, description="Vidu 原始 creations 数据")
|
|
|
|
|
|
class LipSyncCallbackRequest(BaseModel):
    """Request body of a Vidu lip-sync task callback."""

    # Identifier of the task the callback refers to (required).
    task_id: str = Field(..., description="任务 ID")
    # Task state reported by the callback (required).
    state: str = Field(..., description="任务状态")
    # List of generated artifacts, if any.
    creations: list[dict] | None = Field(default=None, description="生成物列表")
    # Error message, if any.
    message: str | None = Field(default=None, description="错误信息")
|
|
|
|
|
|
class LipSyncStatusResponse(BaseModel):
    """Lip-sync task status response, intended for frontend polling."""

    # Identifier of the task (required).
    task_id: str = Field(..., description="任务 ID")
    # Current task state (required).
    state: str = Field(..., description="任务状态")
    # URL of the generated video, when available.
    video_url: str | None = Field(default=None, description="生成后的视频 URL")
    # Error message, when available.
    message: str | None = Field(default=None, description="错误信息")
    # Timestamp of the last state update (required).
    updated_at: float = Field(..., description="状态更新时间戳")
|
|
|
|
|
|
# ========== API 路由 ==========
|
|
|
|
|
|
@router.post("/lip-sync", response_model=ApiResponse[LipSyncResponse])
async def create_lip_sync_task(
    request: LipSyncRequest,
    current_user: User = Depends(get_current_user),
):
    """Submit a Vidu lip-sync task.

    The backend builds the callback URL itself so that Vidu can notify this
    service when the task finishes. The frontend polls
    ``/vidu/tasks/{task_id}/status`` for the task state.

    Args:
        request: Lip-sync parameters; ``audio_url`` or ``text`` must be set.
        current_user: Authenticated caller, resolved by ``get_current_user``.

    Returns:
        The ``success_response`` envelope wrapping a ``LipSyncResponse`` that
        carries the new task id.

    Raises:
        HTTPException: 400 when neither ``audio_url`` nor ``text`` is given;
            500 when submitting the task to Vidu fails.
    """
    # Reject requests without an audio source before doing any other work.
    if not request.audio_url and not request.text:
        raise HTTPException(
            status_code=400,
            detail="必须提供 audio_url 或 text 之一",
        )

    try:
        settings = get_settings()
        # Strip trailing slashes so a base URL such as "https://host/" does not
        # produce "https://host//api/..." as the callback address.
        base_url = str(settings.app_base_url).rstrip("/")
        callback_url = f"{base_url}/api/v1/vidu/callback"

        service = ViduTTSService()
        task_id = await service.lip_sync_create(
            video_url=request.video_url,
            audio_url=request.audio_url,
            text=request.text,
            voice_id=request.voice_id,
            speed=request.speed,
            volume=request.volume,
            ref_photo_url=request.ref_photo_url,
            callback_url=callback_url,
        )
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback; "from e" keeps the cause chain.
        logger.exception("[Vidu] 提交对口型任务失败: %s", e)
        raise HTTPException(status_code=500, detail=f"提交对口型任务失败: {e}") from e

    # Seed the polling state in Redis (1 hour TTL). The task already exists on
    # Vidu's side at this point, so a cache failure must not hide the task id
    # from the caller: log it and continue. The status endpoint falls back to
    # querying Vidu when no cached entry exists.
    try:
        redis = get_redis_client()
        await redis.setex(
            f"vidu:lipsync:{task_id}",
            3600,
            json.dumps({
                "state": "processing",
                "video_url": None,
                "message": None,
                "updated_at": time.time(),
            }),
        )
    except Exception:
        logger.exception("[Vidu] 初始化任务状态缓存失败: task_id=%s", task_id)

    logger.info(
        "[Vidu] 对口型任务提交成功: task_id=%s, user=%s, callback=%s",
        task_id,
        current_user.id,
        callback_url,
    )

    return success_response(
        data=LipSyncResponse(
            task_id=task_id,
            message="对口型任务已提交",
        )
    )
|
|
|
|
|
|
@router.post("/callback")
async def vidu_callback(request: Request):
    """Handle the completion callback of a Vidu lip-sync task.

    Vidu POSTs to this endpoint when a task finishes. No login check is
    performed because the caller is Vidu, not a logged-in user. The reported
    state is written to Redis (1 hour TTL) for the status endpoint to read.

    Args:
        request: Raw incoming request; its body is expected to be a JSON object.

    Returns:
        The ``success_response`` envelope acknowledging the callback.

    Raises:
        HTTPException: 400 when the body is not a JSON object or lacks a task
            id / state; 500 when storing the state fails.
    """
    # NOTE(review): this endpoint is unauthenticated, so any client that can
    # reach it may overwrite the cached state of a task id. Consider verifying
    # a signature or shared secret - confirm what Vidu's callbacks provide.
    try:
        body = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.warning("[Vidu] 回调请求体不是合法 JSON: %s", e)
        raise HTTPException(status_code=400, detail="回调请求体不是合法 JSON") from e

    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="回调请求体必须是 JSON 对象")

    # Vidu's callback uses "id" as the task identifier (as does the query
    # API), not "task_id"; accept both.
    task_id = body.get("id") or body.get("task_id")
    state = body.get("state")
    if not task_id or not state:
        # Without this guard the entry would be stored under
        # "vidu:lipsync:None" or carry a null state that the status endpoint
        # cannot serialize.
        raise HTTPException(status_code=400, detail="回调缺少任务 ID 或 state")

    # Extract the video URL from the first creation of a successful task.
    video_url = None
    creations = body.get("creations")
    if state == "success" and isinstance(creations, list) and creations:
        first = creations[0]
        if isinstance(first, dict):
            video_url = first.get("url")

    try:
        # Update the state in Redis.
        redis = get_redis_client()
        await redis.setex(
            f"vidu:lipsync:{task_id}",
            3600,
            json.dumps({
                "state": state,
                "video_url": video_url,
                "message": body.get("message"),
                "updated_at": time.time(),
            }),
        )
    except Exception as e:
        logger.exception("[Vidu] 回调处理失败: %s", e)
        raise HTTPException(status_code=500, detail=f"回调处理失败: {e}") from e

    logger.info("[Vidu] 回调接收: task_id=%s, state=%s", task_id, state)
    return success_response(message="回调已接收")
|
|
|
|
|
|
@router.get("/tasks/{task_id}/status", response_model=ApiResponse[LipSyncStatusResponse])
async def query_lip_sync_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
):
    """Return the state of a lip-sync task (polled by the frontend).

    The state is read from Redis first (where the callback stores it); when
    Redis holds no entry the Vidu API is queried directly.

    Args:
        task_id: Identifier of the task to look up.
        current_user: Authenticated caller, resolved by ``get_current_user``.

    Returns:
        The ``success_response`` envelope wrapping a ``LipSyncStatusResponse``.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    # NOTE(review): the task is not checked against ``current_user`` here, so
    # presumably any authenticated user can read any task id - confirm whether
    # an ownership check is required.
    try:
        redis = get_redis_client()
        cached = await redis.get(f"vidu:lipsync:{task_id}")

        if cached:
            data = json.loads(cached)
            return success_response(
                data=LipSyncStatusResponse(
                    task_id=task_id,
                    # "or" instead of a .get() default: the key may be present
                    # with a null value, which the response model rejects.
                    state=data.get("state") or "unknown",
                    video_url=data.get("video_url"),
                    message=data.get("message"),
                    updated_at=data.get("updated_at") or 0,
                )
            )

        # Nothing cached in Redis: fall back to querying Vidu directly.
        service = ViduTTSService()
        result = await service.lip_sync_query(task_id)

        state = result.get("state") or "unknown"
        creations = result.get("creations") or []

        # The video URL comes from the first creation of a successful task.
        video_url = None
        if state == "success" and creations:
            video_url = creations[0].get("url")

        return success_response(
            data=LipSyncStatusResponse(
                task_id=task_id,
                state=state,
                video_url=video_url,
                message=result.get("message"),
                updated_at=time.time(),
            )
        )

    except Exception as e:
        # logger.exception keeps the traceback; "from e" keeps the cause chain.
        logger.exception("[Vidu] 查询任务状态失败: %s", e)
        raise HTTPException(status_code=500, detail=f"查询任务状态失败: {e}") from e
|
|
|
|
|
|
@router.get("/tasks/{task_id}/creations", response_model=ApiResponse[LipSyncQueryResponse])
async def query_lip_sync_task(
    task_id: str,
    current_user: User = Depends(get_current_user),
):
    """Query Vidu directly for a lip-sync task (kept for compatibility).

    The frontend should prefer ``/tasks/{task_id}/status``, which is served
    from the Redis cache.

    Args:
        task_id: Identifier of the task to look up.
        current_user: Authenticated caller, resolved by ``get_current_user``.

    Returns:
        The ``success_response`` envelope wrapping a ``LipSyncQueryResponse``.

    Raises:
        HTTPException: 500 when the query fails.
    """
    try:
        payload = await ViduTTSService().lip_sync_query(task_id)

        task_state = payload.get("state", "unknown")
        items = payload.get("creations", [])

        # The video URL is only taken from a successful task that has at
        # least one creation.
        has_output = task_state == "success" and items
        result_url = items[0].get("url") if has_output else None

        logger.info(f"[Vidu] 查询对口型任务: task_id={task_id}, state={task_state}, user={current_user.id}")

        # The message is only forwarded for failed tasks.
        failure_message = None
        if task_state == "failed":
            failure_message = payload.get("message")

        response_body = LipSyncQueryResponse(
            task_id=task_id,
            state=task_state,
            video_url=result_url,
            message=failure_message,
            creations=items or None,
        )
        return success_response(data=response_body)

    except Exception as exc:
        logger.error(f"[Vidu] 查询对口型任务失败: {exc}")
        raise HTTPException(status_code=500, detail=f"查询任务失败: {exc}")
|