"""
Unified task management API
===========================

Provides endpoints for creating tasks and querying their status.

Task types mentioned by this module:
- video: video generation
- script: script generation
- subtitle: subtitle alignment
- tts: speech synthesis

NOTE(review): ``TTSParams`` is defined below, but the create endpoint only
accepts ``script``, ``subtitle`` and ``video``. Confirm whether ``tts`` is
intentionally not exposed here.
"""

import json
import logging
import uuid
from typing import Any, Literal

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator

from app.api.deps import get_current_user
from app.core.redis_client import get_redis_client
from app.models.user import User
from app.scheduler.registry import TaskRegistry

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Tasks"])


# ========== Request / response models ==========


def _require_non_blank(value: str, field_name: str) -> str:
    """Return *value* stripped of surrounding whitespace.

    Raises:
        ValueError: if the value is empty or whitespace-only. The message is
            ``"<field_name> 不能为空"``.
    """
    stripped = value.strip() if value else ""
    if not stripped:
        raise ValueError(f"{field_name} 不能为空")
    return stripped


class ScriptParams(BaseModel):
    """Parameters for a script generation task."""

    category: str = Field(..., min_length=1, description="大类代码")
    subcategory: str = Field(..., min_length=1, description="小类代码")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Reject blank values and strip surrounding whitespace."""
        return _require_non_blank(v, "category")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        """Reject blank values and strip surrounding whitespace."""
        return _require_non_blank(v, "subcategory")


class SubtitleParams(BaseModel):
    """Parameters for a subtitle generation task."""

    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    mode: str = Field(default="caption", description="模式: caption/auto_align")
    audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Reject blank values and strip surrounding whitespace."""
        return _require_non_blank(v, "video_path")


class TTSParams(BaseModel):
    """Parameters for a TTS (speech synthesis) task."""

    segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        """Reject an empty segment list."""
        if not v:
            raise ValueError("segments 不能为空列表")
        return v


class VideoParams(BaseModel):
    """Parameters for a video generation task.

    NOTE(review): the field descriptions say ``audio_url`` and ``text`` are
    an either-or choice, but nothing in this model enforces that. Confirm
    whether the check happens downstream.
    """

    video_url: str = Field(..., min_length=1, description="原视频 URL(数字人模板)")
    audio_url: str | None = Field(default=None, description="音频 URL(与 text 二选一)")
    text: str | None = Field(default=None, description="文本内容(与 audio_url 二选一)")
    voice_id: str | None = Field(default=None, description="音色 ID(文字驱动时生效)")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")
    volume: int = Field(default=0, ge=0, le=10, description="音量")
    ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
    duration: float = Field(..., gt=0, description="输入音频时长(秒),用于后端扣费")

    @field_validator("video_url")
    @classmethod
    def validate_video_url(cls, v: str) -> str:
        """Reject blank values and strip surrounding whitespace."""
        return _require_non_blank(v, "video_url")


class TaskCreateRequest(BaseModel):
    """Request body for creating a task."""

    project_id: str | None = Field(None, description="项目ID(可选)")
    params: dict = Field(default_factory=dict, description="任务参数")


class TaskCreateResponse(BaseModel):
    """Response returned after a task has been created."""

    task_id: str = Field(..., description="任务ID")
    status: str = Field("pending", description="任务状态")
    message: str = Field("任务已创建", description="状态消息")


class TaskStatusResponse(BaseModel):
    """Status snapshot of a task."""

    task_id: str = Field(..., description="任务ID")
    type: str | None = Field(None, description="任务类型")
    status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")
    progress: int = Field(0, description="进度百分比 (0-100)")
    message: str = Field("", description="状态描述")
    completed: int = Field(0, description="已完成子任务数")
    total: int = Field(0, description="总子任务数")
    result: dict | None = Field(None, description="任务结果(完成时)")
    error: str | None = Field(None, description="错误信息(失败时)")
    created_at: str = Field("", description="任务创建时间(ISO格式)")


# ========== Helpers ==========


def _generate_task_id() -> str:
    """Generate a task ID: ``task_`` plus the first 16 hex chars of a UUID4."""
    return f"task_{uuid.uuid4().hex[:16]}"


def _build_update_fields(task_type: str, params: dict, project_id: str) -> dict[str, Any]:
    """Validate *params* for *task_type* and build the registry update fields.

    Returns:
        Keyword arguments to pass to ``registry.update``.

    Raises:
        ValueError: if the parameters fail model validation (caught as
            ``ValueError`` by the caller).
        HTTPException: 400 if *task_type* is not one of the handled types.
    """
    if task_type == "script":
        script = ScriptParams(**params)
        return {
            "status": "running",
            "progress": 0,
            "message": "等待执行...",
            "params": {
                "category": script.category,
                "subcategory": script.subcategory,
            },
        }

    if task_type == "subtitle":
        subtitle = SubtitleParams(**params)
        return {
            "status": "running",
            "message": "准备字幕生成...",
            "completed": 0,
            "total": 1,
            "params": {
                "project_id": project_id,
                "video_path": subtitle.video_path,
                "language": subtitle.language,
                "mode": subtitle.mode,
                "audio_text": subtitle.audio_text,
            },
        }

    if task_type == "video":
        video = VideoParams(**params)
        return {
            "status": "running",
            "message": "准备视频生成...",
            "completed": 0,
            "total": 1,
            "params": {
                "project_id": project_id,
                "video_url": video.video_url,
                "audio_url": video.audio_url,
                "text": video.text,
                "voice_id": video.voice_id,
                "speed": video.speed,
                "volume": video.volume,
                "ref_photo_url": video.ref_photo_url,
                "duration": video.duration,
            },
        }

    # Defensive: the route's Literal path parameter should already rule this
    # out, but direct callers of create_task() are not constrained by it.
    raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")


