"""
Avatar 形象克隆处理器
====================

管理 Kling 形象克隆(自定义音色 + 形象主体)的提交与轮询。

全局槽位上限:2(提交或轮询 Kling 任务期间,每个任务占用 1 个槽位)。

数据策略:不操作数据库,所有中间状态存储在 Redis 中。
"""
|
||
|
||
import asyncio
import contextlib
import json
import logging
import re
from collections.abc import Callable
from datetime import UTC, datetime
from typing import Any

import aiohttp

from app.ai.providers.klingai_provider import KlingAIProvider
from app.config import get_settings
from app.core.redis_client import get_redis_client
from app.scheduler.handlers.base import AsyncHandler
from app.scheduler.models import StateChange
from app.scheduler.registry import JobRegistry
from app.scheduler.slot_manager import SlotManager
from app.schemas.enums import AvatarCloneStatus
|
||
|
||
logger = logging.getLogger(__name__)

# Redis key of the global slot pool shared by all avatar-clone jobs, and the
# capacity passed to ``SlotManager.acquire`` for that pool.
SLOT_KEY: str = "kling:avatar_slots"
MAX_SLOTS: int = 2

# User-facing messages for failures not caused by the user's upload:
# transient/upstream problems (network, 5xx, rate limiting) vs. any other
# unexpected system error.
SYSTEM_BUSY_MESSAGE: str = "系统繁忙,请稍后重试"
SYSTEM_ERROR_MESSAGE: str = "系统处理异常,请稍后重试或联系客服"
|
||
|
||
|
||
def _get_kling_provider() -> KlingAIProvider:
    """Build a new KlingAIProvider from the application settings.

    A fresh provider instance is created on every call. Unset (``None`` or
    empty) credentials are passed to the provider as empty strings.
    """
    cfg = get_settings()
    credentials = {
        "access_key": cfg.KLINGAI_ACCESS_KEY or "",
        "secret_key": cfg.KLINGAI_SECRET_KEY or "",
    }
    return KlingAIProvider(config=credentials)
