30536276ba
核心变更:
- 统一第三方接口架构:所有服务走 PlatformGateway(call_sync/submit_task/query_task/handle_webhook)
- 视频生成(Vidu 对口型)纳入 Async Engine,与 script/subtitle/tts 统一为 POST /tasks/{task_type} 模式
- 新增 VideoHandler、TTSHandler,完善 ScriptHandler/SubtitleHandler
- PlatformGateway 生成 internal_task_id,建立 Redis 双向映射,callback 场景传入 Async Engine task_id 保证映射一致
- SlotManager 新增 acquire_ctx 上下文管理器,所有 Handler 统一使用
- ViduAdapter 状态映射归一化(normalize_state/denormalize_state)
- 移除 ViduService Semaphore 和 tenacity 重试,并发控制完全交予 SlotManager
- nonce 防重放下沉到 CallbackCapable 协议
- Service 层错误统一为 PlatformError,路由层错误信息脱敏
- 废弃 /voice/lip-sync,清理 vidu.py 遗留路由
Bug 修复:
- VideoHandler 轮询阶段后添加 continue,防止已提交任务重复创建
- voice.py synthesize_to_file 变量名冲突(request vs request_body)
- PlatformGateway.submit_task 空 data 防护
- ScriptHandler 动态导入 asyncio 改为模块级导入
- SubtitleHandler 完成时补充 progress=100
文档:
- 更新 AGENTS.md 核心功能、运行时架构、异步调度描述
361 lines
12 KiB
Python
361 lines
12 KiB
Python
"""
|
||
火山引擎音视频字幕服务
|
||
======================
|
||
|
||
基于火山引擎 OpenSpeech API 的音视频字幕生成服务。
|
||
|
||
职责:
|
||
- 通过 PlatformGateway 调用字幕/打轴任务
|
||
- 轮询等待结果
|
||
- 结果后处理(SRT/ASS/VTT 格式转换)
|
||
|
||
文档: https://www.volcengine.com/docs/6561/80907
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
|
||
from fastapi import Request
|
||
|
||
from app.ai.adapters.base import TaskStatus
|
||
from app.ai.adapters.constants import Method
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
from app.platform_gateway import PlatformGateway
|
||
from app.schemas.caption import (
|
||
AutoAlignResult,
|
||
CaptionResult,
|
||
CaptionUtterance,
|
||
CaptionWord,
|
||
)
|
||
|
||
# Per-module logger; its name follows this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class VolcengineCaptionService:
    """
    Volcengine audio/video caption service wrapper.

    All third-party API calls go through PlatformGateway; this class is
    responsible for:
    - building business parameters (task payloads)
    - polling until a submitted task finishes
    - format conversion (SRT / ASS / VTT)
    """

    # Seconds to sleep between two queries of a task that is still processing.
    DEFAULT_POLL_INTERVAL = 1.0
    # Minimum number of polls allowed by _poll_task. The effective limit is
    # raised when needed so that it never ends polling before max_wait_time.
    MAX_POLL_RETRIES = 120

    # Error-code descriptions (log display only; business errors are already
    # mapped uniformly by the Adapter).
    ERROR_CODES = {
        0: "成功",
        2000: "处理中",
        1001: "参数无效",
        1002: "无权限",
        1003: "超频",
        1010: "音频过长",
        1011: "音频过大",
        1012: "格式无效",
        1013: "音频静音",
    }

    # TaskStatus.state -> result code consumed by _poll_task:
    # 0 = success, 2000 = still processing, -1 = failed (also used for any
    # state not listed here).
    _STATE_TO_CODE = {
        "completed": 0,
        "processing": 2000,
        "pending": 2000,
        "failed": -1,
    }

    def __init__(self, gateway: PlatformGateway):
        self.gateway = gateway

    # ==================== Caption generation (via Gateway) ====================

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> str:
        """Submit a caption-generation task and return its task ID.

        The parameters are forwarded as the payload of
        ``gateway.submit_task`` (platform ``volcengine_caption``, task type
        ``Method.CAPTION``); the gateway's return value is returned unchanged.
        """
        result = await self.gateway.submit_task(
            platform="volcengine_caption",
            task_type=Method.CAPTION,
            payload={
                "audio_url": audio_url,
                "language": language,
                "caption_type": caption_type,
                "use_punc": use_punc,
                "use_itn": use_itn,
                "words_per_line": words_per_line,
                "max_lines": max_lines,
            },
        )
        logger.info("字幕任务已提交: %s", result)
        return result

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> CaptionResult:
        """Query a caption task and convert it to a CaptionResult.

        Args:
            task_id: internal task ID (as returned by submit_caption_task).
            blocking: accepted so the method fits _poll_task's call shape;
                not used.
        """
        # Pass the task type explicitly, matching query_caption_task_status.
        status = await self.gateway.query_task_by_internal_id(task_id, task_type="caption")
        return self._status_to_caption_result(status)

    async def query_caption_task_status(self, task_id: str) -> TaskStatus:
        """Query a caption task and return the raw TaskStatus (used by the Scheduler)."""
        return await self.gateway.query_task_by_internal_id(task_id, task_type="caption")

    async def query_auto_align_task_status(self, task_id: str) -> TaskStatus:
        """Query an auto-align task and return the raw TaskStatus (used by the Scheduler)."""
        return await self.gateway.query_task_by_internal_id(task_id, task_type="auto_align")

    # ==================== Generic polling ====================

    async def _poll_task(
        self,
        task_id: str,
        query_func,
        max_wait_time: int = 120,
        task_name: str = "任务",
    ):
        """Poll ``query_func`` until the third-party task finishes.

        Args:
            task_id: internal task ID passed to ``query_func``.
            query_func: coroutine function called as
                ``query_func(task_id, blocking=True)``; its result must expose
                ``code``, ``message`` and ``duration``.
            max_wait_time: seconds after which a still-processing task
                (code 2000) is treated as timed out.
            task_name: label used in log and error messages.

        Returns:
            The first result whose ``code`` is 0.

        Raises:
            PlatformError: ``TIMEOUT`` if the task is still processing after
                ``max_wait_time`` seconds or the poll budget is exhausted;
                ``BAD_REQUEST`` if the task reports any code other than 0/2000.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        interval = self.DEFAULT_POLL_INTERVAL

        # The retry cap is only a safety net. Derive it from max_wait_time so a
        # caller asking for a longer wait is not cut off by a fixed count.
        max_retries = self.MAX_POLL_RETRIES
        if interval > 0:
            max_retries = max(max_retries, int(max_wait_time // interval) + 1)

        retries = 0
        while retries < max_retries:
            result = await query_func(task_id, blocking=True)

            if result.code == 0:
                logger.info("%s完成: %s, 时长: %ss", task_name, task_id, result.duration)
                return result

            if result.code != 2000:
                # Log the raw error; do not expose it to the frontend.
                logger.error(
                    "%s失败: task_id=%s, code=%s, message=%s",
                    task_name, task_id, result.code, result.message,
                )
                raise PlatformError(
                    f"{task_name}失败,请稍后重试",
                    platform="volcengine_caption",
                    retryable=False,
                    error_type=PlatformErrorType.BAD_REQUEST,
                )

            elapsed = loop.time() - start_time
            if elapsed > max_wait_time:
                logger.warning(
                    "%s超时: task_id=%s, elapsed=%.1fs, max_wait_time=%ss",
                    task_name, task_id, elapsed, max_wait_time,
                )
                raise PlatformError(
                    f"{task_name}超时,请稍后重试",
                    platform="volcengine_caption",
                    retryable=False,
                    error_type=PlatformErrorType.TIMEOUT,
                )

            await asyncio.sleep(interval)
            retries += 1

        logger.warning(
            "%s超过最大轮询次数: task_id=%s, retries=%s",
            task_name, task_id, retries,
        )
        raise PlatformError(
            f"{task_name}超时,请稍后重试",
            platform="volcengine_caption",
            retryable=False,
            error_type=PlatformErrorType.TIMEOUT,
        )

    async def generate_caption(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
        max_wait_time: int = 120,
    ) -> CaptionResult:
        """Generate captions end to end: submit, poll, then return the result."""
        task_id = await self.submit_caption_task(
            audio_url=audio_url,
            language=language,
            caption_type=caption_type,
            use_punc=use_punc,
            use_itn=use_itn,
            words_per_line=words_per_line,
            max_lines=max_lines,
        )
        return await self._poll_task(
            task_id=task_id,
            query_func=self.query_caption_task,
            max_wait_time=max_wait_time,
            task_name="字幕生成",
        )

    # ==================== Auto alignment (via Gateway) ====================

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> str:
        """Submit an auto-alignment (text-to-timeline) task and return its task ID."""
        result = await self.gateway.submit_task(
            platform="volcengine_caption",
            task_type=Method.AUTO_ALIGN,
            payload={
                "audio_url": audio_url,
                "audio_text": audio_text,
                "caption_type": caption_type,
                "sta_punc_mode": sta_punc_mode,
            },
        )
        logger.info("打轴任务已提交: %s", result)
        return result

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> AutoAlignResult:
        """Query an auto-align task and convert it to an AutoAlignResult.

        Args:
            task_id: internal task ID (as returned by submit_auto_align_task).
            blocking: accepted so the method fits _poll_task's call shape;
                not used.
        """
        # Pass the task type explicitly, matching query_auto_align_task_status.
        status = await self.gateway.query_task_by_internal_id(task_id, task_type="auto_align")
        caption_result = self._status_to_caption_result(status)
        return AutoAlignResult(
            code=caption_result.code,
            message=caption_result.message,
            duration=caption_result.duration,
            utterances=caption_result.utterances,
        )

    async def auto_align_caption(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
        max_wait_time: int = 120,
    ) -> AutoAlignResult:
        """Auto-align captions end to end: submit, poll, then return the result."""
        task_id = await self.submit_auto_align_task(
            audio_url=audio_url,
            audio_text=audio_text,
            caption_type=caption_type,
            sta_punc_mode=sta_punc_mode,
        )
        return await self._poll_task(
            task_id=task_id,
            query_func=self.query_auto_align_task,
            max_wait_time=max_wait_time,
            task_name="字幕打轴",
        )

    # ==================== Internal helpers ====================

    @staticmethod
    def _pick_time(item: dict, snake_key: str, camel_key: str):
        """Return ``item[snake_key]``, falling back to ``item[camel_key]``, else 0."""
        return item.get(snake_key, 0) or item.get(camel_key, 0)

    def _status_to_caption_result(self, status) -> CaptionResult:
        """Convert a gateway TaskStatus into a CaptionResult.

        ``status.state`` is mapped via ``_STATE_TO_CODE`` (unknown states
        become -1). Utterances and words are read from ``status.result``,
        accepting both ``start_time``/``end_time`` and
        ``startTime``/``endTime`` keys. Missing or null lists are treated
        as empty.
        """
        # NOTE(review): state may be a plain string or an Enum member depending
        # on the adapter; normalise to the underlying value before lookup.
        state = getattr(status.state, "value", status.state)
        code = self._STATE_TO_CODE.get(state, -1)

        result_data = status.result or {}

        utterances = []
        for u in result_data.get("utterances") or []:
            words = [
                CaptionWord(
                    text=w.get("text", ""),
                    start_time=self._pick_time(w, "start_time", "startTime"),
                    end_time=self._pick_time(w, "end_time", "endTime"),
                )
                for w in u.get("words") or []
            ]
            utterances.append(
                CaptionUtterance(
                    text=u.get("text", ""),
                    start_time=self._pick_time(u, "start_time", "startTime"),
                    end_time=self._pick_time(u, "end_time", "endTime"),
                    words=words,
                )
            )

        return CaptionResult(
            code=code,
            message=status.error_message or "",
            duration=result_data.get("duration") or 0.0,
            utterances=utterances,
        )

    # ==================== Format conversion (pure local computation) ====================

    @staticmethod
    def _format_timestamp(ms, fraction_sep: str) -> str:
        """Format a millisecond offset as ``HH:MM:SS<fraction_sep>mmm``.

        Float inputs are rounded to the nearest millisecond. Negative values
        are clamped to 0, so the output is always a valid timestamp.
        """
        total = max(0, int(round(ms)))
        hours, rem = divmod(total, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        seconds, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}{fraction_sep}{millis:03d}"

    @staticmethod
    def to_srt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as SubRip (SRT) text.

        Each cue is a 1-based index, a ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` line
        and the text. Cues are separated by a blank line, and surrounding
        whitespace is stripped from the result.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = []
        for i, u in enumerate(utterances, 1):
            lines.append(str(i))
            lines.append(f"{fmt(u.start_time, ',')} --> {fmt(u.end_time, ',')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()

    @staticmethod
    def to_ass(
        utterances: list[CaptionUtterance],
        video_width: int = 1080,
        video_height: int = 1920,
    ) -> str:
        """Render utterances as ASS subtitles via ``app.services.ass_generator.generate_ass``."""
        # Imported lazily, as in the rest of this module's format helpers.
        from app.services.ass_generator import generate_ass

        return generate_ass(
            utterances=utterances,
            video_width=video_width,
            video_height=video_height,
        )

    @staticmethod
    def to_vtt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as WebVTT text.

        The output starts with a ``WEBVTT`` header and a blank line. Each cue is
        an ``HH:MM:SS.mmm --> HH:MM:SS.mmm`` line followed by the text.
        Surrounding whitespace is stripped from the result.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = ["WEBVTT", ""]
        for u in utterances:
            lines.append(f"{fmt(u.start_time, '.')} --> {fmt(u.end_time, '.')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()
|
||
|
||
|
||
async def get_caption_service(request: Request) -> VolcengineCaptionService:
    """FastAPI dependency returning a caption service for the current request.

    Reads the shared PlatformGateway from ``request.app.state.platform_gateway``
    and wraps it in a new VolcengineCaptionService on every call. The wrapper
    only stores the gateway reference; the gateway itself is the shared object.
    """
    shared_gateway = request.app.state.platform_gateway
    service = VolcengineCaptionService(gateway=shared_gateway)
    return service
|