Files
meijiaka-zy/.review_diffs/providers.diff
T

346 lines
15 KiB
Diff

diff --git a/python-api/app/ai/providers/vidu_provider.py b/python-api/app/ai/providers/vidu_provider.py
index cab5902..fccfbbf 100644
--- a/python-api/app/ai/providers/vidu_provider.py
+++ b/python-api/app/ai/providers/vidu_provider.py
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
-def _map_vidu_error(status: int, message: str) -> PlatformError:
- """把 Vidu HTTP 错误映射为标准 PlatformError"""
+# Vidu 错误码分类
+_VIDU_AUDIT_ERROR_CODES = {
+ "TaskPromptPolicyViolation",
+ "AuditSubmitIllegal",
+ "CreationPolicyViolation",
+ "PhotoAuditNotPass",
+ "AuditFailed",
+ "ImageCheckBodyJointsFailed",
+ "ImageCheckFaceFailed",
+ "ImageObjectsUndetected",
+ "FaceDetectFailure",
+ "FaceDetectNotPass",
+ "NoFaceDetected",
+ "MultiFaceDetected",
+}
+
+_VIDU_RETRYABLE_ERROR_CODES = {
+ "InternalServiceFailure",
+ "ModelUnavailable",
+ "Unknown",
+}
+
+_VIDU_RATE_LIMIT_ERROR_CODES = {
+ "QuotaExceeded",
+ "TooManyRequests",
+ "SystemThrottling",
+ "OperationInProcess",
+}
+
+
+def _extract_vidu_error_code(message: str | None) -> str | None:
+ """从 Vidu 错误信息中提取错误码"""
+ if not message:
+ return None
+ # Vidu 错误码格式:"ErrorCode: 中文描述"
+ return message.split(":")[0].strip() or None
+
+
+def _map_vidu_error(
+ status: int,
+ message: str,
+ *,
+ err_code: str | None = None,
+) -> PlatformError:
+ """把 Vidu HTTP 错误映射为标准 PlatformError
+
+ 优先根据 Vidu 业务错误码(err_code)判断类型,HTTP status 仅作为兜底。
+ """
+ raw_code = err_code or _extract_vidu_error_code(message)
+
+ # 1. 内容安全/审核类:不可重试
+ if raw_code in _VIDU_AUDIT_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=False,
+ error_type=PlatformErrorType.CONTENT_VIOLATION,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 2. 平台内部/模型不可用:可重试
+ if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=True,
+ error_type=PlatformErrorType.SERVER_ERROR,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 3. 限流类:可重试
+ if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=True,
+ error_type=PlatformErrorType.RATE_LIMIT,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 4. HTTP status 兜底
mapping = {
429: (PlatformErrorType.RATE_LIMIT, True),
401: (PlatformErrorType.AUTH_FAILED, False),
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
retryable=retryable,
error_type=error_type,
status_code=status,
+ raw_code=raw_code,
)
@@ -66,7 +149,9 @@ class ViduProvider:
from app.core.platform_config import get_platform_config_loader
platform_config = get_platform_config_loader().get_platform("vidu")
- self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
+ self.base_url = (
+ platform_config.base_url if platform_config else "https://api.vidu.cn"
+ ).rstrip("/")
if not self.api_key:
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
@@ -135,9 +220,12 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu TTS] 网络错误: {e}")
@@ -182,9 +270,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Clone] 网络错误: {e}")
@@ -238,9 +331,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body)
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu LipSync] 网络错误: {e}")
@@ -264,9 +362,14 @@ class ViduProvider:
resp = await self.client.get(url)
data = resp.json()
if resp.status_code != 200:
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Query] 网络错误: {e}")
diff --git a/python-api/app/ai/providers/volcengine_caption_provider.py b/python-api/app/ai/providers/volcengine_caption_provider.py
index 0f2f271..09ddcc7 100644
--- a/python-api/app/ai/providers/volcengine_caption_provider.py
+++ b/python-api/app/ai/providers/volcengine_caption_provider.py
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
if code is not None and code in error_mapping:
error_type, retryable = error_mapping[code]
return PlatformError(
- message, platform="volcengine_caption",
- retryable=retryable, error_type=error_type,
+ message,
+ platform="volcengine_caption",
+ retryable=retryable,
+ error_type=error_type,
status_code=status,
)
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
}
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
return PlatformError(
- message, platform="volcengine_caption",
- retryable=retryable, error_type=error_type,
+ message,
+ platform="volcengine_caption",
+ retryable=retryable,
+ error_type=error_type,
status_code=status,
)
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
max_lines: int = 1,
) -> dict[str, Any]:
"""提交字幕生成任务,返回 {id: task_id}"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"language": language,
"caption_type": caption_type,
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
+ raise _map_caption_error(
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
+ ) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
+ f"字幕服务网络错误: {e}",
+ platform="volcengine_caption",
+ retryable=True,
+ error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
blocking: bool = False,
) -> dict[str, Any]:
"""查询字幕任务结果,返回原始 JSON"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"id": task_id,
"blocking": 1 if blocking else 0,
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
+ raise _map_caption_error(
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
+ ) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
+ f"字幕服务网络错误: {e}",
+ platform="volcengine_caption",
+ retryable=True,
+ error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
sta_punc_mode: int = 3,
) -> dict[str, Any]:
"""提交自动字幕打轴任务,返回 {id: task_id}"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"caption_type": caption_type,
"sta_punc_mode": sta_punc_mode,
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
response.raise_for_status()
data = response.json()
if "id" not in data:
- raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
+ raise _map_caption_error(
+ 500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
+ )
return data
except PlatformError:
raise
diff --git a/python-api/app/ai/providers/volcengine_provider.py b/python-api/app/ai/providers/volcengine_provider.py
index 0e2a5d5..9f029a0 100644
--- a/python-api/app/ai/providers/volcengine_provider.py
+++ b/python-api/app/ai/providers/volcengine_provider.py
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
if status == 429 or "rate limit" in message.lower():
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
- error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
+ error_type=PlatformErrorType.RATE_LIMIT,
+ status_code=status,
)
elif status in (401, 403) or "authentication" in message.lower():
return PlatformError(
- message, platform="volcengine_ark", retryable=False,
- error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=False,
+ error_type=PlatformErrorType.AUTH_FAILED,
+ status_code=status,
)
elif status and status >= 500:
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
- error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
+ error_type=PlatformErrorType.SERVER_ERROR,
+ status_code=status,
)
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
error_type=PlatformErrorType.TIMEOUT,
)
else:
return PlatformError(
- message, platform="volcengine_ark", retryable=False,
+ message,
+ platform="volcengine_ark",
+ retryable=False,
error_type=PlatformErrorType.UNKNOWN,
)