""" 统一任务管理 API =============== 提供任务创建和状态查询接口,支持: - video: 视频生成 - script: 脚本生成 - subtitle: 字幕对齐 - tts: 语音合成 """ import logging import uuid from typing import Literal from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field, field_validator, model_validator from app.api.deps import get_current_user from app.core.exceptions import InsufficientPointsException from app.core.redis_client import get_redis_client from app.db.session import AsyncSessionLocal from app.models.user import User from app.scheduler.registry import TaskRegistry from app.services import point_service as ps logger = logging.getLogger(__name__) router = APIRouter(tags=["Tasks"]) # ========== 请求/响应模型 ========== class ScriptParams(BaseModel): """脚本生成参数""" category: str = Field(..., min_length=1, description="大类代码") filename: str = Field(..., min_length=1, description="提示词文件名") @field_validator("category") @classmethod def validate_category(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("category 不能为空") return v.strip() @field_validator("filename") @classmethod def validate_filename(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("filename 不能为空") return v.strip() class SubtitleParams(BaseModel): """字幕生成参数""" video_path: str = Field(..., min_length=1, description="视频文件路径") language: str = Field(default="zh", description="语言代码") mode: str = Field(default="caption", description="模式: caption/auto_align") audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)") @field_validator("video_path") @classmethod def validate_video_path(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("video_path 不能为空") return v.strip() @model_validator(mode="after") def validate_auto_align(self) -> "SubtitleParams": if self.mode == "auto_align" and (not self.audio_text or not self.audio_text.strip()): raise ValueError("auto_align 模式必须提供 audio_text") return self class TTSParams(BaseModel): """TTS 语音合成参数""" segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover") voice_id: str = Field(default="zh_female_yizhi", description="音色 ID") speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速") @field_validator("segments") @classmethod def validate_segments(cls, v: list[dict]) -> list[dict]: if not v: raise ValueError("segments 不能为空列表") return v class VideoParams(BaseModel): """视频生成参数""" video_url: str = Field(..., min_length=1, description="原视频 URL(数字人模板)") audio_url: str | None = Field(default=None, description="音频 URL(与 text 二选一)") text: str | None = Field(default=None, description="文本内容(与 audio_url 二选一)") voice_id: str | None = Field(default=None, description="音色 ID(文字驱动时生效)") speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速") volume: int = Field(default=0, ge=0, le=10, description="音量") ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL") planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检") total_planned_duration: float = Field( ..., gt=0, description="所有分镜规划时长之和(秒),用于预检" ) batch_id: str | None = Field(default=None, description="批次ID(可选)") @field_validator("video_url") @classmethod def validate_video_url(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("video_url 不能为空") return v.strip() @model_validator(mode="after") def validate_audio_or_text(self) -> "VideoParams": if not self.audio_url and not self.text: raise ValueError("audio_url 和 text 必须至少填一个") return self class TaskCreateRequest(BaseModel): """创建任务请求""" project_id: str | None = Field(None, description="项目ID(可选)") params: dict = Field(default_factory=dict, description="任务参数") class TaskCreateResponse(BaseModel): """创建任务响应""" task_id: str = Field(..., description="任务ID") status: str = Field("pending", description="任务状态") message: str = Field("任务已创建", description="状态消息") class TaskStatusResponse(BaseModel): """任务状态响应""" task_id: str = Field(..., description="任务ID") type: str | None = Field(None, description="任务类型") status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed") progress: int = Field(0, description="进度百分比 (0-100)") message: str = Field("", description="状态描述") completed: int = Field(0, description="已完成子任务数") total: int = Field(0, description="总子任务数") result: dict | None = Field(None, description="任务结果(完成时)") error: str | None = Field(None, description="错误信息(失败时)") error_code: str | None = Field(None, description="错误码(失败时,如 content_violation)") created_at: str = Field("", description="任务创建时间(ISO格式)") # ========== 辅助函数 ========== def _generate_task_id() -> str: """生成任务ID""" return f"task_{uuid.uuid4().hex[:16]}" # ========== API 路由 ========== @router.post("/{task_type}", response_model=TaskCreateResponse) async def create_task( task_type: Literal["script", "subtitle", "video"], request: TaskCreateRequest, current_user: User = Depends(get_current_user), ) -> TaskCreateResponse: """ 创建新任务 根据任务类型写入 Redis,由 Async Engine Scheduler 统一调度。 创建前检查积分余额,不足时直接返回 402。 """ task_id = _generate_task_id() user_id = str(current_user.id) project_id = request.project_id or request.params.get("project_id", "") # ── 1. 参数验证 + 积分预检 ────────────────────────── required_points = 0 validated_params: dict = {} try: if task_type == "script": script_validated = ScriptParams(**request.params) required_points = ps._calculate_cost("script") validated_params = { "category": script_validated.category, "filename": script_validated.filename, "user_id": user_id, "required_points": required_points, "project_id": project_id, } elif task_type == "subtitle": subtitle_validated = SubtitleParams(**request.params) required_points = 0 # 字幕生成免费 validated_params = { "project_id": project_id, "video_path": subtitle_validated.video_path, "language": subtitle_validated.language, "mode": subtitle_validated.mode, "audio_text": subtitle_validated.audio_text, } elif task_type == "video": video_validated = VideoParams(**request.params) # 视频生成按总时长预检(不是按单个分镜) required_points = ps._estimate_max_cost( "video", {"input_seconds": video_validated.total_planned_duration} ) validated_params = { "project_id": project_id, "video_url": video_validated.video_url, "audio_url": video_validated.audio_url, "text": video_validated.text, "voice_id": video_validated.voice_id, "speed": video_validated.speed, "volume": video_validated.volume, "ref_photo_url": video_validated.ref_photo_url, "planned_duration": video_validated.planned_duration, } else: raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}") except ValueError as e: logger.warning(f"[API] Invalid params for {task_type}: {e}") raise HTTPException(status_code=422, detail=f"参数错误: {e}") # ── 2. 积分余额检查 ──────────────────────────────── if required_points > 0: async with AsyncSessionLocal() as db: check = await ps.check_balance(db, user_id, required_points) if not check["sufficient"]: logger.warning( f"[API] 积分不足: user={user_id}, type={task_type}, " f"required={required_points}, balance={check['balance']}" ) raise InsufficientPointsException( f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}" ) # ── 3. 写入 Redis ────────────────────────────────── redis = get_redis_client() registry = TaskRegistry(redis) try: await registry.create(task_id, task_type, user_id) except Exception as e: logger.error(f"[API] Failed to create registry entry: {e}") raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") try: await registry.update( task_id, status="running", progress=0, message="等待执行...", params=validated_params, ) await registry.add_running(task_id) except Exception as e: logger.error(f"[API] Failed to update registry: {e}") raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误") logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}") return TaskCreateResponse( task_id=task_id, status="running", message=f"{task_type} 任务已创建", ) @router.get("", response_model=list[TaskStatusResponse]) async def list_tasks( project_id: str | None = None, current_user: User = Depends(get_current_user), ) -> list[TaskStatusResponse]: """ 查询当前用户所有进行中的任务 从 Redis running 集合读取真实状态,支持按 project_id 过滤。 """ redis = get_redis_client() registry = TaskRegistry(redis) try: tasks = await registry.list_running_by_user(str(current_user.id)) except Exception as e: logger.error(f"[API] Redis error when listing tasks: {e}") raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试") results: list[TaskStatusResponse] = [] for task in tasks: # 按 project_id 过滤 if project_id and task.project_id != project_id: continue results.append( TaskStatusResponse( task_id=task.task_id, type=task.task_type, status=task.status, progress=task.progress, message=task.message, completed=task.completed, total=task.total, result=None, # 列表查询不返回 result,避免数据过大 error=task.error, error_code=task.error_code, created_at=task.created_at, ) ) return results @router.get("/{task_id}", response_model=TaskStatusResponse) async def get_task_status( task_id: str, current_user: User = Depends(get_current_user), ) -> TaskStatusResponse: """ 查询任务状态 前端通过轮询此接口获取任务进度。 任务状态仅从 Redis 查询,记录过期后返回 404。 """ redis = get_redis_client() registry = TaskRegistry(redis) try: task = await registry.get(task_id) except Exception as e: logger.error(f"[API] Redis error when getting task {task_id}: {e}") raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试") if not task: raise HTTPException(status_code=404, detail="任务不存在或已过期") # 权限检查 if task.user_id != str(current_user.id): raise HTTPException(status_code=403, detail="无权访问此任务") return TaskStatusResponse( task_id=task_id, type=task.task_type, status=task.status, progress=task.progress, message=task.message, completed=task.completed, total=task.total, result=task.result, error=task.error, error_code=task.error_code, created_at=task.created_at, ) @router.get("/{task_id}/result") async def get_task_result( task_id: str, current_user: User = Depends(get_current_user), ) -> dict: """ 获取任务结果(简化接口,直接返回 result 字段) """ redis = get_redis_client() registry = TaskRegistry(redis) try: task = await registry.get(task_id) except Exception as e: logger.error(f"[API] Redis error when getting result {task_id}: {e}") raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试") if not task: raise HTTPException(status_code=404, detail="任务不存在或已过期") if task.user_id != str(current_user.id): raise HTTPException(status_code=403, detail="无权访问此任务") if task.status != "completed": raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {task.status}") return task.result or {}