"""
Volcengine audio/video caption service
======================================

Caption generation and caption alignment service built on the Volcengine
OpenSpeech API.

Docs: https://www.volcengine.com/docs/6561/80907
"""

from __future__ import annotations

import asyncio
import json
import logging

import httpx
from fastapi import Request

from app.config import get_settings
from app.core.exceptions import PlatformError, PlatformErrorType
from app.schemas.caption import (
    AutoAlignResult,
    CaptionResult,
    CaptionUtterance,
    CaptionWord,
)

logger = logging.getLogger(__name__)

# Volcengine caption business error code -> (error type, retryable).
_BUSINESS_ERROR_MAP = {
    1001: (PlatformErrorType.BAD_REQUEST, False),  # invalid parameter
    1002: (PlatformErrorType.AUTH_FAILED, False),  # no permission
    1003: (PlatformErrorType.RATE_LIMIT, True),  # rate limited (retryable)
    1010: (PlatformErrorType.BAD_REQUEST, False),  # audio too long
    1011: (PlatformErrorType.BAD_REQUEST, False),  # audio too large
    1012: (PlatformErrorType.BAD_REQUEST, False),  # invalid format
    1013: (PlatformErrorType.BAD_REQUEST, False),  # silent audio
}

# HTTP status code -> (error type, retryable); used when no business code matches.
_HTTP_ERROR_MAP = {
    429: (PlatformErrorType.RATE_LIMIT, True),
    401: (PlatformErrorType.AUTH_FAILED, False),
    403: (PlatformErrorType.AUTH_FAILED, False),
    400: (PlatformErrorType.BAD_REQUEST, False),
    500: (PlatformErrorType.SERVER_ERROR, True),
    502: (PlatformErrorType.SERVER_ERROR, True),
    503: (PlatformErrorType.SERVER_ERROR, True),
}


