464 lines
16 KiB
Python
464 lines
16 KiB
Python
"""
|
||
统一任务管理 API
|
||
===============
|
||
|
||
提供任务创建和状态查询接口,支持:
|
||
- video: 视频生成
|
||
- image: 图片生成
|
||
- script: 脚本生成
|
||
- subtitle: 字幕对齐
- tts: 语音合成
|
||
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import Literal
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from pydantic import BaseModel, Field, field_validator
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.core.redis_client import get_redis_client
|
||
from app.models.user import User
|
||
from app.scheduler.registry import JobRegistry
|
||
from app.schemas.segment import Segment
|
||
|
||
# Logger for this module, keyed by the module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Router carrying every task endpoint below; grouped under the "Tasks" OpenAPI tag.
router = APIRouter(tags=["Tasks"])
|
||
|
||
|
||
# ========== 请求/响应模型 ==========
|
||
|
||
|
||
class VideoParams(BaseModel):
    """Parameters for a ``video`` task.

    Attributes:
        segments: Storyboard shots to render; at least one entry is required.
    """

    segments: list[Segment] = Field(..., description="分镜列表")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[Segment]) -> list[Segment]:
        """Pass a non-empty shot list through unchanged; reject an empty one."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
|
||
|
||
|
||
class ImageParams(BaseModel):
    """Parameters for an ``image`` task.

    Attributes:
        prompt: Text description of the image; surrounding whitespace is
            stripped and a blank prompt is rejected.
        image_type: Kind of image; the field description names
            ``empty_shot`` and ``cover`` (default), but any string is accepted.
        reference_image: Optional reference image URL (image-to-image).
    """

    prompt: str = Field(..., min_length=1, description="图片描述")
    image_type: str = Field(default="cover", description="图片类型: empty_shot/cover")
    reference_image: str | None = Field(default=None, description="参考图片URL(图生图)")

    @field_validator("prompt")
    @classmethod
    def validate_prompt(cls, v: str) -> str:
        """Return the stripped prompt, raising if nothing but whitespace remains."""
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError("prompt 不能为空")
        return stripped
|
||
|
||
|
||
class ScriptParams(BaseModel):
    """Parameters for a ``script`` task.

    Attributes:
        category: Top-level category code; stripped, must not be blank.
        subcategory: Sub-category code; stripped, must not be blank.
        style: Script style, ``"default"`` unless given.
        duration: Target video length in seconds, constrained to 10..300.
    """

    category: str = Field(..., min_length=1, description="大类代码")
    subcategory: str = Field(..., min_length=1, description="小类代码")
    style: str = Field(default="default", description="脚本风格")
    duration: int = Field(default=60, ge=10, le=300, description="视频时长(秒)")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Strip the category code; a blank value is an error."""
        trimmed = (v or "").strip()
        if trimmed:
            return trimmed
        raise ValueError("category 不能为空")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        """Strip the sub-category code; a blank value is an error."""
        trimmed = (v or "").strip()
        if trimmed:
            return trimmed
        raise ValueError("subcategory 不能为空")
|
||
|
||
|
||
class SubtitleParams(BaseModel):
    """Parameters for a ``subtitle`` task.

    Attributes:
        video_path: Path of the source video; stripped, must not be blank.
        language: Language code, ``"zh"`` by default.
        mode: ``"caption"`` (default) or ``"auto_align"`` per the field description.
        audio_text: Transcript to align against the audio; required (non-blank)
            when ``mode == "auto_align"``, optional otherwise.
    """

    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    # Declared before audio_text on purpose: field validators see previously
    # validated fields through ``info.data``, and validate_audio_text reads mode.
    mode: str = Field(default="caption", description="模式: caption/auto_align")
    # validate_default=True makes validate_audio_text run even when the client
    # omits audio_text, so the auto_align requirement also covers the default None.
    audio_text: str | None = Field(
        default=None,
        validate_default=True,
        description="打轴文本(auto_align 模式必填)",
    )

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Return the stripped path, rejecting a blank one."""
        if not v or not v.strip():
            raise ValueError("video_path 不能为空")
        return v.strip()

    @field_validator("audio_text")
    @classmethod
    def validate_audio_text(cls, v: str | None, info) -> str | None:
        """Enforce the documented rule that auto_align mode needs audio_text.

        Raises:
            ValueError: mode is ``auto_align`` and audio_text is missing or blank.
        """
        # .get(): if mode itself failed validation it is absent from info.data.
        if info.data.get("mode") == "auto_align" and (v is None or not v.strip()):
            raise ValueError("auto_align 模式下 audio_text 不能为空")
        return v
