"""
Avatar cloning module
=====================

Serial flow:
1. Create a KlingAI custom voice (custom-voices) from the uploaded video
2. Poll until the voice is generated and obtain voice_id
3. Create a KlingAI element (advanced-custom-elements) from the same video + voice_id
4. Poll until the element is generated and obtain provider_element_id
5. Return a unified AvatarItem

Async architecture:
- POST /avatar/clone only registers the job with the Async Engine (pure Redis,
  no DB) and returns a task_id immediately
- The actual polling is performed in the background by the Async Engine Scheduler
- The frontend tracks progress via SSE or by polling GET /avatar/tasks/{task_id}

Data policy:
- Avatar clone data is kept only on the frontend; the backend does not persist
  it to a database
- Intermediate runtime state of a task is stored entirely in Redis (TTL 24h)

Error message policy:
- custom-voice failure: report reasons related to "a video of a person with sound"
- element failure: report that the video content/quality does not meet the
  element creation requirements
- timeout: marked as timeout, retry is supported
"""

import asyncio
import contextlib
import json
import logging
import uuid
from datetime import UTC, datetime

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field

from app.ai.providers.klingai_provider import KlingAIProvider
from app.api.deps import get_current_user
from app.config import get_settings
from app.core.redis_client import get_redis_client
from app.scheduler.registry import JobRegistry
from app.schemas.common import ApiResponse, success_response
from app.schemas.enums import AvatarCloneStatus

logger = logging.getLogger(__name__)

router = APIRouter()

# Status groups, always expressed as raw enum *values* because that is what is
# stored in (and read back from) Redis.
_IN_PROGRESS_STATUSES = (
    AvatarCloneStatus.PENDING.value,
    AvatarCloneStatus.VOICE_PROCESSING.value,
    AvatarCloneStatus.ELEMENT_PROCESSING.value,
)
_RETRYABLE_STATUSES = (
    AvatarCloneStatus.VOICE_FAILED.value,
    AvatarCloneStatus.ELEMENT_FAILED.value,
    AvatarCloneStatus.TIMEOUT.value,
)
_TERMINAL_STATUSES = (AvatarCloneStatus.SUCCEED.value, *_RETRYABLE_STATUSES)

# SSE polling: 400 polls * 3 s = at most 20 minutes per connection.
_SSE_MAX_POLLS = 400
_SSE_POLL_INTERVAL_SECONDS = 3

# Identifiers (matched against ``current_user.mobile``) that are treated as admins.
_ADMIN_MOBILES = ("13800138000", "admin")


def _get_kling_provider() -> KlingAIProvider:
    """Build a KlingAI provider from the application settings.

    Missing access/secret keys are passed as empty strings.
    """
    settings = get_settings()
    return KlingAIProvider(
        config={
            "access_key": settings.KLINGAI_ACCESS_KEY or "",
            "secret_key": settings.KLINGAI_SECRET_KEY or "",
        }
    )


async def _get_avatar_state(redis, job_id: str) -> dict | None:
    """Read the full state of an avatar job from the Redis hash ``job:{job_id}``.

    The ``result`` and ``params`` fields are JSON-decoded when possible; a
    field that is not valid JSON is left as the raw stored value.

    Returns:
        The state mapping, or ``None`` when the hash is empty / missing.
    """
    data = await redis.hgetall(f"job:{job_id}")
    if not data:
        return None

    # Decode the JSON-encoded fields in place.
    for key in ("result", "params"):
        if key in data and data[key]:
            with contextlib.suppress(json.JSONDecodeError):
                data[key] = json.loads(data[key])

    return data


def _job_params(state: dict) -> dict:
    """Return the decoded ``params`` of a job state, or ``{}`` if it is not a dict."""
    params = state.get("params")
    return params if isinstance(params, dict) else {}


def _parse_timestamp(raw: str) -> float | None:
    """Convert an ISO-8601 string to a POSIX timestamp; ``None`` if empty or invalid."""
    if not raw:
        return None
    try:
        return datetime.fromisoformat(raw).timestamp()
    except ValueError:
        return None


