Files
meijiaka-zy/python-api/app/api/v1/tasks.py
T
2026-05-04 19:18:22 +08:00

464 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
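"""
Unified task management API
===========================
Endpoints for creating tasks and querying their status. Supported task types:
- video: video generation
- image: image generation
- script: script generation
- subtitle: subtitle alignment
- tts: text-to-speech synthesis
"""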
import json
import logging
import re
import uuid
from typing import Literal

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator, model_validator

from app.api.deps import get_current_user
from app.core.redis_client import get_redis_client
from app.models.user import User
from app.scheduler.registry import TaskRegistry
from app.schemas.segment import Segment
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Every endpoint in this module is grouped under the "Tasks" OpenAPI tag.
router = APIRouter(tags=["Tasks"])
# ========== 请求/响应模型 ==========
class VideoParams(BaseModel):
    """Parameters for a video-generation task: a non-empty list of segments."""

    segments: list[Segment] = Field(..., description="分镜列表")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[Segment]) -> list[Segment]:
        """Accept the segment list only when it contains at least one item."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
class ImageParams(BaseModel):
    """Parameters for an image-generation task.

    ``image_type`` is described as ``empty_shot`` or ``cover`` but is not
    validated here. ``reference_image`` optionally points at a source image
    for image-to-image generation.
    """

    prompt: str = Field(..., min_length=1, description="图片描述")
    image_type: str = Field(default="cover", description="图片类型: empty_shot/cover")
    reference_image: str | None = Field(None, description="参考图片URL(图生图)")

    @field_validator("prompt")
    @classmethod
    def validate_prompt(cls, v: str) -> str:
        """Return the prompt without surrounding whitespace; reject blank prompts."""
        trimmed = v.strip() if v else ""
        if not trimmed:
            raise ValueError("prompt 不能为空")
        return trimmed
class ScriptParams(BaseModel):
    """Parameters for a script-generation task.

    ``category``/``subcategory`` are trimmed codes that must not be blank;
    ``duration`` is the target video length in seconds (10-300).
    """

    category: str = Field(..., min_length=1, description="大类代码")
    subcategory: str = Field(..., min_length=1, description="小类代码")
    style: str = Field(default="default", description="脚本风格")
    duration: int = Field(default=60, ge=10, le=300, description="视频时长(秒)")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Trim the category code; whitespace-only input is rejected."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("category 不能为空")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        """Trim the subcategory code; whitespace-only input is rejected."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("subcategory 不能为空")
class SubtitleParams(BaseModel):
    """Parameters for a subtitle task.

    ``mode`` is documented as ``caption`` or ``auto_align``. ``audio_text`` is
    the text to align against and must be non-blank in ``auto_align`` mode.
    """

    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    mode: str = Field(default="caption", description="模式: caption/auto_align")
    audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Reject blank paths and strip surrounding whitespace."""
        if not v or not v.strip():
            raise ValueError("video_path 不能为空")
        return v.strip()

    @model_validator(mode="after")
    def validate_audio_text(self) -> "SubtitleParams":
        """Enforce the documented rule that ``auto_align`` needs ``audio_text``.

        Without this check, an auto_align request lacking text was accepted and
        a task ID returned. Raising ValueError here surfaces as a pydantic
        ValidationError (a ValueError subclass), so callers can reject the
        request up front.
        """
        if self.mode == "auto_align" and not (self.audio_text and self.audio_text.strip()):
            raise ValueError("auto_align 模式下 audio_text 不能为空")
        return self
class TTSParams(BaseModel):
    """Parameters for a text-to-speech task.

    ``segments`` is a non-empty list of dicts (each described as carrying an
    ``id`` and ``text``/``voiceover``); ``speed`` is limited to 0.5-2.0.
    """

    segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        """Accept the segment list only when it contains at least one item."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
class TaskCreateRequest(BaseModel):
    """Request body for task creation: optional project ID plus free-form params."""

    project_id: str | None = Field(default=None, description="项目ID(可选)")
    params: dict = Field(default_factory=dict, description="任务参数")
class TaskCreateResponse(BaseModel):
    """Response returned once a task has been registered."""

    task_id: str = Field(..., description="任务ID")
    status: str = Field(default="pending", description="任务状态")
    message: str = Field(default="任务已创建", description="状态消息")
class TaskStatusResponse(BaseModel):
    """Snapshot of a task's state as exposed to clients.

    ``completed``/``total`` count sub-tasks; ``result`` and ``error`` are only
    meaningful once the task has completed or failed respectively.
    """

    task_id: str = Field(..., description="任务ID")
    type: str | None = Field(default=None, description="任务类型")
    status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")
    progress: int = Field(default=0, description="进度百分比 (0-100)")
    message: str = Field(default="", description="状态描述")
    completed: int = Field(default=0, description="已完成子任务数")
    total: int = Field(default=0, description="总子任务数")
    result: dict | None = Field(default=None, description="任务结果(完成时)")
    error: str | None = Field(default=None, description="错误信息(失败时)")
    created_at: str = Field(default="", description="任务创建时间(ISO格式)")
# ========== 辅助函数 ==========
def _generate_task_id() -> str:
"""生成任务ID"""
return f"task_{uuid.uuid4().hex[:16]}"
# ========== API 路由 ==========
# First run of digits in a textual duration such as "5s".
_DURATION_DIGITS_RE = re.compile(r"\d+")


def _normalize_shots(shots: object) -> object:
    """Coerce a frontend ``shots`` payload into something ``Segment`` accepts.

    Dict items are copied, so the caller's request payload is not mutated. In
    each copy, a non-str ``id`` is stringified and a str ``duration`` such as
    ``"5s"`` is reduced to its first run of digits (``None`` if it has none).
    A non-list payload or non-dict items are passed through unchanged, so that
    ``VideoParams`` rejects them as a validation error (HTTP 422) instead of
    crashing here (HTTP 500).
    """
    if not isinstance(shots, list):
        return shots
    cleaned: list[object] = []
    for shot in shots:
        if isinstance(shot, dict):
            shot = dict(shot)
            if "id" in shot and not isinstance(shot["id"], str):
                shot["id"] = str(shot["id"])
            duration = shot.get("duration")
            if isinstance(duration, str):
                match = _DURATION_DIGITS_RE.search(duration)
                shot["duration"] = int(match.group()) if match else None
        cleaned.append(shot)
    return cleaned


def _build_video_fields(params: dict, project_id: str, user_id: str) -> dict:
    """Validate video params and return the registry fields for a video task."""
    video_params = dict(params)
    if "shots" in video_params:
        # The frontend sends ``shots``; VideoParams expects ``segments``.
        video_params["segments"] = _normalize_shots(video_params.pop("shots"))
    validated = VideoParams(**video_params)
    shots = [
        {
            "id": str(seg.id),
            "type": seg.type,
            "scene": seg.scene,
            "voiceover": seg.voiceover,
            "duration": seg.duration,
            "voice_id": seg.voice_id,
            "provider_task_id": None,
            "status": "pending",
            "video_url": None,
            "local_path": None,
            "qiniu_url": None,
            "error_message": None,
        }
        for seg in validated.segments
    ]
    return {
        "message": f"开始生成视频,共 {len(shots)} 个镜头...",
        "completed": 0,
        "total": len(shots),
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "shots": json.dumps(shots, ensure_ascii=False),
        },
    }


def _build_image_fields(params: dict, project_id: str, user_id: str) -> dict:
    """Validate image params and return the registry fields for an image task."""
    validated = ImageParams(**params)
    return {
        "message": "准备生成图片...",
        "completed": 0,
        "total": 1,
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "prompt": validated.prompt,
            "image_type": validated.image_type,
            "reference_image": validated.reference_image,
        },
    }


def _build_script_fields(params: dict, project_id: str, user_id: str) -> dict:
    """Validate script params and return the registry fields for a script task.

    ``project_id``/``user_id`` are accepted for a uniform builder signature
    but, as before, are not stored in the script task's params.
    """
    validated = ScriptParams(**params)
    return {
        "progress": 0,
        "message": "等待执行...",
        "params": {
            "category": validated.category,
            "subcategory": validated.subcategory,
            "style": validated.style,
            "duration": validated.duration,
        },
    }


def _build_subtitle_fields(params: dict, project_id: str, user_id: str) -> dict:
    """Validate subtitle params and return the registry fields for a subtitle task."""
    validated = SubtitleParams(**params)
    return {
        "message": "准备字幕生成...",
        "completed": 0,
        "total": 1,
        "params": {
            "project_id": project_id,
            "video_path": validated.video_path,
            "language": validated.language,
            "mode": validated.mode,
            "audio_text": validated.audio_text,
        },
    }


def _normalize_tts_segment(index: int, seg: object) -> dict:
    """Turn one raw TTS item into a ``{"id", "text", "index"}`` dict."""
    if isinstance(seg, dict):
        return {
            "id": seg.get("id", f"tts_{index}"),
            # ``or ""`` also covers an explicit ``"voiceover": null``.
            "text": seg.get("text") or seg.get("voiceover") or "",
            "index": index,
        }
    return {"id": f"tts_{index}", "text": str(seg), "index": index}


def _build_tts_fields(params: dict, project_id: str, user_id: str) -> dict:
    """Normalize and validate TTS params; return the registry fields for a tts task.

    Segments come from ``params["segments"]`` or, when that is missing or
    empty, ``params["texts"]``. The normalized list is validated with
    ``TTSParams``, so its non-empty check and ``voice_id``/``speed``
    constraints are enforced; the original code declared them but never
    applied them.
    """
    raw_segments = params.get("segments") or params.get("texts", [])
    if not isinstance(raw_segments, list):
        raise ValueError("segments 必须为列表")
    segments = [_normalize_tts_segment(i, seg) for i, seg in enumerate(raw_segments)]
    # Forward only explicitly provided values so TTSParams' defaults apply otherwise.
    overrides = {key: params[key] for key in ("voice_id", "speed") if params.get(key) is not None}
    validated = TTSParams(segments=segments, **overrides)
    return {
        "message": "准备语音合成...",
        "completed": 0,
        "total": len(validated.segments),
        "params": {
            "project_id": project_id,
            "user_id": user_id,
            "segments": json.dumps(validated.segments, ensure_ascii=False),
            "voice_id": validated.voice_id,
            "speed": validated.speed,
        },
    }


# task_type -> builder returning the extra ``registry.update`` kwargs.
_TASK_FIELD_BUILDERS = {
    "video": _build_video_fields,
    "image": _build_image_fields,
    "script": _build_script_fields,
    "subtitle": _build_subtitle_fields,
    "tts": _build_tts_fields,
}


@router.post("/{task_type}", response_model=TaskCreateResponse)
async def create_task(
    task_type: Literal["video", "image", "script", "subtitle", "tts"],
    request: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
) -> TaskCreateResponse:
    """Create a task of ``task_type`` and queue it for the scheduler.

    A registry entry is created in Redis via TaskRegistry. The validated,
    type-specific parameters are then written with ``status="running"``, and
    the task is added to the running set for the Async Engine Scheduler.

    Raises:
        HTTPException(500): the registry entry could not be created, or an
            unexpected error occurred while preparing the task.
        HTTPException(422): ``request.params`` failed validation for this type.
        HTTPException(400): unsupported ``task_type`` (unreachable while the
            path parameter is constrained by ``Literal``).
    """
    task_id = _generate_task_id()
    user_id = str(current_user.id)
    # Store project_id as a string ("" when absent) so the string comparison
    # in list_tasks' project filter matches.
    raw_project_id = request.project_id or request.params.get("project_id")
    project_id = str(raw_project_id) if raw_project_id else ""

    registry = TaskRegistry(get_redis_client())
    try:
        await registry.create(task_id, task_type, user_id)
    except Exception as e:
        logger.error("[API] Failed to create registry entry: %s", e)
        raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") from e

    try:
        build_fields = _TASK_FIELD_BUILDERS.get(task_type)
        if build_fields is None:
            raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")
        fields = build_fields(request.params, project_id, user_id)
        await registry.update(task_id, status="running", **fields)
        await registry.add_running(task_id)
        logger.info("[API] Task created: %s, type=%s, user=%s", task_id, task_type, user_id)
        # NOTE(review): the registry already says "running" here; the response
        # keeps reporting "pending" as before — confirm what clients expect.
        return TaskCreateResponse(
            task_id=task_id,
            status="pending",
            message=f"{task_type} 任务已创建",
        )
    except HTTPException:
        raise
    except ValueError as e:
        # Includes pydantic ValidationError, which subclasses ValueError.
        logger.warning("[API] Invalid params for %s: %s", task_type, e)
        try:
            await registry.update(task_id, status="failed", message=f"参数错误: {e}", error=str(e))
        except Exception as registry_err:
            logger.warning("[API] Failed to update registry for %s: %s", task_id, registry_err)
        raise HTTPException(status_code=422, detail=f"参数错误: {e}") from e
    except Exception as e:
        logger.exception("[API] Failed to create task: %s", e)
        try:
            await registry.update(task_id, status="failed", message=str(e), error=str(e))
        except Exception as registry_err:
            logger.warning("[API] Failed to update registry for %s: %s", task_id, registry_err)
        raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}") from e
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
    project_id: str | None = None,
    current_user: User = Depends(get_current_user),
) -> list[TaskStatusResponse]:
    """List the caller's running tasks, optionally restricted to one project.

    State is read from the Redis running set through TaskRegistry. ``result``
    is always left empty in this list view to keep the payload small.
    A Redis failure is reported as HTTP 503.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        running_tasks = await registry.list_running_by_user(str(current_user.id))
    except Exception as exc:
        logger.error(f"[API] Redis error when listing tasks: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    return [
        TaskStatusResponse(
            task_id=task.task_id,
            type=task.task_type,
            status=task.status,
            progress=task.progress,
            message=task.message,
            completed=task.completed,
            total=task.total,
            result=None,  # omitted in the list view to keep responses small
            error=task.error,
            created_at=task.created_at,
        )
        for task in running_tasks
        if not project_id or task.project_id == project_id
    ]
@router.get("/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> TaskStatusResponse:
    """Return the current status of one task owned by the caller.

    Intended for frontend polling. State comes only from Redis: a Redis
    failure yields 503, a missing or expired record yields 404, and a task
    owned by another user yields 403.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        record = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting task {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not record:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    caller_id = str(current_user.id)
    if record.user_id != caller_id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    return TaskStatusResponse(
        task_id=task_id,
        type=record.task_type,
        status=record.status,
        progress=record.progress,
        message=record.message,
        completed=record.completed,
        total=record.total,
        result=record.result,
        error=record.error,
        created_at=record.created_at,
    )
@router.get("/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return only the ``result`` payload of a completed task owned by the caller.

    Responds 503 on Redis failure, 404 if the task is missing or expired,
    403 for another user's task, and 400 while the task is not yet
    ``completed``. An empty or missing result is returned as ``{}``.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        record = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting result {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not record:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    if record.user_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="无权访问此任务")
    if record.status != "completed":
        raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {record.status}")

    payload = record.result
    return payload if payload else {}