Files
meijiaka-zy/python-api/app/services/volcengine_caption_service.py
T

567 lines
17 KiB
Python

"""
火山引擎音视频字幕服务
======================
基于火山引擎 OpenSpeech API 的音视频字幕生成服务。
文档: https://www.volcengine.com/docs/6561/80907
"""
from __future__ import annotations

import asyncio
import json
import logging
import math

import httpx

from app.config import get_settings
from app.schemas.caption import (
    AutoAlignResult,
    CaptionResult,
    CaptionUtterance,
    CaptionWord,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class VolcengineCaptionError(Exception):
    """
    Error raised by the Volcengine caption service wrapper.

    Attributes:
        code: Result code reported by the API, when one is available.
        original_error: The underlying exception this error wraps, if any.
    """

    def __init__(
        self,
        message: str,
        code: int | None = None,
        original_error: Exception | None = None,
    ):
        super().__init__(message)
        self.code = code
        self.original_error = original_error
class VolcengineCaptionService:
    """
    Async wrapper around the Volcengine OpenSpeech caption APIs.

    Two task-based flows are exposed, both rooted at ``BASE_URL``:

    * caption generation: ``/submit`` then ``/query``;
    * automatic caption alignment (timing a given text against the audio):
      ``/ata/submit`` then ``/ata/query``.

    Each flow submits a task and then queries it until the result code is no
    longer 2000 ("processing"). One ``httpx.AsyncClient`` is created lazily and
    reused until :meth:`close` is called.
    """

    # API base configuration
    BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
    DEFAULT_TIMEOUT = 60.0
    DEFAULT_POLL_INTERVAL = 1.0
    # Minimum number of query attempts per task. It is raised automatically
    # when the caller's max_wait_time needs more attempts at DEFAULT_POLL_INTERVAL.
    MAX_POLL_RETRIES = 120

    # Result codes with special meaning in the polling loop.
    _CODE_SUCCESS = 0
    _CODE_PROCESSING = 2000

    # Result code -> description used when building error messages.
    ERROR_CODES = {
        0: "成功",
        2000: "处理中",
        1001: "参数无效",
        1002: "无权限",
        1003: "超频",
        1010: "音频过长",
        1011: "音频过大",
        1012: "格式无效",
        1013: "音频静音",
    }

    def __init__(self, appid: str | None = None, token: str | None = None):
        """
        Initialise the caption service.

        Args:
            appid: Application ID; falls back to ``VOLCENGINE_CAPTION_APPID``
                from Settings.
            token: Access token; falls back to ``VOLCENGINE_CAPTION_TOKEN``
                from Settings.

        Raises:
            VolcengineCaptionError: if no appid or token could be resolved.
        """
        settings = get_settings()
        self.appid = appid or settings.VOLCENGINE_CAPTION_APPID or ""
        self.token = token or settings.VOLCENGINE_CAPTION_TOKEN or ""
        if not self.appid:
            raise VolcengineCaptionError("VOLCENGINE_CAPTION_APPID 未配置")
        if not self.token:
            raise VolcengineCaptionError("VOLCENGINE_CAPTION_TOKEN 未配置")
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, recreating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
        return self._client

    def _get_headers(self) -> dict:
        """Build the request headers."""
        return {
            # NOTE: "Bearer;" with a semicolon is the token format used by the
            # OpenSpeech API docs -- do not change to standard "Bearer <token>"
            # without checking the API documentation.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def _request_json(
        self,
        method: str,
        path: str,
        *,
        params: dict,
        error_prefix: str,
        payload: dict | None = None,
    ) -> dict:
        """
        Send one API request and return its decoded JSON body.

        Args:
            method: HTTP method ("GET" / "POST").
            path: Path appended to ``BASE_URL``.
            params: Query-string parameters.
            error_prefix: Message prefix for transport/decoding failures.
            payload: Optional JSON body.

        Returns:
            The decoded JSON object.

        Raises:
            VolcengineCaptionError: on HTTP error status, transport failure,
                undecodable JSON, or a non-object JSON body.
        """
        client = await self._get_client()
        try:
            response = await client.request(
                method,
                f"{self.BASE_URL}{path}",
                params=params,
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            data = response.json()
        except httpx.HTTPStatusError as e:
            raise VolcengineCaptionError(
                f"HTTP错误: {e.response.status_code}",
                original_error=e,
            ) from e
        except Exception as e:
            raise VolcengineCaptionError(f"{error_prefix}: {e}", original_error=e) from e
        if not isinstance(data, dict):
            raise VolcengineCaptionError(f"{error_prefix}: 响应格式无效")
        return data

    async def _poll_until_done(self, query, task_id: str, max_wait_time: float, label: str):
        """
        Query ``task_id`` until it succeeds, fails, or runs out of time.

        Args:
            query: Bound query coroutine function, called as
                ``query(task_id, blocking=True)``; its result must expose
                ``code`` and ``message``.
            task_id: Task ID returned by the matching submit call.
            max_wait_time: Overall time budget in seconds.
            label: Operation name used as the error-message prefix.

        Returns:
            The first result whose ``code`` is 0.

        Raises:
            VolcengineCaptionError: on any code other than 0/2000, or when the
                time budget or the retry budget is exhausted.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        interval = self.DEFAULT_POLL_INTERVAL
        max_retries = self.MAX_POLL_RETRIES
        if interval > 0:
            # Make sure MAX_POLL_RETRIES never cuts a longer max_wait_time short.
            max_retries = max(max_retries, math.ceil(max_wait_time / interval))

        for _ in range(max_retries):
            result = await query(task_id, blocking=True)
            if result.code == self._CODE_SUCCESS:
                return result
            if result.code != self._CODE_PROCESSING:
                error_msg = self.ERROR_CODES.get(result.code, f"未知错误: {result.code}")
                raise VolcengineCaptionError(
                    f"{label}失败: {error_msg} ({result.message})", code=result.code
                )
            if loop.time() - start_time > max_wait_time:
                raise VolcengineCaptionError(f"{label}超时: 已等待 {max_wait_time}s")
            await asyncio.sleep(interval)
        raise VolcengineCaptionError(f"{label}超时: 超过最大重试次数")

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> str:
        """
        Submit a caption-generation task.

        Args:
            audio_url: URL of the audio/video file.
            language: Language code.
            caption_type: Recognition type (auto/speech/singing).
            use_punc: Add punctuation automatically.
            use_itn: Convert spoken numbers to digits.
            words_per_line: Characters per caption line.
            max_lines: Lines per screen.

        Returns:
            The task ID.

        Raises:
            VolcengineCaptionError: if the request fails or no task ID is returned.
        """
        params = {
            "appid": self.appid,
            "language": language,
            "caption_type": caption_type,
            # The API takes the literal strings "True"/"False".
            "use_punc": str(use_punc),
            "use_itn": str(use_itn),
            "words_per_line": words_per_line,
            "max_lines": max_lines,
        }
        data = await self._request_json(
            "POST",
            "/submit",
            params=params,
            payload={"url": audio_url},
            error_prefix="提交任务失败",
        )
        task_id = data.get("id")
        if not task_id:
            raise VolcengineCaptionError(
                f"提交任务失败: {data.get('message', '未知错误')}",
                code=data.get("code"),
            )
        logger.info("字幕任务已提交: %s", task_id)
        return task_id

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> CaptionResult:
        """
        Query the result of a caption-generation task.

        Args:
            task_id: Task ID.
            blocking: Ask the server to wait for the result (``blocking=1``).

        Returns:
            The parsed caption result; ``code`` 2000 means still processing.

        Raises:
            VolcengineCaptionError: if the request or the parsing fails.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        data = await self._request_json(
            "GET", "/query", params=params, error_prefix="查询任务失败"
        )
        try:
            return self._parse_caption_result(data)
        except Exception as e:
            raise VolcengineCaptionError(f"查询任务失败: {e}", original_error=e) from e

    async def generate_caption(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
        max_wait_time: int = 120,
    ) -> CaptionResult:
        """
        Generate captions end to end: submit, poll, return the result.

        Args:
            audio_url: URL of the audio/video file.
            language: Language code.
            caption_type: Recognition type.
            use_punc: Add punctuation automatically.
            use_itn: Convert spoken numbers to digits.
            words_per_line: Characters per caption line.
            max_lines: Lines per screen.
            max_wait_time: Maximum time to wait for the result, in seconds.

        Returns:
            The finished caption result.

        Raises:
            VolcengineCaptionError: on failure or timeout.
        """
        task_id = await self.submit_caption_task(
            audio_url=audio_url,
            language=language,
            caption_type=caption_type,
            use_punc=use_punc,
            use_itn=use_itn,
            words_per_line=words_per_line,
            max_lines=max_lines,
        )
        result = await self._poll_until_done(
            self.query_caption_task, task_id, max_wait_time, "字幕生成"
        )
        logger.info("字幕生成完成: %s, 时长: %ss", task_id, result.duration)
        return result

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> str:
        """
        Submit an automatic caption-alignment task.

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align against the audio.
            caption_type: Recognition type (speech/singing).
            sta_punc_mode: Punctuation mode (1/2/3).

        Returns:
            The task ID.

        Raises:
            VolcengineCaptionError: if the request fails or no task ID is returned.
        """
        params = {
            "appid": self.appid,
            "caption_type": caption_type,
            "sta_punc_mode": sta_punc_mode,
        }
        payload = {
            "url": audio_url,
            "audio_text": audio_text,
        }
        data = await self._request_json(
            "POST",
            "/ata/submit",
            params=params,
            payload=payload,
            error_prefix="提交打轴任务失败",
        )
        task_id = data.get("id")
        if not task_id:
            raise VolcengineCaptionError(
                f"提交打轴任务失败: {data.get('message', '未知错误')}",
                code=data.get("code"),
            )
        logger.info("打轴任务已提交: %s", task_id)
        return task_id

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> AutoAlignResult:
        """
        Query the result of an alignment task.

        Args:
            task_id: Task ID.
            blocking: Ask the server to wait for the result.

        Returns:
            The alignment result; ``code`` 2000 means still processing.

        Raises:
            VolcengineCaptionError: if the request or the parsing fails.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        data = await self._request_json(
            "GET", "/ata/query", params=params, error_prefix="查询打轴任务失败"
        )
        try:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    "[VolcengineCaption] Query response: %s",
                    json.dumps(data, ensure_ascii=False),
                )
            # The alignment response is parsed with the same logic as caption generation.
            caption_result = self._parse_caption_result(data)
            logger.debug(
                "[VolcengineCaption] First utterance: %s",
                caption_result.utterances[0] if caption_result.utterances else None,
            )
            return AutoAlignResult(
                code=caption_result.code,
                message=caption_result.message,
                duration=caption_result.duration,
                utterances=caption_result.utterances,
            )
        except Exception as e:
            raise VolcengineCaptionError(f"查询打轴任务失败: {e}", original_error=e) from e

    async def auto_align_caption(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
        max_wait_time: int = 120,
    ) -> AutoAlignResult:
        """
        Align captions end to end: submit, poll, return the result.

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align against the audio.
            caption_type: Recognition type.
            sta_punc_mode: Punctuation mode.
            max_wait_time: Maximum time to wait for the result, in seconds.

        Returns:
            The finished alignment result.

        Raises:
            VolcengineCaptionError: on failure or timeout.
        """
        task_id = await self.submit_auto_align_task(
            audio_url=audio_url,
            audio_text=audio_text,
            caption_type=caption_type,
            sta_punc_mode=sta_punc_mode,
        )
        result = await self._poll_until_done(
            self.query_auto_align_task, task_id, max_wait_time, "打轴"
        )
        logger.info("打轴完成: %s", task_id)
        return result

    @staticmethod
    def _read_time(item: dict, snake_key: str, camel_key: str):
        """Return a timestamp stored under either spelling of a key, defaulting to 0."""
        return item.get(snake_key) or item.get(camel_key) or 0

    def _parse_caption_result(self, data: dict) -> CaptionResult:
        """
        Convert a raw query response into a ``CaptionResult``.

        Timing fields are read as snake_case (``start_time``) with a camelCase
        (``startTime``) fallback, since either spelling may appear. Missing or
        ``null`` lists and fields fall back to empty/zero values.
        """
        logger.debug("[VolcengineCaption] Parsing caption result: %s", data)
        utterances = []
        for u in data.get("utterances") or []:
            logger.debug("[VolcengineCaption] Parsing utterance: %s", u)
            words = [
                CaptionWord(
                    text=w.get("text") or "",
                    start_time=self._read_time(w, "start_time", "startTime"),
                    end_time=self._read_time(w, "end_time", "endTime"),
                )
                for w in u.get("words") or []
            ]
            utterances.append(
                CaptionUtterance(
                    text=u.get("text") or "",
                    start_time=self._read_time(u, "start_time", "startTime"),
                    end_time=self._read_time(u, "end_time", "endTime"),
                    words=words,
                )
            )
        result = CaptionResult(
            code=data.get("code", -1),
            message=data.get("message") or "",
            duration=data.get("duration") or 0.0,
            utterances=utterances,
        )
        logger.debug("[VolcengineCaption] Parsed result: %s", result)
        return result

    @staticmethod
    def _format_timestamp(ms, separator: str) -> str:
        """
        Format a millisecond offset as ``HH:MM:SS<separator>mmm``.

        Float inputs are rounded to whole milliseconds; negative values clamp to 0.
        """
        total = max(0, int(round(ms)))
        hours, rem = divmod(total, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        seconds, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}{separator}{millis:03d}"

    @staticmethod
    def to_srt(utterances: list[CaptionUtterance]) -> str:
        """
        Render utterances as SRT subtitles.

        Args:
            utterances: Timed utterances (``start_time``/``end_time`` in ms).

        Returns:
            The SRT text, without a trailing newline.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = []
        for i, u in enumerate(utterances, 1):
            lines.append(str(i))
            lines.append(f"{fmt(u.start_time, ',')} --> {fmt(u.end_time, ',')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()

    @staticmethod
    def to_ass(
        utterances: list[CaptionUtterance],
        video_width: int = 1080,
        video_height: int = 1920,
    ) -> str:
        """
        Render utterances as ASS subtitles via ``app.services.ass_generator``.

        Args:
            utterances: Timed utterances.
            video_width: Video width in pixels.
            video_height: Video height in pixels.

        Returns:
            The ASS text.
        """
        # Imported lazily, as in the original code (possibly to avoid an import cycle).
        from app.services.ass_generator import generate_ass

        return generate_ass(
            utterances=utterances,
            video_width=video_width,
            video_height=video_height,
        )

    @staticmethod
    def to_vtt(utterances: list[CaptionUtterance]) -> str:
        """
        Render utterances as WebVTT subtitles.

        Args:
            utterances: Timed utterances (``start_time``/``end_time`` in ms).

        Returns:
            The WebVTT text, without a trailing newline.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = ["WEBVTT", ""]
        for u in utterances:
            lines.append(f"{fmt(u.start_time, '.')} --> {fmt(u.end_time, '.')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()

    async def close(self):
        """Close the HTTP client if open and drop the reference to it."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None
# Process-wide caption service instance, created lazily by get_caption_service().
_caption_service: VolcengineCaptionService | None = None


async def get_caption_service() -> VolcengineCaptionService:
    """Return the shared caption service, constructing it on first use."""
    global _caption_service
    if _caption_service is not None:
        return _caption_service
    _caption_service = VolcengineCaptionService()
    return _caption_service
def reset_caption_service() -> None:
    """
    Forget the cached caption service so the next lookup builds a fresh one.

    Intended for tests. The discarded instance's HTTP client is not closed here.
    """
    global _caption_service
    _caption_service = None