567 lines
17 KiB
Python
567 lines
17 KiB
Python
"""
|
|
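火山引擎音视频字幕服务 (Volcengine audio/video caption service)
==============================================================

基于火山引擎 OpenSpeech API 的音视频字幕生成服务。
Caption generation and automatic caption alignment (打轴) for audio/video
files, built on the Volcengine OpenSpeech API.

文档 / Docs: https://www.volcengine.com/docs/6561/80907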
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
import httpx
|
|
|
|
from app.config import get_settings
|
|
from app.schemas.caption import (
|
|
AutoAlignResult,
|
|
CaptionResult,
|
|
CaptionUtterance,
|
|
CaptionWord,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VolcengineCaptionError(Exception):
    """Error raised by the Volcengine caption service wrapper.

    Attributes:
        code: Status code reported by the API, when one is available
            (for example a key of ``VolcengineCaptionService.ERROR_CODES``);
            ``None`` for local/transport failures.
        original_error: The lower-level exception (HTTP, network, parsing)
            that triggered this error, if any.
    """

    def __init__(
        self,
        message: str,
        code: int | None = None,
        original_error: Exception | None = None,
    ):
        super().__init__(message)
        self.code = code
        self.original_error = original_error
|
|
|
|
|
|
class VolcengineCaptionService:
    """Async wrapper around the Volcengine audio/video caption API.

    Two task-based workflows are exposed, each as a submit call plus a query
    call:

    * caption generation (``/submit`` + ``/query``): recognizes speech and
      returns timed utterances;
    * automatic caption alignment (``/ata/submit`` + ``/ata/query``): aligns a
      caller-supplied text against the audio.

    ``generate_caption`` and ``auto_align_caption`` run the complete
    submit -> poll -> result flow. The static ``to_srt`` / ``to_vtt`` /
    ``to_ass`` helpers render utterances into subtitle formats.

    One ``httpx.AsyncClient`` is created lazily and reused; call ``close()``
    to release it.
    """

    # API base configuration
    BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
    DEFAULT_TIMEOUT = 60.0  # per-request HTTP timeout, seconds
    DEFAULT_POLL_INTERVAL = 1.0  # pause between polls, seconds
    # Upper bound on the number of polls per task. With blocking queries a
    # single poll can take far longer than DEFAULT_POLL_INTERVAL, so this is
    # not a wall-clock limit; max_wait_time is checked separately.
    MAX_POLL_RETRIES = 120

    # Status codes that drive the polling loop.
    _CODE_SUCCESS = 0
    _CODE_PROCESSING = 2000

    # API status code -> description used in error messages.
    ERROR_CODES = {
        0: "成功",
        2000: "处理中",
        1001: "参数无效",
        1002: "无权限",
        1003: "超频",
        1010: "音频过长",
        1011: "音频过大",
        1012: "格式无效",
        1013: "音频静音",
    }

    def __init__(self, appid: str | None = None, token: str | None = None):
        """Create the service.

        Args:
            appid: Application ID; defaults to
                ``settings.VOLCENGINE_CAPTION_APPID``.
            token: Access token; defaults to
                ``settings.VOLCENGINE_CAPTION_TOKEN``.

        Raises:
            VolcengineCaptionError: If no appid or no token is available.
        """
        settings = get_settings()
        self.appid = appid or settings.VOLCENGINE_CAPTION_APPID or ""
        self.token = token or settings.VOLCENGINE_CAPTION_TOKEN or ""

        if not self.appid:
            raise VolcengineCaptionError("VOLCENGINE_CAPTION_APPID 未配置")
        if not self.token:
            raise VolcengineCaptionError("VOLCENGINE_CAPTION_TOKEN 未配置")

        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, (re)creating it if missing or closed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
        return self._client

    def _get_headers(self) -> dict:
        """Build the request headers (auth + JSON content type)."""
        return {
            # NOTE: the semicolon after "Bearer" is deliberate; it matches the
            # header format shown in the OpenSpeech docs. Verify before changing.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def _request_json(
        self,
        method: str,
        path: str,
        *,
        params: dict,
        error_prefix: str,
        payload: dict | None = None,
    ) -> dict:
        """Send one API request and return the decoded JSON object.

        Args:
            method: HTTP method ("GET" / "POST").
            path: Path appended to ``BASE_URL``.
            params: Query-string parameters.
            error_prefix: Message prefix for non-HTTP-status failures.
            payload: Optional JSON request body.

        Raises:
            VolcengineCaptionError: ``"HTTP错误: <status>"`` for non-2xx
                responses; ``"<error_prefix>: <detail>"`` for transport errors,
                invalid JSON, or a body that is not a JSON object. The
                underlying exception is stored in ``original_error`` and
                chained as ``__cause__``.
        """
        client = await self._get_client()
        try:
            response = await client.request(
                method,
                f"{self.BASE_URL}{path}",
                params=params,
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            data = response.json()
        except httpx.HTTPStatusError as e:
            raise VolcengineCaptionError(
                f"HTTP错误: {e.response.status_code}",
                original_error=e,
            ) from e
        except Exception as e:
            raise VolcengineCaptionError(f"{error_prefix}: {e}", original_error=e) from e

        if not isinstance(data, dict):
            raise VolcengineCaptionError(f"{error_prefix}: 响应格式无效")
        return data

    async def _poll_until_done(self, task_id: str, query, max_wait_time: float, label: str):
        """Poll a task until it succeeds, fails, or runs out of time.

        Args:
            task_id: Task ID returned by the matching submit call.
            query: Async callable ``query(task_id)`` returning an object with
                ``code`` and ``message`` attributes.
            max_wait_time: Wall-clock budget in seconds, checked after each
                "still processing" response.
            label: Runtime label used in error messages ("字幕生成" / "打轴").

        Returns:
            The first result whose code is ``_CODE_SUCCESS``.

        Raises:
            VolcengineCaptionError: On any code other than success/processing
                (``code`` is set), when ``max_wait_time`` is exceeded, or after
                ``MAX_POLL_RETRIES`` polls.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()

        for _ in range(self.MAX_POLL_RETRIES):
            result = await query(task_id)

            if result.code == self._CODE_SUCCESS:
                return result

            if result.code != self._CODE_PROCESSING:
                error_msg = self.ERROR_CODES.get(result.code, f"未知错误: {result.code}")
                raise VolcengineCaptionError(
                    f"{label}失败: {error_msg} ({result.message})", code=result.code
                )

            # Still processing: give up once the time budget is spent.
            if loop.time() - start_time > max_wait_time:
                raise VolcengineCaptionError(f"{label}超时: 已等待 {max_wait_time}s")
            await asyncio.sleep(self.DEFAULT_POLL_INTERVAL)

        raise VolcengineCaptionError(f"{label}超时: 超过最大重试次数")

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> str:
        """Submit a caption-generation task.

        Args:
            audio_url: URL of the audio/video file.
            language: Language code, e.g. ``"zh-CN"``.
            caption_type: Recognition type (``auto`` / ``speech`` / ``singing``).
            use_punc: Add punctuation automatically.
            use_itn: Convert spoken numbers to digits (ITN).
            words_per_line: Characters per caption line.
            max_lines: Lines per caption screen.

        Returns:
            The task ID assigned by the API.

        Raises:
            VolcengineCaptionError: On HTTP/transport failure, or when the
                response has no task ID (``code`` is copied from the response
                when present).
        """
        params = {
            "appid": self.appid,
            "language": language,
            "caption_type": caption_type,
            # Flags are sent as str(bool), i.e. "True" / "False".
            "use_punc": str(use_punc),
            "use_itn": str(use_itn),
            "words_per_line": words_per_line,
            "max_lines": max_lines,
        }
        data = await self._request_json(
            "POST",
            "/submit",
            params=params,
            payload={"url": audio_url},
            error_prefix="提交任务失败",
        )

        task_id = data.get("id")
        if not task_id:
            raise VolcengineCaptionError(
                f"提交任务失败: {data.get('message', '未知错误')}",
                code=data.get("code"),
            )

        logger.info("字幕任务已提交: %s", task_id)
        return task_id

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> CaptionResult:
        """Query the result of a caption-generation task.

        Args:
            task_id: Task ID from ``submit_caption_task``.
            blocking: Ask the API to wait for the result (sent as blocking=1).

        Returns:
            The parsed caption result. ``code`` tells whether the task is done
            (0), still processing (2000), or failed.

        Raises:
            VolcengineCaptionError: On HTTP/transport failure or when the
                response cannot be parsed.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        data = await self._request_json(
            "GET", "/query", params=params, error_prefix="查询任务失败"
        )
        try:
            return self._parse_caption_result(data)
        except Exception as e:
            raise VolcengineCaptionError(f"查询任务失败: {e}", original_error=e) from e

    async def generate_caption(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
        max_wait_time: int = 120,
    ) -> CaptionResult:
        """Generate captions end to end (submit -> poll -> result).

        Args:
            audio_url: URL of the audio/video file.
            language: Language code.
            caption_type: Recognition type.
            use_punc: Add punctuation automatically.
            use_itn: Convert spoken numbers to digits (ITN).
            words_per_line: Characters per caption line.
            max_lines: Lines per caption screen.
            max_wait_time: Maximum time to wait for completion, in seconds.

        Returns:
            The finished caption result.

        Raises:
            VolcengineCaptionError: If submission, polling, or the task fails,
                or the wait budget is exhausted.
        """
        task_id = await self.submit_caption_task(
            audio_url=audio_url,
            language=language,
            caption_type=caption_type,
            use_punc=use_punc,
            use_itn=use_itn,
            words_per_line=words_per_line,
            max_lines=max_lines,
        )

        # NOTE(review): polls use blocking=1. If the server holds a request
        # longer than DEFAULT_TIMEOUT, the HTTP timeout surfaces as a
        # VolcengineCaptionError rather than another poll. Confirm intended.
        result = await self._poll_until_done(
            task_id,
            lambda tid: self.query_caption_task(tid, blocking=True),
            max_wait_time,
            "字幕生成",
        )
        logger.info("字幕生成完成: %s, 时长: %ss", task_id, result.duration)
        return result

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> str:
        """Submit an automatic caption-alignment task.

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align against the audio.
            caption_type: Recognition type (``speech`` / ``singing``).
            sta_punc_mode: Punctuation mode (1 / 2 / 3).

        Returns:
            The task ID assigned by the API.

        Raises:
            VolcengineCaptionError: On HTTP/transport failure, or when the
                response has no task ID (``code`` is copied from the response
                when present).
        """
        params = {
            "appid": self.appid,
            "caption_type": caption_type,
            "sta_punc_mode": sta_punc_mode,
        }
        payload = {
            "url": audio_url,
            "audio_text": audio_text,
        }
        data = await self._request_json(
            "POST",
            "/ata/submit",
            params=params,
            payload=payload,
            error_prefix="提交打轴任务失败",
        )

        task_id = data.get("id")
        if not task_id:
            raise VolcengineCaptionError(
                f"提交打轴任务失败: {data.get('message', '未知错误')}",
                code=data.get("code"),
            )

        logger.info("打轴任务已提交: %s", task_id)
        return task_id

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> AutoAlignResult:
        """Query the result of an alignment task.

        Args:
            task_id: Task ID from ``submit_auto_align_task``.
            blocking: Ask the API to wait for the result (sent as blocking=1).

        Returns:
            The alignment result (same fields as a caption result).

        Raises:
            VolcengineCaptionError: On HTTP/transport failure or when the
                response cannot be parsed.
        """
        params = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        data = await self._request_json(
            "GET", "/ata/query", params=params, error_prefix="查询打轴任务失败"
        )
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[VolcengineCaption] Query response: %s",
                json.dumps(data, ensure_ascii=False),
            )

        try:
            # The alignment response is parsed with the caption-result parser.
            caption_result = self._parse_caption_result(data)
            return AutoAlignResult(
                code=caption_result.code,
                message=caption_result.message,
                duration=caption_result.duration,
                utterances=caption_result.utterances,
            )
        except Exception as e:
            raise VolcengineCaptionError(f"查询打轴任务失败: {e}", original_error=e) from e

    async def auto_align_caption(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
        max_wait_time: int = 120,
    ) -> AutoAlignResult:
        """Align captions end to end (submit -> poll -> result).

        Args:
            audio_url: URL of the audio/video file.
            audio_text: Caption text to align.
            caption_type: Recognition type.
            sta_punc_mode: Punctuation mode.
            max_wait_time: Maximum time to wait for completion, in seconds.

        Returns:
            The finished alignment result.

        Raises:
            VolcengineCaptionError: If submission, polling, or the task fails,
                or the wait budget is exhausted.
        """
        task_id = await self.submit_auto_align_task(
            audio_url=audio_url,
            audio_text=audio_text,
            caption_type=caption_type,
            sta_punc_mode=sta_punc_mode,
        )

        result = await self._poll_until_done(
            task_id,
            lambda tid: self.query_auto_align_task(tid, blocking=True),
            max_wait_time,
            "打轴",
        )
        logger.info("打轴完成: %s", task_id)
        return result

    @staticmethod
    def _read_ms(item: dict, snake_key: str, camel_key: str):
        """Read a timing value that may use a snake_case or camelCase key.

        Missing, null, or zero values fall back to the other key, then to 0.
        """
        return item.get(snake_key) or item.get(camel_key) or 0

    def _parse_caption_result(self, data: dict) -> CaptionResult:
        """Convert a raw API response into a ``CaptionResult``.

        Timing keys may be snake_case (``start_time``) or camelCase
        (``startTime``). Missing or null ``utterances`` / ``words`` lists are
        treated as empty.
        """
        logger.debug("[VolcengineCaption] Parsing caption result: %s", data)

        utterances = []
        for raw_utt in data.get("utterances") or []:
            words = [
                CaptionWord(
                    text=raw_word.get("text", ""),
                    start_time=self._read_ms(raw_word, "start_time", "startTime"),
                    end_time=self._read_ms(raw_word, "end_time", "endTime"),
                )
                for raw_word in raw_utt.get("words") or []
            ]
            utterances.append(
                CaptionUtterance(
                    text=raw_utt.get("text", ""),
                    start_time=self._read_ms(raw_utt, "start_time", "startTime"),
                    end_time=self._read_ms(raw_utt, "end_time", "endTime"),
                    words=words,
                )
            )

        result = CaptionResult(
            code=data.get("code", -1),
            message=data.get("message", ""),
            duration=data.get("duration", 0.0),
            utterances=utterances,
        )
        logger.debug("[VolcengineCaption] Parsed result: %s", result)
        return result

    @staticmethod
    def _format_timestamp(ms, fraction_sep: str) -> str:
        """Format a millisecond offset as ``HH:MM:SS<fraction_sep>mmm``.

        Float inputs are truncated to whole milliseconds.
        """
        total_ms = int(ms)
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        seconds, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}{fraction_sep}{millis:03d}"

    @staticmethod
    def to_srt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as SRT text.

        Args:
            utterances: Timed utterances (start/end in milliseconds).

        Returns:
            SRT content: numbered cues separated by blank lines, with leading
            and trailing whitespace stripped.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = []
        for i, u in enumerate(utterances, 1):
            lines.append(str(i))
            lines.append(f"{fmt(u.start_time, ',')} --> {fmt(u.end_time, ',')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()

    @staticmethod
    def to_ass(
        utterances: list[CaptionUtterance],
        video_width: int = 1080,
        video_height: int = 1920,
    ) -> str:
        """Render utterances as ASS text via ``app.services.ass_generator``.

        Args:
            utterances: Timed utterances.
            video_width: Video width in pixels.
            video_height: Video height in pixels.

        Returns:
            ASS content produced by ``generate_ass``.
        """
        # Local import, presumably to avoid a module-level dependency/cycle.
        from app.services.ass_generator import generate_ass

        return generate_ass(
            utterances=utterances,
            video_width=video_width,
            video_height=video_height,
        )

    @staticmethod
    def to_vtt(utterances: list[CaptionUtterance]) -> str:
        """Render utterances as WebVTT text.

        Args:
            utterances: Timed utterances (start/end in milliseconds).

        Returns:
            WebVTT content starting with the ``WEBVTT`` header, with leading
            and trailing whitespace stripped.
        """
        fmt = VolcengineCaptionService._format_timestamp
        lines = ["WEBVTT", ""]
        for u in utterances:
            lines.append(f"{fmt(u.start_time, '.')} --> {fmt(u.end_time, '.')}")
            lines.append(u.text)
            lines.append("")
        return "\n".join(lines).strip()

    async def close(self):
        """Close the shared HTTP client, if open, and drop the reference."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
        self._client = None
|
|
|
|
|
|
# Process-wide service singleton, created lazily by get_caption_service().
_caption_service: VolcengineCaptionService | None = None
|
|
|
|
|
|
async def get_caption_service() -> VolcengineCaptionService:
    """Return the shared caption service, constructing it on first use.

    There is no await between the check and the assignment, so concurrent
    coroutines on one event loop cannot create two instances.
    """
    global _caption_service
    if _caption_service is not None:
        return _caption_service
    _caption_service = VolcengineCaptionService()
    return _caption_service
|
|
|
|
|
|
def reset_caption_service():
    """Reset the caption service singleton (intended for tests).

    The next call to ``get_caption_service()`` builds a fresh instance.
    NOTE(review): the previous instance is dropped without calling its
    ``close()``, so any HTTP client it opened is not closed here.
    """
    global _caption_service
    _caption_service = None
|