|
||
|
||
|
||
def _translate_voice_error(message: str) -> str:
|
||
msg = (message or "").lower()
|
||
if "no valid audio" in msg or "audio" in msg or "voice" in msg or "人声" in msg:
|
||
return "自定义音色创建失败:视频中没有检测到清晰的人声。请确保上传「有声的人物视频」,且人声干净、无杂音、背景噪音小。"
|
||
if "duration" in msg or "时长" in msg:
|
||
return "自定义音色创建失败:视频时长不符合要求。请使用 5-30 秒的视频。"
|
||
if "format" in msg or "格式" in msg:
|
||
return "自定义音色创建失败:视频格式不支持。请使用 MP4 或 MOV 格式。"
|
||
if "size" in msg or "大小" in msg or "mb" in msg:
|
||
return "自定义音色创建失败:视频文件过大。请压缩至 200MB 以内。"
|
||
if "quality" in msg or "质量" in msg:
|
||
return "自定义音色创建失败:视频/音频质量不符合要求。请确保画面清晰、人声干净、无强烈背景噪音。"
|
||
return f"自定义音色创建失败:{message}。请检查是否上传了符合要求的「有声的人物视频」。"
|
||
|
||
|
||
def _translate_element_error(message: str) -> str:
|
||
msg = (message or "").lower()
|
||
if "duration" in msg or "时长" in msg:
|
||
return "主体创建失败:视频时长不符合要求。请使用 3-8 秒的人物特写视频。"
|
||
if "resolution" in msg or "height" in msg or "像素" in msg or "720" in msg or "2160" in msg:
|
||
return "主体创建失败:视频分辨率不符合要求。请确保视频高度在 720px~2160px 之间。"
|
||
if "size" in msg or "大小" in msg or "mb" in msg or "200" in msg:
|
||
return "主体创建失败:视频文件过大。请压缩至 200MB 以内。"
|
||
if "format" in msg or "格式" in msg or "mp4" in msg or "mov" in msg:
|
||
return "主体创建失败:视频格式不支持。请使用 MP4 或 MOV 格式。"
|
||
if "face" in msg or "人脸" in msg or "detect" in msg or "主体" in msg:
|
||
return "主体创建失败:未能从视频中检测到稳定的人脸。请确保视频为「写实风格的人物正面特写」,人脸清晰、无遮挡、光线充足。"
|
||
if "human" in msg or "人形" in msg or "character" in msg or "写实" in msg:
|
||
return "主体创建失败:视频内容不符合要求。请确保视频中是「写实风格的真实人物」,非卡通、非动物、非虚拟形象。"
|
||
return f"主体创建失败:{message}。请检查视频是否为 3-8 秒、人脸清晰、写实风格的正面人物视频。"
|
||
|
||
|
||
def _translate_system_error(error: Exception, step: str) -> tuple[str, str]:
    """Classify a system-level failure into a user message and a log detail.

    Args:
        error: The exception raised while processing the step.
        step: Short step label (e.g. "voice_create") prefixed to the detail.

    Returns:
        ``(user_message, log_detail)``. ``user_message`` is
        ``SYSTEM_BUSY_MESSAGE`` for aiohttp client errors and timeouts, for
        messages containing "500"/"502"/"503", and for rate-limit markers
        ("rate limit", "too many requests", "429"). Any other error gets
        ``SYSTEM_ERROR_MESSAGE``. ``log_detail`` names the category, the
        exception type and the message.
    """
    text = str(error)
    summary = f"{type(error).__name__}: {text}"

    if isinstance(error, (aiohttp.ClientError, asyncio.TimeoutError)):
        return SYSTEM_BUSY_MESSAGE, f"[{step}] 网络错误: {summary}"
    if any(code in text for code in ("500", "503", "502")):
        return SYSTEM_BUSY_MESSAGE, f"[{step}] KlingAI 服务错误: {summary}"

    lowered = text.lower()
    rate_limited = (
        "rate limit" in lowered or "too many requests" in lowered or "429" in text
    )
    if rate_limited:
        return SYSTEM_BUSY_MESSAGE, f"[{step}] API 限流: {summary}"
    return SYSTEM_ERROR_MESSAGE, f"[{step}] 系统错误: {summary}"
|
||
|
||
|
||
async def _update_avatar_state(registry: JobRegistry, avatar_id: str, **fields: Any) -> None:
    """Write *fields* for *avatar_id* through the registry, stamping ``updated_at``.

    ``updated_at`` is set to the current UTC time in ISO-8601 form. It
    overrides any caller-supplied ``updated_at``.
    """
    stamped = {**fields, "updated_at": datetime.now(UTC).isoformat()}
    await registry.update(avatar_id, **stamped)
|
||
|
||
|
||
class _ProviderResultError(Exception):
    """KlingAI returned a response this handler cannot use.

    Raised for provider-contract violations, such as a task reported as
    created or succeeded without the expected ID. These are not problems with
    the user's upload. ``_is_system_error`` therefore classifies them as system
    errors instead of passing them to the keyword-based
    ``_translate_voice_error`` / ``_translate_element_error``. Those helpers
    match words such as "voice" or "主体", which appear in these internal
    messages, and would show a misleading hint about the video.
    """


def _is_system_error(error: Exception) -> bool:
    """Return True if *error* is a system/upstream failure, not a rejection of the input.

    The following count as system failures:

    * ``_ProviderResultError``;
    * aiohttp client errors and timeouts;
    * messages containing "500"/"502"/"503"/"429" (treated as upstream HTTP
      status codes);
    * the rate-limit phrases that ``_translate_system_error`` also recognises.
    """
    if isinstance(error, (_ProviderResultError, aiohttp.ClientError, asyncio.TimeoutError)):
        return True
    text = str(error)
    if any(code in text for code in ("500", "503", "502", "429")):
        return True
    lowered = text.lower()
    return "rate limit" in lowered or "too many requests" in lowered