|
||
|
||
|
||
class TTSParams(BaseModel):
    """Parameters for a ``tts`` (text-to-speech) task.

    Attributes:
        segments: Items to synthesize; each dict is described as carrying an
            ``id`` and ``text``/``voiceover``. At least one item is required.
        voice_id: Voice identifier, ``"zh_female_yizhi"`` by default.
        speed: Speech rate multiplier, constrained to 0.5..2.0.
    """

    segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        """Accept any non-empty segment list as-is; reject an empty one."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
|
||
|
||
|
||
class TaskCreateRequest(BaseModel):
    """Request body for ``POST /{task_type}``.

    Attributes:
        project_id: Optional owning project id.
        params: Task-type-specific parameters; defaults to a fresh empty dict.
    """

    project_id: str | None = Field(default=None, description="项目ID(可选)")
    params: dict = Field(default_factory=dict, description="任务参数")
|
||
|
||
|
||
class TaskCreateResponse(BaseModel):
    """Response body returned after a task has been registered.

    Attributes:
        task_id: Identifier the client uses to poll the task.
        status: Reported status, ``"pending"`` unless set.
        message: Human-readable status message.
    """

    task_id: str = Field(..., description="任务ID")
    status: str = Field(default="pending", description="任务状态")
    message: str = Field(default="任务已创建", description="状态消息")
|
||
|
||
|
||
class TaskStatusResponse(BaseModel):
    """Snapshot of one task's state as reported to clients.

    Attributes:
        task_id: Task identifier.
        type: Task type, if known.
        status: One of pending/running/waiting/completed/failed (per the
            field description).
        progress: Percentage in 0..100 (not range-checked here).
        message: Status description.
        completed: Number of finished sub-tasks.
        total: Total number of sub-tasks.
        result: Task output, populated when the task has completed.
        error: Error message, populated when the task has failed.
        created_at: Creation time as an ISO-formatted string.
    """

    task_id: str = Field(..., description="任务ID")
    type: str | None = Field(default=None, description="任务类型")
    status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")
    progress: int = Field(default=0, description="进度百分比 (0-100)")
    message: str = Field(default="", description="状态描述")
    completed: int = Field(default=0, description="已完成子任务数")
    total: int = Field(default=0, description="总子任务数")
    result: dict | None = Field(default=None, description="任务结果(完成时)")
    error: str | None = Field(default=None, description="错误信息(失败时)")
    created_at: str = Field(default="", description="任务创建时间(ISO格式)")
|
||
|
||
|
||
# ========== 辅助函数 ==========
|
||
|
||
|
||
def _generate_task_id() -> str:
|
||
"""生成任务ID"""
|
||
return f"task_{uuid.uuid4().hex[:16]}"
|
||
|
||
|
||
# ========== API 路由 ==========
|
||
|
||
|
||
def _normalize_shots(shots: object) -> list:
    """Return a cleaned copy of the frontend ``shots`` list, ready for ``Segment`` parsing.

    The frontend may send a numeric ``id`` (Segment expects ``str``) and a
    ``duration`` such as ``"5s"`` (Segment expects ``int``). Each dict entry is
    copied before cleaning so the request payload is not mutated. Non-dict
    entries are passed through untouched and left for pydantic to reject.

    Raises:
        ValueError: ``shots`` is not a list.
    """
    import re

    if not isinstance(shots, list):
        raise ValueError("shots 必须为列表")

    cleaned = []
    for shot in shots:
        if not isinstance(shot, dict):
            cleaned.append(shot)
            continue
        shot = dict(shot)
        if "id" in shot and not isinstance(shot["id"], str):
            shot["id"] = str(shot["id"])
        duration = shot.get("duration")
        if isinstance(duration, str):
            # Keep the first run of digits ("5s" -> 5); no digits -> None.
            match = re.search(r"\d+", duration)
            shot["duration"] = int(match.group()) if match else None
        cleaned.append(shot)
    return cleaned


def _build_video_job(params: dict, project_id: str, user_id: str) -> dict:
    """Validate ``video`` params and return the ``registry.update`` fields for the job.

    The shot list may arrive under ``segments`` or the frontend's ``shots`` key;
    ``shots`` takes precedence when both are present.
    """
    video_params = dict(params)
    if "shots" in video_params:
        video_params["segments"] = _normalize_shots(video_params.pop("shots"))
    segments = VideoParams(**video_params).segments

    # One record per shot. The fields after voice_id start empty/pending and are
    # presumably filled in while the shot is processed.
    normalized_segments = [
        {
            "id": str(s.id),
            "type": s.type,
            "scene": s.scene,
            "voiceover": s.voiceover,
            "duration": s.duration,
            "voice_id": s.voice_id,
            "provider_task_id": None,
            "status": "pending",
            "video_url": None,
            "local_path": None,
            "qiniu_url": None,
            "error_message": None,
        }
        for s in segments
    ]
    return {
        "message": f"开始生成视频,共 {len(normalized_segments)} 个镜头...",
        "completed": 0,
        "total": len(normalized_segments),
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "shots": json.dumps(normalized_segments, ensure_ascii=False),
        },
    }


def _build_image_job(params: dict, project_id: str, user_id: str) -> dict:
    """Validate ``image`` params and return the ``registry.update`` fields for the job."""
    validated = ImageParams(**params)
    return {
        "message": "准备生成图片...",
        "completed": 0,
        "total": 1,
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "prompt": validated.prompt,
            "image_type": validated.image_type,
            "reference_image": validated.reference_image,
        },
    }


def _build_script_job(params: dict, project_id: str, user_id: str) -> dict:
    """Validate ``script`` params and return the ``registry.update`` fields for the job."""
    validated = ScriptParams(**params)
    # NOTE(review): unlike the other task types, script params carry no
    # project_id/user_id (list_tasks filters on job.project_id) — confirm this
    # is intended before adding them, since the worker's expectations aren't visible here.
    return {
        "progress": 0,
        "message": "等待执行...",
        "params": {
            "category": validated.category,
            "subcategory": validated.subcategory,
            "style": validated.style,
            "duration": validated.duration,
        },
    }


def _build_subtitle_job(params: dict, project_id: str, user_id: str) -> dict:
    """Validate ``subtitle`` params and return the ``registry.update`` fields for the job."""
    validated = SubtitleParams(**params)
    return {
        "message": "准备字幕生成...",
        "completed": 0,
        "total": 1,
        "params": {
            "project_id": project_id,
            "video_path": validated.video_path,
            "language": validated.language,
            "mode": validated.mode,
            "audio_text": validated.audio_text,
        },
    }


def _build_tts_job(params: dict, project_id: str, user_id: str) -> dict:
    """Normalize and validate ``tts`` params; return the ``registry.update`` fields.

    Items are read from ``segments`` (falling back to ``texts``). Each item may be
    a dict with ``id`` and ``text``/``voiceover``, or any other value, which is
    turned into text with ``str()``. The normalized payload is then validated by
    ``TTSParams``, so an empty list or an out-of-range ``speed`` is rejected.

    Raises:
        ValueError: segments is not a list, or TTSParams validation fails.
    """
    raw_segments = params.get("segments") or params.get("texts", [])
    if not isinstance(raw_segments, list):
        raise ValueError("segments 必须为列表")

    normalized = []
    for i, seg in enumerate(raw_segments):
        if isinstance(seg, dict):
            normalized.append({
                "id": seg.get("id", f"tts_{i}"),
                "text": seg.get("text") or seg.get("voiceover", ""),
                "index": i,
            })
        else:
            normalized.append({"id": f"tts_{i}", "text": str(seg), "index": i})

    # An explicit null for voice_id/speed falls back to the model default
    # instead of failing validation.
    overrides = {k: params[k] for k in ("voice_id", "speed") if params.get(k) is not None}
    validated = TTSParams(segments=normalized, **overrides)

    return {
        "message": "准备语音合成...",
        "completed": 0,
        "total": len(validated.segments),
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "segments": json.dumps(validated.segments, ensure_ascii=False),
            "voice_id": validated.voice_id,
            "speed": validated.speed,
        },
    }


# Dispatch table: task type -> builder(params, project_id, user_id) returning the
# keyword arguments passed to ``registry.update`` alongside ``status="running"``.
_JOB_BUILDERS = {
    "video": _build_video_job,
    "image": _build_image_job,
    "script": _build_script_job,
    "subtitle": _build_subtitle_job,
    "tts": _build_tts_job,
}


@router.post("/{task_type}", response_model=TaskCreateResponse)
async def create_task(
    task_type: Literal["video", "image", "script", "subtitle", "tts"],
    request: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
) -> TaskCreateResponse:
    """Create a new task and hand it to the scheduler through Redis.

    The params are validated first. Only then is a registry entry created,
    marked ``running`` with the type-specific fields, and added to the running
    set. A request with bad params therefore leaves nothing behind in Redis.

    Raises:
        HTTPException: 400 for an unsupported task type, 422 for invalid params,
            500 when the registry cannot be written or an unexpected error occurs.
    """
    builder = _JOB_BUILDERS.get(task_type)
    if builder is None:
        # Defensive: unreachable while task_type is constrained by the Literal.
        raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")

    user_id = str(current_user.id)
    project_id = request.project_id or request.params.get("project_id") or ""

    try:
        update_fields = builder(request.params, project_id, user_id)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError, so it lands here too.
        logger.warning("[API] Invalid params for %s: %s", task_type, e)
        raise HTTPException(status_code=422, detail=f"参数错误: {e}") from e
    except Exception as e:
        logger.exception("[API] Failed to create task: %s", e)
        raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}") from e

    task_id = _generate_task_id()
    redis = get_redis_client()
    registry = JobRegistry(redis)

    try:
        await registry.create(task_id, task_type, user_id)
    except Exception as e:
        logger.error("[API] Failed to create registry entry: %s", e)
        raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") from e

    try:
        await registry.update(task_id, status="running", **update_fields)
        await registry.add_running(task_id)
    except Exception as e:
        logger.exception("[API] Failed to create task: %s", e)
        # Best effort: mark the half-created entry as failed; never mask the original error.
        try:
            await registry.update(task_id, status="failed", message=str(e), error=str(e))
        except Exception as registry_err:
            logger.warning("[API] Failed to update registry for %s: %s", task_id, registry_err)
        raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}") from e

    logger.info("[API] Task created: %s, type=%s, user=%s", task_id, task_type, user_id)
    return TaskCreateResponse(
        task_id=task_id,
        status="pending",
        message=f"{task_type} 任务已创建",
    )
|
||
|
||
|
||
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
    project_id: str | None = None,
    current_user: User = Depends(get_current_user),
) -> list[TaskStatusResponse]:
    """List the caller's in-progress tasks.

    Jobs are read from the registry's running set for the current user. When
    ``project_id`` is given, only jobs of that project are returned. ``result``
    is always omitted to keep the response small.

    Raises:
        HTTPException: 503 if Redis cannot be queried.
    """
    registry = JobRegistry(get_redis_client())

    try:
        jobs = await registry.list_running_by_user(str(current_user.id))
    except Exception as e:
        logger.error(f"[API] Redis error when listing tasks: {e}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    return [
        TaskStatusResponse(
            task_id=job.job_id,
            type=job.job_type,
            status=job.status,
            progress=job.progress,
            message=job.message,
            completed=job.completed,
            total=job.total,
            result=None,
            error=job.error,
            created_at=job.created_at,
        )
        for job in jobs
        if not project_id or job.project_id == project_id
    ]
|
||
|
||
|
||
@router.get("/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> TaskStatusResponse:
    """Return the current status of one task owned by the caller.

    The frontend polls this endpoint for progress. Status comes from Redis
    only; once the record has expired the endpoint answers 404.

    Raises:
        HTTPException: 503 if Redis fails, 404 if the task is unknown or
            expired, 403 if the task belongs to another user.
    """
    registry = JobRegistry(get_redis_client())

    try:
        job = await registry.get(task_id)
    except Exception as e:
        logger.error(f"[API] Redis error when getting task {task_id}: {e}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not job:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")

    requester_id = str(current_user.id)
    if job.user_id != requester_id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    return TaskStatusResponse(
        task_id=task_id,
        type=job.job_type,
        status=job.status,
        progress=job.progress,
        message=job.message,
        completed=job.completed,
        total=job.total,
        result=job.result,
        error=job.error,
        created_at=job.created_at,
    )
|
||
|
||
|
||
@router.get("/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return only the ``result`` of a completed task (empty dict if it has none).

    Raises:
        HTTPException: 503 if Redis fails, 404 if the task is unknown or
            expired, 403 if the task belongs to another user, 400 if the task
            has not reached ``completed``.
    """
    registry = JobRegistry(get_redis_client())

    try:
        job = await registry.get(task_id)
    except Exception as e:
        logger.error(f"[API] Redis error when getting result {task_id}: {e}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not job:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")

    requester_id = str(current_user.id)
    if job.user_id != requester_id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    if job.status == "completed":
        return job.result or {}
    raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {job.status}")
|