346 lines
15 KiB
Diff
346 lines
15 KiB
Diff
diff --git a/python-api/app/ai/providers/vidu_provider.py b/python-api/app/ai/providers/vidu_provider.py
|
|
index cab5902..fccfbbf 100644
|
|
--- a/python-api/app/ai/providers/vidu_provider.py
|
|
+++ b/python-api/app/ai/providers/vidu_provider.py
|
|
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
-def _map_vidu_error(status: int, message: str) -> PlatformError:
|
|
- """把 Vidu HTTP 错误映射为标准 PlatformError"""
|
|
+# Vidu 错误码分类
|
|
+_VIDU_AUDIT_ERROR_CODES = {
|
|
+ "TaskPromptPolicyViolation",
|
|
+ "AuditSubmitIllegal",
|
|
+ "CreationPolicyViolation",
|
|
+ "PhotoAuditNotPass",
|
|
+ "AuditFailed",
|
|
+ "ImageCheckBodyJointsFailed",
|
|
+ "ImageCheckFaceFailed",
|
|
+ "ImageObjectsUndetected",
|
|
+ "FaceDetectFailure",
|
|
+ "FaceDetectNotPass",
|
|
+ "NoFaceDetected",
|
|
+ "MultiFaceDetected",
|
|
+}
|
|
+
|
|
+_VIDU_RETRYABLE_ERROR_CODES = {
|
|
+ "InternalServiceFailure",
|
|
+ "ModelUnavailable",
|
|
+ "Unknown",
|
|
+}
|
|
+
|
|
+_VIDU_RATE_LIMIT_ERROR_CODES = {
|
|
+ "QuotaExceeded",
|
|
+ "TooManyRequests",
|
|
+ "SystemThrottling",
|
|
+ "OperationInProcess",
|
|
+}
|
|
+
|
|
+
|
|
+def _extract_vidu_error_code(message: str | None) -> str | None:
|
|
+ """从 Vidu 错误信息中提取错误码"""
|
|
+ if not message:
|
|
+ return None
|
|
+ # Vidu 错误码格式:"ErrorCode: 中文描述"
|
|
+ return message.split(":")[0].strip() or None
|
|
+
|
|
+
|
|
+def _map_vidu_error(
|
|
+ status: int,
|
|
+ message: str,
|
|
+ *,
|
|
+ err_code: str | None = None,
|
|
+) -> PlatformError:
|
|
+ """把 Vidu HTTP 错误映射为标准 PlatformError
|
|
+
|
|
+ 优先根据 Vidu 业务错误码(err_code)判断类型,HTTP status 仅作为兜底。
|
|
+ """
|
|
+ raw_code = err_code or _extract_vidu_error_code(message)
|
|
+
|
|
+ # 1. 内容安全/审核类:不可重试
|
|
+ if raw_code in _VIDU_AUDIT_ERROR_CODES:
|
|
+ return PlatformError(
|
|
+ message=message,
|
|
+ platform="vidu",
|
|
+ retryable=False,
|
|
+ error_type=PlatformErrorType.CONTENT_VIOLATION,
|
|
+ status_code=status,
|
|
+ raw_code=raw_code,
|
|
+ )
|
|
+
|
|
+ # 2. 平台内部/模型不可用:可重试
|
|
+ if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
|
|
+ return PlatformError(
|
|
+ message=message,
|
|
+ platform="vidu",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.SERVER_ERROR,
|
|
+ status_code=status,
|
|
+ raw_code=raw_code,
|
|
+ )
|
|
+
|
|
+ # 3. 限流类:可重试
|
|
+ if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
|
|
+ return PlatformError(
|
|
+ message=message,
|
|
+ platform="vidu",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.RATE_LIMIT,
|
|
+ status_code=status,
|
|
+ raw_code=raw_code,
|
|
+ )
|
|
+
|
|
+ # 4. HTTP status 兜底
|
|
mapping = {
|
|
429: (PlatformErrorType.RATE_LIMIT, True),
|
|
401: (PlatformErrorType.AUTH_FAILED, False),
|
|
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
|
|
retryable=retryable,
|
|
error_type=error_type,
|
|
status_code=status,
|
|
+ raw_code=raw_code,
|
|
)
|
|
|
|
|
|
@@ -66,7 +149,9 @@ class ViduProvider:
|
|
from app.core.platform_config import get_platform_config_loader
|
|
|
|
platform_config = get_platform_config_loader().get_platform("vidu")
|
|
- self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
|
|
+ self.base_url = (
|
|
+ platform_config.base_url if platform_config else "https://api.vidu.cn"
|
|
+ ).rstrip("/")
|
|
|
|
if not self.api_key:
|
|
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
|
|
@@ -135,9 +220,12 @@ class ViduProvider:
|
|
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
|
data = resp.json()
|
|
if resp.status_code != 200 or data.get("state") == "failed":
|
|
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
|
- logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
|
- raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
|
|
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
|
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
|
+ logger.error(
|
|
+ f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
|
+ )
|
|
+ raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
|
|
return data
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
logger.error(f"[Vidu TTS] 网络错误: {e}")
|
|
@@ -182,9 +270,14 @@ class ViduProvider:
|
|
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
|
data = resp.json()
|
|
if resp.status_code != 200 or data.get("state") == "failed":
|
|
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
|
- logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
|
- raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
|
|
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
|
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
|
+ logger.error(
|
|
+ f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
|
+ )
|
|
+ raise _map_vidu_error(
|
|
+ resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
|
|
+ )
|
|
return data
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
logger.error(f"[Vidu Clone] 网络错误: {e}")
|
|
@@ -238,9 +331,14 @@ class ViduProvider:
|
|
resp = await self.client.post(url, json=body)
|
|
data = resp.json()
|
|
if resp.status_code != 200 or data.get("state") == "failed":
|
|
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
|
- logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
|
- raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
|
|
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
|
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
|
+ logger.error(
|
|
+ f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
|
+ )
|
|
+ raise _map_vidu_error(
|
|
+ resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
|
|
+ )
|
|
return data
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
logger.error(f"[Vidu LipSync] 网络错误: {e}")
|
|
@@ -264,9 +362,14 @@ class ViduProvider:
|
|
resp = await self.client.get(url)
|
|
data = resp.json()
|
|
if resp.status_code != 200:
|
|
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
|
- logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
|
- raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
|
|
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
|
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
|
+ logger.error(
|
|
+ f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
|
+ )
|
|
+ raise _map_vidu_error(
|
|
+ resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
|
|
+ )
|
|
return data
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
logger.error(f"[Vidu Query] 网络错误: {e}")
|
|
diff --git a/python-api/app/ai/providers/volcengine_caption_provider.py b/python-api/app/ai/providers/volcengine_caption_provider.py
|
|
index 0f2f271..09ddcc7 100644
|
|
--- a/python-api/app/ai/providers/volcengine_caption_provider.py
|
|
+++ b/python-api/app/ai/providers/volcengine_caption_provider.py
|
|
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
|
if code is not None and code in error_mapping:
|
|
error_type, retryable = error_mapping[code]
|
|
return PlatformError(
|
|
- message, platform="volcengine_caption",
|
|
- retryable=retryable, error_type=error_type,
|
|
+ message,
|
|
+ platform="volcengine_caption",
|
|
+ retryable=retryable,
|
|
+ error_type=error_type,
|
|
status_code=status,
|
|
)
|
|
|
|
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
|
}
|
|
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
|
|
return PlatformError(
|
|
- message, platform="volcengine_caption",
|
|
- retryable=retryable, error_type=error_type,
|
|
+ message,
|
|
+ platform="volcengine_caption",
|
|
+ retryable=retryable,
|
|
+ error_type=error_type,
|
|
status_code=status,
|
|
)
|
|
|
|
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
|
|
max_lines: int = 1,
|
|
) -> dict[str, Any]:
|
|
"""提交字幕生成任务,返回 {id: task_id}"""
|
|
- params = {
|
|
+ params: dict[str, str | int] = {
|
|
"appid": self.appid,
|
|
"language": language,
|
|
"caption_type": caption_type,
|
|
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
|
|
except PlatformError:
|
|
raise
|
|
except httpx.HTTPStatusError as e:
|
|
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
|
+ raise _map_caption_error(
|
|
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
|
+ ) from e
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
raise PlatformError(
|
|
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
|
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
|
+ f"字幕服务网络错误: {e}",
|
|
+ platform="volcengine_caption",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.TIMEOUT,
|
|
) from e
|
|
except Exception as e:
|
|
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
|
|
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
|
|
blocking: bool = False,
|
|
) -> dict[str, Any]:
|
|
"""查询字幕任务结果,返回原始 JSON"""
|
|
- params = {
|
|
+ params: dict[str, str | int] = {
|
|
"appid": self.appid,
|
|
"id": task_id,
|
|
"blocking": 1 if blocking else 0,
|
|
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
|
|
except PlatformError:
|
|
raise
|
|
except httpx.HTTPStatusError as e:
|
|
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
|
+ raise _map_caption_error(
|
|
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
|
+ ) from e
|
|
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
|
raise PlatformError(
|
|
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
|
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
|
+ f"字幕服务网络错误: {e}",
|
|
+ platform="volcengine_caption",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.TIMEOUT,
|
|
) from e
|
|
except Exception as e:
|
|
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
|
|
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
|
|
sta_punc_mode: int = 3,
|
|
) -> dict[str, Any]:
|
|
"""提交自动字幕打轴任务,返回 {id: task_id}"""
|
|
- params = {
|
|
+ params: dict[str, str | int] = {
|
|
"appid": self.appid,
|
|
"caption_type": caption_type,
|
|
"sta_punc_mode": sta_punc_mode,
|
|
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
if "id" not in data:
|
|
- raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
|
|
+ raise _map_caption_error(
|
|
+ 500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
|
|
+ )
|
|
return data
|
|
except PlatformError:
|
|
raise
|
|
diff --git a/python-api/app/ai/providers/volcengine_provider.py b/python-api/app/ai/providers/volcengine_provider.py
|
|
index 0e2a5d5..9f029a0 100644
|
|
--- a/python-api/app/ai/providers/volcengine_provider.py
|
|
+++ b/python-api/app/ai/providers/volcengine_provider.py
|
|
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
|
|
|
|
if status == 429 or "rate limit" in message.lower():
|
|
return PlatformError(
|
|
- message, platform="volcengine_ark", retryable=True,
|
|
- error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
|
|
+ message,
|
|
+ platform="volcengine_ark",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.RATE_LIMIT,
|
|
+ status_code=status,
|
|
)
|
|
elif status in (401, 403) or "authentication" in message.lower():
|
|
return PlatformError(
|
|
- message, platform="volcengine_ark", retryable=False,
|
|
- error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
|
|
+ message,
|
|
+ platform="volcengine_ark",
|
|
+ retryable=False,
|
|
+ error_type=PlatformErrorType.AUTH_FAILED,
|
|
+ status_code=status,
|
|
)
|
|
elif status and status >= 500:
|
|
return PlatformError(
|
|
- message, platform="volcengine_ark", retryable=True,
|
|
- error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
|
|
+ message,
|
|
+ platform="volcengine_ark",
|
|
+ retryable=True,
|
|
+ error_type=PlatformErrorType.SERVER_ERROR,
|
|
+ status_code=status,
|
|
)
|
|
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
|
|
return PlatformError(
|
|
- message, platform="volcengine_ark", retryable=True,
|
|
+ message,
|
|
+ platform="volcengine_ark",
|
|
+ retryable=True,
|
|
error_type=PlatformErrorType.TIMEOUT,
|
|
)
|
|
else:
|
|
return PlatformError(
|
|
- message, platform="volcengine_ark", retryable=False,
|
|
+ message,
|
|
+ platform="volcengine_ark",
|
|
+ retryable=False,
|
|
error_type=PlatformErrorType.UNKNOWN,
|
|
)
|
|
|