135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
"""
|
||
火山引擎 MediaKit Provider
|
||
===========================
|
||
|
||
直接封装火山引擎 MediaKit HTTP API:
|
||
- 图像背景移除(/api/v1/tools/remove-image-background/sync)
|
||
|
||
使用 httpx.AsyncClient,支持外部注入(由 lifespan 管理生命周期)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _map_mediakit_error(status: int, message: str, code: int | None = None) -> PlatformError:
    """Translate a MediaKit failure into a standard ``PlatformError``.

    Args:
        status: HTTP status code used to pick the error category.
        message: Human-readable message stored on the resulting error.
        code: MediaKit business code. Accepted for call-site convenience;
            it does not currently influence the mapping.

    Returns:
        A ``PlatformError`` (not raised here; the caller raises it) tagged
        with platform ``"volcengine_mediakit"`` and the given ``status``.
        429 and 500/502/503 are marked retryable; 400, 401/403 and every
        unlisted status are not.
    """
    if status == 400:
        kind, can_retry = PlatformErrorType.BAD_REQUEST, False
    elif status in (401, 403):
        kind, can_retry = PlatformErrorType.AUTH_FAILED, False
    elif status == 429:
        kind, can_retry = PlatformErrorType.RATE_LIMIT, True
    elif status in (500, 502, 503):
        kind, can_retry = PlatformErrorType.SERVER_ERROR, True
    else:
        # Any status outside the table above is treated as unknown and final.
        kind, can_retry = PlatformErrorType.UNKNOWN, False

    return PlatformError(
        message,
        platform="volcengine_mediakit",
        retryable=can_retry,
        error_type=kind,
        status_code=status,
    )
|
||
|
||
|
||
class VolcengineMediakitProvider:
    """Thin async client for the Volcengine MediaKit HTTP API.

    Calls the MediaKit endpoints directly and hands back their raw JSON;
    no business-level processing happens here.

    The ``httpx.AsyncClient`` can be injected by the caller, in which case
    the caller keeps ownership and :meth:`close` leaves it open. When no
    client is injected the provider creates its own and :meth:`close`
    closes it. The provider can also be used as an async context manager
    (``async with VolcengineMediakitProvider(...) as provider:``), which
    calls :meth:`close` on exit.
    """

    BASE_URL = "https://mediakit.cn-beijing.volces.com"
    REMOVE_BG_PATH = "/api/v1/tools/remove-image-background/sync"
    # Timeout handed to httpx when the provider creates its own client.
    DEFAULT_TIMEOUT = 60.0

    def __init__(
        self,
        token: str | None = None,
        client: httpx.AsyncClient | None = None,
    ) -> None:
        """Resolve the API token and pick the HTTP client to use.

        Args:
            token: MediaKit token. Falls back to
                ``settings.VOLCENGINE_MEDIAKIT_TOKEN`` when falsy.
            client: Optional externally managed ``httpx.AsyncClient``. When
                given, this provider never closes it.

        Raises:
            PlatformError: Non-retryable ``BAD_REQUEST`` when neither the
                argument nor the settings supply a token.
        """
        settings = get_settings()
        self.token = token or settings.VOLCENGINE_MEDIAKIT_TOKEN or ""

        # Fail before creating a client so nothing needs cleaning up.
        if not self.token:
            raise PlatformError(
                "VOLCENGINE_MEDIAKIT_TOKEN 未配置",
                platform="volcengine_mediakit",
                retryable=False,
                error_type=PlatformErrorType.BAD_REQUEST,
            )

        if client is not None:
            self.client = client
            self._owns_client = False
        else:
            self.client = httpx.AsyncClient(timeout=self.DEFAULT_TIMEOUT)
            self._owns_client = True

    async def __aenter__(self) -> VolcengineMediakitProvider:
        """Enter ``async with``; returns the provider itself."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Leave ``async with``; closes the client if this provider owns it."""
        await self.close()

    def _get_headers(self) -> dict[str, str]:
        """Build the request headers (authorization + JSON content type)."""
        return {
            # NOTE(review): the ';' after "Bearer" is sent exactly as the
            # original code did; confirm against the MediaKit auth docs
            # before "normalizing" it.
            "Authorization": f"Bearer; {self.token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_client and self.client and not self.client.is_closed:
            await self.client.aclose()

    async def remove_background(
        self,
        image_url: str,
        scene: str = "general",
    ) -> dict[str, Any]:
        """Run synchronous background removal and return the raw JSON body.

        Args:
            image_url: URL of the source image, sent as ``image_url``.
            scene: Value sent as ``scene`` in the request payload.

        Returns:
            The decoded JSON response whose ``code`` is ``0``, e.g.
            ``{"code": 0, "message": "Success", "data": {"image_url": "https://..."}}``.

        Raises:
            PlatformError: When the response carries a non-zero business
                ``code``, when the HTTP status is an error, on network or
                timeout failures (retryable, ``TIMEOUT``), or for any other
                exception (mapped as status 500).
        """
        payload = {"image_url": image_url, "scene": scene}

        try:
            response = await self.client.post(
                f"{self.BASE_URL}{self.REMOVE_BG_PATH}",
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            data = response.json()

            # A missing "code" is treated as a failure (-1), not as success.
            code = data.get("code", -1)
            if code != 0:
                # Lazy %-style arguments: the message is only rendered if
                # the record is actually emitted.
                logger.warning(
                    "[MediaKit] 抠图业务失败: code=%s, message=%s, "
                    "raw_response=%s, image_url=%s...",
                    code,
                    data.get("message", "N/A"),
                    data,
                    image_url[:80],
                )
                # The HTTP status already passed raise_for_status(), so it
                # is none of the mapped error statuses and the resulting
                # error is UNKNOWN / non-retryable.
                raise _map_mediakit_error(
                    response.status_code,
                    data.get("message", f"抠图失败: code={code}"),
                    code=code,
                )
            return data

        except PlatformError:
            # Already in the standard form (raised just above); pass through
            # so the catch-all below does not re-wrap it.
            raise
        except httpx.HTTPStatusError as e:
            raise _map_mediakit_error(
                e.response.status_code, f"HTTP错误: {e.response.status_code}"
            ) from e
        except (httpx.NetworkError, httpx.TimeoutException) as e:
            raise PlatformError(
                f"MediaKit 网络错误: {e}",
                platform="volcengine_mediakit",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            ) from e
        except Exception as e:
            # Boundary catch-all: anything unexpected (e.g. a body that is
            # not valid JSON) is reported as a 500-class platform error.
            raise _map_mediakit_error(500, f"抠图失败: {e}") from e
|