def _map_caption_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Build a standard PlatformError from a Volcengine caption failure.

    A known business ``code`` takes precedence; otherwise the HTTP ``status``
    decides the error type. Anything unmapped becomes a non-retryable UNKNOWN.

    Args:
        status: HTTP status code, also stored on the error as ``status_code``.
        message: Human-readable error message.
        code: Optional Volcengine business error code.

    Returns:
        The PlatformError instance (not raised here; callers raise it).
    """
    mapped = _BUSINESS_ERROR_MAP.get(code) if code is not None else None
    if mapped is None:
        mapped = _HTTP_ERROR_MAP.get(status, (PlatformErrorType.UNKNOWN, False))
    error_type, retryable = mapped
    return PlatformError(
        message,
        platform="volcengine_caption",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
    )


class VolcengineCaptionService:
    """Wrapper around the Volcengine audio/video caption HTTP API."""

    # Basic API configuration
    BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
    DEFAULT_TIMEOUT = 60.0
    DEFAULT_POLL_INTERVAL = 1.0
    # Maximum polling iterations (about 120 s of sleeping at the default
    # interval, not counting the time spent inside each query request).
    MAX_POLL_RETRIES = 120

    # Business code -> message text used when building error messages.
    ERROR_CODES = {
        0: "成功",
        2000: "处理中",
        1001: "参数无效",
        1002: "无权限",
        1003: "超频",
        1010: "音频过长",
        1011: "音频过大",
        1012: "格式无效",
        1013: "音频静音",
    }

    def __init__(
        self,
        appid: str | None = None,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Initialise the caption service.

        Args:
            appid: Application ID; falls back to ``VOLCENGINE_CAPTION_APPID``
                from settings.
            token: Auth token; falls back to ``VOLCENGINE_CAPTION_TOKEN``
                from settings.
            client: Externally injected ``httpx.AsyncClient``. When given, this
                service never closes it; otherwise a private client is created.

        Raises:
            PlatformError: If the app ID or the token is not configured.
        """
        settings = get_settings()
        self.appid = appid or settings.VOLCENGINE_CAPTION_APPID or ""
        self.token = token or settings.VOLCENGINE_CAPTION_TOKEN or ""

        if not self.appid:
            raise PlatformError(
                "VOLCENGINE_CAPTION_APPID 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )
        if not self.token:
            raise PlatformError(
                "VOLCENGINE_CAPTION_TOKEN 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    def _get_headers(self) -> dict[str, str]:
        """Return the HTTP headers sent with every API request."""
        return {
            # NOTE(review): the semicolon after "Bearer" is kept on purpose; it
            # looks like the Volcengine OpenSpeech token format - confirm
            # against the API docs before "fixing" it.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _translate_exception(exc: Exception, action: str) -> PlatformError:
        """Convert an exception raised during an API call into a PlatformError.

        Shared by all submit/query methods so that HTTP status errors and
        network failures are classified the same way everywhere.

        Args:
            exc: The exception caught around the API call.
            action: Message prefix used for otherwise unclassified failures.
        """
        if isinstance(exc, PlatformError):
            return exc
        if isinstance(exc, httpx.HTTPStatusError):
            return _map_caption_error(
                exc.response.status_code,
                f"HTTP错误: {exc.response.status_code}",
            )
        if isinstance(exc, (httpx.NetworkError, httpx.TimeoutException)):
            return PlatformError(
                f"字幕服务网络错误: {exc}",
                platform="volcengine_caption",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            )
        return _map_caption_error(500, f"{action}: {exc}")

    @staticmethod
    def _extract_task_id(data: dict, action: str) -> str:
        """Return the task id from a submit response or raise a PlatformError.

        When the response has no ``id``, its business ``code`` (if it is an
        int) is forwarded so that known codes are classified correctly instead
        of being reported as a generic server error.
        """
        if "id" not in data:
            code = data.get("code")
            raise _map_caption_error(
                500,
                f"{action}: {data.get('message', '未知错误')}",
                code=code if isinstance(code, int) else None,
            )
        return data["id"]

    async def _poll_task(self, query, task_id: str, max_wait_time: int, label: str):
        """Poll ``query`` until the task finishes, fails, or times out.

        Args:
            query: Bound coroutine method called as ``query(task_id, blocking=True)``.
            task_id: Task identifier returned by the submit call.
            max_wait_time: Maximum number of seconds to keep polling.
            label: Message prefix identifying the operation in error messages.

        Returns:
            The first result whose ``code`` is 0.

        Raises:
            PlatformError: On a failure code, or when either the time budget or
                ``MAX_POLL_RETRIES`` is exhausted.
        """
        loop = asyncio.get_running_loop()
        started = loop.time()

        for _ in range(self.MAX_POLL_RETRIES):
            result = await query(task_id, blocking=True)

            if result.code == 0:
                return result

            if result.code != 2000:
                # Any code other than success (0) / processing (2000) is a failure.
                error_msg = self.ERROR_CODES.get(result.code, f"未知错误: {result.code}")
                raise _map_caption_error(
                    500,
                    f"{label}失败: {error_msg} ({result.message})",
                    code=result.code,
                )

            # Still processing: enforce the time budget before sleeping again.
            if loop.time() - started > max_wait_time:
                raise _map_caption_error(
                    504,
                    f"{label}超时: 已等待 {max_wait_time}s",
                )
            await asyncio.sleep(self.DEFAULT_POLL_INTERVAL)

        raise _map_caption_error(504, f"{label}超时: 超过最大重试次数")

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> str:
        """Submit a caption generation task.

        Args:
            audio_url: URL of the audio/video file.
            language: Language code.
            caption_type: Recognition type (auto/speech/singing).
            use_punc: Enable automatic punctuation.
            use_itn: Enable number conversion.
            words_per_line: Characters per caption line.
            max_lines: Caption lines per screen.

        Returns:
            The task ID.

        Raises:
            PlatformError: If the submission fails.
        """
        params = {
            "appid": self.appid,
            "language": language,
            "caption_type": caption_type,
            "use_punc": str(use_punc),
            "use_itn": str(use_itn),
            "words_per_line": words_per_line,
            "max_lines": max_lines,
        }
        payload = {"url": audio_url}

        try:
            response = await self.client.post(
                f"{self.BASE_URL}/submit",
                params=params,
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            task_id = self._extract_task_id(response.json(), "提交任务失败")
            logger.info(f"字幕任务已提交: {task_id}")
            return task_id
        except PlatformError:
            raise
        except Exception as e:
            raise self._translate_exception(e, "提交任务失败") from e

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> CaptionResult:
        """Query the result of a caption task.

        Args:
            task_id: Task ID.
            blocking: Whether to send ``blocking=1`` so the request waits for
                the result.

        Returns:
            The parsed caption result.

        Raises:
            PlatformError: If the query fails.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }

        try:
            response = await self.client.get(
                f"{self.BASE_URL}/query",
                params=params,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            return self._parse_caption_result(response.json())
        except PlatformError:
            raise
        except Exception as e:
            raise self._translate_exception(e, "查询任务失败") from e

    async def generate_caption(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
        max_wait_time: int = 120,
    ) -> CaptionResult:
        """Generate captions end to end: submit, poll, return the result.

        Args:
            audio_url: URL of the audio/video file.
            language: Language code.
            caption_type: Recognition type.
            use_punc: Enable automatic punctuation.
            use_itn: Enable number conversion.
            words_per_line: Characters per caption line.
            max_lines: Caption lines per screen.
            max_wait_time: Maximum time to wait, in seconds.

        Returns:
            The caption generation result.

        Raises:
            PlatformError: If generation fails or times out.
        """
        task_id = await self.submit_caption_task(
            audio_url=audio_url,
            language=language,
            caption_type=caption_type,
            use_punc=use_punc,
            use_itn=use_itn,
            words_per_line=words_per_line,
            max_lines=max_lines,
        )

        result = await self._poll_task(
            self.query_caption_task, task_id, max_wait_time, "字幕生成"
        )
        logger.info(f"字幕生成完成: {task_id}, 时长: {result.duration}s")
        return result

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> str:
        """Submit an automatic caption alignment (timing) task.

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align against the audio.
            caption_type: Recognition type (speech/singing).
            sta_punc_mode: Punctuation mode (1/2/3).

        Returns:
            The task ID.

        Raises:
            PlatformError: If the submission fails.
        """
        params = {
            "appid": self.appid,
            "caption_type": caption_type,
            "sta_punc_mode": sta_punc_mode,
        }
        payload = {
            "url": audio_url,
            "audio_text": audio_text,
        }

        try:
            response = await self.client.post(
                f"{self.BASE_URL}/ata/submit",
                params=params,
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            task_id = self._extract_task_id(response.json(), "提交打轴任务失败")
            logger.info(f"打轴任务已提交: {task_id}")
            return task_id
        except PlatformError:
            raise
        except Exception as e:
            raise self._translate_exception(e, "提交打轴任务失败") from e

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> AutoAlignResult:
        """Query the result of an alignment task.

        Args:
            task_id: Task ID.
            blocking: Whether to send ``blocking=1`` so the request waits for
                the result.

        Returns:
            The alignment result.

        Raises:
            PlatformError: If the query fails.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }

        try:
            response = await self.client.get(
                f"{self.BASE_URL}/ata/query",
                params=params,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            data = response.json()

            logger.info(
                f"[VolcengineCaption] Query response: {json.dumps(data, ensure_ascii=False)}"
            )

            # The response is parsed with the same parser as caption generation.
            caption_result = self._parse_caption_result(data)

            logger.info(f"[VolcengineCaption] Parsed result: {caption_result}")
            logger.info(
                f"[VolcengineCaption] First utterance: {caption_result.utterances[0] if caption_result.utterances else None}"
            )

            return AutoAlignResult(
                code=caption_result.code,
                message=caption_result.message,
                duration=caption_result.duration,
                utterances=caption_result.utterances,
            )
        except PlatformError:
            raise
        except Exception as e:
            raise self._translate_exception(e, "查询打轴任务失败") from e

    async def auto_align_caption(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
        max_wait_time: int = 120,
    ) -> AutoAlignResult:
        """Align caption text to audio end to end: submit, poll, return.

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align against the audio.
            caption_type: Recognition type.
            sta_punc_mode: Punctuation mode.
            max_wait_time: Maximum time to wait, in seconds.

        Returns:
            The alignment result.

        Raises:
            PlatformError: If alignment fails or times out.
        """
        task_id = await self.submit_auto_align_task(
            audio_url=audio_url,
            audio_text=audio_text,
            caption_type=caption_type,
            sta_punc_mode=sta_punc_mode,
        )

        result = await self._poll_task(
            self.query_auto_align_task, task_id, max_wait_time, "打轴"
        )
        logger.info(f"打轴完成: {task_id}")
        return result

    def _parse_caption_result(self, data: dict) -> CaptionResult:
        """Parse an API response dict into a CaptionResult.

        Timing fields are read under both snake_case and camelCase keys.
        Missing or ``null`` ``utterances`` / ``words`` are treated as empty.
        """
        utterances = []

        logger.info(f"[VolcengineCaption] Parsing caption result: {data}")

        # ``or []`` also covers an explicit JSON null, which ``.get(key, [])``
        # would pass through as None and crash the loop.
        for u in data.get("utterances") or []:
            logger.info(f"[VolcengineCaption] Parsing utterance: {u}")

            # The API may use either camelCase or snake_case field names.
            words = [
                CaptionWord(
                    text=w.get("text", ""),
                    start_time=w.get("start_time", 0) or w.get("startTime", 0),
                    end_time=w.get("end_time", 0) or w.get("endTime", 0),
                )
                for w in u.get("words") or []
            ]
            utterances.append(
                CaptionUtterance(
                    text=u.get("text", ""),
                    start_time=u.get("start_time", 0) or u.get("startTime", 0),
                    end_time=u.get("end_time", 0) or u.get("endTime", 0),
                    words=words,
                )
            )

        result = CaptionResult(
            code=data.get("code", -1),
            message=data.get("message", ""),
            duration=data.get("duration", 0.0),
            utterances=utterances,
        )

        logger.info(f"[VolcengineCaption] Parsed result: {result}")
        return result

    @staticmethod
    def _format_timestamp(ms: int, fraction_sep: str) -> str:
        """Format milliseconds as ``HH:MM:SS<sep>mmm``.

        The value is coerced to ``int`` first so that a float millisecond
        value does not break the integer format specifiers.
        """
        total_ms = int(ms)
        hours, remainder = divmod(total_ms, 3600000)
        minutes, remainder = divmod(remainder, 60000)
        seconds, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}{fraction_sep}{millis:03d}"

    @staticmethod
    def to_srt(utterances: list[CaptionUtterance]) -> str:
        """Convert caption utterances to an SRT document.

        Args:
            utterances: Caption utterances with millisecond start/end times.

        Returns:
            The SRT text, with cues numbered from 1.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = []
        for index, u in enumerate(utterances, 1):
            lines.append(str(index))
            lines.append(f"{fmt(u.start_time, ',')} --> {fmt(u.end_time, ',')}")
            lines.append(u.text)
            lines.append("")

        return "\n".join(lines).strip()

    @staticmethod
    def to_ass(
        utterances: list[CaptionUtterance],
        video_width: int = 1080,
        video_height: int = 1920,
    ) -> str:
        """Convert caption utterances to an ASS document.

        Delegates to ``app.services.ass_generator.generate_ass``; according to
        the original note, that generator uses the Douyin Meihao typeface.

        Args:
            utterances: Caption utterances.
            video_width: Video width.
            video_height: Video height.

        Returns:
            The ASS text.
        """
        # Imported at call time rather than at module import time.
        from app.services.ass_generator import generate_ass

        return generate_ass(
            utterances=utterances,
            video_width=video_width,
            video_height=video_height,
        )

    @staticmethod
    def to_vtt(utterances: list[CaptionUtterance]) -> str:
        """Convert caption utterances to a WebVTT document.

        Args:
            utterances: Caption utterances with millisecond start/end times.

        Returns:
            The WebVTT text, starting with the ``WEBVTT`` header.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = ["WEBVTT", ""]
        for u in utterances:
            lines.append(f"{fmt(u.start_time, '.')} --> {fmt(u.end_time, '.')}")
            lines.append(u.text)
            lines.append("")

        return "\n".join(lines).strip()

    async def close(self):
        """Close the HTTP client, but only if this service created it."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()


async def get_caption_service(request: Request) -> VolcengineCaptionService:
    """FastAPI dependency: return the caption service stored on ``app.state``."""
    return request.app.state.volcengine_caption_service