Files
meijiaka-zy/python-api/app/api/v1/tasks.py
T
小鱼开发 d4a13ece17 chore: 清理后端未使用 import(9 处)
ruff --select F401 --fix 自动修复:
- deps.py: user_crud
- caption.py: ApiResponse, VolcengineCaptionService
- points.py: UTC
- tasks.py: json
- voice.py: asyncio
- main.py: init_db
- broll_category.py: Text, ARRAY
2026-05-14 22:40:01 +08:00

371 lines
13 KiB
Python

"""
Unified task-management API
===========================
Endpoints for creating tasks and querying their status. Parameter models
are defined here for these task types:
- video: video generation
- script: script generation
- subtitle: subtitle alignment
- tts: speech synthesis

Note: the create endpoint (``POST /{task_type}``) currently accepts only
``script``, ``subtitle`` and ``video``; ``TTSParams`` is defined in this
module, but ``tts`` is not an accepted ``task_type`` for that route.
"""
import logging
import uuid
from typing import Literal

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator, model_validator

from app.api.deps import get_current_user
from app.core.redis_client import get_redis_client
from app.db.session import AsyncSessionLocal
from app.models.user import User
from app.scheduler.registry import TaskRegistry
from app.services import point_service as ps
# Module-level logger; every message below is prefixed with "[API]".
logger = logging.getLogger(__name__)
# Routes are grouped under the "Tasks" tag. No URL prefix is set here —
# presumably it is applied where this router is included (not visible in this file).
router = APIRouter(tags=["Tasks"])
# ========== Request / response models ==========
class ScriptParams(BaseModel):
    """Parameters for a ``script`` task.

    Both category codes are required. The validators strip surrounding
    whitespace and reject any value that is blank once stripped.
    """

    # Top-level category code.
    category: str = Field(min_length=1, description="大类代码")
    # Sub-category code within ``category``.
    subcategory: str = Field(min_length=1, description="小类代码")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("category 不能为空")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("subcategory 不能为空")
