c6eba97b43
后端: - 简化积分服务: 删除 freeze/settle/refund, 保留 consume/recharge/expire - 计费配置化: config/points-config.yaml 驱动 fixed/duration/free 三种模式 - TTS 时长探测: app/utils/audio_utils.py (httpx + mutagen 纯 Python) - Python 层扣费: script(5)/polish(1)/title(1)/voice_clone(200)/tts(按秒)/video(按秒) - 字幕 free_services: caption/auto_align 不扣费 - 新增 POST /points/consume 端点(402余额预检) - 新增 check_balance + /points/cost 返回 sufficient/balance/required - 新增 expire_batches 定时回收, 接入 scheduler main(每5分钟) - 删除废弃 tts_handler.py - Alembic 迁移: 删除 frozen/total_refunded 字段 - 同步 requirements.lock 添加 mutagen 前端: - Rust/IPC 层扣费: compose(5)/subtitle_burn(2)/cover_design(2) - 字幕打轴改异步: 走 scheduler subtitle handler - 对口型传 duration: VideoGeneration 传 actualDuration - 创建 pointStore: 全局余额 + fetchBalance + 充值弹窗控制 - 402 欠费弹 RechargeModal: VideoGeneration/SubtitleBurning/CoverDesign - 修复 VoiceDubbing.tsx 类型错误 (alignResult never) - 同步 PointBalance 类型(删除 frozen/available/totalRefunded) Refs: 积分消耗集成收尾
369 lines
12 KiB
Python
369 lines
12 KiB
Python
"""
|
|
统一任务管理 API
|
|
===============
|
|
|
|
提供任务创建和状态查询接口,支持:
|
|
- video: 视频生成
|
|
- script: 脚本生成
|
|
- subtitle: 字幕对齐
|
|
- tts: 语音合成
|
|
|
|
"""
|
|
|
|
import json
import logging
import uuid
from typing import Literal

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator, model_validator

from app.api.deps import get_current_user
from app.core.redis_client import get_redis_client
from app.models.user import User
from app.scheduler.registry import TaskRegistry
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Router for the task endpoints; "Tasks" is the OpenAPI tag they are grouped under.
router = APIRouter(tags=["Tasks"])
|
|
|
|
|
|
# ========== 请求/响应模型 ==========
|
|
|
|
|
|
class ScriptParams(BaseModel):
    """Validated parameters for a script-generation task.

    Both codes are required. Surrounding whitespace is stripped, and a value
    that is blank after stripping is rejected with a ``ValueError``.
    """

    category: str = Field(..., min_length=1, description="大类代码")
    subcategory: str = Field(..., min_length=1, description="小类代码")

    @field_validator("category")
    @classmethod
    def validate_category(cls, v: str) -> str:
        """Return the stripped category code, rejecting blank input."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("category 不能为空")

    @field_validator("subcategory")
    @classmethod
    def validate_subcategory(cls, v: str) -> str:
        """Return the stripped subcategory code, rejecting blank input."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("subcategory 不能为空")
