Files
meijiaka-zy/python-api/app/ai/providers/volcengine_mediakit_provider.py
T

179 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山引擎 MediaKit Provider
===========================
直接封装火山引擎 MediaKit HTTP API
- 图像背景移除(/api/v1/tools/remove-image-background/sync
使用 httpx.AsyncClient,支持外部注入(由 lifespan 管理生命周期)。
"""
from __future__ import annotations
import logging
from typing import Any
import httpx
from app.config import get_settings
from app.core.exceptions import PlatformError, PlatformErrorType
# Module-level logger; the provider uses it to warn about failed MediaKit calls.
logger = logging.getLogger(__name__)
# HTTP status -> (PlatformErrorType, retryable). Built once at import time
# instead of on every call. Statuses not listed here are resolved by the
# fallback in _map_mediakit_error below.
_STATUS_ERROR_MAP: dict[int, tuple[PlatformErrorType, bool]] = {
    400: (PlatformErrorType.BAD_REQUEST, False),
    401: (PlatformErrorType.AUTH_FAILED, False),
    403: (PlatformErrorType.AUTH_FAILED, False),
    # Request Timeout: transient, safe to retry.
    408: (PlatformErrorType.TIMEOUT, True),
    429: (PlatformErrorType.RATE_LIMIT, True),
    500: (PlatformErrorType.SERVER_ERROR, True),
    502: (PlatformErrorType.SERVER_ERROR, True),
    503: (PlatformErrorType.SERVER_ERROR, True),
    # Gateway Timeout: transient upstream failure, like 502/503.
    504: (PlatformErrorType.SERVER_ERROR, True),
}


def _map_mediakit_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Map a MediaKit failure to a standard PlatformError.

    Args:
        status: HTTP status code of the response (may be 200 for
            business-level failures reported in the JSON body).
        message: Human-readable error message for the PlatformError.
        code: MediaKit business error code. Currently unused.
            NOTE(review): forward it if PlatformError supports such a field.

    Returns:
        A PlatformError (not raised) tagged with platform "volcengine_mediakit".
        Known statuses use _STATUS_ERROR_MAP; any other 5xx is a retryable
        SERVER_ERROR; everything else is a non-retryable UNKNOWN error.
    """
    mapped = _STATUS_ERROR_MAP.get(status)
    if mapped is None:
        if 500 <= status < 600:
            mapped = (PlatformErrorType.SERVER_ERROR, True)
        else:
            mapped = (PlatformErrorType.UNKNOWN, False)
    error_type, retryable = mapped
    return PlatformError(
        message,
        platform="volcengine_mediakit",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
    )
class VolcengineMediakitProvider:
    """Volcengine MediaKit provider.

    Calls the MediaKit HTTP API directly and performs no business-level
    processing. An ``httpx.AsyncClient`` may be injected (e.g. one whose
    lifecycle is managed by the application lifespan); :meth:`close` only
    closes a client that this provider created itself. The provider can also
    be used as an async context manager, which calls :meth:`close` on exit.
    """

    BASE_URL = "https://mediakit.cn-beijing.volces.com"
    REMOVE_BG_PATH = "/api/v1/tools/remove-image-background/sync"
    DEFAULT_TIMEOUT = 60.0
    # Maximum number of characters of an HTTP error body kept in error messages.
    _ERROR_DETAIL_MAX_CHARS = 200

    def __init__(
        self,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Create the provider.

        Args:
            token: MediaKit API token. When empty/None, falls back to
                ``get_settings().VOLCENGINE_MEDIAKIT_TOKEN``.
            client: Optional shared ``httpx.AsyncClient``. When given, the
                caller owns it and :meth:`close` leaves it open; otherwise a
                client with ``DEFAULT_TIMEOUT`` is created and owned here.

        Raises:
            PlatformError: BAD_REQUEST (non-retryable) if no token is available.
        """
        # Settings are only consulted when no explicit token was supplied.
        self.token = token or get_settings().VOLCENGINE_MEDIAKIT_TOKEN or ""
        if not self.token:
            raise PlatformError(
                "VOLCENGINE_MEDIAKIT_TOKEN 未配置",
                platform="volcengine_mediakit",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )
        # The token is validated before any client is created, so a failed
        # construction never leaves an unclosed AsyncClient behind.
        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    def _get_headers(self) -> dict[str, str]:
        """Build the request headers for MediaKit calls."""
        return {
            # NOTE(review): "Bearer;" (with a semicolon) is not RFC 6750 syntax,
            # but is presumably intentional: some Volcengine APIs document this
            # form. Confirm against the MediaKit docs before changing it.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the HTTP client if this provider created it (injected clients are left open)."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()

    async def __aenter__(self) -> "VolcengineMediakitProvider":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    @staticmethod
    def _build_payload(
        image_url: str,
        scene: str,
        need_contour: bool,
        contour_color: str,
        contour_size: int,
        need_crop_background: bool,
    ) -> dict[str, Any]:
        """Build the JSON request body; optional fields are only sent when enabled."""
        payload: dict[str, Any] = {"image_url": image_url, "scene": scene}
        if need_contour:
            payload["need_contour"] = True
            payload["contour_color"] = contour_color
            # Clamp to the [1, 100] px range documented for contour_size.
            payload["contour_size"] = max(1, min(100, contour_size))
        if need_crop_background:
            payload["need_crop_background"] = True
        return payload

    @staticmethod
    def _normalize_response(data: Any, status_code: int, image_url: str) -> dict[str, Any]:
        """Validate a MediaKit JSON body and normalize it to format 1.

        MediaKit responds in one of two shapes:
          format 1: {"code": 0, "message": "...", "data": {...}}
          format 2: {"success": true, "result": {...}, "expires_at": ...}
        Format 1 is returned as-is; format 2 is converted to format 1 so that
        callers only handle one shape.

        Raises:
            PlatformError: on a business-level failure, or when the body is not
                a JSON object.
        """
        if not isinstance(data, dict):
            logger.warning(
                "[MediaKit] 抠图响应格式异常: raw_response=%r, image_url=%s...",
                data,
                image_url[:80],
            )
            raise _map_mediakit_error(500, f"抠图失败: 响应格式异常 (type={type(data).__name__})")

        code = data.get("code")
        if code is not None:
            # Format 1
            if code != 0:
                logger.warning(
                    "[MediaKit] 抠图业务失败: code=%s, message=%s, raw_response=%s, image_url=%s...",
                    code,
                    data.get("message", "N/A"),
                    data,
                    image_url[:80],
                )
                # `or` (not a .get default) so an empty/None message still yields
                # a useful error text.
                raise _map_mediakit_error(
                    status_code,
                    data.get("message") or f"抠图失败: code={code}",
                    code=code,
                )
            return data

        # Format 2
        if not data.get("success", False):
            logger.warning(
                "[MediaKit] 抠图业务失败: success=false, raw_response=%s, image_url=%s...",
                data,
                image_url[:80],
            )
            raise _map_mediakit_error(status_code, "抠图失败: 平台返回失败状态")
        return {
            "code": 0,
            "message": "Success",
            # A null "result" is normalized to an empty dict.
            "data": data.get("result") or {},
        }

    @classmethod
    def _extract_error_detail(cls, response: httpx.Response) -> str | None:
        """Best-effort: pull a human-readable reason out of an HTTP error body.

        Prefers the JSON ``message`` field, falls back to the raw body text, and
        truncates to ``_ERROR_DETAIL_MAX_CHARS``. Returns None when nothing useful
        is found.
        """
        try:
            body = response.json()
        except ValueError:  # non-JSON body (e.g. an HTML gateway page)
            body = None
        if isinstance(body, dict) and body.get("message"):
            detail = str(body["message"])
        else:
            detail = response.text.strip()
        return detail[: cls._ERROR_DETAIL_MAX_CHARS] or None

    async def remove_background(
        self,
        image_url: str,
        scene: str = "general",
        need_contour: bool = False,
        contour_color: str = "#FFFFFF",
        contour_size: int = 10,
        need_crop_background: bool = False,
    ) -> dict[str, Any]:
        """Remove the image background synchronously and return the JSON result.

        Args:
            image_url: URL of the source image.
            scene: Scene type.
            need_contour: Whether to draw a contour around the subject (per the
                API, only effective for the human/product scenes).
            contour_color: Contour color as a hex RGB string.
            contour_size: Contour width in px; clamped to [1, 100].
            need_crop_background: Whether to crop the transparent background
                tightly around the subject.

        Returns:
            A format-1 dict, e.g.
            {"code": 0, "message": "Success", "data": {"image_url": "https://..."}}.
            Format-2 responses are normalized to this shape.

        Raises:
            PlatformError: business failures; HTTP errors (mapped by status via
                _map_mediakit_error); network errors/timeouts (TIMEOUT,
                retryable); any other unexpected failure (mapped as status 500).
        """
        payload = self._build_payload(
            image_url,
            scene,
            need_contour,
            contour_color,
            contour_size,
            need_crop_background,
        )
        try:
            response = await self.client.post(
                f"{self.BASE_URL}{self.REMOVE_BG_PATH}",
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            return self._normalize_response(response.json(), response.status_code, image_url)
        except PlatformError:
            raise
        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            detail = self._extract_error_detail(e.response)
            message = f"HTTP错误: {status}"
            if detail:
                # Keep the platform's reason (e.g. why a 400 was rejected).
                message = f"{message}, {detail}"
            logger.warning(
                "[MediaKit] 抠图 HTTP 错误: status=%s, detail=%s, image_url=%s...",
                status,
                detail,
                image_url[:80],
            )
            raise _map_mediakit_error(status, message) from e
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise PlatformError(
                f"MediaKit 网络错误: {e}",
                platform="volcengine_mediakit",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e
        except Exception as e:
            # Anything else (e.g. an undecodable JSON body) is reported as a
            # generic server-side failure rather than escaping unmapped.
            raise _map_mediakit_error(500, f"抠图失败: {str(e)}") from e