class SubtitleParams(BaseModel):
    """Parameters for a ``subtitle`` task.

    ``mode`` is restricted to the two documented values:

    * ``"caption"`` (default)
    * ``"auto_align"`` — requires a non-blank ``audio_text``, the text to
      align against the video's audio.

    Raises (via pydantic ``ValidationError``, a ``ValueError`` subclass):
        * blank ``video_path``
        * unknown ``mode``
        * ``auto_align`` without ``audio_text``
    """

    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    # Previously typed as a free-form ``str``, so any typo passed validation.
    mode: Literal["caption", "auto_align"] = Field(
        default="caption", description="模式: caption/auto_align"
    )
    audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Strip whitespace and reject a blank path."""
        if not v or not v.strip():
            raise ValueError("video_path 不能为空")
        return v.strip()

    @model_validator(mode="after")
    def validate_audio_text_for_auto_align(self) -> "SubtitleParams":
        """Enforce the documented rule that ``auto_align`` needs ``audio_text``."""
        # Runs after all fields are validated, so ``self.mode`` is guaranteed
        # to be one of the two literal values here.
        if self.mode == "auto_align" and not (self.audio_text and self.audio_text.strip()):
            raise ValueError("auto_align 模式下 audio_text 不能为空")
        return self
class TTSParams(BaseModel):
    """Parameters for text-to-speech synthesis.

    ``segments`` must be a non-empty list. Its items are not inspected here;
    per the field description, each is expected to carry ``id`` and
    ``text``/``voiceover``.
    """

    segments: list[dict] = Field(description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    # Playback speed multiplier, bounded to [0.5, 2.0].
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        if v:
            return v
        raise ValueError("segments 不能为空列表")
class VideoParams(BaseModel):
    """Parameters for a digital-human ``video`` task.

    The audio source is either ``audio_url`` or ``text`` (optionally voiced
    with ``voice_id``); at least one of them must be non-blank. When both
    are supplied, this model does not choose between them; that is left to
    the downstream consumer.

    ``planned_duration`` is this segment's scripted length.
    ``total_planned_duration`` is the sum over all segments and is what the
    points pre-check uses.
    """

    video_url: str = Field(..., min_length=1, description="原视频 URL(数字人模板)")
    audio_url: str | None = Field(default=None, description="音频 URL(与 text 二选一)")
    text: str | None = Field(default=None, description="文本内容(与 audio_url 二选一)")
    voice_id: str | None = Field(default=None, description="音色 ID(文字驱动时生效)")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")
    volume: int = Field(default=0, ge=0, le=10, description="音量")
    ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
    planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
    total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
    batch_id: str | None = Field(default=None, description="批次ID(可选)")

    @field_validator("video_url")
    @classmethod
    def validate_video_url(cls, v: str) -> str:
        """Strip whitespace and reject a blank template URL."""
        if not v or not v.strip():
            raise ValueError("video_url 不能为空")
        return v.strip()

    @model_validator(mode="after")
    def validate_audio_source(self) -> "VideoParams":
        """Reject requests that supply neither ``audio_url`` nor ``text``."""
        # Without this, a task with no audio source passed validation and the
        # points pre-check before being enqueued.
        has_audio = bool(self.audio_url and self.audio_url.strip())
        has_text = bool(self.text and self.text.strip())
        if not (has_audio or has_text):
            raise ValueError("audio_url 与 text 必须提供其一")
        return self
class TaskCreateRequest(BaseModel):
    """Request body for creating a task.

    ``params`` is an untyped dict here; it is validated against the
    type-specific parameter model by the create handler.
    """

    # Optional; the handler falls back to ``params["project_id"]`` when unset.
    project_id: str | None = Field(default=None, description="项目ID(可选)")
    params: dict = Field(default_factory=dict, description="任务参数")
class TaskCreateResponse(BaseModel):
    """Response returned after a task has been created.

    ``status`` and ``message`` have defaults, but the create handler in this
    module sets both explicitly.
    """

    task_id: str = Field(description="任务ID")
    status: str = Field(default="pending", description="任务状态")
    message: str = Field(default="任务已创建", description="状态消息")
class TaskStatusResponse(BaseModel):
    """Snapshot of a task's state as exposed to API clients.

    ``result`` is populated only by the single-task endpoint; list responses
    leave it ``None``. ``completed``/``total`` count sub-tasks.
    """

    task_id: str = Field(description="任务ID")
    type: str | None = Field(default=None, description="任务类型")
    status: str = Field(description="任务状态: pending/running/waiting/completed/failed")
    progress: int = Field(default=0, description="进度百分比 (0-100)")
    message: str = Field(default="", description="状态描述")
    completed: int = Field(default=0, description="已完成子任务数")
    total: int = Field(default=0, description="总子任务数")
    result: dict | None = Field(default=None, description="任务结果(完成时)")
    error: str | None = Field(default=None, description="错误信息(失败时)")
    # ISO-formatted string; empty when unknown.
    created_at: str = Field(default="", description="任务创建时间(ISO格式)")
# ========== 辅助函数 ==========
def _generate_task_id() -> str:
"""生成任务ID"""
return f"task_{uuid.uuid4().hex[:16]}"
def _validate_task_params(
    task_type: str,
    params: dict,
    project_id: str | None,
) -> tuple[float, dict]:
    """Validate ``params`` for ``task_type`` and compute the points pre-check.

    Args:
        task_type: One of ``"script"``, ``"subtitle"``, ``"video"``.
        params: The raw ``TaskCreateRequest.params`` payload.
        project_id: Project ID resolved by the caller. It is copied into the
            subtitle and video parameter dicts.

    Returns:
        ``(required_points, validated_params)``, where ``validated_params``
        is the normalized dict stored in the task registry.

    Raises:
        ValueError: Parameter validation failed. Pydantic's
            ``ValidationError`` subclasses ``ValueError``, and any
            ``ValueError`` from the point-cost helpers also propagates.
        HTTPException: 400 for an unsupported task type.
    """
    if task_type == "script":
        script = ScriptParams.model_validate(params)
        return ps._calculate_cost("script"), {
            "category": script.category,
            "subcategory": script.subcategory,
        }

    if task_type == "subtitle":
        subtitle = SubtitleParams.model_validate(params)
        # Subtitle generation is free, so no balance check is needed.
        return 0, {
            "project_id": project_id,
            "video_path": subtitle.video_path,
            "language": subtitle.language,
            "mode": subtitle.mode,
            "audio_text": subtitle.audio_text,
        }

    if task_type == "video":
        video = VideoParams.model_validate(params)
        # Pre-check against the summed duration of all segments, not this
        # segment alone.
        required = ps._estimate_max_cost(
            "video", {"input_seconds": video.total_planned_duration}
        )
        return required, {
            "project_id": project_id,
            "video_url": video.video_url,
            "audio_url": video.audio_url,
            "text": video.text,
            "voice_id": video.voice_id,
            "speed": video.speed,
            "volume": video.volume,
            "ref_photo_url": video.ref_photo_url,
            "planned_duration": video.planned_duration,
        }

    raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")


async def _ensure_sufficient_points(
    user_id: str, task_type: str, required_points: float
) -> None:
    """Raise HTTP 402 if ``user_id``'s balance cannot cover ``required_points``.

    This only reads the balance via ``ps.check_balance``; it does not deduct
    any points itself.
    """
    async with AsyncSessionLocal() as db:
        check = await ps.check_balance(db, user_id, required_points)
    if not check["sufficient"]:
        logger.warning(
            "[API] 积分不足: user=%s, type=%s, required=%s, balance=%s",
            user_id,
            task_type,
            required_points,
            check["balance"],
        )
        raise HTTPException(
            status_code=402,
            detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
        )


async def _register_task(
    task_id: str, task_type: str, user_id: str, validated_params: dict
) -> None:
    """Write a new task into the Redis-backed registry and mark it running.

    Raises:
        HTTPException: 500 if the registry entry cannot be created or updated.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        await registry.create(task_id, task_type, user_id)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") dropped.
        logger.exception("[API] Failed to create registry entry: %s", e)
        raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误") from e

    try:
        await registry.update(
            task_id,
            status="running",
            progress=0,
            message="等待执行...",
            params=validated_params,
        )
        await registry.add_running(task_id)
    except Exception as e:
        # NOTE(review): the entry created above is left behind if this step
        # fails; TaskRegistry's removal API is not visible from this module.
        logger.exception("[API] Failed to update registry: %s", e)
        raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误") from e


