561 lines
19 KiB
Python
561 lines
19 KiB
Python
"""
|
||
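Avatar clone module
===================

Serial pipeline:
1. Create a KlingAI custom voice (custom-voices) from the uploaded video
2. Poll until the voice has been generated and obtain the voice_id
3. Create a KlingAI element (advanced-custom-elements) from the same video + voice_id
4. Poll until the element has been generated and obtain the provider_element_id
5. Return a unified AvatarItem

Async architecture:
- POST /avatar/clone only registers the task with the Async Engine (Redis only, no DB) and returns a task_id immediately
- The actual polling is performed in the background by the Async Engine Scheduler
- The frontend tracks progress via SSE or by polling GET /avatar/tasks/{task_id}

Data strategy:
- Avatar clone data is kept only on the frontend; the backend does not persist it to a database
- All intermediate runtime state of a task is stored in Redis (TTL 24h)

Error messaging strategy:
- custom-voice failure: report reasons related to "a person video with audible speech"
- element failure: report that the video content/quality does not meet the element creation requirements
- timeout: mark the task as timeout; retry is supported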
|
||
"""
|
||
|
||
import asyncio
|
||
import contextlib
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from datetime import UTC, datetime
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel, ConfigDict, Field
|
||
|
||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||
from app.api.deps import get_current_user
|
||
from app.config import get_settings
|
||
from app.core.redis_client import get_redis_client
|
||
from app.scheduler.registry import JobRegistry
|
||
from app.schemas.common import ApiResponse, success_response
|
||
from app.schemas.enums import AvatarCloneStatus
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router on which the avatar endpoints defined in this file are registered.
router = APIRouter()
|
||
|
||
|
||
def _get_kling_provider() -> KlingAIProvider:
    """Build a KlingAI provider from the application settings.

    The access and secret keys are read from the settings object; a missing
    (falsy) key is passed to the provider as an empty string.
    """
    cfg = get_settings()
    credentials = {
        "access_key": cfg.KLINGAI_ACCESS_KEY or "",
        "secret_key": cfg.KLINGAI_SECRET_KEY or "",
    }
    return KlingAIProvider(config=credentials)
|
||
|
||
|
||
async def _get_avatar_state(redis, job_id: str) -> dict | None:
|
||
"""从 Redis 读取 avatar 任务完整状态"""
|
||
data = await redis.hgetall(f"job:{job_id}")
|
||
if not data:
|
||
return None
|
||
|
||
# 解析 JSON 字段
|
||
for key in ("result", "params"):
|
||
if key in data and data[key]:
|
||
with contextlib.suppress(json.JSONDecodeError):
|
||
data[key] = json.loads(data[key])
|
||
return data
|
||
|
||
|
||
class CloneAvatarRequest(BaseModel):
    """Request body for submitting an avatar clone task."""

    # Display name of the avatar; required, 1 to 20 characters.
    name: str = Field(min_length=1, max_length=20, description="形象名称")
    # URL of the person video used for cloning; required.
    video_url: str = Field(description="人物视频 URL")
|
||
|
||
|
||
class CloneAvatarResponse(BaseModel):
    """Response body returned after an avatar clone task is submitted."""

    # Task identifier used to follow progress via SSE or polling; required.
    task_id: str = Field(description="任务 ID(用于 SSE/轮询跟踪进度)")
    # Initial status reported to the client; defaults to "pending".
    status: str = Field(default="pending", description="初始状态")
|
||
|
||
|
||
class AvatarTaskStatusResponse(BaseModel):
    """Response body for an avatar clone task status query."""

    task_id: str
    # Current status of the task; required.
    status: str = Field(description="当前状态")
    # Optional result fields; each defaults to None until it is available.
    fail_reason: str | None = Field(default=None, description="失败原因")
    voice_id: str | None = Field(default=None, description="已生成的音色 ID")
    human_id: int | None = Field(default=None, description="已生成的主体 ID")
    trial_url: str | None = Field(default=None, description="试听 URL")
    # Original submission data and timestamps; all required.
    video_url: str = Field(description="原始视频 URL")
    name: str = Field(description="形象名称")
    created_at: datetime = Field(description="创建时间")
    updated_at: datetime = Field(description="更新时间")
|
||
|
||
|
||
class AvatarItem(BaseModel):
    """Single entry of the avatar library list."""

    # Allow building the model from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)

    # Identification and display fields; all required.
    id: str = Field(description="形象唯一标识")
    name: str = Field(description="展示名称")
    voice_id: str = Field(description="Kling 自定义音色 ID")
    human_id: int = Field(description="数字人主体 ID")
    video_url: str = Field(description="原始人物视频 URL")
    # Optional voice preview URL; defaults to None.
    trial_url: str | None = Field(default=None, description="音色试听 URL")
    # Creation time carried as a string (ISO format per the field description).
    record_time: str = Field(description="创建时间 ISO 字符串")
|
||
|
||
|
||
class UpdateAvatarNameRequest(BaseModel):
    """Request body for renaming an avatar."""

    # New display name; required, 1 to 20 characters.
    name: str = Field(min_length=1, max_length=20, description="新形象名称")
|
||
|
||
|
||
# ============================================================
# API routes
# ============================================================
|
||
|
||
|
||
@router.post("/avatar/clone", response_model=ApiResponse[CloneAvatarResponse])
async def clone_avatar(
    data: CloneAvatarRequest,
    current_user: dict = Depends(get_current_user),
):
    """Submit an avatar clone task.

    Registers the task in the job registry (Redis) and returns its task_id
    right away; the client follows progress via SSE or polling. Nothing is
    written to a database here.

    Args:
        data: Avatar name and source video URL.
        current_user: Authenticated user; its ``id`` is stored as the owner.

    Returns:
        A success response wrapping ``CloneAvatarResponse`` with status
        ``"pending"``.

    Raises:
        HTTPException: 400 when the name or the video URL is empty after
            trimming surrounding whitespace.
    """
    user_id = str(current_user.id)
    name = data.name.strip()
    video_url = data.video_url.strip()

    # A whitespace-only name satisfies min_length=1 on the request model but is
    # empty once trimmed; reject it here, as update_avatar_name does.
    if not name:
        raise HTTPException(status_code=400, detail="名称不能为空")
    if not video_url:
        raise HTTPException(status_code=400, detail="视频 URL 不能为空")

    # Task id: "avt_" + the first 16 hex characters of a random UUID.
    task_id = f"avt_{uuid.uuid4().hex[:16]}"
    timestamp = datetime.now(UTC).isoformat()

    # Write the job to Redis for scheduling, together with the initial avatar
    # state fields that the query endpoints read back.
    redis = get_redis_client()
    registry = JobRegistry(redis)
    await registry.create(task_id, "avatar_clone", user_id)
    await registry.update(
        task_id,
        status="running",
        progress=5,
        message="开始形象克隆...",
        completed=0,
        total=1,
        params={
            "avatar_id": task_id,
            "name": name,
            "video_url": video_url,
            "user_id": user_id,
        },
        # Avatar state fields (read by the status / SSE endpoints).
        avatar_status=AvatarCloneStatus.PENDING.value,
        avatar_name=name,
        avatar_video_url=video_url,
        voice_id="",
        provider_element_id="",
        provider_voice_job_id="",
        provider_element_job_id="",
        trial_url="",
        fail_reason="",
        created_at=timestamp,
        updated_at=timestamp,
    )
    await registry.add_running(task_id)

    return success_response(data=CloneAvatarResponse(task_id=task_id, status="pending"))
|
||
|
||
|
||
@router.get("/avatar/tasks/{task_id}", response_model=ApiResponse[AvatarTaskStatusResponse])
async def get_avatar_task_status(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Return the current status of an avatar clone task, read from Redis.

    Raises:
        HTTPException: 404 when the task has no state in Redis or its stored
            owner differs from the current user.
    """
    state = await _get_avatar_state(get_redis_client(), task_id)
    if not state:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Ownership check; a foreign task is reported as "not found".
    raw_params = state.get("params")
    params = raw_params if isinstance(raw_params, dict) else {}
    if params.get("user_id") != str(current_user.id):
        raise HTTPException(status_code=404, detail="任务不存在")

    def parse_timestamp(field: str) -> datetime:
        """Parse an ISO timestamp field, falling back to the current UTC time."""
        value = state.get(field, "")
        if not value:
            return datetime.now(UTC)
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return datetime.now(UTC)

    def parse_int(field: str) -> int | None:
        """Parse an integer field, returning None when empty or non-numeric."""
        value = state.get(field, "")
        if not value:
            return None
        try:
            return int(value)
        except ValueError:
            return None

    body = AvatarTaskStatusResponse(
        task_id=task_id,
        # Prefer the avatar-specific status; fall back to the generic job status.
        status=state.get("avatar_status", state.get("status", "unknown")),
        fail_reason=state.get("fail_reason") or None,
        voice_id=state.get("voice_id") or None,
        human_id=parse_int("provider_element_id"),
        trial_url=state.get("trial_url") or None,
        video_url=params.get("video_url", ""),
        name=params.get("name", ""),
        created_at=parse_timestamp("created_at"),
        updated_at=parse_timestamp("updated_at"),
    )
    return success_response(data=body)
|
||
|
||
|
||
@router.get("/avatar/clone/stream")
async def sse_avatar_clone(
    task_id: str = Query(..., alias="task_id", description="任务 ID"),
    current_user: dict = Depends(get_current_user),
):
    """Stream the status of an avatar clone task as Server-Sent Events.

    The task state is read from Redis and pushed every 3 seconds until the
    task reaches a terminal status (SUCCEED, VOICE_FAILED, ELEMENT_FAILED or
    TIMEOUT). After 400 polls without a terminal status a ``timeout`` event is
    sent. A missing task or one owned by another user produces a single
    ``error`` event and ends the stream.
    """
    user_id = str(current_user.id)

    # Redis holds plain strings, so compare against the enum *values*, as the
    # other endpoints in this module do. Comparing against the members only
    # matches when the enum mixes in ``str``.
    terminal_statuses = (
        AvatarCloneStatus.SUCCEED.value,
        AvatarCloneStatus.VOICE_FAILED.value,
        AvatarCloneStatus.ELEMENT_FAILED.value,
        AvatarCloneStatus.TIMEOUT.value,
    )

    # Same event for "missing" and "not yours", so task existence is not leaked.
    denied_payload = json.dumps(
        {"status": "error", "fail_reason": "任务不存在或无权限"}, ensure_ascii=False
    )
    denied_event = f"event: error\ndata: {denied_payload}\n\n"

    async def event_stream():
        for _ in range(400):  # at most 20 minutes (400 * 3s)
            redis = get_redis_client()
            state = await _get_avatar_state(redis, task_id)

            if not state:
                yield denied_event
                break

            # Ownership check.
            params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
            if params.get("user_id") != user_id:
                yield denied_event
                break

            avatar_status = state.get("avatar_status", state.get("status", "unknown"))

            payload = json.dumps(
                {
                    "task_id": task_id,
                    "status": avatar_status,
                    "fail_reason": state.get("fail_reason") or None,
                    "voice_id": state.get("voice_id") or None,
                    "provider_element_id": state.get("provider_element_id") or None,
                    "trial_url": state.get("trial_url") or None,
                    "video_url": params.get("video_url", ""),
                    "name": params.get("name", ""),
                    "created_at": state.get("created_at", ""),
                    "updated_at": state.get("updated_at", ""),
                },
                ensure_ascii=False,
            )
            yield f"data: {payload}\n\n"

            if avatar_status in terminal_statuses:
                break

            await asyncio.sleep(3)
        else:
            # Poll budget exhausted without a terminal status: tell the client
            # to continue with the polling endpoint.
            payload = json.dumps(
                {"status": "timeout", "fail_reason": "连接超时,请通过轮询接口继续跟踪"},
                ensure_ascii=False,
            )
            yield f"event: timeout\ndata: {payload}\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|
||
|
||
|
||
@router.post("/avatar/tasks/{task_id}/retry", response_model=ApiResponse[dict])
async def retry_avatar_task(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Retry a failed or timed-out avatar clone task.

    Only tasks whose status is VOICE_FAILED, ELEMENT_FAILED or TIMEOUT may be
    retried. The avatar state is reset to PENDING, provider IDs and the
    failure reason are cleared, and the task is added to the running set
    again.

    Raises:
        HTTPException: 404 when the task has no state in Redis or belongs to
            another user; 400 when its current status is not retryable.
    """
    redis = get_redis_client()
    state = await _get_avatar_state(redis, task_id)
    if not state:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Ownership check; a foreign task is reported as "not found".
    params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
    if params.get("user_id") != str(current_user.id):
        raise HTTPException(status_code=404, detail="任务不存在")

    retryable_statuses = (
        AvatarCloneStatus.VOICE_FAILED.value,
        AvatarCloneStatus.ELEMENT_FAILED.value,
        AvatarCloneStatus.TIMEOUT.value,
    )
    avatar_status = state.get("avatar_status", state.get("status", ""))
    if avatar_status not in retryable_statuses:
        raise HTTPException(status_code=400, detail=f"当前状态 {avatar_status} 不支持重试")

    # Reset the stored state.
    registry = JobRegistry(redis)
    now = datetime.now(UTC).isoformat()
    await registry.update(
        task_id,
        status="running",
        # Store the plain value, as clone_avatar does, so the persisted string
        # does not depend on how the registry serialises an enum member.
        avatar_status=AvatarCloneStatus.PENDING.value,
        fail_reason="",
        voice_id="",
        provider_element_id="",
        provider_voice_job_id="",
        provider_element_job_id="",
        trial_url="",
        updated_at=now,
    )
    await registry.add_running(task_id)

    return success_response(data={"task_id": task_id, "status": "pending"})
|
||
|
||
|
||
@router.delete("/avatar/{avatar_id}", response_model=ApiResponse[dict])
async def delete_avatar(
    avatar_id: str,
    voice_id: str | None = None,
    current_user: dict = Depends(get_current_user),
):
    """Delete an avatar: clean up Kling resources and drop the Redis task record.

    No database is touched. Kling deletions are best-effort: a failure is
    logged as a warning and does not fail the request.

    Args:
        avatar_id: Avatar / task identifier; also the Redis job id.
        voice_id: Optional Kling voice id supplied by the client. It takes
            precedence over the one stored in Redis.
        current_user: Authenticated user.

    Raises:
        HTTPException: 404 when a Redis record exists for ``avatar_id`` but
            belongs to another user.
    """
    redis = get_redis_client()
    state = await _get_avatar_state(redis, avatar_id)

    # Resolve Kling resource IDs (caller-supplied voice_id first, Redis second).
    actual_voice_id = voice_id
    actual_element_id = None
    if state:
        params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
        if params.get("user_id") != str(current_user.id):
            # Refuse here; otherwise the code below would delete another
            # user's task record from Redis.
            raise HTTPException(status_code=404, detail="形象不存在")
        actual_element_id = state.get("provider_element_id")
        if not actual_voice_id:
            actual_voice_id = state.get("voice_id")

    # NOTE(review): a caller-supplied voice_id is not checked for ownership
    # when no Redis record exists - confirm this is acceptable.

    # Best-effort cleanup of Kling resources. The calls are awaited, so the
    # request waits for them; failures are only logged.
    provider = _get_kling_provider()
    if actual_element_id:
        try:
            await provider.delete_element(str(actual_element_id))
        except Exception as e:
            logger.warning("delete_element failed: %s", e)

    if actual_voice_id:
        try:
            await provider.delete_custom_voice(actual_voice_id)
        except Exception as e:
            logger.warning("delete_custom_voice failed: %s", e)

    # Remove the Redis task record.
    registry = JobRegistry(redis)
    await registry.delete(avatar_id)

    return success_response(data={"success": True, "message": "形象已删除"})
|
||
|
||
|
||
@router.get("/avatar/library", response_model=ApiResponse[list[AvatarItem]])
async def get_avatar_library(
    current_user: dict = Depends(get_current_user),
):
    """Return the current user's cloned avatar library.

    Avatar data is kept on the frontend only and is not persisted by the
    backend, so this endpoint always answers with an empty list.
    """
    empty_library: list[AvatarItem] = []
    return success_response(data=empty_library)
|
||
|
||
|
||
@router.patch("/avatar/{avatar_id}", response_model=ApiResponse[dict])
async def update_avatar_name(
    avatar_id: str,
    data: UpdateAvatarNameRequest,
    current_user: dict = Depends(get_current_user),
):
    """Rename an avatar.

    Nothing is stored here: the endpoint only validates the trimmed name and
    echoes it back in a success response.

    Raises:
        HTTPException: 400 when the name is empty after trimming.
    """
    trimmed = data.name.strip()
    if trimmed:
        return success_response(data={"success": True, "name": trimmed})
    raise HTTPException(status_code=400, detail="名称不能为空")
|
||
|
||
|
||
# =============================================================================
# Management and monitoring endpoints (for troubleshooting and manual recovery)
# =============================================================================
|
||
|
||
|
||
class AvatarHealthResponse(BaseModel):
    """Health counters of the avatar clone service; every field is required."""

    # Counters broken down by task status.
    total_processing: int = Field(description="处理中的任务总数")
    pending: int = Field(description="待处理任务数")
    voice_processing: int = Field(description="音色生成中任务数")
    element_processing: int = Field(description="主体生成中任务数")
    # Tasks considered stuck, and failures within the recent window.
    stuck_tasks: int = Field(description="卡住任务数(超过30分钟)")
    recent_failures: int = Field(description="最近1小时失败数")
|
||
|
||
|
||
@router.get("/avatar/health", response_model=ApiResponse[AvatarHealthResponse])
async def get_avatar_health(
    current_user: dict = Depends(get_current_user),
):
    """Report health counters for the avatar clone service.

    Walks the job ids in the registry's running set and counts the
    ``avatar_clone`` jobs owned by the current user. No database is queried.
    """
    redis = get_redis_client()
    job_ids = await JobRegistry(redis).get_running_job_ids()

    now_ts = datetime.now(UTC).timestamp()
    stuck_cutoff = now_ts - 30 * 60  # 30 minutes ago
    recent_cutoff = now_ts - 60 * 60  # 1 hour ago

    in_progress_statuses = (
        AvatarCloneStatus.PENDING.value,
        AvatarCloneStatus.VOICE_PROCESSING.value,
        AvatarCloneStatus.ELEMENT_PROCESSING.value,
    )
    failure_statuses = (
        AvatarCloneStatus.VOICE_FAILED.value,
        AvatarCloneStatus.ELEMENT_FAILED.value,
        AvatarCloneStatus.TIMEOUT.value,
    )

    by_status = dict.fromkeys(in_progress_statuses, 0)
    total = 0
    stuck = 0
    failures = 0

    for job_id in job_ids:
        state = await _get_avatar_state(redis, job_id)
        if not state:
            continue

        # Only the current user's avatar_clone jobs are counted.
        raw_params = state.get("params")
        owner = raw_params.get("user_id") if isinstance(raw_params, dict) else None
        if owner != str(current_user.id):
            continue
        if state.get("type", "") != "avatar_clone":
            continue

        status = state.get("avatar_status", state.get("status", ""))
        total += 1
        if status in by_status:
            by_status[status] += 1

        # Both remaining checks need a parseable updated_at timestamp.
        raw_updated = state.get("updated_at", "")
        if not raw_updated:
            continue
        try:
            updated_ts = datetime.fromisoformat(raw_updated).timestamp()
        except ValueError:
            continue

        # Stuck: still in progress but not updated for over 30 minutes.
        if updated_ts < stuck_cutoff and status in in_progress_statuses:
            stuck += 1
        # Recent failure: failed or timed out within the last hour.
        if status in failure_statuses and updated_ts >= recent_cutoff:
            failures += 1

    report = AvatarHealthResponse(
        total_processing=total,
        pending=by_status[AvatarCloneStatus.PENDING.value],
        voice_processing=by_status[AvatarCloneStatus.VOICE_PROCESSING.value],
        element_processing=by_status[AvatarCloneStatus.ELEMENT_PROCESSING.value],
        stuck_tasks=stuck,
        recent_failures=failures,
    )
    return success_response(data=report)
|
||
|
||
|
||
@router.post("/avatar/admin/trigger-recovery", response_model=ApiResponse[dict])
async def admin_trigger_recovery(
    current_user: dict = Depends(get_current_user),
):
    """Manually trigger recovery of stuck tasks (admin endpoint).

    No recovery is started here: after the admin check the endpoint only
    returns an informational message with ``task_id`` set to ``None``.

    Raises:
        HTTPException: 403 when the current user is not an administrator.
    """
    # Administrators are identified by a fixed list of mobile values.
    # NOTE(review): hard-coded identities - consider moving to configuration.
    admin_mobiles = ["13800138000", "admin"]
    if current_user.mobile not in admin_mobiles:
        raise HTTPException(status_code=403, detail="需要管理员权限")

    notice = {
        "message": "Async Engine 会持续自动轮询,无需手动触发恢复",
        "task_id": None,
    }
    return success_response(data=notice)
|