async def _mark_failed(registry: TaskRegistry, task_id: str, message: str, error: str) -> None:
    """Best-effort: record the task as failed in the registry.

    A failure to update the registry is logged and swallowed so that it does
    not mask the error the caller is about to report.
    """
    try:
        await registry.update(task_id, status="failed", message=message, error=error)
    except Exception as registry_err:
        logger.warning("[API] Failed to update registry for %s: %s", task_id, registry_err)


def _to_status_response(task: Any, *, task_id: str, include_result: bool) -> TaskStatusResponse:
    """Convert a registry task record into a ``TaskStatusResponse``.

    ``include_result`` is False for list queries so large results are not
    returned in bulk.
    """
    return TaskStatusResponse(
        task_id=task_id,
        type=task.task_type,
        status=task.status,
        progress=task.progress,
        message=task.message,
        completed=task.completed,
        total=task.total,
        result=task.result if include_result else None,
        error=task.error,
        created_at=task.created_at,
    )


async def _load_owned_task(task_id: str, current_user: User, what: str) -> Any:
    """Fetch a task from the registry and check it belongs to *current_user*.

    Args:
        task_id: ID of the task to load.
        current_user: the authenticated user.
        what: short label used in the error log line (``"task"``/``"result"``).

    Raises:
        HTTPException: 503 if the registry lookup fails, 404 if the task is
            not found, 403 if it belongs to a different user.
    """
    redis = get_redis_client()
    registry = TaskRegistry(redis)

    try:
        task = await registry.get(task_id)
    except Exception as e:
        logger.error("[API] Redis error when getting %s %s: %s", what, task_id, e)
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试") from e

    if not task:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")

    # Ownership check.
    if task.user_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="无权访问此任务")

    return task


# ========== API routes ==========


@router.post("/{task_type}", response_model=TaskCreateResponse)
async def create_task(
    task_type: Literal["script", "subtitle", "video"],
    request: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
) -> TaskCreateResponse:
    """
    Create a new task.

    A registry entry is created first; the parameters are then validated for
    the given task type, written to the registry with status ``running``, and
    the task is added to the running set.

    Raises:
        HTTPException: 500 if the registry entry cannot be created or an
            unexpected error occurs, 422 if the parameters are invalid,
            400 if the task type is not supported.
    """
    task_id = _generate_task_id()
    user_id = str(current_user.id)
    project_id = request.project_id or request.params.get("project_id", "")

    redis = get_redis_client()
    registry = TaskRegistry(redis)

    try:
        await registry.create(task_id, task_type, user_id)
    except Exception as e:
        logger.error("[API] Failed to create registry entry: %s", e)
        raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") from e

    try:
        update_fields = _build_update_fields(task_type, request.params, project_id)
        await registry.update(task_id, **update_fields)
        await registry.add_running(task_id)

        logger.info("[API] Task created: %s, type=%s, user=%s", task_id, task_type, user_id)

        return TaskCreateResponse(
            task_id=task_id,
            status="running",
            message=f"{task_type} 任务已创建",
        )

    except ValueError as e:
        # Model validation errors are caught here as ValueError.
        logger.warning("[API] Invalid params for %s: %s", task_type, e)
        await _mark_failed(registry, task_id, "参数错误,请检查后重试", "参数错误")
        raise HTTPException(status_code=422, detail="参数错误,请检查后重试") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[API] Failed to create task: %s", e)
        await _mark_failed(registry, task_id, "任务创建失败,请稍后重试", "任务创建失败")
        raise HTTPException(status_code=500, detail="任务创建失败,请稍后重试") from e


@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
    project_id: str | None = None,
    current_user: User = Depends(get_current_user),
) -> list[TaskStatusResponse]:
    """
    List the current user's running tasks.

    Tasks come from ``registry.list_running_by_user`` and can optionally be
    filtered by ``project_id``. The ``result`` field is always omitted here
    to keep the payload small.

    Raises:
        HTTPException: 503 if the registry lookup fails.
    """
    redis = get_redis_client()
    registry = TaskRegistry(redis)

    try:
        tasks = await registry.list_running_by_user(str(current_user.id))
    except Exception as e:
        logger.error("[API] Redis error when listing tasks: %s", e)
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试") from e

    return [
        _to_status_response(task, task_id=task.task_id, include_result=False)
        for task in tasks
        # An empty/None project_id means "no filter".
        if not project_id or task.project_id == project_id
    ]


@router.get("/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> TaskStatusResponse:
    """
    Get the status of a task.

    Intended to be polled by the frontend for progress. Status is read from
    the registry only; a task that is not found yields a 404.

    Raises:
        HTTPException: 503 if the registry lookup fails, 404 if the task is
            not found, 403 if it belongs to a different user.
    """
    task = await _load_owned_task(task_id, current_user, "task")
    return _to_status_response(task, task_id=task_id, include_result=True)


@router.get("/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """
    Get the result of a task (shortcut that returns only the ``result`` field).

    Returns an empty dict when a completed task has no result.

    Raises:
        HTTPException: 503 if the registry lookup fails, 404 if the task is
            not found, 403 if it belongs to a different user, 400 if the
            task is not in the ``completed`` state.
    """
    task = await _load_owned_task(task_id, current_user, "result")

    if task.status != "completed":
        raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {task.status}")

    return task.result or {}