179 lines
6.2 KiB
Python
179 lines
6.2 KiB
Python
"""
|
||
火山引擎 MediaKit Provider
|
||
===========================
|
||
|
||
直接封装火山引擎 MediaKit HTTP API:
|
||
- 图像背景移除(/api/v1/tools/remove-image-background/sync)
|
||
|
||
使用 httpx.AsyncClient,支持外部注入(由 lifespan 管理生命周期)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
# Module-level logger; the provider below uses it to record MediaKit
# business-level failures (with the raw response) before raising.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _map_mediakit_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Map a MediaKit failure onto the project's standard ``PlatformError``.

    The error is built but NOT raised; callers do ``raise _map_mediakit_error(...)``.

    Args:
        status: HTTP status code associated with the failure. For business-level
            failures reported inside a successful HTTP response this is the
            response's own status (e.g. 200), which maps to ``UNKNOWN``.
        message: Human-readable error message, used verbatim when non-empty.
        code: Optional MediaKit business error code. ``PlatformError`` is not
            given a field for it here; it is only used to build a fallback
            message when ``message`` is empty.

    Returns:
        A ``PlatformError`` whose ``error_type`` / ``retryable`` are derived from
        ``status``. Statuses not listed fall back to ``UNKNOWN`` / non-retryable.
    """
    error_mapping = {
        400: (PlatformErrorType.BAD_REQUEST, False),
        401: (PlatformErrorType.AUTH_FAILED, False),
        403: (PlatformErrorType.AUTH_FAILED, False),
        429: (PlatformErrorType.RATE_LIMIT, True),
        500: (PlatformErrorType.SERVER_ERROR, True),
        502: (PlatformErrorType.SERVER_ERROR, True),
        503: (PlatformErrorType.SERVER_ERROR, True),
        # 504 Gateway Timeout is as transient as 502/503, so it is retryable too.
        504: (PlatformErrorType.SERVER_ERROR, True),
    }
    error_type, retryable = error_mapping.get(status, (PlatformErrorType.UNKNOWN, False))

    if not message:
        # Never surface an empty error message; fall back to status (and code).
        message = f"MediaKit 错误: status={status}"
        if code is not None:
            message += f", code={code}"

    return PlatformError(
        message,
        platform="volcengine_mediakit",
        retryable=retryable,
        error_type=error_type,
        status_code=status,
    )