def _failure_changes(avatar_id: str, reason: str) -> list[StateChange]:
    """Return the status/message/error StateChanges that report *avatar_id* as failed."""
    return [
        StateChange(job_id=avatar_id, field_path="status", value="failed"),
        StateChange(job_id=avatar_id, field_path="message", value=reason),
        StateChange(job_id=avatar_id, field_path="error", value=reason),
    ]


def _state_field(state: dict[str, Any], key: str) -> str:
    """Read *key* from the job's Redis hash, mapping a missing or empty value to ""."""
    return state.get(key, "") or ""


class AvatarHandler(AsyncHandler):
    """Drive avatar-clone jobs through KlingAI voice creation and element creation.

    Progress is read from the Redis hash ``job:{avatar_id}``, field
    ``avatar_status`` (falling back to ``status``). Each tick advances a job
    by at most one step:

    * ``PENDING``            -> submit custom voice -> ``VOICE_PROCESSING``
    * ``VOICE_PROCESSING``   -> poll voice task     -> ``ELEMENT_PENDING`` or ``VOICE_FAILED``
    * ``ELEMENT_PENDING``    -> submit element      -> ``ELEMENT_PROCESSING``
    * ``ELEMENT_PROCESSING`` -> poll element task   -> ``SUCCEED`` or ``ELEMENT_FAILED``
    * ``SUCCEED`` / ``*_FAILED`` -> release slot, re-emit the final outcome

    Submitting a provider task requires a slot in ``SLOT_KEY`` (capacity
    ``MAX_SLOTS``). The slot is released when that task succeeds, fails, or
    when creating or polling it raises.
    """

    name = "avatar_clone"
    slot_key = SLOT_KEY
    max_slots = MAX_SLOTS

    async def tick(
        self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Advance each job in *jobs* by one step and collect the resulting state changes."""
        changes: list[StateChange] = []
        for job in jobs:
            changes.extend(await self._process_job(job, registry, slots))
        return changes

    async def _process_job(
        self, job: Any, registry: JobRegistry, slots: SlotManager
    ) -> list[StateChange]:
        """Load one job's state from Redis and run the step for its current status.

        Returns the StateChanges to report for the job. An unrecognised status
        yields no changes.
        """
        avatar_id = job.job_id

        redis = get_redis_client()
        state_raw = await redis.hgetall(f"job:{avatar_id}")
        if not state_raw:
            logger.error(f"Avatar job not found in Redis: {avatar_id}")
            return _failure_changes(avatar_id, "任务记录丢失,请重新提交")

        # Submission parameters are stored as JSON. Unparsable JSON leaves them
        # empty, so the provider calls below receive "" values.
        params: dict[str, Any] = {}
        if state_raw.get("params"):
            with contextlib.suppress(json.JSONDecodeError):
                params = json.loads(state_raw["params"])

        status = state_raw.get("avatar_status", state_raw.get("status", ""))

        if status == AvatarCloneStatus.PENDING.value:
            return await self._create_voice(avatar_id, params, registry, slots)
        if status == AvatarCloneStatus.VOICE_PROCESSING.value:
            return await self._poll_voice(avatar_id, state_raw, registry, slots)
        if status == AvatarCloneStatus.ELEMENT_PENDING.value:
            return await self._create_element(avatar_id, params, state_raw, registry, slots)
        if status == AvatarCloneStatus.ELEMENT_PROCESSING.value:
            return await self._poll_element(avatar_id, params, state_raw, registry, slots)
        if status in (
            AvatarCloneStatus.SUCCEED.value,
            AvatarCloneStatus.VOICE_FAILED.value,
            AvatarCloneStatus.ELEMENT_FAILED.value,
        ):
            return await self._finish(avatar_id, status, state_raw, slots)
        return []

    # ---------- pending: create the custom voice ----------

    async def _create_voice(
        self,
        avatar_id: str,
        params: dict[str, Any],
        registry: JobRegistry,
        slots: SlotManager,
    ) -> list[StateChange]:
        """PENDING: take a slot and submit the custom-voice task to KlingAI.

        Returns no changes when every slot is taken; the job is retried on the
        next tick.
        """
        # Built before acquiring, so a construction error cannot leak a slot.
        provider = _get_kling_provider()
        slot_id = f"avatar:{avatar_id}"
        if not await slots.acquire(SLOT_KEY, slot_id, MAX_SLOTS):
            return []

        changes: list[StateChange] = []
        try:
            await _update_avatar_state(
                registry, avatar_id, avatar_status=AvatarCloneStatus.VOICE_PROCESSING.value
            )
            changes.append(
                StateChange(job_id=avatar_id, field_path="message", value="正在创建自定义音色...")
            )
            voice_result = await provider.create_custom_voice(
                voice_name=params.get("name", ""),
                video_url=params.get("video_url", ""),
            )
            voice_task_id = voice_result.get("task_id")
            if not voice_task_id:
                raise _ProviderResultError("未返回音色任务 ID")
            await _update_avatar_state(registry, avatar_id, provider_voice_job_id=voice_task_id)
            logger.info(f"Avatar {avatar_id}: created voice task {voice_task_id}")
        except Exception as e:
            await slots.release(SLOT_KEY, slot_id)
            changes.extend(
                await self._fail_from_exception(
                    e,
                    avatar_id,
                    registry,
                    step="voice_create",
                    failed_status=AvatarCloneStatus.VOICE_FAILED.value,
                    translate=_translate_voice_error,
                )
            )
        return changes

    # ---------- voice_processing: poll the voice task ----------

    async def _poll_voice(
        self,
        avatar_id: str,
        state: dict[str, Any],
        registry: JobRegistry,
        slots: SlotManager,
    ) -> list[StateChange]:
        """VOICE_PROCESSING: poll the voice task.

        On success, stores ``voice_id``/``trial_url`` and moves to
        ELEMENT_PENDING. On failure, moves to VOICE_FAILED.
        """
        voice_job_id = _state_field(state, "provider_voice_job_id")
        if not voice_job_id:
            # The PENDING step sets VOICE_PROCESSING before it stores the task
            # ID, so the ID can be absent here; wait for a later tick.
            return []

        provider = _get_kling_provider()
        slot_id = f"avatar:{avatar_id}"
        changes: list[StateChange] = []
        try:
            result = await provider.get_custom_voice_task(voice_job_id)
            kling_status = result.get("task_status", "processing")
            logger.info(f"Avatar {avatar_id}: voice task {voice_job_id} status={kling_status}")

            if kling_status == "processing":
                changes.append(
                    StateChange(job_id=avatar_id, field_path="message", value="音色处理中...")
                )
            elif kling_status == "succeed":
                await slots.release(SLOT_KEY, slot_id)
                task_result = result.get("task_result") or {}
                voices = task_result.get("voices") or []
                voice_id = None
                trial_url = None
                if voices:
                    voice_info = voices[0]
                    voice_id = voice_info.get("voice_id") or voice_info.get("id")
                    trial_url = (
                        voice_info.get("trial_url")
                        or voice_info.get("preview_url")
                        or voice_info.get("voice_url")
                    )
                if not voice_id:
                    raise _ProviderResultError("音色任务成功但未返回 voice_id")
                await _update_avatar_state(
                    registry,
                    avatar_id,
                    avatar_status=AvatarCloneStatus.ELEMENT_PENDING.value,
                    voice_id=voice_id,
                    trial_url=trial_url or "",
                )
                changes.append(
                    StateChange(
                        job_id=avatar_id,
                        field_path="message",
                        value="音色创建成功,准备创建形象主体...",
                    )
                )
                logger.info(f"Avatar {avatar_id}: voice succeed, voice_id={voice_id}")
            elif kling_status == "failed":
                await slots.release(SLOT_KEY, slot_id)
                error_msg = result.get("task_msg") or "任务执行失败"
                changes.extend(
                    await self._mark_failed(
                        avatar_id,
                        registry,
                        AvatarCloneStatus.VOICE_FAILED.value,
                        _translate_voice_error(error_msg),
                    )
                )
        except Exception as e:
            logger.exception(f"Avatar {avatar_id}: voice poll error")
            # A poll error fails the job, so the slot must be returned here too.
            # _finish already re-releases slots that may have been released, so
            # repeat calls are expected — presumably SlotManager.release is
            # idempotent; TODO confirm.
            await slots.release(SLOT_KEY, slot_id)
            changes.extend(
                await self._fail_from_exception(
                    e,
                    avatar_id,
                    registry,
                    step="voice_poll",
                    failed_status=AvatarCloneStatus.VOICE_FAILED.value,
                    translate=_translate_voice_error,
                )
            )
        return changes

    # ---------- element_pending: create the element ----------

    async def _create_element(
        self,
        avatar_id: str,
        params: dict[str, Any],
        state: dict[str, Any],
        registry: JobRegistry,
        slots: SlotManager,
    ) -> list[StateChange]:
        """ELEMENT_PENDING: take a slot and submit the element task with the stored voice_id.

        Returns no changes when every slot is taken; the job is retried on the
        next tick.
        """
        provider = _get_kling_provider()
        slot_id = f"avatar:{avatar_id}"
        if not await slots.acquire(SLOT_KEY, slot_id, MAX_SLOTS):
            return []

        changes: list[StateChange] = []
        try:
            await _update_avatar_state(
                registry, avatar_id, avatar_status=AvatarCloneStatus.ELEMENT_PROCESSING.value
            )
            changes.append(
                StateChange(job_id=avatar_id, field_path="message", value="正在创建形象主体...")
            )
            element_name = params.get("name", "")
            element_result = await provider.create_element(
                element_name=element_name,
                element_description=f"{element_name} 的克隆形象",
                reference_type="video_refer",
                element_video_list={
                    "refer_videos": [{"video_url": params.get("video_url", "")}]
                },
                element_voice_id=_state_field(state, "voice_id"),
            )
            element_task_id = element_result.get("task_id")
            if not element_task_id:
                raise _ProviderResultError("未返回主体任务 ID")
            await _update_avatar_state(
                registry, avatar_id, provider_element_job_id=element_task_id
            )
            logger.info(f"Avatar {avatar_id}: created element task {element_task_id}")
        except Exception as e:
            await slots.release(SLOT_KEY, slot_id)
            changes.extend(
                await self._fail_from_exception(
                    e,
                    avatar_id,
                    registry,
                    step="element_create",
                    failed_status=AvatarCloneStatus.ELEMENT_FAILED.value,
                    translate=_translate_element_error,
                )
            )
        return changes

    # ---------- element_processing: poll the element task ----------

    async def _poll_element(
        self,
        avatar_id: str,
        params: dict[str, Any],
        state: dict[str, Any],
        registry: JobRegistry,
        slots: SlotManager,
    ) -> list[StateChange]:
        """ELEMENT_PROCESSING: poll the element task.

        On success, stores ``provider_element_id``, moves to SUCCEED and emits
        "completed" plus the result payload. On failure, moves to
        ELEMENT_FAILED.
        """
        element_job_id = _state_field(state, "provider_element_job_id")
        if not element_job_id:
            # The ELEMENT_PENDING step sets ELEMENT_PROCESSING before it stores
            # the task ID, so the ID can be absent here; wait for a later tick.
            return []

        provider = _get_kling_provider()
        slot_id = f"avatar:{avatar_id}"
        changes: list[StateChange] = []
        try:
            result = await provider.get_element_task(element_job_id)
            kling_status = result.get("task_status", "processing")
            logger.info(
                f"Avatar {avatar_id}: element task {element_job_id} status={kling_status}"
            )

            if kling_status == "processing":
                changes.append(
                    StateChange(job_id=avatar_id, field_path="message", value="形象主体处理中...")
                )
            elif kling_status == "succeed":
                await slots.release(SLOT_KEY, slot_id)
                task_result = result.get("task_result") or {}
                elements = task_result.get("elements") or []
                element_id = elements[0].get("element_id") if elements else None
                if not element_id:
                    element_id = task_result.get("element_id")
                if not element_id:
                    raise _ProviderResultError("主体任务成功但未返回 element_id")
                # Validate before persisting SUCCEED. The result payload needs an
                # int, and failing later would emit "completed" followed by
                # "failed" changes for the same job.
                try:
                    numeric_element_id = int(element_id)
                except (TypeError, ValueError) as exc:
                    raise _ProviderResultError(
                        f"主体任务返回的 element_id 无效: {element_id!r}"
                    ) from exc
                await _update_avatar_state(
                    registry,
                    avatar_id,
                    avatar_status=AvatarCloneStatus.SUCCEED.value,
                    provider_element_id=str(element_id),
                )
                changes.append(
                    StateChange(job_id=avatar_id, field_path="status", value="completed")
                )
                changes.append(
                    StateChange(
                        job_id=avatar_id,
                        field_path="result",
                        value={
                            "avatar_id": avatar_id,
                            "name": params.get("name", ""),
                            "video_url": params.get("video_url", ""),
                            "voice_id": _state_field(state, "voice_id"),
                            "element_id": numeric_element_id,
                            "trial_url": _state_field(state, "trial_url"),
                        },
                    )
                )
                logger.info(f"Avatar {avatar_id}: element succeed, element_id={element_id}")
            elif kling_status == "failed":
                await slots.release(SLOT_KEY, slot_id)
                error_msg = result.get("task_msg") or "任务执行失败"
                changes.extend(
                    await self._mark_failed(
                        avatar_id,
                        registry,
                        AvatarCloneStatus.ELEMENT_FAILED.value,
                        _translate_element_error(error_msg),
                    )
                )
        except Exception as e:
            logger.exception(f"Avatar {avatar_id}: element poll error")
            # A poll error fails the job, so the slot must be returned here too
            # (repeat releases also happen in _finish; presumably idempotent —
            # TODO confirm).
            await slots.release(SLOT_KEY, slot_id)
            changes.extend(
                await self._fail_from_exception(
                    e,
                    avatar_id,
                    registry,
                    step="element_poll",
                    failed_status=AvatarCloneStatus.ELEMENT_FAILED.value,
                    translate=_translate_element_error,
                )
            )
        return changes

    # ---------- terminal states: leave the running set ----------

    async def _finish(
        self,
        avatar_id: str,
        status: str,
        state: dict[str, Any],
        slots: SlotManager,
    ) -> list[StateChange]:
        """Terminal status: release the slot and re-emit the job's final outcome.

        For a failed job, the stored ``fail_reason`` is re-emitted, so the
        reason already shown to the user is not overwritten by a generic
        message. "任务状态异常" is used only when no reason was stored.
        """
        await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
        if status == AvatarCloneStatus.SUCCEED.value:
            return [StateChange(job_id=avatar_id, field_path="status", value="completed")]
        return _failure_changes(avatar_id, _state_field(state, "fail_reason") or "任务状态异常")

    # ---------- failure helpers ----------

    async def _mark_failed(
        self,
        avatar_id: str,
        registry: JobRegistry,
        failed_status: str,
        reason: str,
    ) -> list[StateChange]:
        """Persist *failed_status* and ``fail_reason`` for the job, then return the failure changes."""
        await _update_avatar_state(
            registry, avatar_id, avatar_status=failed_status, fail_reason=reason
        )
        return _failure_changes(avatar_id, reason)

    async def _fail_from_exception(
        self,
        error: Exception,
        avatar_id: str,
        registry: JobRegistry,
        *,
        step: str,
        failed_status: str,
        translate: Callable[[str], str],
    ) -> list[StateChange]:
        """Turn an exception raised during *step* into a persisted failure.

        System or upstream errors (see ``_is_system_error``) get the generic
        busy/error message, and their details are logged. Any other error is
        mapped to a user hint through *translate*.
        """
        if _is_system_error(error):
            user_msg, cloud_detail = _translate_system_error(error, step)
            logger.error(f"Avatar {avatar_id} {step} system error: {cloud_detail}")
        else:
            user_msg = translate(str(error))
        return await self._mark_failed(avatar_id, registry, failed_status, user_msg)
|