""" 火山引擎音视频字幕服务 ====================== 基于火山引擎 OpenSpeech API 的音视频字幕生成服务。 职责: - 通过 PlatformGateway 调用字幕/打轴任务 - 轮询等待结果 - 结果后处理(SRT/ASS/VTT 格式转换) 文档: https://www.volcengine.com/docs/6561/80907 """ from __future__ import annotations import asyncio import logging from fastapi import Request from app.ai.adapters.base import TaskStatus from app.ai.adapters.constants import Method from app.core.exceptions import PlatformError, PlatformErrorType from app.platform_gateway import PlatformGateway from app.schemas.caption import ( AutoAlignResult, CaptionResult, CaptionUtterance, CaptionWord, ) logger = logging.getLogger(__name__) class VolcengineCaptionService: """ 火山引擎音视频字幕服务封装 通过 PlatformGateway 调用第三方 API,自身负责: - 业务参数处理 - 轮询等待 - 格式转换 """ DEFAULT_POLL_INTERVAL = 1.0 MAX_POLL_RETRIES = 120 # 最多轮询120秒 # 错误码映射(仅用于日志展示,业务错误已由 Adapter 统一映射) ERROR_CODES = { 0: "成功", 2000: "处理中", 1001: "参数无效", 1002: "无权限", 1003: "超频", 1010: "音频过长", 1011: "音频过大", 1012: "格式无效", 1013: "音频静音", } def __init__(self, gateway: PlatformGateway): self.gateway = gateway # ==================== 字幕生成(通过 Gateway)==================== async def submit_caption_task( self, audio_url: str, language: str = "zh-CN", caption_type: str = "auto", use_punc: bool = True, use_itn: bool = True, words_per_line: int = 46, max_lines: int = 1, ) -> str: """提交字幕生成任务,返回任务ID""" result = await self.gateway.submit_task( platform="volcengine_caption", task_type=Method.CAPTION, payload={ "audio_url": audio_url, "language": language, "caption_type": caption_type, "use_punc": use_punc, "use_itn": use_itn, "words_per_line": words_per_line, "max_lines": max_lines, }, ) logger.info(f"字幕任务已提交: {result}") return result async def query_caption_task( self, task_id: str, blocking: bool = False, ) -> CaptionResult: """查询字幕任务结果(task_id 为内部 ID)""" status = await self.gateway.query_task_by_internal_id(task_id) return self._status_to_caption_result(status) async def query_caption_task_status(self, task_id: str) -> TaskStatus: """查询字幕任务状态,返回标准 TaskStatus(供 Scheduler 使用)""" return await self.gateway.query_task_by_internal_id(task_id, task_type="caption") async def query_auto_align_task_status(self, task_id: str) -> TaskStatus: """查询打轴任务状态,返回标准 TaskStatus(供 Scheduler 使用)""" return await self.gateway.query_task_by_internal_id(task_id, task_type="auto_align") # ==================== 通用轮询 ==================== async def _poll_task( self, task_id: str, query_func, max_wait_time: int = 120, task_name: str = "任务", ): """通用轮询逻辑:提交后等待第三方任务完成。""" start_time = asyncio.get_event_loop().time() retries = 0 while retries < self.MAX_POLL_RETRIES: result = await query_func(task_id, blocking=True) if result.code == 0: logger.info(f"{task_name}完成: {task_id}, 时长: {result.duration}s") return result elif result.code == 2000: elapsed = asyncio.get_event_loop().time() - start_time if elapsed > max_wait_time: logger.warning( f"{task_name}超时: task_id={task_id}, " f"elapsed={elapsed:.1f}s, max_wait_time={max_wait_time}s" ) raise PlatformError( f"{task_name}超时,请稍后重试", platform="volcengine_caption", retryable=False, error_type=PlatformErrorType.TIMEOUT, ) await asyncio.sleep(self.DEFAULT_POLL_INTERVAL) retries += 1 else: # 原始错误信息记录日志,不暴露给前端 logger.error( f"{task_name}失败: task_id={task_id}, " f"code={result.code}, message={result.message}" ) raise PlatformError( f"{task_name}失败,请稍后重试", platform="volcengine_caption", retryable=False, error_type=PlatformErrorType.BAD_REQUEST, ) logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}") raise PlatformError( f"{task_name}超时,请稍后重试", platform="volcengine_caption", retryable=False, error_type=PlatformErrorType.TIMEOUT, ) async def generate_caption( self, audio_url: str, language: str = "zh-CN", caption_type: str = "auto", use_punc: bool = True, use_itn: bool = True, words_per_line: int = 46, max_lines: int = 1, max_wait_time: int = 120, ) -> CaptionResult: """生成字幕(完整流程:提交->轮询->返回结果)""" task_id = await self.submit_caption_task( audio_url=audio_url, language=language, caption_type=caption_type, use_punc=use_punc, use_itn=use_itn, words_per_line=words_per_line, max_lines=max_lines, ) return await self._poll_task( task_id=task_id, query_func=self.query_caption_task, max_wait_time=max_wait_time, task_name="字幕生成", ) # ==================== 自动打轴(通过 Gateway)==================== async def submit_auto_align_task( self, audio_url: str, audio_text: str, caption_type: str = "speech", sta_punc_mode: int = 3, ) -> str: """提交自动字幕打轴任务,返回任务ID""" result = await self.gateway.submit_task( platform="volcengine_caption", task_type=Method.AUTO_ALIGN, payload={ "audio_url": audio_url, "audio_text": audio_text, "caption_type": caption_type, "sta_punc_mode": sta_punc_mode, }, ) logger.info(f"打轴任务已提交: {result}") return result async def query_auto_align_task( self, task_id: str, blocking: bool = False, ) -> AutoAlignResult: """查询打轴任务结果(task_id 为内部 ID)""" status = await self.gateway.query_task_by_internal_id(task_id) caption_result = self._status_to_caption_result(status) return AutoAlignResult( code=caption_result.code, message=caption_result.message, duration=caption_result.duration, utterances=caption_result.utterances, ) async def auto_align_caption( self, audio_url: str, audio_text: str, caption_type: str = "speech", sta_punc_mode: int = 3, max_wait_time: int = 120, ) -> AutoAlignResult: """自动字幕打轴(完整流程)""" task_id = await self.submit_auto_align_task( audio_url=audio_url, audio_text=audio_text, caption_type=caption_type, sta_punc_mode=sta_punc_mode, ) return await self._poll_task( task_id=task_id, query_func=self.query_auto_align_task, max_wait_time=max_wait_time, task_name="字幕打轴", ) # ==================== 内部工具方法 ==================== def _status_to_caption_result(self, status) -> CaptionResult: """将 Gateway 返回的 TaskStatus 转换为 CaptionResult""" # TaskStatus.state 映射 state_to_code = { "completed": 0, "processing": 2000, "pending": 2000, "failed": -1, } code = state_to_code.get(status.state, -1) utterances = [] result_data = status.result or {} raw_utterances = result_data.get("utterances", []) for u in raw_utterances: words = [ CaptionWord( text=w.get("text", ""), start_time=w.get("start_time", 0) or w.get("startTime", 0), end_time=w.get("end_time", 0) or w.get("endTime", 0), ) for w in u.get("words", []) ] utterances.append( CaptionUtterance( text=u.get("text", ""), start_time=u.get("start_time", 0) or u.get("startTime", 0), end_time=u.get("end_time", 0) or u.get("endTime", 0), words=words, ) ) return CaptionResult( code=code, message=status.error_message or "", duration=result_data.get("duration", 0.0), utterances=utterances, ) # ==================== 格式转换(纯本地计算)==================== @staticmethod def to_srt(utterances: list[CaptionUtterance]) -> str: """将字幕时间轴转换为 SRT 格式""" def ms_to_time(ms: int) -> str: h = ms // 3600000 m = (ms % 3600000) // 60000 s = (ms % 60000) // 1000 ms_remain = ms % 1000 return f"{h:02d}:{m:02d}:{s:02d},{ms_remain:03d}" lines = [] for i, u in enumerate(utterances, 1): lines.append(str(i)) lines.append(f"{ms_to_time(u.start_time)} --> {ms_to_time(u.end_time)}") lines.append(u.text) lines.append("") return "\n".join(lines).strip() @staticmethod def to_ass( utterances: list[CaptionUtterance], video_width: int = 1080, video_height: int = 1920, ) -> str: """将字幕转换为 ASS 格式(使用抖音美好体)""" from app.services.ass_generator import generate_ass return generate_ass( utterances=utterances, video_width=video_width, video_height=video_height, ) @staticmethod def to_vtt(utterances: list[CaptionUtterance]) -> str: """将字幕时间轴转换为 WebVTT 格式""" def ms_to_vtt_time(ms: int) -> str: h = ms // 3600000 m = (ms % 3600000) // 60000 s = (ms % 60000) // 1000 ms_remain = ms % 1000 return f"{h:02d}:{m:02d}:{s:02d}.{ms_remain:03d}" lines = ["WEBVTT", ""] for u in utterances: lines.append(f"{ms_to_vtt_time(u.start_time)} --> {ms_to_vtt_time(u.end_time)}") lines.append(u.text) lines.append("") return "\n".join(lines).strip() async def get_caption_service(request: Request) -> VolcengineCaptionService: """FastAPI Depends:从 app.state 获取全局字幕服务实例。""" gateway = request.app.state.platform_gateway return VolcengineCaptionService(gateway)