266 lines
8.7 KiB
Python
266 lines
8.7 KiB
Python
"""
|
||
火山引擎 OpenSpeech Provider
|
||
=============================
|
||
|
||
直接封装火山 OpenSpeech HTTP API:
|
||
- 字幕生成(/vc/submit + /vc/query)
|
||
- 自动打轴(/vc/ata/submit + /vc/ata/query)
|
||
|
||
使用 httpx.AsyncClient,支持外部注入(由 lifespan 管理生命周期)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _map_caption_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Map a Volcengine caption failure to a standard ``PlatformError``.

    Classification order:
      1. ``code`` -- the API's business error code, when it is one listed below;
      2. the HTTP ``status`` -- first the explicit table, then any other 5xx
         is treated as a retryable server error;
      3. otherwise ``UNKNOWN`` and not retryable.

    Args:
        status: HTTP status used for classification and recorded as
            ``status_code`` on the error. Callers in this module pass 500 for
            failures that have no real HTTP status.
        message: Human-readable error message carried by the error.
        code: Business error code taken from the API response body, if any.

    Returns:
        A ``PlatformError`` tagged with ``platform="volcengine_caption"``.
    """
    code_mapping = {
        1001: (PlatformErrorType.BAD_REQUEST, False),
        1002: (PlatformErrorType.AUTH_FAILED, False),
        1003: (PlatformErrorType.RATE_LIMIT, True),
        1010: (PlatformErrorType.BAD_REQUEST, False),
        1011: (PlatformErrorType.BAD_REQUEST, False),
        1012: (PlatformErrorType.BAD_REQUEST, False),
        1013: (PlatformErrorType.BAD_REQUEST, False),
    }
    http_mapping = {
        429: (PlatformErrorType.RATE_LIMIT, True),
        401: (PlatformErrorType.AUTH_FAILED, False),
        403: (PlatformErrorType.AUTH_FAILED, False),
        400: (PlatformErrorType.BAD_REQUEST, False),
        500: (PlatformErrorType.SERVER_ERROR, True),
        502: (PlatformErrorType.SERVER_ERROR, True),
        503: (PlatformErrorType.SERVER_ERROR, True),
    }

    if code is not None and code in code_mapping:
        error_type, retryable = code_mapping[code]
    elif status in http_mapping:
        error_type, retryable = http_mapping[status]
    elif 500 <= status < 600:
        # Unlisted 5xx responses (e.g. 504 Gateway Timeout) are server-side
        # and usually transient; previously they fell through to UNKNOWN and
        # were never retried.
        error_type, retryable = PlatformErrorType.SERVER_ERROR, True
    else:
        error_type, retryable = PlatformErrorType.UNKNOWN, False

    return PlatformError(
        message,
        platform="volcengine_caption",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
    )