|
|
|
|
|
|
class SubtitleParams(BaseModel):
    """Validated parameters for a subtitle task.

    ``video_path`` is stripped and must be non-blank. When ``mode`` is
    ``"auto_align"``, ``audio_text`` must be a non-blank string, as the field
    description requires; without it there is nothing to align.
    """

    video_path: str = Field(..., min_length=1, description="视频文件路径")
    language: str = Field(default="zh", description="语言代码")
    # NOTE(review): mode is free-form; the description lists caption/auto_align
    # as the supported values — confirm whether other values should be rejected.
    mode: str = Field(default="caption", description="模式: caption/auto_align")
    audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")

    @field_validator("video_path")
    @classmethod
    def validate_video_path(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank path."""
        if not v or not v.strip():
            raise ValueError("video_path 不能为空")
        return v.strip()

    @model_validator(mode="after")
    def validate_auto_align_text(self) -> "SubtitleParams":
        """Enforce the documented rule that auto_align needs alignment text.

        Raises:
            ValueError: if ``mode == "auto_align"`` and ``audio_text`` is
                missing or blank.
        """
        if self.mode == "auto_align" and not (self.audio_text and self.audio_text.strip()):
            raise ValueError("audio_text 不能为空(auto_align 模式必填)")
        return self
|
|
|
|
|
|
class TTSParams(BaseModel):
    """Validated parameters for TTS speech synthesis.

    ``segments`` must contain at least one entry. Each entry is expected to
    carry an ``id`` plus ``text``/``voiceover`` (per the field description;
    entry contents are not validated here).
    """

    segments: list[dict] = Field(..., description="分镜列表,每项包含 id, text/voiceover")
    voice_id: str = Field(default="zh_female_yizhi", description="音色 ID")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")

    @field_validator("segments")
    @classmethod
    def validate_segments(cls, v: list[dict]) -> list[dict]:
        """Pass the list through unchanged unless it is empty."""
        if v:
            return v
        raise ValueError("segments 不能为空列表")
|
|
|
|
|
|
class VideoParams(BaseModel):
    """Validated parameters for a video-generation (lip-sync) task.

    ``video_url`` (the digital-human template) is required and stripped.
    The driving input is ``audio_url`` or ``text``; at least one of them must
    be a non-blank string. ``voice_id`` only applies to text-driven videos.
    """

    video_url: str = Field(..., min_length=1, description="原视频 URL(数字人模板)")
    audio_url: str | None = Field(default=None, description="音频 URL(与 text 二选一)")
    text: str | None = Field(default=None, description="文本内容(与 audio_url 二选一)")
    voice_id: str | None = Field(default=None, description="音色 ID(文字驱动时生效)")
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="语速")
    volume: int = Field(default=0, ge=0, le=10, description="音量")
    ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")

    @field_validator("video_url")
    @classmethod
    def validate_video_url(cls, v: str) -> str:
        """Strip surrounding whitespace and reject a blank template URL."""
        if not v or not v.strip():
            raise ValueError("video_url 不能为空")
        return v.strip()

    @model_validator(mode="after")
    def validate_audio_or_text(self) -> "VideoParams":
        """Require a driving input, since the two fields are documented as either/or.

        Only the "neither supplied" case is rejected. When both are present
        they pass through unchanged, because rejecting that could break
        existing callers.

        Raises:
            ValueError: if both ``audio_url`` and ``text`` are missing/blank.
        """
        has_audio = bool(self.audio_url and self.audio_url.strip())
        has_text = bool(self.text and self.text.strip())
        if not (has_audio or has_text):
            raise ValueError("audio_url 与 text 必须提供其一")
        return self
|
|
|
|
|
|
class TaskCreateRequest(BaseModel):
    """Request body for creating a task.

    ``params`` is an untyped payload that the endpoint validates against the
    model matching the task type.
    """

    # Optional project association; the endpoint also falls back to params["project_id"].
    project_id: str | None = Field(default=None, description="项目ID(可选)")
    params: dict = Field(default_factory=dict, description="任务参数")
|
|
|
|
|
|
class TaskCreateResponse(BaseModel):
    """Response returned once a task has been registered."""

    task_id: str = Field(..., description="任务ID")
    status: str = Field(default="pending", description="任务状态")
    message: str = Field(default="任务已创建", description="状态消息")
|
|
|
|
|
|
class TaskStatusResponse(BaseModel):
    """Snapshot of a task's state as reported to clients.

    Field order is significant: it determines the JSON/OpenAPI property order.
    """

    # Identity and lifecycle state.
    task_id: str = Field(..., description="任务ID")
    type: str | None = Field(default=None, description="任务类型")
    status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")

    # Progress reporting.
    progress: int = Field(default=0, description="进度百分比 (0-100)")
    message: str = Field(default="", description="状态描述")
    completed: int = Field(default=0, description="已完成子任务数")
    total: int = Field(default=0, description="总子任务数")

    # Outcome and bookkeeping.
    result: dict | None = Field(default=None, description="任务结果(完成时)")
    error: str | None = Field(default=None, description="错误信息(失败时)")
    created_at: str = Field(default="", description="任务创建时间(ISO格式)")
|
|
|
|
|
|
# ========== 辅助函数 ==========
|
|
|
|
|
|
def _generate_task_id() -> str:
|
|
"""生成任务ID"""
|
|
return f"task_{uuid.uuid4().hex[:16]}"
|
|
|
|
|
|
# ========== API 路由 ==========
|
|
|
|
|
|
@router.post("/{task_type}", response_model=TaskCreateResponse)
|
|
async def create_task(
|
|
task_type: Literal["script", "subtitle", "video"],
|
|
request: TaskCreateRequest,
|
|
current_user: User = Depends(get_current_user),
|
|
) -> TaskCreateResponse:
|
|
"""
|
|
创建新任务
|
|
|
|
根据任务类型写入 Redis,由 Async Engine Scheduler 统一调度。
|
|
"""
|
|
task_id = _generate_task_id()
|
|
user_id = str(current_user.id)
|
|
project_id = request.project_id or request.params.get("project_id", "")
|
|
|
|
redis = get_redis_client()
|
|
registry = TaskRegistry(redis)
|
|
|
|
try:
|
|
await registry.create(task_id, task_type, user_id)
|
|
except Exception as e:
|
|
logger.error(f"[API] Failed to create registry entry: {e}")
|
|
raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误")
|
|
|
|
try:
|
|
if task_type == "script":
|
|
script_validated = ScriptParams(**request.params)
|
|
await registry.update(
|
|
task_id,
|
|
status="running",
|
|
progress=0,
|
|
message="等待执行...",
|
|
params={
|
|
"category": script_validated.category,
|
|
"subcategory": script_validated.subcategory,
|
|
|
|
},
|
|
)
|
|
await registry.add_running(task_id)
|
|
|
|
elif task_type == "subtitle":
|
|
subtitle_validated = SubtitleParams(**request.params)
|
|
await registry.update(
|
|
task_id,
|
|
status="running",
|
|
message="准备字幕生成...",
|
|
completed=0,
|
|
total=1,
|
|
params={
|
|
"project_id": project_id,
|
|
"video_path": subtitle_validated.video_path,
|
|
"language": subtitle_validated.language,
|
|
"mode": subtitle_validated.mode,
|
|
"audio_text": subtitle_validated.audio_text,
|
|
},
|
|
)
|
|
await registry.add_running(task_id)
|
|
|
|
elif task_type == "video":
|
|
video_validated = VideoParams(**request.params)
|
|
await registry.update(
|
|
task_id,
|
|
status="running",
|
|
message="准备视频生成...",
|
|
completed=0,
|
|
total=1,
|
|
params={
|
|
"project_id": project_id,
|
|
"video_url": video_validated.video_url,
|
|
"audio_url": video_validated.audio_url,
|
|
"text": video_validated.text,
|
|
"voice_id": video_validated.voice_id,
|
|
"speed": video_validated.speed,
|
|
"volume": video_validated.volume,
|
|
"ref_photo_url": video_validated.ref_photo_url,
|
|
|
|
},
|
|
)
|
|
await registry.add_running(task_id)
|
|
|
|
else:
|
|
raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")
|
|
|
|
logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
|
|
return TaskCreateResponse(
|
|
task_id=task_id,
|
|
status="pending",
|
|
message=f"{task_type} 任务已创建",
|
|
)
|
|
|
|
except ValueError as e:
|
|
logger.warning(f"[API] Invalid params for {task_type}: {e}")
|
|
try:
|
|
await registry.update(
|
|
task_id, status="failed", message="参数错误,请检查后重试", error="参数错误"
|
|
)
|
|
except Exception as registry_err:
|
|
logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
|
|
raise HTTPException(status_code=422, detail="参数错误,请检查后重试")
|
|
|
|
except HTTPException:
|
|
raise
|
|
|
|
except Exception as e:
|
|
logger.exception(f"[API] Failed to create task: {e}")
|
|
try:
|
|
await registry.update(
|
|
task_id, status="failed", message="任务创建失败,请稍后重试", error="任务创建失败"
|
|
)
|
|
except Exception as registry_err:
|
|
logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
|
|
raise HTTPException(status_code=500, detail="任务创建失败,请稍后重试")
|
|
|
|
|
|
@router.get("", response_model=list[TaskStatusResponse])
|
|
async def list_tasks(
|
|
project_id: str | None = None,
|
|
current_user: User = Depends(get_current_user),
|
|
) -> list[TaskStatusResponse]:
|
|
"""
|
|
查询当前用户所有进行中的任务
|
|
|
|
从 Redis running 集合读取真实状态,支持按 project_id 过滤。
|
|
"""
|
|
redis = get_redis_client()
|
|
registry = TaskRegistry(redis)
|
|
|
|
try:
|
|
tasks = await registry.list_running_by_user(str(current_user.id))
|
|
except Exception as e:
|
|
logger.error(f"[API] Redis error when listing tasks: {e}")
|
|
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
|
|
|
results: list[TaskStatusResponse] = []
|
|
for task in tasks:
|
|
# 按 project_id 过滤
|
|
if project_id and task.project_id != project_id:
|
|
continue
|
|
results.append(
|
|
TaskStatusResponse(
|
|
task_id=task.task_id,
|
|
type=task.task_type,
|
|
status=task.status,
|
|
progress=task.progress,
|
|
message=task.message,
|
|
completed=task.completed,
|
|
total=task.total,
|
|
result=None, # 列表查询不返回 result,避免数据过大
|
|
error=task.error,
|
|
created_at=task.created_at,
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
@router.get("/{task_id}", response_model=TaskStatusResponse)
|
|
async def get_task_status(
|
|
task_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
) -> TaskStatusResponse:
|
|
"""
|
|
查询任务状态
|
|
|
|
前端通过轮询此接口获取任务进度。
|
|
任务状态仅从 Redis 查询,记录过期后返回 404。
|
|
"""
|
|
redis = get_redis_client()
|
|
registry = TaskRegistry(redis)
|
|
|
|
try:
|
|
task = await registry.get(task_id)
|
|
except Exception as e:
|
|
logger.error(f"[API] Redis error when getting task {task_id}: {e}")
|
|
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
|
|
|
if not task:
|
|
raise HTTPException(status_code=404, detail="任务不存在或已过期")
|
|
|
|
# 权限检查
|
|
if task.user_id != str(current_user.id):
|
|
raise HTTPException(status_code=403, detail="无权访问此任务")
|
|
|
|
return TaskStatusResponse(
|
|
task_id=task_id,
|
|
type=task.task_type,
|
|
status=task.status,
|
|
progress=task.progress,
|
|
message=task.message,
|
|
completed=task.completed,
|
|
total=task.total,
|
|
result=task.result,
|
|
error=task.error,
|
|
created_at=task.created_at,
|
|
)
|
|
|
|
|
|
@router.get("/{task_id}/result")
|
|
async def get_task_result(
|
|
task_id: str,
|
|
current_user: User = Depends(get_current_user),
|
|
) -> dict:
|
|
"""
|
|
获取任务结果(简化接口,直接返回 result 字段)
|
|
"""
|
|
redis = get_redis_client()
|
|
registry = TaskRegistry(redis)
|
|
|
|
try:
|
|
task = await registry.get(task_id)
|
|
except Exception as e:
|
|
logger.error(f"[API] Redis error when getting result {task_id}: {e}")
|
|
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
|
|
|
if not task:
|
|
raise HTTPException(status_code=404, detail="任务不存在或已过期")
|
|
|
|
if task.user_id != str(current_user.id):
|
|
raise HTTPException(status_code=403, detail="无权访问此任务")
|
|
|
|
if task.status != "completed":
|
|
raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {task.status}")
|
|
|
|
return task.result or {}
|