Files
meijiaka-zy/python-api/app/ai/providers/volcengine_caption_provider.py
T

266 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山引擎 OpenSpeech Provider
=============================
直接封装火山 OpenSpeech HTTP API
- 字幕生成(/vc/submit + /vc/query)
- 自动打轴(/vc/ata/submit + /vc/ata/query)
使用 httpx.AsyncClient,支持外部注入(由 lifespan 管理生命周期)。
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from app.config import get_settings
from app.core.exceptions import PlatformError, PlatformErrorType
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
def _map_caption_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Map a Volcengine caption/alignment failure to a standard ``PlatformError``.

    A Volcengine business ``code`` (taken from a response body) wins when it is
    one of the known codes below; otherwise the HTTP ``status`` decides. Any
    status not listed maps to ``UNKNOWN`` and is not retryable.

    Args:
        status: HTTP status used for classification; always attached to the
            result as ``status_code``.
        message: Human-readable error message.
        code: Optional Volcengine business error code from the response body.

    Returns:
        A ``PlatformError`` (returned, not raised) tagged with
        ``platform="volcengine_caption"``.
    """
    # Business codes from the response body -> (error type, retryable).
    code_mapping = {
        1001: (PlatformErrorType.BAD_REQUEST, False),
        1002: (PlatformErrorType.AUTH_FAILED, False),
        1003: (PlatformErrorType.RATE_LIMIT, True),
        1010: (PlatformErrorType.BAD_REQUEST, False),
        1011: (PlatformErrorType.BAD_REQUEST, False),
        1012: (PlatformErrorType.BAD_REQUEST, False),
        1013: (PlatformErrorType.BAD_REQUEST, False),
    }
    # HTTP status fallback -> (error type, retryable).
    http_mapping = {
        429: (PlatformErrorType.RATE_LIMIT, True),
        401: (PlatformErrorType.AUTH_FAILED, False),
        403: (PlatformErrorType.AUTH_FAILED, False),
        400: (PlatformErrorType.BAD_REQUEST, False),
        500: (PlatformErrorType.SERVER_ERROR, True),
        502: (PlatformErrorType.SERVER_ERROR, True),
        503: (PlatformErrorType.SERVER_ERROR, True),
        # Gateway timeout is as transient as 502/503; without this entry it
        # fell through to UNKNOWN and was treated as non-retryable.
        504: (PlatformErrorType.SERVER_ERROR, True),
    }
    if code is not None and code in code_mapping:
        error_type, retryable = code_mapping[code]
    else:
        error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
    return PlatformError(
        message,
        platform="volcengine_caption",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
    )
class VolcengineCaptionProvider:
    """Volcengine OpenSpeech caption provider.

    Calls the OpenSpeech HTTP API directly and returns its JSON responses; no
    business-level processing (format conversion, polling) happens here.
    All four endpoints go through one request helper, so HTTP errors, network
    errors and unexpected failures are classified identically for caption
    generation (``/submit``, ``/query``) and auto-alignment
    (``/ata/submit``, ``/ata/query``).

    An ``httpx.AsyncClient`` may be injected, in which case the caller owns its
    lifecycle; otherwise the provider creates its own and closes it in
    :meth:`close` (or on exit from ``async with``).
    """

    BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
    DEFAULT_TIMEOUT = 60.0

    def __init__(
        self,
        appid: str | None = None,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Resolve credentials and set up the HTTP client.

        Args:
            appid: OpenSpeech app id; falls back to
                ``settings.VOLCENGINE_CAPTION_APPID``.
            token: Access token; falls back to
                ``settings.VOLCENGINE_CAPTION_TOKEN``.
            client: Optional shared client. When given, :meth:`close` leaves
                it open; otherwise a client with ``DEFAULT_TIMEOUT`` is created.

        Raises:
            PlatformError: ``BAD_REQUEST`` (not retryable) when the app id or
                token is empty after the settings fallback.
        """
        settings = get_settings()
        self.appid = appid or settings.VOLCENGINE_CAPTION_APPID or ""
        self.token = token or settings.VOLCENGINE_CAPTION_TOKEN or ""
        if not self.appid:
            raise PlatformError(
                "VOLCENGINE_CAPTION_APPID 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )
        if not self.token:
            raise PlatformError(
                "VOLCENGINE_CAPTION_TOKEN 未配置",
                platform="volcengine_caption",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )
        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    def _get_headers(self) -> dict:
        """Build the headers sent with every OpenSpeech request."""
        return {
            # NOTE: the semicolon after "Bearer" looks deliberate - Volcengine's
            # OpenSpeech examples use this "Bearer; <token>" form. Confirm against
            # the API docs before changing it to the RFC 6750 "Bearer <token>".
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()

    async def __aenter__(self) -> VolcengineCaptionProvider:
        """Support ``async with``; returns the provider itself."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close an owned client on context exit; exceptions are not suppressed."""
        await self.close()

    # ==================== Internal helpers ====================

    async def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, str | int],
        failure_prefix: str,
        payload: dict[str, Any] | None = None,
    ) -> Any:
        """Send one OpenSpeech request and return the decoded JSON body.

        Args:
            method: HTTP method, ``"GET"`` or ``"POST"``.
            path: Endpoint path appended to ``BASE_URL``, e.g. ``"/submit"``.
            params: Query-string parameters.
            failure_prefix: Message prefix for unexpected failures,
                e.g. ``"提交任务失败"``.
            payload: Optional JSON request body.

        Raises:
            PlatformError: an HTTP error status is classified by
                ``_map_caption_error`` using that status; network errors and
                timeouts become a retryable ``TIMEOUT`` error; anything else
                (e.g. an undecodable body) is passed to ``_map_caption_error``
                with status 500.
        """
        try:
            response = await self.client.request(
                method,
                f"{self.BASE_URL}{path}",
                params=params,
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            return response.json()
        except PlatformError:
            raise
        except httpx.HTTPStatusError as e:
            raise _map_caption_error(
                e.response.status_code, f"HTTP错误: {e.response.status_code}"
            ) from e
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise PlatformError(
                f"字幕服务网络错误: {e}",
                platform="volcengine_caption",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e
        except Exception as e:
            raise _map_caption_error(500, f"{failure_prefix}: {e}") from e

    @staticmethod
    def _ensure_task_id(data: Any, failure_prefix: str) -> dict[str, Any]:
        """Return a submit response unchanged if it carries a task ``id``.

        Otherwise raise a ``PlatformError`` built from the body's ``message``
        and ``code``, so known Volcengine business codes are classified by
        ``_map_caption_error`` rather than all being treated as status 500.
        """
        if isinstance(data, dict) and "id" in data:
            return data
        # Guard against non-object bodies (list, string, ...) before .get().
        body = data if isinstance(data, dict) else {}
        raise _map_caption_error(
            500,
            f"{failure_prefix}: {body.get('message', '未知错误')}",
            code=body.get("code"),
        )

    # ==================== Caption generation ====================

    async def submit_caption_task(
        self,
        audio_url: str,
        language: str = "zh-CN",
        caption_type: str = "auto",
        use_punc: bool = True,
        use_itn: bool = True,
        words_per_line: int = 46,
        max_lines: int = 1,
    ) -> dict[str, Any]:
        """Submit a caption-generation job for the audio at ``audio_url``.

        ``use_punc`` and ``use_itn`` are sent as the strings ``"True"`` or
        ``"False"``; the other options are passed through as query parameters.

        Returns:
            The response JSON, guaranteed to contain the task ``id``.

        Raises:
            PlatformError: on HTTP, network or decode errors, or when the
                response has no ``id``.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "language": language,
            "caption_type": caption_type,
            "use_punc": str(use_punc),
            "use_itn": str(use_itn),
            "words_per_line": words_per_line,
            "max_lines": max_lines,
        }
        data = await self._request(
            "POST",
            "/submit",
            params=params,
            payload={"url": audio_url},
            failure_prefix="提交任务失败",
        )
        return self._ensure_task_id(data, "提交任务失败")

    async def query_caption_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> dict[str, Any]:
        """Query a caption job and return the raw response JSON.

        ``blocking`` is sent as ``1``/``0``.

        Raises:
            PlatformError: on HTTP, network or decode errors.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        return await self._request(
            "GET", "/query", params=params, failure_prefix="查询任务失败"
        )

    # ==================== Auto alignment ====================

    async def submit_auto_align_task(
        self,
        audio_url: str,
        audio_text: str,
        caption_type: str = "speech",
        sta_punc_mode: int = 3,
    ) -> dict[str, Any]:
        """Submit an auto-alignment job pairing ``audio_url`` with ``audio_text``.

        Both values go in the JSON body; ``caption_type`` and ``sta_punc_mode``
        are sent as query parameters.

        Returns:
            The response JSON, guaranteed to contain the task ``id``.

        Raises:
            PlatformError: on HTTP, network or decode errors (classified the
                same way as for caption generation), or when the response has
                no ``id``.
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "caption_type": caption_type,
            "sta_punc_mode": sta_punc_mode,
        }
        data = await self._request(
            "POST",
            "/ata/submit",
            params=params,
            payload={"url": audio_url, "audio_text": audio_text},
            failure_prefix="提交打轴任务失败",
        )
        return self._ensure_task_id(data, "提交打轴任务失败")

    async def query_auto_align_task(
        self,
        task_id: str,
        blocking: bool = False,
    ) -> dict[str, Any]:
        """Query an auto-alignment job and return the raw response JSON.

        ``blocking`` is sent as ``1``/``0``.

        Raises:
            PlatformError: on HTTP, network or decode errors (classified the
                same way as for caption generation).
        """
        params: dict[str, str | int] = {
            "appid": self.appid,
            "id": task_id,
            "blocking": 1 if blocking else 0,
        }
        return await self._request(
            "GET", "/ata/query", params=params, failure_prefix="查询打轴任务失败"
        )