Files
meijiaka-zy/python-api/app/services/volcengine_caption_service.py
T

358 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山引擎音视频字幕服务
======================
基于火山引擎 OpenSpeech API 的音视频字幕生成服务。
职责:
- 通过 PlatformGateway 调用字幕/打轴任务
- 轮询等待结果
- 结果后处理(SRT/ASS/VTT 格式转换)
文档: https://www.volcengine.com/docs/6561/80907
"""
from __future__ import annotations
import asyncio
import logging
from fastapi import Request
from app.ai.adapters.base import TaskStatus
from app.ai.adapters.constants import Method
from app.core.exceptions import PlatformError, PlatformErrorType
from app.platform_gateway import PlatformGateway
from app.schemas.caption import (
AutoAlignResult,
CaptionResult,
CaptionUtterance,
CaptionWord,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class VolcengineCaptionService:
    """
    Wrapper around the Volcengine audio/video caption service.

    Third-party API calls are made through the PlatformGateway; this class
    itself is responsible for:
    - business parameter handling
    - polling until a task finishes
    - subtitle format conversion
    """

    # Seconds to sleep between two consecutive status queries.
    DEFAULT_POLL_INTERVAL = 1.0
    # Hard cap on the number of polls (about 120 s at the default interval,
    # not counting the time spent in the queries themselves).
    MAX_POLL_RETRIES = 120

    # Error-code table (used for log display only; per the original note,
    # business errors are already mapped uniformly by the Adapter).
    ERROR_CODES = {
        0: "成功",
        2000: "处理中",
        1001: "参数无效",
        1002: "无权限",
        1003: "超频",
        1010: "音频过长",
        1011: "音频过大",
        1012: "格式无效",
        1013: "音频静音",
    }

    def __init__(self, gateway: PlatformGateway):
        """Store the gateway through which every remote call is made."""
        self.gateway = gateway

    # ==================== Caption generation (via Gateway) ====================

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> str:
        """Submit a caption-generation task and return the task ID.

        All arguments are forwarded unchanged, under the same names, in the
        payload handed to ``gateway.submit_task``.

        Returns:
            Whatever ``gateway.submit_task`` returns (used as the task ID).
        """
        result = await self.gateway.submit_task(
            platform="volcengine_caption",
            task_type=Method.CAPTION,
            payload={
                "audio_url": audio_url,
                "language": language,
                "caption_type": caption_type,
                "use_punc": use_punc,
                "use_itn": use_itn,
                "words_per_line": words_per_line,
                "max_lines": max_lines,
            },
        )
        logger.info(f"字幕任务已提交: {result}")
        return result

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> CaptionResult:
        """Query a caption task and return its result.

        Args:
            task_id: Internal task ID.
            blocking: Accepted so the method can be used as a ``_poll_task``
                query callback; it is not used by this implementation.
        """
        status = await self.gateway.query_task_by_internal_id(task_id)
        return self._status_to_caption_result(status)

    async def query_caption_task_status(self, task_id: str) -> TaskStatus:
        """Return the gateway's status object for a caption task (for the Scheduler)."""
        return await self.gateway.query_task_by_internal_id(task_id, task_type="caption")

    async def query_auto_align_task_status(self, task_id: str) -> TaskStatus:
        """Return the gateway's status object for an auto-align task (for the Scheduler)."""
        return await self.gateway.query_task_by_internal_id(task_id, task_type="auto_align")

    # ==================== Generic polling ====================

    async def _poll_task(
        self,
        task_id: str,
        query_func,
        max_wait_time: int = 120,
        task_name: str = "任务",
    ):
        """Poll ``query_func`` until the task completes, fails, or times out.

        Args:
            task_id: ID passed to ``query_func`` on every poll.
            query_func: Awaitable callable invoked as
                ``query_func(task_id, blocking=True)``; its result must expose
                ``code``, ``message`` and ``duration``.
            max_wait_time: Wall-clock budget in seconds. The effective limit
                is whichever is hit first: this budget or
                ``MAX_POLL_RETRIES`` polls.
            task_name: Label used in log lines and error messages.

        Returns:
            The first query result whose ``code`` is 0.

        Raises:
            PlatformError: ``TIMEOUT`` when the time budget or the poll count
                is exhausted; ``BAD_REQUEST`` for any code other than 0
                (done) or 2000 (still processing).
        """
        # Inside a coroutine the running loop always exists; fetch it once.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        retries = 0
        while retries < self.MAX_POLL_RETRIES:
            result = await query_func(task_id, blocking=True)

            if result.code == 0:
                logger.info(f"{task_name}完成: {task_id}, 时长: {result.duration}s")
                return result

            if result.code != 2000:
                # Log the raw error details; do not expose them to the frontend.
                logger.error(
                    f"{task_name}失败: task_id={task_id}, "
                    f"code={result.code}, message={result.message}"
                )
                raise PlatformError(
                    f"{task_name}失败,请稍后重试",
                    platform="volcengine_caption",
                    retryable=False,
                    error_type=PlatformErrorType.BAD_REQUEST,
                )

            # Still processing: give up once the time budget is spent.
            elapsed = loop.time() - start_time
            if elapsed > max_wait_time:
                logger.warning(
                    f"{task_name}超时: task_id={task_id}, "
                    f"elapsed={elapsed:.1f}s, max_wait_time={max_wait_time}s"
                )
                raise PlatformError(
                    f"{task_name}超时,请稍后重试",
                    platform="volcengine_caption",
                    retryable=False,
                    error_type=PlatformErrorType.TIMEOUT,
                )

            await asyncio.sleep(self.DEFAULT_POLL_INTERVAL)
            retries += 1

        # Poll count exhausted before the time budget ran out.
        logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
        raise PlatformError(
            f"{task_name}超时,请稍后重试",
            platform="volcengine_caption",
            retryable=False,
            error_type=PlatformErrorType.TIMEOUT,
        )

    async def generate_caption(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
        max_wait_time: int = 120,
    ) -> CaptionResult:
        """Generate captions end to end: submit, poll, return the result.

        Every argument except ``max_wait_time`` is forwarded to
        ``submit_caption_task``; ``max_wait_time`` is the polling budget in
        seconds handed to ``_poll_task``.

        Raises:
            PlatformError: Propagated from ``_poll_task`` on failure/timeout.
        """
        task_id = await self.submit_caption_task(
            audio_url=audio_url,
            language=language,
            caption_type=caption_type,
            use_punc=use_punc,
            use_itn=use_itn,
            words_per_line=words_per_line,
            max_lines=max_lines,
        )
        return await self._poll_task(
            task_id=task_id,
            query_func=self.query_caption_task,
            max_wait_time=max_wait_time,
            task_name="字幕生成",
        )

    # ==================== Auto alignment (via Gateway) ====================

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> str:
        """Submit an automatic caption-alignment task and return the task ID.

        All arguments are forwarded unchanged, under the same names, in the
        payload handed to ``gateway.submit_task``.
        """
        result = await self.gateway.submit_task(
            platform="volcengine_caption",
            task_type=Method.AUTO_ALIGN,
            payload={
                "audio_url": audio_url,
                "audio_text": audio_text,
                "caption_type": caption_type,
                "sta_punc_mode": sta_punc_mode,
            },
        )
        logger.info(f"打轴任务已提交: {result}")
        return result

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> AutoAlignResult:
        """Query an auto-align task and return its result.

        The gateway status is converted with ``_status_to_caption_result``
        and then copied field by field into an ``AutoAlignResult``.

        Args:
            task_id: Internal task ID.
            blocking: Accepted so the method can be used as a ``_poll_task``
                query callback; it is not used by this implementation.
        """
        status = await self.gateway.query_task_by_internal_id(task_id)
        caption_result = self._status_to_caption_result(status)
        return AutoAlignResult(
            code=caption_result.code,
            message=caption_result.message,
            duration=caption_result.duration,
            utterances=caption_result.utterances,
        )

    async def auto_align_caption(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
        max_wait_time: int = 120,
    ) -> AutoAlignResult:
        """Align captions end to end: submit, poll, return the result.

        Raises:
            PlatformError: Propagated from ``_poll_task`` on failure/timeout.
        """
        task_id = await self.submit_auto_align_task(
            audio_url=audio_url,
            audio_text=audio_text,
            caption_type=caption_type,
            sta_punc_mode=sta_punc_mode,
        )
        return await self._poll_task(
            task_id=task_id,
            query_func=self.query_auto_align_task,
            max_wait_time=max_wait_time,
            task_name="字幕打轴",
        )

    # ==================== Internal helpers ====================

    def _status_to_caption_result(self, status) -> CaptionResult:
        """Convert the status object returned by the gateway into a ``CaptionResult``.

        Reads ``status.state``, ``status.result`` and ``status.error_message``.
        States outside the mapping below are treated as failures (code -1).
        Timestamps are read from the snake_case key first and fall back to
        the camelCase key when the former is missing or falsy.
        """
        # TaskStatus.state -> numeric code understood by ``_poll_task``.
        state_to_code = {
            "completed": 0,
            "processing": 2000,
            "pending": 2000,
            "failed": -1,
        }
        code = state_to_code.get(status.state, -1)
        result_data = status.result or {}

        def pick_time(item, snake_key, camel_key):
            # Same precedence as before: snake_case wins unless it is falsy.
            return item.get(snake_key, 0) or item.get(camel_key, 0)

        # ``or []`` guards against keys that are present but set to null,
        # which ``dict.get(key, [])`` alone would pass through as None.
        utterances = [
            CaptionUtterance(
                text=u.get("text", ""),
                start_time=pick_time(u, "start_time", "startTime"),
                end_time=pick_time(u, "end_time", "endTime"),
                words=[
                    CaptionWord(
                        text=w.get("text", ""),
                        start_time=pick_time(w, "start_time", "startTime"),
                        end_time=pick_time(w, "end_time", "endTime"),
                    )
                    for w in (u.get("words") or [])
                ],
            )
            for u in (result_data.get("utterances") or [])
        ]

        return CaptionResult(
            code=code,
            message=status.error_message or "",
            duration=result_data.get("duration", 0.0),
            utterances=utterances,
        )

    # ==================== Format conversion (pure local computation) ====================

    @staticmethod
    def _format_timestamp(ms, fraction_sep: str) -> str:
        """Format a millisecond offset as ``HH:MM:SS<fraction_sep>mmm``.

        The value is truncated to a whole number of milliseconds and clamped
        at zero, so float or negative inputs still yield a valid timestamp.
        """
        total_ms = max(0, int(ms))
        total_seconds, millis = divmod(total_ms, 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}{fraction_sep}{millis:03d}"

    @staticmethod
    def to_srt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as SRT text.

        Each utterance becomes a 1-based index line, a ``start --> end`` line
        (comma before the milliseconds) and the text, followed by a blank
        line. Leading/trailing whitespace of the whole document is stripped.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = []
        for index, u in enumerate(utterances, 1):
            lines.extend(
                (
                    str(index),
                    f"{fmt(u.start_time, ',')} --> {fmt(u.end_time, ',')}",
                    u.text,
                    "",
                )
            )
        return "\n".join(lines).strip()

    @staticmethod
    def to_ass(
        utterances: list[CaptionUtterance],
        video_width: int = 1080,
        video_height: int = 1920,
    ) -> str:
        """Render utterances as ASS text by delegating to ``generate_ass``.

        Styling is decided entirely by ``generate_ass`` (the original note
        says it uses the Douyin Meihao typeface -- not verifiable from here).
        """
        # Imported at call time, as in the original code.
        from app.services.ass_generator import generate_ass

        return generate_ass(
            utterances=utterances,
            video_width=video_width,
            video_height=video_height,
        )

    @staticmethod
    def to_vtt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as WebVTT text.

        Output starts with the ``WEBVTT`` header and a blank line; each
        utterance becomes a ``start --> end`` line (dot before the
        milliseconds) and the text, followed by a blank line.
        Leading/trailing whitespace of the whole document is stripped.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = ["WEBVTT", ""]
        for u in utterances:
            lines.extend(
                (
                    f"{fmt(u.start_time, '.')} --> {fmt(u.end_time, '.')}",
                    u.text,
                    "",
                )
            )
        return "\n".join(lines).strip()
async def get_caption_service(request: Request) -> VolcengineCaptionService:
    """FastAPI dependency: build a caption service around the app-wide gateway.

    Reads ``platform_gateway`` from ``request.app.state`` and wraps it in a
    new ``VolcengineCaptionService`` on every call.
    """
    app_state = request.app.state
    return VolcengineCaptionService(app_state.platform_gateway)