class CloneAvatarRequest(BaseModel):
    """Request body for creating an avatar clone."""

    name: str = Field(..., min_length=1, max_length=20, description="形象名称")
    video_url: str = Field(description="人物视频 URL")


class CloneAvatarResponse(BaseModel):
    """Response returned when an avatar clone task is created."""

    task_id: str = Field(..., description="任务 ID(用于 SSE/轮询跟踪进度)")
    status: str = Field("pending", description="初始状态")


class AvatarTaskStatusResponse(BaseModel):
    """Response for a task status query."""

    task_id: str
    status: str = Field(..., description="当前状态")
    fail_reason: str | None = Field(None, description="失败原因")
    voice_id: str | None = Field(None, description="已生成的音色 ID")
    human_id: int | None = Field(None, description="已生成的主体 ID")
    trial_url: str | None = Field(None, description="试听 URL")
    video_url: str = Field(..., description="原始视频 URL")
    name: str = Field(..., description="形象名称")
    created_at: datetime = Field(..., description="创建时间")
    updated_at: datetime = Field(..., description="更新时间")


class AvatarItem(BaseModel):
    """An entry of the avatar library list."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(..., description="形象唯一标识")
    name: str = Field(..., description="展示名称")
    voice_id: str = Field(..., description="Kling 自定义音色 ID")
    human_id: int = Field(..., description="数字人主体 ID")
    video_url: str = Field(description="原始人物视频 URL")
    trial_url: str | None = Field(None, description="音色试听 URL")
    record_time: str = Field(description="创建时间 ISO 字符串")


class UpdateAvatarNameRequest(BaseModel):
    """Request body for renaming an avatar."""

    name: str = Field(..., min_length=1, max_length=20, description="新形象名称")


# ============================================================
# API routes
# ============================================================
# NOTE(review): ``current_user`` is annotated as ``dict`` throughout but is
# accessed by attribute (``.id`` / ``.mobile``) — confirm the real type
# returned by ``get_current_user``.


@router.post("/avatar/clone", response_model=ApiResponse[CloneAvatarResponse])
async def clone_avatar(
    data: CloneAvatarRequest,
    current_user: dict = Depends(get_current_user),
):
    """
    Submit an avatar clone task.

    Registers the job in Redis through ``JobRegistry`` and returns the new
    ``task_id`` right away; the frontend then tracks progress via SSE or
    polling. Task state is stored in Redis only, nothing is written to a
    database.

    Raises:
        HTTPException: 400 when the name or video URL is blank after trimming.
    """
    user_id = str(current_user.id)
    name = data.name.strip()
    video_url = data.video_url.strip()

    # ``min_length`` is validated before trimming, so a whitespace-only value
    # would otherwise be registered as an empty string.
    if not name:
        raise HTTPException(status_code=400, detail="名称不能为空")
    if not video_url:
        raise HTTPException(status_code=400, detail="视频 URL 不能为空")

    task_id = f"avt_{uuid.uuid4().hex[:16]}"
    now = datetime.now(UTC).isoformat()

    # Write the job to Redis for the Async Engine to schedule, together with
    # the initial avatar state that the query endpoints read back.
    redis = get_redis_client()
    registry = JobRegistry(redis)
    await registry.create(task_id, "avatar_clone", user_id)
    await registry.update(
        task_id,
        status="running",
        progress=5,
        message="开始形象克隆...",
        completed=0,
        total=1,
        params={
            "avatar_id": task_id,
            "name": name,
            "video_url": video_url,
            "user_id": user_id,
        },
        # Avatar state fields (read by the API query endpoints).
        avatar_status=AvatarCloneStatus.PENDING.value,
        avatar_name=name,
        avatar_video_url=video_url,
        voice_id="",
        provider_element_id="",
        provider_voice_job_id="",
        provider_element_job_id="",
        trial_url="",
        fail_reason="",
        created_at=now,
        updated_at=now,
    )
    await registry.add_running(task_id)

    return success_response(data=CloneAvatarResponse(task_id=task_id, status="pending"))


@router.get("/avatar/tasks/{task_id}", response_model=ApiResponse[AvatarTaskStatusResponse])
async def get_avatar_task_status(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """
    Query the status of an avatar clone task (read from Redis).

    Raises:
        HTTPException: 404 when the task does not exist or belongs to another
            user (the two cases are deliberately indistinguishable).
    """
    redis = get_redis_client()
    state = await _get_avatar_state(redis, task_id)
    if not state:
        raise HTTPException(status_code=404, detail="任务不存在")

    # Ownership check.
    params = _job_params(state)
    if params.get("user_id") != str(current_user.id):
        raise HTTPException(status_code=404, detail="任务不存在")

    def _dt(key: str) -> datetime:
        # Falls back to "now" when the stored value is missing or unparsable.
        raw = state.get(key, "")
        if raw:
            try:
                return datetime.fromisoformat(raw)
            except ValueError:
                pass
        return datetime.now(UTC)

    def _int(key: str) -> int | None:
        # Empty or non-numeric stored values are reported as ``None``.
        raw = state.get(key, "")
        if raw:
            try:
                return int(raw)
            except ValueError:
                pass
        return None

    return success_response(
        data=AvatarTaskStatusResponse(
            task_id=task_id,
            status=state.get("avatar_status", state.get("status", "unknown")),
            fail_reason=state.get("fail_reason") or None,
            voice_id=state.get("voice_id") or None,
            human_id=_int("provider_element_id"),
            trial_url=state.get("trial_url") or None,
            video_url=params.get("video_url", ""),
            name=params.get("name", ""),
            created_at=_dt("created_at"),
            updated_at=_dt("updated_at"),
        )
    )


@router.get("/avatar/clone/stream")
async def sse_avatar_clone(
    task_id: str = Query(..., alias="task_id", description="任务 ID"),
    current_user: dict = Depends(get_current_user),
):
    """
    SSE stream that pushes the state of an avatar clone task.

    After the frontend connects, the state is pushed every 3 seconds until the
    task reaches a terminal status (succeed / voice_failed / element_failed /
    timeout), the task is missing or not owned by the caller (``error``
    event), or the maximum number of polls is reached (``timeout`` event).
    """
    user_id = str(current_user.id)

    async def event_stream():
        redis = get_redis_client()

        for _ in range(_SSE_MAX_POLLS):
            state = await _get_avatar_state(redis, task_id)

            # A missing task and a task owned by someone else produce the
            # same error event.
            params = _job_params(state) if state else {}
            if not state or params.get("user_id") != user_id:
                payload = json.dumps(
                    {"status": "error", "fail_reason": "任务不存在或无权限"},
                    ensure_ascii=False,
                )
                yield f"event: error\ndata: {payload}\n\n"
                break

            avatar_status = state.get("avatar_status", state.get("status", "unknown"))
            payload = json.dumps(
                {
                    "task_id": task_id,
                    "status": avatar_status,
                    "fail_reason": state.get("fail_reason") or None,
                    "voice_id": state.get("voice_id") or None,
                    "provider_element_id": state.get("provider_element_id") or None,
                    "trial_url": state.get("trial_url") or None,
                    "video_url": params.get("video_url", ""),
                    "name": params.get("name", ""),
                    "created_at": state.get("created_at", ""),
                    "updated_at": state.get("updated_at", ""),
                },
                ensure_ascii=False,
            )
            yield f"data: {payload}\n\n"

            # Compare against enum *values*: the status read from Redis is a
            # plain string, which never equals a non-str Enum member.
            if avatar_status in _TERMINAL_STATUSES:
                break

            await asyncio.sleep(_SSE_POLL_INTERVAL_SECONDS)
        else:
            # Maximum number of polls reached without a terminal status.
            payload = json.dumps(
                {"status": "timeout", "fail_reason": "连接超时,请通过轮询接口继续跟踪"},
                ensure_ascii=False,
            )
            yield f"event: timeout\ndata: {payload}\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )


@router.post("/avatar/tasks/{task_id}/retry", response_model=ApiResponse[dict])
async def retry_avatar_task(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """
    Retry a failed or timed-out avatar clone task.

    Only tasks in the voice_failed / element_failed / timeout status may be
    retried. The avatar state is reset to pending and the task is added to the
    running set again.

    Raises:
        HTTPException: 404 when the task does not exist or belongs to another
            user; 400 when the current status does not allow a retry.
    """
    redis = get_redis_client()
    state = await _get_avatar_state(redis, task_id)
    if not state:
        raise HTTPException(status_code=404, detail="任务不存在")

    if _job_params(state).get("user_id") != str(current_user.id):
        raise HTTPException(status_code=404, detail="任务不存在")

    avatar_status = state.get("avatar_status", state.get("status", ""))
    if avatar_status not in _RETRYABLE_STATUSES:
        raise HTTPException(status_code=400, detail=f"当前状态 {avatar_status} 不支持重试")

    # Reset the state. The raw enum value is stored, exactly as clone_avatar
    # does, so readers comparing against ``.value`` see a consistent string.
    registry = JobRegistry(redis)
    now = datetime.now(UTC).isoformat()
    await registry.update(
        task_id,
        status="running",
        avatar_status=AvatarCloneStatus.PENDING.value,
        fail_reason="",
        voice_id="",
        provider_element_id="",
        provider_voice_job_id="",
        provider_element_job_id="",
        trial_url="",
        updated_at=now,
    )
    await registry.add_running(task_id)

    return success_response(data={"task_id": task_id, "status": "pending"})


@router.delete("/avatar/{avatar_id}", response_model=ApiResponse[dict])
async def delete_avatar(
    avatar_id: str,
    voice_id: str | None = None,
    current_user: dict = Depends(get_current_user),
):
    """
    Delete an avatar: clean up Kling resources and delete the Redis task record.

    No database is touched; avatar data is managed locally by the frontend.
    Provider cleanup is best-effort: failures are logged and do not fail the
    request.

    Raises:
        HTTPException: 404 when a Redis record exists for ``avatar_id`` but
            belongs to another user.
    """
    redis = get_redis_client()
    state = await _get_avatar_state(redis, avatar_id)

    # Kling resource IDs: prefer the explicitly supplied voice_id, otherwise
    # fall back to what is stored in Redis.
    actual_voice_id = voice_id
    actual_element_id = None

    if state:
        # Never let a user delete another user's record or provider resources.
        if _job_params(state).get("user_id") != str(current_user.id):
            raise HTTPException(status_code=404, detail="任务不存在")
        actual_element_id = state.get("provider_element_id")
        if not actual_voice_id:
            actual_voice_id = state.get("voice_id")

    provider = _get_kling_provider()
    if actual_element_id:
        try:
            await provider.delete_element(str(actual_element_id))
        except Exception as e:
            logger.warning("delete_element failed: %s", e)
    if actual_voice_id:
        try:
            await provider.delete_custom_voice(actual_voice_id)
        except Exception as e:
            logger.warning("delete_custom_voice failed: %s", e)

    # Delete the Redis task record (either owned by the caller or absent).
    registry = JobRegistry(redis)
    await registry.delete(avatar_id)

    return success_response(data={"success": True, "message": "形象已删除"})


@router.get("/avatar/library", response_model=ApiResponse[list[AvatarItem]])
async def get_avatar_library(
    current_user: dict = Depends(get_current_user),
):
    """
    Get the current user's cloned avatar library.

    Avatar data is kept only on the frontend and is not persisted by the
    backend, so this endpoint always returns an empty list; the frontend reads
    the real data from its own local storage.
    """
    return success_response(data=[])


@router.patch("/avatar/{avatar_id}", response_model=ApiResponse[dict])
async def update_avatar_name(
    avatar_id: str,
    data: UpdateAvatarNameRequest,
    current_user: dict = Depends(get_current_user),
):
    """
    Update an avatar's name.

    Avatar data is managed locally by the frontend; the backend only validates
    the name and echoes it back.

    Raises:
        HTTPException: 400 when the name is blank after trimming.
    """
    new_name = data.name.strip()
    if not new_name:
        raise HTTPException(status_code=400, detail="名称不能为空")

    return success_response(data={"success": True, "name": new_name})


# =============================================================================
# Management and monitoring endpoints (troubleshooting and manual recovery)
# =============================================================================


class AvatarHealthResponse(BaseModel):
    """Health status of the avatar clone service."""

    total_processing: int = Field(..., description="处理中的任务总数")
    pending: int = Field(..., description="待处理任务数")
    voice_processing: int = Field(..., description="音色生成中任务数")
    element_processing: int = Field(..., description="主体生成中任务数")
    stuck_tasks: int = Field(..., description="卡住任务数(超过30分钟)")
    recent_failures: int = Field(..., description="最近1小时失败数")


@router.get("/avatar/health", response_model=ApiResponse[AvatarHealthResponse])
async def get_avatar_health(
    current_user: dict = Depends(get_current_user),
):
    """
    Get the health status of the avatar clone service.

    Statistics are computed from the jobs in the Redis running set; no
    database is queried. Only ``avatar_clone`` jobs owned by the current user
    are counted.
    """
    redis = get_redis_client()
    registry = JobRegistry(redis)
    job_ids = await registry.get_running_job_ids()

    total_processing = 0
    pending = 0
    voice_processing = 0
    element_processing = 0
    stuck_tasks = 0
    recent_failures = 0

    now_ts = datetime.now(UTC).timestamp()
    stuck_threshold = now_ts - 30 * 60  # 30 minutes ago
    recent_threshold = now_ts - 60 * 60  # 1 hour ago

    for job_id in job_ids:
        state = await _get_avatar_state(redis, job_id)
        if not state:
            continue

        # Only count the current user's avatar_clone jobs.
        if _job_params(state).get("user_id") != str(current_user.id):
            continue
        if state.get("type", "") != "avatar_clone":
            continue

        avatar_status = state.get("avatar_status", state.get("status", ""))
        total_processing += 1

        if avatar_status == AvatarCloneStatus.PENDING.value:
            pending += 1
        elif avatar_status == AvatarCloneStatus.VOICE_PROCESSING.value:
            voice_processing += 1
        elif avatar_status == AvatarCloneStatus.ELEMENT_PROCESSING.value:
            element_processing += 1

        # Jobs without a parsable updated_at contribute to neither counter.
        updated_ts = _parse_timestamp(state.get("updated_at", ""))
        if updated_ts is None:
            continue

        if avatar_status in _IN_PROGRESS_STATUSES:
            # Stuck: still in progress but not updated for over 30 minutes.
            if updated_ts < stuck_threshold:
                stuck_tasks += 1
        elif avatar_status in _RETRYABLE_STATUSES:
            # Recent failure: failed / timed out within the last hour.
            if updated_ts >= recent_threshold:
                recent_failures += 1

    return success_response(
        data=AvatarHealthResponse(
            total_processing=total_processing,
            pending=pending,
            voice_processing=voice_processing,
            element_processing=element_processing,
            stuck_tasks=stuck_tasks,
            recent_failures=recent_failures,
        )
    )


@router.post("/avatar/admin/trigger-recovery", response_model=ApiResponse[dict])
async def admin_trigger_recovery(
    current_user: dict = Depends(get_current_user),
):
    """
    Manually trigger recovery of stuck tasks (admin endpoint).

    This endpoint performs no recovery itself; it only checks admin permission
    and returns an informational message stating that the Async Engine polls
    automatically.

    Raises:
        HTTPException: 403 when the caller is not an admin.
    """
    # Permission check: admins are identified by specific ``mobile`` values.
    is_admin = current_user.mobile in _ADMIN_MOBILES
    if not is_admin:
        raise HTTPException(status_code=403, detail="需要管理员权限")

    return success_response(
        data={
            "message": "Async Engine 会持续自动轮询,无需手动触发恢复",
            "task_id": None,
        }
    )