Files
meijiaka-zy/python-api/app/api/v1/tasks.py
T

562 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
统一任务管理 API
===============
提供任务创建和状态查询接口,支持:
- video: 视频生成
- image: 图片生成
- script: 脚本生成
- subtitle: 字幕对齐
- copy: 文案提取
- avatar_clone: 形象克隆
"""
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator
from app.api.deps import get_current_user
from app.core.redis_client import get_redis_client
from app.models.user import User
from app.scheduler.registry import JobRegistry
from app.schemas.enums import AvatarCloneStatus
from app.schemas.segment import Segment
logger = logging.getLogger(__name__)
router = APIRouter(tags=["Tasks"])
# ========== 请求/响应模型 ==========
class VideoParams(BaseModel):
    """Parameters accepted by the ``video`` task type."""

    # Required list of storyboard segments; must contain at least one item.
    segments: list[Segment] = Field(..., description="分镜列表")
    # Optional digital-human subject id.
    human_id: int | None = Field(default=None, description="数字人主体ID")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[Segment]) -> list[Segment]:
        """Return the segment list unchanged, rejecting an empty one."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
class ImageParams(BaseModel):
    """Parameters accepted by the ``image`` task type."""

    # Required, non-blank textual description of the image.
    prompt: str = Field(..., min_length=1, description="图片描述")
    image_type: str = Field(default="cover", description="图片类型: empty_shot/cover")
    # Optional reference image URL.
    reference_image: str | None = Field(default=None, description="参考图片URL(图生图)")
    # Optional digital-human subject id.
    human_id: int | None = Field(default=None, description="数字人主体IDomni-image使用)")

    @field_validator("prompt")
    @classmethod
    def validate_prompt(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank prompt."""
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError("prompt 不能为空")
class ScriptParams(BaseModel):
    """Parameters accepted by the ``script`` task type."""

    # Required category / subcategory codes; both are stripped and must be non-blank.
    category: str = Field(..., min_length=1, description="大类代码")
    subcategory: str = Field(..., min_length=1, description="小类代码")
    style: str = Field(default="default", description="脚本风格")
    # Target video duration in seconds, constrained to 10..300.
    duration: int = Field(default=60, ge=10, le=300, description="视频时长(秒)")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank category."""
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError("category 不能为空")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank subcategory."""
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError("subcategory 不能为空")
class SubtitleParams(BaseModel):
    """Parameters accepted by the ``subtitle`` task type."""

    # Required, non-blank path of the video file.
    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    mode: str = Field(default="caption", description="模式: caption/auto_align")
    # Alignment text; the field description marks it as required for auto_align mode,
    # but no validator in this model enforces that.
    audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank path."""
        cleaned = v.strip()
        if cleaned:
            return cleaned
        raise ValueError("video_path 不能为空")
class CopyParams(BaseModel):
    """Parameters accepted by the ``copy`` (copywriting extraction) task type."""

    # Required http(s) URL of the source video.
    video_url: str = Field(..., min_length=1, description="视频链接URL")

    @field_validator("video_url")
    @classmethod
    def validate_video_url(cls, v: str) -> str:
        """Strip the URL, then require it to be non-blank and http(s).

        The scheme check runs on the stripped value so that input with
        surrounding whitespace is accepted, since the returned value is
        stripped anyway.

        Raises:
            ValueError: if the value is blank or does not start with
                ``http://`` or ``https://``.
        """
        url = v.strip()
        if not url:
            raise ValueError("video_url 不能为空")
        if not url.startswith(("http://", "https://")):
            raise ValueError("video_url 必须是有效的URL")
        return url
class TTSParams(BaseModel):
    """Parameters for text-to-speech (TTS) synthesis."""

    # Required, non-empty list of segment dicts.
    segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    # Speech rate multiplier, constrained to 0.5..2.0.
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        """Return the segment list unchanged, rejecting an empty one."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
class TaskCreateRequest(BaseModel):
    """Request body for creating a task."""

    # Optional project id.
    project_id: str | None = Field(default=None, description="项目ID(可选)")
    # Free-form, task-type-specific parameters; defaults to an empty dict.
    params: dict = Field(default_factory=dict, description="任务参数")
class TaskCreateResponse(BaseModel):
    """Response body returned after a task has been created."""

    task_id: str = Field(..., description="任务ID")
    status: str = Field(default="pending", description="任务状态")
    message: str = Field(default="任务已创建", description="状态消息")
class TaskStatusResponse(BaseModel):
    """Response body describing the current state of a task."""

    task_id: str = Field(..., description="任务ID")
    type: str | None = Field(default=None, description="任务类型")
    status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")
    # Progress percentage; the description documents 0-100 but no bound is enforced here.
    progress: int = Field(default=0, description="进度百分比 (0-100)")
    message: str = Field(default="", description="状态描述")
    # Sub-task counters.
    completed: int = Field(default=0, description="已完成子任务数")
    total: int = Field(default=0, description="总子任务数")
    # Result payload and error text; both optional.
    result: dict | None = Field(default=None, description="任务结果(完成时)")
    error: str | None = Field(default=None, description="错误信息(失败时)")
    created_at: str = Field(default="", description="任务创建时间(ISO格式)")
# ========== 辅助函数 ==========
def _generate_task_id() -> str:
"""生成任务ID"""
return f"task_{uuid.uuid4().hex[:16]}"
# ========== API 路由 ==========
@router.post("/{task_type}", response_model=TaskCreateResponse)
async def create_task(
    task_type: Literal["video", "image", "script", "subtitle", "copy", "avatar_clone", "tts"],
    request: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
) -> TaskCreateResponse:
    """Create a new task of the given type.

    A registry entry is created first; the type-specific parameters are then
    validated, written to the entry with status ``running``, and the entry is
    added to the running set via ``registry.add_running``.

    Args:
        task_type: Which kind of task to create (path parameter).
        request: Request body carrying ``project_id`` and free-form ``params``.
        current_user: Authenticated user; its ``id`` is stored as the owner.

    Returns:
        TaskCreateResponse with the id clients should poll. For
        ``avatar_clone`` this is the ``avt_...`` id, not the generic
        ``task_...`` id.

    Raises:
        HTTPException: 422 when parameters are invalid (any ``ValueError``,
            which under pydantic v2 includes model validation errors), 500
            when the registry entry cannot be created or an unexpected error
            occurs, 400 for an unsupported task type.
    """
    task_id = _generate_task_id()
    user_id = str(current_user.id)
    project_id = request.project_id or request.params.get("project_id", "")
    redis = get_redis_client()
    registry = JobRegistry(redis)
    try:
        await registry.create(task_id, task_type, user_id)
    except Exception as e:
        logger.error(f"[API] Failed to create registry entry: {e}")
        raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") from e
    try:
        if task_type == "video":
            # Field adaptation: frontend shots/element_id -> backend segments/human_id.
            validated = VideoParams(**_adapt_video_params(request.params))
            human_id = validated.human_id
            normalized_segments = [
                {
                    "id": str(s.id),
                    "type": s.type,
                    "scene": s.scene,
                    "voiceover": s.voiceover,
                    "duration": s.duration,
                    # Only items of type "segment" carry the digital-human id.
                    "human_id": (human_id if s.type == "segment" else None),
                    "voice_id": s.voice_id,
                    "provider_task_id": None,
                    "status": "pending",
                    "video_url": None,
                    "local_path": None,
                    "qiniu_url": None,
                    "error_message": None,
                }
                for s in validated.segments
            ]
            await registry.update(
                task_id,
                status="running",
                message=f"开始生成视频,共 {len(normalized_segments)} 个镜头...",
                completed=0,
                total=len(normalized_segments),
                params={
                    "project_id": project_id,
                    "user_id": user_id,
                    "human_id": human_id,
                    "shots": json.dumps(normalized_segments, ensure_ascii=False),
                },
            )
            await registry.add_running(task_id)
        elif task_type == "image":
            image_validated = ImageParams(**request.params)
            await registry.update(
                task_id,
                status="running",
                message="准备生成图片...",
                completed=0,
                total=1,
                params={
                    "project_id": project_id,
                    "user_id": user_id,
                    "prompt": image_validated.prompt,
                    "image_type": image_validated.image_type,
                    "reference_image": image_validated.reference_image,
                    "human_id": image_validated.human_id,
                },
            )
            await registry.add_running(task_id)
        elif task_type == "script":
            script_validated = ScriptParams(**request.params)
            await registry.update(
                task_id,
                status="running",
                progress=0,
                message="等待执行...",
                params={
                    "category": script_validated.category,
                    "subcategory": script_validated.subcategory,
                    "style": script_validated.style,
                    "duration": script_validated.duration,
                },
            )
            await registry.add_running(task_id)
        elif task_type == "subtitle":
            subtitle_validated = SubtitleParams(**request.params)
            await registry.update(
                task_id,
                status="running",
                message="准备字幕生成...",
                completed=0,
                total=1,
                params={
                    "project_id": project_id,
                    "video_path": subtitle_validated.video_path,
                    "language": subtitle_validated.language,
                    "mode": subtitle_validated.mode,
                    "audio_text": subtitle_validated.audio_text,
                },
            )
            await registry.add_running(task_id)
        elif task_type == "copy":
            copy_validated = CopyParams(**request.params)
            await registry.update(
                task_id,
                status="running",
                message="准备提取文案...",
                completed=0,
                total=1,
                params={"video_url": copy_validated.video_url},
            )
            await registry.add_running(task_id)
        elif task_type == "avatar_clone":
            # Non-string / null values become a 422 instead of an AttributeError (500).
            name = _clean_str_param(request.params, "name")
            video_url = _clean_str_param(request.params, "video_url")
            if not name:
                raise ValueError("name 不能为空")
            if not video_url:
                raise ValueError("video_url 不能为空")
            if not video_url.startswith(("http://", "https://")):
                raise ValueError("video_url 必须是有效的URL")
            avatar_id = f"avt_{uuid.uuid4().hex[:16]}"
            now = datetime.now(UTC).isoformat()
            # avatar_clone uses its own id (avt_xxx) instead of the generic task_xxx.
            # NOTE(review): the task_xxx entry created above is left untouched on
            # success; confirm it expires on its own or clean it up in the registry.
            await registry.create(avatar_id, "avatar_clone", user_id)
            await registry.update(
                avatar_id,
                status="running",
                progress=5,
                message="开始形象克隆...",
                completed=0,
                total=1,
                params={
                    "avatar_id": avatar_id,
                    "name": name,
                    "video_url": video_url,
                    "user_id": user_id,
                },
                avatar_status=AvatarCloneStatus.PENDING.value,
                avatar_name=name,
                avatar_video_url=video_url,
                voice_id="",
                provider_element_id="",
                provider_voice_job_id="",
                provider_element_job_id="",
                trial_url="",
                fail_reason="",
                created_at=now,
                updated_at=now,
            )
            await registry.add_running(avatar_id)
            # Return the avatar id as the task id to stay compatible with the frontend.
            task_id = avatar_id
        elif task_type == "tts":
            raw_segments = request.params.get("segments") or request.params.get("texts", [])
            tts_fields: dict = {"segments": _normalize_tts_segments(raw_segments)}
            # Only forward explicitly provided values so model defaults apply otherwise.
            for key in ("voice_id", "speed"):
                if request.params.get(key) is not None:
                    tts_fields[key] = request.params[key]
            # Validate through the model like every other task type: rejects an
            # empty segment list and an out-of-range speed.
            tts_validated = TTSParams(**tts_fields)
            await registry.update(
                task_id,
                status="running",
                message="准备语音合成...",
                completed=0,
                total=len(tts_validated.segments),
                params={
                    "project_id": project_id,
                    "user_id": user_id,
                    "segments": json.dumps(tts_validated.segments, ensure_ascii=False),
                    "voice_id": tts_validated.voice_id,
                    "speed": tts_validated.speed,
                },
            )
            await registry.add_running(task_id)
        else:
            # Defensive: the Literal annotation should already reject other values.
            raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")
        logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
        return TaskCreateResponse(
            task_id=task_id,
            status="pending",
            message=f"{task_type} 任务已创建",
        )
    except ValueError as e:
        logger.warning(f"[API] Invalid params for {task_type}: {e}")
        # Best effort: record the failure, but never mask the 422 with a registry error.
        try:
            await registry.update(task_id, status="failed", message=f"参数错误: {e}", error=str(e))
        except Exception as registry_err:
            logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
        raise HTTPException(status_code=422, detail=f"参数错误: {e}") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"[API] Failed to create task: {e}")
        try:
            await registry.update(task_id, status="failed", message=str(e), error=str(e))
        except Exception as registry_err:
            logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
        raise HTTPException(status_code=500, detail=f"创建任务失败: {e}") from e


def _adapt_video_params(params: dict) -> dict:
    """Return a copy of ``params`` with frontend video fields renamed and cleaned.

    ``shots`` becomes ``segments`` and ``element_id`` becomes ``human_id``.
    Each shot is copied before cleaning so the request payload is not mutated.

    Raises:
        ValueError: if ``shots`` is not a list or contains a non-dict item.
    """
    import re

    video_params = dict(params)
    if "shots" in video_params:
        shots = video_params.pop("shots")
        if not isinstance(shots, list):
            raise ValueError("shots 必须为列表")
        cleaned_shots = []
        for shot in shots:
            if not isinstance(shot, dict):
                raise ValueError("shots 的每一项必须为对象")
            shot = dict(shot)
            # The frontend may send a numeric id; the Segment model expects str.
            if "id" in shot and not isinstance(shot["id"], str):
                shot["id"] = str(shot["id"])
            # The frontend may send a duration like "5s"; the Segment model expects int.
            duration = shot.get("duration")
            if isinstance(duration, str):
                digits = re.search(r"\d+", duration)
                shot["duration"] = int(digits.group()) if digits else None
            cleaned_shots.append(shot)
        video_params["segments"] = cleaned_shots
    if "element_id" in video_params:
        video_params["human_id"] = video_params.pop("element_id")
    return video_params


def _clean_str_param(params: dict, key: str) -> str:
    """Return ``params[key]`` stripped, or ``""`` when it is missing or null.

    Raises:
        ValueError: if the value is present but not a string.
    """
    value = params.get(key)
    if value is None:
        return ""
    if not isinstance(value, str):
        raise ValueError(f"{key} 必须为字符串")
    return value.strip()


def _normalize_tts_segments(raw_segments: object) -> list[dict]:
    """Convert raw TTS input items into ``{"id", "text", "index"}`` dicts.

    Dict items keep their ``id`` (default ``tts_<index>``) and take their text
    from ``text`` or, failing that, ``voiceover``; any other item is
    stringified and used as the text.

    Raises:
        ValueError: if ``raw_segments`` is not a list.
    """
    if not isinstance(raw_segments, list):
        raise ValueError("segments 必须为列表")
    normalized = []
    for i, seg in enumerate(raw_segments):
        if isinstance(seg, dict):
            normalized.append(
                {
                    "id": seg.get("id", f"tts_{i}"),
                    "text": seg.get("text") or seg.get("voiceover", ""),
                    "index": i,
                }
            )
        else:
            normalized.append({"id": f"tts_{i}", "text": str(seg), "index": i})
    return normalized
def _map_avatar_status(status: str) -> str:
"""将 AvatarCloneStatus 映射为统一任务状态"""
mapping = {
"succeed": "completed",
"voice_failed": "failed",
"element_failed": "failed",
"timeout": "failed",
"pending": "running",
"voice_processing": "running",
"element_pending": "running",
"element_processing": "running",
}
return mapping.get(status, "running")
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
    project_id: str | None = None,
    current_user: User = Depends(get_current_user),
) -> list[TaskStatusResponse]:
    """List the current user's running tasks.

    Jobs are fetched with ``JobRegistry.list_running_by_user``. When
    ``project_id`` is given, only jobs with a matching ``project_id`` are
    returned. ``result`` is always ``None`` in this listing to keep the
    payload small.

    Raises:
        HTTPException: 503 when the registry lookup fails.
    """
    registry = JobRegistry(get_redis_client())
    try:
        running_jobs = await registry.list_running_by_user(str(current_user.id))
    except Exception as exc:
        logger.error(f"[API] Redis error when listing tasks: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
    return [
        TaskStatusResponse(
            task_id=item.job_id,
            type=item.job_type,
            status=item.status,
            progress=item.progress,
            message=item.message,
            completed=item.completed,
            total=item.total,
            result=None,  # omitted from list responses to avoid oversized payloads
            error=item.error,
            created_at=item.created_at,
        )
        for item in running_jobs
        # No filter given, or the job belongs to the requested project.
        if not project_id or item.project_id == project_id
    ]
@router.get("/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> TaskStatusResponse:
    """Return the current status of a task.

    The job is looked up only through ``JobRegistry.get``; a missing entry
    yields 404.

    Raises:
        HTTPException: 503 when the registry lookup fails, 404 when no job is
            found, 403 when the job belongs to a different user.
    """
    registry = JobRegistry(get_redis_client())
    try:
        job = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting task {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    # Ownership check: only the user stored on the job may read it.
    is_owner = job.user_id == str(current_user.id)
    if not is_owner:
        raise HTTPException(status_code=403, detail="无权访问此任务")
    fields = {
        "task_id": task_id,
        "type": job.job_type,
        "status": job.status,
        "progress": job.progress,
        "message": job.message,
        "completed": job.completed,
        "total": job.total,
        "result": job.result,
        "error": job.error,
        "created_at": job.created_at,
    }
    return TaskStatusResponse(**fields)
@router.get("/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return only the ``result`` field of a completed task.

    Returns:
        The job's result, or an empty dict when the result is falsy.

    Raises:
        HTTPException: 503 when the registry lookup fails, 404 when no job is
            found, 403 when the job belongs to a different user, 400 when the
            job status is not ``completed``.
    """
    registry = JobRegistry(get_redis_client())
    try:
        job = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting result {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
    if not job:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    is_owner = job.user_id == str(current_user.id)
    if not is_owner:
        raise HTTPException(status_code=403, detail="无权访问此任务")
    if job.status != "completed":
        raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {job.status}")
    payload = job.result
    return payload if payload else {}