|
||
|
||
|
||
class VolcengineCaptionProvider:
    """Volcengine caption provider.

    Calls the OpenSpeech HTTP API directly. It only submits and queries tasks
    and maps failures to ``PlatformError``; business-level processing (format
    conversion, polling) is left to callers.

    All four endpoints share one request/error-mapping path (``_request``), so
    HTTP status errors, network errors/timeouts and unexpected failures are
    classified the same way for caption generation and auto-alignment.
    """

    BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
    DEFAULT_TIMEOUT = 60.0

    def __init__(
        self,
        appid: str | None = None,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create the provider.

        Args:
            appid: OpenSpeech app id; falls back to
                ``settings.VOLCENGINE_CAPTION_APPID`` when not given.
            token: Access token; falls back to
                ``settings.VOLCENGINE_CAPTION_TOKEN`` when not given.
            client: Optional shared ``httpx.AsyncClient``. An injected client is
                never closed by ``close()``; without one, the provider creates
                (and owns) a client with ``DEFAULT_TIMEOUT``.

        Raises:
            PlatformError: ``BAD_REQUEST`` (not retryable) when the app id or
                the token resolves to an empty value.
        """
        settings = get_settings()
        self.appid = appid or settings.VOLCENGINE_CAPTION_APPID or ""
        self.token = token or settings.VOLCENGINE_CAPTION_TOKEN or ""

        if not self.appid:
            raise PlatformError(
                "VOLCENGINE_CAPTION_APPID 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )
        if not self.token:
            raise PlatformError(
                "VOLCENGINE_CAPTION_TOKEN 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    async def __aenter__(self) -> VolcengineCaptionProvider:
        """Allow ``async with VolcengineCaptionProvider(...) as provider:``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the owned HTTP client on leaving the ``async with`` block."""
        await self.close()

    def _get_headers(self) -> dict[str, str]:
        """Build the request headers.

        NOTE(review): the ``"Bearer; "`` form (with a semicolon) is not the
        RFC 6750 style but appears deliberate -- the OpenSpeech docs use this
        format. Confirm against the API docs before "fixing" it.
        """
        return {
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()

    # ==================== Shared request helpers ====================

    async def _request(
        self,
        method: str,
        path: str,
        params: dict[str, str | int],
        failure_prefix: str,
        payload: dict[str, Any] | None = None,
    ) -> Any:
        """Send one request to ``BASE_URL + path`` and return the decoded JSON.

        Args:
            method: ``"POST"`` (sends ``payload`` as a JSON body) or anything
                else for a GET.
            path: Endpoint path appended to ``BASE_URL``, e.g. ``"/submit"``.
            params: Query-string parameters.
            failure_prefix: Prefix for the message of unexpected failures.
            payload: JSON body for POST requests.

        Raises:
            PlatformError: HTTP error status -> ``_map_caption_error(status)``;
                network error / timeout -> retryable ``TIMEOUT``; anything else
                (e.g. an undecodable body) -> ``_map_caption_error(500)``.
        """
        url = f"{self.BASE_URL}{path}"
        headers = self._get_headers()
        try:
            if method == "POST":
                response = await self.client.post(
                    url, params=params, json=payload, headers=headers
                )
            else:
                response = await self.client.get(url, params=params, headers=headers)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise _map_caption_error(
                e.response.status_code, f"HTTP错误: {e.response.status_code}"
            ) from e
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise PlatformError(
                f"字幕服务网络错误: {e}",
                platform="volcengine_caption",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e
        except Exception as e:
            raise _map_caption_error(500, f"{failure_prefix}: {e}") from e

    @staticmethod
    def _require_task_id(data: Any, failure_prefix: str) -> dict[str, Any]:
        """Return a submit response if it carries a non-empty task ``id``.

        Otherwise raise a ``PlatformError`` built from the body's ``message``
        and (integer) business ``code``, so that e.g. an auth failure is not
        reported as a retryable server error.
        """
        if isinstance(data, dict) and data.get("id"):
            return data
        # A non-dict body (or a missing/empty id) is a failed submission.
        body: dict[str, Any] = data if isinstance(data, dict) else {}
        code = body.get("code")
        raise _map_caption_error(
            500,
            f"{failure_prefix}: {body.get('message', '未知错误')}",
            code=code if isinstance(code, int) else None,
        )

    # ==================== Caption generation ====================

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> dict[str, Any]:
        """Submit a caption-generation task.

        Boolean options are sent as the strings ``"True"`` / ``"False"``.

        Returns:
            The raw response, guaranteed to contain a non-empty ``id``
            (the task id).

        Raises:
            PlatformError: on HTTP/network failures or when no task id is
                returned.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "language": language,
            "caption_type": caption_type,
            "use_punc": str(use_punc),
            "use_itn": str(use_itn),
            "words_per_line": words_per_line,
            "max_lines": max_lines,
        }
        data = await self._request(
            "POST", "/submit", params, "提交任务失败", payload={"url": audio_url}
        )
        return self._require_task_id(data, "提交任务失败")

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> dict[str, Any]:
        """Query a caption task and return the raw JSON response.

        ``blocking`` is sent as ``1``/``0``. NOTE(review): presumably a blocking
        query holds the request open until the task finishes -- make sure the
        client's timeout is long enough; confirm against the OpenSpeech docs.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        return await self._request("GET", "/query", params, "查询任务失败")

    # ==================== Automatic alignment ====================

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> dict[str, Any]:
        """Submit an automatic subtitle-alignment task for ``audio_text``.

        Returns:
            The raw response, guaranteed to contain a non-empty ``id``.

        Raises:
            PlatformError: on HTTP/network failures (now classified the same
                way as the caption endpoints) or when no task id is returned.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "caption_type": caption_type,
            "sta_punc_mode": sta_punc_mode,
        }
        data = await self._request(
            "POST",
            "/ata/submit",
            params,
            "提交打轴任务失败",
            payload={"url": audio_url, "audio_text": audio_text},
        )
        return self._require_task_id(data, "提交打轴任务失败")

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> dict[str, Any]:
        """Query an alignment task and return the raw JSON response.

        ``blocking`` is sent as ``1``/``0`` (see ``query_caption_task``).
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        return await self._request("GET", "/ata/query", params, "查询打轴任务失败")
|