# ========== API routes ==========
@router.post("/{task_type}", response_model=TaskCreateResponse)
async def create_task(
    task_type: Literal["script", "subtitle", "video"],
    request: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
) -> TaskCreateResponse:
    """Create a new task.

    The handler runs three steps:

    1. Validate the type-specific ``params``.
    2. Check the user's point balance.
    3. Write the task to Redis with status ``running`` so the Async Engine
       Scheduler can pick it up.

    Raises:
        HTTPException: 422 on invalid params, 402 on insufficient points,
            500 when the Redis registry cannot be written.
    """
    task_id = _generate_task_id()
    user_id = str(current_user.id)
    # The body-level project_id wins; otherwise fall back to one inside params.
    project_id = request.project_id or request.params.get("project_id", "")

    # 1. Parameter validation + cost estimate.
    try:
        required_points, validated_params = _validate_task_params(
            task_type, request.params, project_id
        )
    except ValueError as e:
        logger.warning("[API] Invalid params for %s: %s", task_type, e)
        raise HTTPException(status_code=422, detail=f"参数错误: {e}") from e

    # 2. Balance check (skipped for free tasks).
    if required_points > 0:
        await _ensure_sufficient_points(user_id, task_type, required_points)

    # 3. Persist to Redis for the scheduler.
    await _register_task(task_id, task_type, user_id, validated_params)

    logger.info("[API] Task created: %s, type=%s, user=%s", task_id, task_type, user_id)
    return TaskCreateResponse(
        task_id=task_id,
        status="running",
        message=f"{task_type} 任务已创建",
    )
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
    project_id: str | None = None,
    current_user: User = Depends(get_current_user),
) -> list[TaskStatusResponse]:
    """List the current user's in-progress tasks.

    Reads the user's entries from the registry's running set. When
    ``project_id`` is given, only tasks with that project are returned.
    ``result`` is always ``None`` here to keep list payloads small.
    Redis failures map to HTTP 503.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        running = await registry.list_running_by_user(str(current_user.id))
    except Exception as exc:
        logger.error(f"[API] Redis error when listing tasks: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    return [
        TaskStatusResponse(
            task_id=t.task_id,
            type=t.task_type,
            status=t.status,
            progress=t.progress,
            message=t.message,
            completed=t.completed,
            total=t.total,
            result=None,
            error=t.error,
            created_at=t.created_at,
        )
        for t in running
        if not project_id or t.project_id == project_id
    ]
@router.get("/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> TaskStatusResponse:
    """Return the current state of one task (polled by the frontend).

    State is read only from Redis. The handler responds with:

    * 503 on a Redis error
    * 404 when the record is missing or expired
    * 403 when the task belongs to another user
    """
    registry = TaskRegistry(get_redis_client())
    try:
        record = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting task {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not record:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    owner_id = str(current_user.id)
    if record.user_id != owner_id:
        raise HTTPException(status_code=403, detail="无权访问此任务")

    return TaskStatusResponse(
        task_id=task_id,
        type=record.task_type,
        status=record.status,
        progress=record.progress,
        message=record.message,
        completed=record.completed,
        total=record.total,
        result=record.result,
        error=record.error,
        created_at=record.created_at,
    )
@router.get("/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return only the ``result`` payload of a completed task.

    The handler responds with:

    * 503 on a Redis error
    * 404 when the record is missing or expired
    * 403 when the task belongs to another user
    * 400 while the task is not yet ``completed``

    A completed task with no result yields ``{}``.
    """
    registry = TaskRegistry(get_redis_client())
    try:
        record = await registry.get(task_id)
    except Exception as exc:
        logger.error(f"[API] Redis error when getting result {task_id}: {exc}")
        raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")

    if not record:
        raise HTTPException(status_code=404, detail="任务不存在或已过期")
    if record.user_id != str(current_user.id):
        raise HTTPException(status_code=403, detail="无权访问此任务")
    if record.status == "completed":
        return record.result or {}
    raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {record.status}")