|
||
|
||
|
||
class VolcengineMediakitProvider:
    """Volcengine MediaKit provider.

    Calls the MediaKit HTTP API directly and returns its JSON (normalized to a
    single shape); no business-layer processing is done here.

    An ``httpx.AsyncClient`` may be injected (e.g. one managed by an application
    lifespan); an injected client is never closed by this provider. Without one,
    the provider creates its own client and closes it in :meth:`close`. The
    provider can also be used as ``async with``, which calls :meth:`close` on exit.
    """

    BASE_URL = "https://mediakit.cn-beijing.volces.com"
    REMOVE_BG_PATH = "/api/v1/tools/remove-image-background/sync"
    DEFAULT_TIMEOUT = 60.0  # seconds; only applied to a self-created client

    def __init__(
        self,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ):
        """Create the provider.

        Args:
            token: MediaKit API token. When falsy, falls back to
                ``settings.VOLCENGINE_MEDIAKIT_TOKEN``.
            client: Optional shared ``httpx.AsyncClient``. If omitted, a private
                client with ``DEFAULT_TIMEOUT`` is created and owned here.

        Raises:
            PlatformError: If no token is available (``BAD_REQUEST``, non-retryable).
        """
        settings = get_settings()
        self.token = token or settings.VOLCENGINE_MEDIAKIT_TOKEN or ""

        if not self.token:
            raise PlatformError(
                "VOLCENGINE_MEDIAKIT_TOKEN 未配置",
                platform="volcengine_mediakit",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    async def __aenter__(self) -> VolcengineMediakitProvider:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def _get_headers(self) -> dict[str, str]:
        """Build the JSON request headers, including the token."""
        return {
            # NOTE(review): "Bearer;" (with a semicolon) differs from the RFC 6750
            # "Bearer <token>" form. It is kept as-is because it is presumably what
            # the MediaKit endpoint expects; verify against the MediaKit docs
            # before changing it.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()

    @staticmethod
    def _extract_error_message(response: httpx.Response) -> str | None:
        """Best-effort read of a non-empty ``message`` field from an error body."""
        try:
            body = response.json()
        except (ValueError, httpx.StreamError):
            # Body is not JSON (e.g. an HTML gateway page) or was never read.
            return None
        if isinstance(body, dict) and body.get("message"):
            return str(body["message"])
        return None

    @staticmethod
    def _normalize_response(data: Any, status_code: int, image_url: str) -> dict[str, Any]:
        """Validate a MediaKit JSON body and return it in the "format 1" shape.

        MediaKit answers in one of two shapes:
            format 1: {"code": 0, "message": "...", "data": {...}}
            format 2: {"success": true, "result": {...}, "expires_at": ...}
        Format-1 successes are returned unchanged. Format-2 successes are
        rewritten as {"code": 0, "message": "Success", "data": <result>} so
        upstream code handles a single shape.

        Raises:
            PlatformError: On a business-level failure in either shape, or when
                the body is not a JSON object.
        """
        if not isinstance(data, dict):
            raise _map_mediakit_error(500, f"抠图失败: 响应格式异常 ({type(data).__name__})")

        code = data.get("code")
        if code is not None:
            # Format 1
            if code != 0:
                logger.warning(
                    "[MediaKit] 抠图业务失败: code=%s, message=%s, raw_response=%s, image_url=%s...",
                    code,
                    data.get("message", "N/A"),
                    data,
                    image_url[:80],
                )
                raise _map_mediakit_error(
                    status_code,
                    # `or` also covers a present-but-empty/null "message".
                    data.get("message") or f"抠图失败: code={code}",
                    code=code,
                )
            return data

        # Format 2
        if not data.get("success", False):
            logger.warning(
                "[MediaKit] 抠图业务失败: success=false, raw_response=%s, image_url=%s...",
                data,
                image_url[:80],
            )
            raise _map_mediakit_error(status_code, "抠图失败: 平台返回失败状态")

        return {
            "code": 0,
            "message": "Success",
            # `or {}` also covers an explicit "result": null.
            "data": data.get("result") or {},
        }

    async def remove_background(
        self,
        image_url: str,
        scene: str = "general",
        need_contour: bool = False,
        contour_color: str = "#FFFFFF",
        contour_size: int = 10,
        need_crop_background: bool = False,
    ) -> dict[str, Any]:
        """Remove an image's background synchronously.

        Args:
            image_url: URL of the source image.
            scene: Scene type.
            need_contour: Whether to draw a contour around the subject (only
                effective for the human/product scenes).
            contour_color: Contour color, hex RGB.
            contour_size: Contour width in px; clamped to [1, 100].
            need_crop_background: Whether to crop the transparent background
                tightly around the subject.

        Returns:
            Normalized JSON, e.g.
            {"code": 0, "message": "Success", "data": {"image_url": "https://..."}}

        Raises:
            PlatformError: For HTTP errors (mapped by status, including the
                server's message when available), transport errors (retryable
                ``TIMEOUT``), business failures, or any other unexpected error
                (mapped as HTTP 500).
        """
        payload: dict[str, Any] = {"image_url": image_url, "scene": scene}
        if need_contour:
            payload["need_contour"] = True
            payload["contour_color"] = contour_color
            payload["contour_size"] = max(1, min(100, contour_size))
        if need_crop_background:
            payload["need_crop_background"] = True

        try:
            response = await self.client.post(
                f"{self.BASE_URL}{self.REMOVE_BG_PATH}",
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            return self._normalize_response(response.json(), response.status_code, image_url)

        except PlatformError:
            raise
        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            detail = self._extract_error_message(e.response)
            # Keep the server's explanation (e.g. which parameter was invalid).
            message = f"HTTP错误: {status}" if detail is None else f"HTTP错误: {status}, {detail}"
            logger.warning("[MediaKit] 抠图 HTTP 错误: %s, image_url=%s...", message, image_url[:80])
            raise _map_mediakit_error(status, message) from e
        except httpx.TransportError as e:
            # Superset of NetworkError/TimeoutException: also covers protocol
            # errors such as a server dropping a keep-alive connection.
            raise PlatformError(
                f"MediaKit 网络错误: {e}",
                platform="volcengine_mediakit",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e
        except Exception as e:
            raise _map_mediakit_error(500, f"抠图失败: {str(e)}") from e
|