Files
meijiaka-zy/python-api/app/ai/providers/volcengine_provider.py
T

337 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山方舟官方 SDK Provider
==========================
基于火山方舟官方 Python SDK 实现,支持:
- 文本生成 (Chat Completions)
- 流式输出
- 向量化
- 深度思考
- 工具调用
安装依赖:
pip install 'volcengine-python-sdk[ark]'
文档:
https://www.volcengine.com/docs/82379
"""
from __future__ import annotations
import logging
import time
from app.ai.providers.base import (
GenerationResult,
LLMProvider,
ModelHealth,
ProviderError,
)
from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
# 尝试导入火山方舟 SDK
try:
    # The Ark SDK is an optional dependency; record whether it could be imported
    # so that the provider can fail with a clear error at construction time.
    from volcenginesdkarkruntime import AsyncArk

    VOLCENGINE_SDK_AVAILABLE = True
except ImportError:
    VOLCENGINE_SDK_AVAILABLE = False
    logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
class VolcengineProvider(LLMProvider):
    """
    Provider backed by the official Volcengine Ark Python SDK.

    Capabilities implemented by this class:
    - Chat completions (``generate``)
    - Embeddings (``create_embeddings``)
    - Reasoning control (``reasoning_effort`` keyword of ``generate``)
    """

    provider_id = "volcengine"
    provider_name = "火山方舟"

    # Default configuration
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
    DEFAULT_TIMEOUT = 1800  # seconds; Ark recommends 1800 s or more

    # Alias -> model ID mapping, populated by load_models_from_config()
    PRESET_MODELS: dict[str, str] = {}

    # Fallback model used when no model is configured at all
    _FALLBACK_MODEL_ALIAS = "doubao-seed-2-0-pro"
    _FALLBACK_MODEL_ID = "doubao-seed-2-0-pro-260215"
    # Embedding model used when create_embeddings() receives no model
    _DEFAULT_EMBEDDING_MODEL = "doubao-embedding-1.5"
    # Key identifying this platform in the platform config and in PlatformError
    _PLATFORM_KEY = "volcengine_ark"

    @classmethod
    def load_models_from_config(cls, models: list[dict]) -> None:
        """
        Rebuild ``PRESET_MODELS`` from a list of model configuration entries.

        Each entry contributes ``entry["id"] -> entry["model_name"]``; entries
        missing either field are skipped. If nothing usable remains, a single
        built-in default model is registered instead.

        Args:
            models: Model entries, each a dict with ``id`` (alias) and
                ``model_name`` (model ID). ``None`` is treated as empty.
        """
        cls.PRESET_MODELS = {}
        for model in models or []:
            model_id = model.get("model_name")
            model_alias = model.get("id")
            if model_id and model_alias:
                cls.PRESET_MODELS[model_alias] = model_id

        # Make sure at least one default model is always present
        if not cls.PRESET_MODELS:
            cls.PRESET_MODELS = {cls._FALLBACK_MODEL_ALIAS: cls._FALLBACK_MODEL_ID}

        logger.info(f"已加载 {len(cls.PRESET_MODELS)} 个模型: {list(cls.PRESET_MODELS.keys())}")

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: int = DEFAULT_TIMEOUT,
        default_model: str | None = None,
        **kwargs,
    ):
        """
        Initialise the Volcengine Ark provider.

        Args:
            api_key: API key; falls back to ``settings.VOLCENGINE_API_KEY``.
            base_url: API base URL; falls back to the ``volcengine_ark``
                platform config entry, then to ``DEFAULT_BASE_URL``.
            timeout: Request timeout in seconds.
            default_model: Default model, given as an alias from
                ``PRESET_MODELS`` or as a model ID.
            **kwargs: Forwarded to the base class constructor.

        Raises:
            ProviderError: If the Ark SDK is not installed or no API key is
                configured.
        """
        # Imported inside the constructor, as in the original code
        from app.config import get_settings
        from app.core.platform_config import get_platform_config_loader

        settings = get_settings()
        resolved_api_key = api_key or settings.VOLCENGINE_API_KEY

        # Resolution order: explicit argument > platform config > class constant
        loader = get_platform_config_loader()
        platform = loader.get_platform(self._PLATFORM_KEY)
        yaml_base_url = platform.base_url if platform else None
        resolved_base_url = base_url or yaml_base_url or self.DEFAULT_BASE_URL

        # NOTE(review): presumably the base class stores these as
        # self.api_key / self.base_url, which are read below - confirm.
        super().__init__(resolved_api_key, resolved_base_url, **kwargs)

        if not VOLCENGINE_SDK_AVAILABLE:
            raise ProviderError(
                "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
                provider_id=self.provider_id,
            )
        if not self.api_key:
            raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)

        self.timeout = timeout

        # Accept either an alias or a model ID for the default model
        if default_model:
            self.default_model = self.PRESET_MODELS.get(default_model, default_model)
        elif self.PRESET_MODELS:
            # First configured model, in insertion order
            self.default_model = next(iter(self.PRESET_MODELS.values()))
        else:
            # Last resort when no model has been configured at all
            self.default_model = self._FALLBACK_MODEL_ID

        self.client = self._create_client()

    def _create_client(self) -> AsyncArk:
        """Create the asynchronous Ark SDK client."""
        return AsyncArk(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
        )

    def _resolve_model_id(self, model: str | None) -> str:
        """Map an alias to its model ID; use the default model when none is given."""
        if not model:
            return self.default_model
        return self.PRESET_MODELS.get(model, model)

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Generate text with a single, non-streaming chat completion request.

        Args:
            prompt: User prompt.
            model: Alias or model ID (e.g. ``doubao-seed-2-0-pro-260215``);
                defaults to the provider's default model.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum number of tokens to generate; omitted from
                the request when falsy.
            system_prompt: Optional system prompt.
            **kwargs: Extra options. ``reasoning_effort`` is forwarded as-is;
                ``response_format="json_object"`` requests JSON output.
                Other keys are ignored.

        Returns:
            GenerationResult: Generated content, token usage (``None`` when
            the response carries none) and the model name.

        Raises:
            PlatformError: Any failure, translated by ``_wrap_error``.
        """
        try:
            # Build the message list
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})

            # Build the request parameters
            request_params = {
                "model": self._resolve_model_id(model),
                "messages": messages,
                "temperature": temperature,
            }
            if max_tokens:
                request_params["max_tokens"] = max_tokens

            # Optional extras (e.g. reasoning depth)
            if "reasoning_effort" in kwargs:
                request_params["reasoning_effort"] = kwargs["reasoning_effort"]
            if kwargs.get("response_format") == "json_object":
                request_params["response_format"] = {"type": "json_object"}

            completion = await self.client.chat.completions.create(**request_params)

            # Parse the response
            content = completion.choices[0].message.content or ""
            usage = None
            if completion.usage:
                usage = {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            return GenerationResult(
                content=content,
                usage=usage,
                model=completion.model or model or self.default_model,
            )
        except Exception as e:
            # Chain explicitly so the SDK exception stays visible as the cause
            raise self._wrap_error(e) from e

    async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
        """
        Embed a list of texts.

        Args:
            texts: Texts to embed.
            model: Embedding model; defaults to ``doubao-embedding-1.5``.
            **kwargs: Forwarded to the SDK's ``embeddings.create`` call.

        Returns:
            dict: ``embeddings`` (list of ``{"index", "embedding"}``),
            ``model`` and ``usage`` (``None`` when the response has none).

        Raises:
            PlatformError: Any failure, translated by ``_wrap_error``.
        """
        try:
            response = await self.client.embeddings.create(
                model=model or self._DEFAULT_EMBEDDING_MODEL, input=texts, **kwargs
            )
            embeddings = [
                {"index": item.index, "embedding": item.embedding}
                for item in response.data
            ]

            # Tolerate a response without usage data, as generate() does
            usage = None
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "total_tokens": response.usage.total_tokens,
                }
            return {
                "embeddings": embeddings,
                "model": response.model,
                "usage": usage,
            }
        except Exception as e:
            raise self._wrap_error(e) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """
        Probe a model with a minimal chat completion request.

        Never raises: a failed request is reported as an unavailable model
        with the error text in ``last_error``.

        Args:
            model: Alias or model ID to probe; defaults to the default model.

        Returns:
            ModelHealth: Availability and elapsed time in milliseconds.
        """
        # Monotonic clock: elapsed time is immune to wall-clock adjustments
        start_time = time.perf_counter()
        # Reported under the name the caller used ...
        test_model = model or self.default_model
        # ... but requested with the resolved model ID, consistent with generate()
        request_model = self._resolve_model_id(model)
        try:
            await self.client.chat.completions.create(
                model=request_model,
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
            )
            response_time = (time.perf_counter() - start_time) * 1000
            return ModelHealth(
                id=test_model,
                name=f"火山方舟 {test_model}",
                is_available=True,
                response_time=response_time,
                last_error=None,
            )
        except Exception as e:
            # Deliberately broad: a health check reports failures, it does not raise
            return ModelHealth(
                id=test_model,
                name=f"火山方舟 {test_model}",
                is_available=False,
                response_time=(time.perf_counter() - start_time) * 1000,
                last_error=str(e),
            )

    @staticmethod
    def _extract_status(e: Exception) -> int | None:
        """Return the integer status carried by ``e`` (``status_code``, then ``code``), if any."""
        for attr in ("status_code", "code"):
            value = getattr(e, attr, None)
            if not value or isinstance(value, bool):
                continue
            if isinstance(value, int):
                return value
            if isinstance(value, str) and value.isascii() and value.isdigit():
                return int(value)
        # Non-numeric codes (e.g. symbolic error names) carry no HTTP status
        return None

    def _wrap_error(self, e: Exception) -> PlatformError:
        """
        Translate an SDK exception into a standard ``PlatformError``.

        Classification, first match wins: rate limit (429 or "rate limit" in
        the message), authentication (401/403 or "authentication"), server
        error (status >= 500), timeout ("timeout" in the message or a
        ``TimeoutError``), otherwise unknown. Rate-limit, server and timeout
        errors are marked retryable.

        Args:
            e: The exception raised by the SDK call.

        Returns:
            PlatformError: The translated error; the caller raises it.
        """
        message = str(e)
        lowered = message.lower()
        # Normalised to int | None so the numeric comparison below cannot
        # raise TypeError when the exception carries a non-numeric code
        status = self._extract_status(e)

        if status == 429 or "rate limit" in lowered:
            return PlatformError(
                message,
                platform=self._PLATFORM_KEY,
                retryable=True,
                error_type=PlatformErrorType.RATE_LIMIT,
                status_code=status,
            )
        if status in (401, 403) or "authentication" in lowered:
            return PlatformError(
                message,
                platform=self._PLATFORM_KEY,
                retryable=False,
                error_type=PlatformErrorType.AUTH_FAILED,
                status_code=status,
            )
        if status is not None and status >= 500:
            return PlatformError(
                message,
                platform=self._PLATFORM_KEY,
                retryable=True,
                error_type=PlatformErrorType.SERVER_ERROR,
                status_code=status,
            )
        if "timeout" in lowered or isinstance(e, TimeoutError):
            return PlatformError(
                message,
                platform=self._PLATFORM_KEY,
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            )
        return PlatformError(
            message,
            platform=self._PLATFORM_KEY,
            retryable=False,
            error_type=PlatformErrorType.UNKNOWN,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the list of available model aliases."""
        # NOTE(review): this list is hard-coded and does not reflect aliases
        # loaded into PRESET_MODELS by load_models_from_config() - confirm
        # whether it should be derived from the configuration instead.
        return [
            self._FALLBACK_MODEL_ALIAS,
        ]