337 lines
11 KiB
Python
337 lines
11 KiB
Python
"""
|
||
火山方舟官方 SDK Provider
==========================

基于火山方舟官方 Python SDK 实现,支持:
- 文本生成 (Chat Completions,非流式)
- 流式输出(当前 Provider 尚未实现)
- 向量化 (Embeddings)
- 深度思考 (通过 reasoning_effort 参数)
- 工具调用(当前 Provider 尚未实现)

安装依赖:
    pip install 'volcengine-python-sdk[ark]'

文档:
    https://www.volcengine.com/docs/82379
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import time
|
||
|
||
from app.ai.providers.base import (
|
||
GenerationResult,
|
||
LLMProvider,
|
||
ModelHealth,
|
||
ProviderError,
|
||
)
|
||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||
|
||
logger = logging.getLogger(__name__)

# The Volcengine Ark runtime SDK is an optional dependency. Record whether it
# imported so that VolcengineProvider can raise a clear ProviderError when it
# is constructed, instead of this module failing at import time.
try:
    from volcenginesdkarkruntime import AsyncArk
except ImportError:
    VOLCENGINE_SDK_AVAILABLE = False
    logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
else:
    VOLCENGINE_SDK_AVAILABLE = True
|
||
|
||
|
||
class VolcengineProvider(LLMProvider):
    """LLM provider backed by the official Volcengine Ark Python SDK (AsyncArk).

    Supported capabilities:
    - Text chat (Chat Completions), single non-streaming request
    - Embeddings
    - Deep reasoning via the ``reasoning_effort`` request parameter
    """

    provider_id = "volcengine"
    provider_name = "火山方舟"

    # Default configuration
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
    DEFAULT_TIMEOUT = 1800  # seconds; Ark recommends 1800 s or more

    # Built-in fallback chat model, used when no model has been configured.
    DEFAULT_MODEL_ALIAS = "doubao-seed-2-0-pro"
    DEFAULT_MODEL_ID = "doubao-seed-2-0-pro-260215"
    # Embedding model used when create_embeddings() is called without one.
    DEFAULT_EMBEDDING_MODEL = "doubao-embedding-1.5"

    # Alias -> Ark model ID mapping, populated by load_models_from_config().
    PRESET_MODELS: dict[str, str] = {}

    @classmethod
    def load_models_from_config(cls, models: list[dict] | None) -> None:
        """Rebuild the alias -> Ark model ID mapping from configuration entries.

        Args:
            models: Model entries. An entry contributes a mapping only when it
                has both an ``id`` (the alias callers use) and a ``model_name``
                (the Ark model ID sent to the API). If no entry qualifies
                (including ``None`` or an empty list), the mapping falls back
                to the single built-in default model.
        """
        cls.PRESET_MODELS = {}
        for entry in models or ():
            model_id = entry.get("model_name")
            model_alias = entry.get("id")
            if model_id and model_alias:
                cls.PRESET_MODELS[model_alias] = model_id

        # Always keep at least one usable model.
        if not cls.PRESET_MODELS:
            cls.PRESET_MODELS = {cls.DEFAULT_MODEL_ALIAS: cls.DEFAULT_MODEL_ID}

        logger.info(
            "已加载 %s 个模型: %s",
            len(cls.PRESET_MODELS),
            list(cls.PRESET_MODELS.keys()),
        )

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: int = DEFAULT_TIMEOUT,
        default_model: str | None = None,
        **kwargs,
    ):
        """Initialize the provider and create the async Ark client.

        Args:
            api_key: Ark API key; falls back to ``settings.VOLCENGINE_API_KEY``.
            base_url: API base URL; falls back to the ``volcengine_ark`` entry of
                the platform config loader, then to ``DEFAULT_BASE_URL``
                (Beijing region).
            timeout: Request timeout in seconds, passed to the SDK client.
            default_model: Default model, either a configured alias (mapped via
                ``PRESET_MODELS``) or a raw Ark model ID.
            **kwargs: Forwarded to ``LLMProvider.__init__``.

        Raises:
            ProviderError: If the Ark SDK is not installed or no API key is set.
        """
        from app.config import get_settings
        from app.core.platform_config import get_platform_config_loader

        settings = get_settings()
        # API key: explicit argument first, then application settings.
        resolved_api_key = api_key or settings.VOLCENGINE_API_KEY

        # base_url: explicit argument, then platform config, then the constant.
        loader = get_platform_config_loader()
        platform = loader.get_platform("volcengine_ark")
        yaml_base_url = platform.base_url if platform else None
        resolved_base_url = base_url or yaml_base_url or self.DEFAULT_BASE_URL

        super().__init__(resolved_api_key, resolved_base_url, **kwargs)

        if not VOLCENGINE_SDK_AVAILABLE:
            raise ProviderError(
                "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
                provider_id=self.provider_id,
            )

        if not self.api_key:
            raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)

        self.timeout = timeout
        if default_model:
            # Map a configured alias to its model ID; raw IDs pass through.
            self.default_model = self.PRESET_MODELS.get(default_model, default_model)
        elif self.PRESET_MODELS:
            # First configured model (dicts preserve insertion order).
            self.default_model = next(iter(self.PRESET_MODELS.values()))
        else:
            # No models configured at all: use the built-in model ID.
            self.default_model = self.DEFAULT_MODEL_ID

        self.client = self._create_client()

    def _create_client(self) -> AsyncArk:
        """Create the asynchronous Ark SDK client from this provider's settings."""
        return AsyncArk(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
        )

    def _resolve_model(self, model: str | None) -> str:
        """Map a configured alias to its Ark model ID; default to ``default_model``."""
        if not model:
            return self.default_model
        return self.PRESET_MODELS.get(model, model)

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """Generate text with a single, non-streaming chat completion request.

        Args:
            prompt: User prompt.
            model: Configured alias or Ark model ID (e.g.
                doubao-seed-2-0-pro-260215); defaults to ``default_model``.
            temperature: Sampling randomness (0-2).
            max_tokens: Maximum number of tokens to generate; ``None`` or 0
                omits the limit from the request.
            system_prompt: Optional system prompt.
            **kwargs: Extra options. ``reasoning_effort`` is forwarded as-is.
                ``response_format`` may be the string ``"json_object"`` or a
                full ``response_format`` dict, which is passed through.

        Returns:
            GenerationResult: Generated content, token usage (if reported) and
            the model the API reports (or the requested model ID).

        Raises:
            PlatformError: Any failure, classified by ``_wrap_error``.
        """
        try:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})

            model_id = self._resolve_model(model)

            request_params = {
                "model": model_id,
                "messages": messages,
                "temperature": temperature,
            }

            if max_tokens:  # None and 0 both leave max_tokens out of the request
                request_params["max_tokens"] = max_tokens

            # Extra parameters (e.g. deep reasoning)
            if "reasoning_effort" in kwargs:
                request_params["reasoning_effort"] = kwargs["reasoning_effort"]

            response_format = kwargs.get("response_format")
            if response_format == "json_object":
                request_params["response_format"] = {"type": "json_object"}
            elif isinstance(response_format, dict):
                # Already in API shape (e.g. {"type": "json_schema", ...}).
                request_params["response_format"] = response_format

            completion = await self.client.chat.completions.create(**request_params)

            content = completion.choices[0].message.content or ""
            usage = None
            if completion.usage:
                usage = {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }

            return GenerationResult(
                content=content,
                usage=usage,
                # Fall back to the ID actually sent, not the caller's alias.
                model=completion.model or model_id,
            )

        except Exception as e:
            raise self._wrap_error(e) from e

    async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
        """Embed a list of texts.

        Args:
            texts: Texts to embed.
            model: Embedding model; defaults to ``DEFAULT_EMBEDDING_MODEL``.
            **kwargs: Forwarded unchanged to the SDK ``embeddings.create`` call.

        Returns:
            dict: ``{"embeddings": [{"index", "embedding"}, ...], "model": ...,
            "usage": {"prompt_tokens", "total_tokens"}}``. Usage values are
            ``None`` when the response carries no usage information.

        Raises:
            PlatformError: Any failure, classified by ``_wrap_error``.
        """
        try:
            response = await self.client.embeddings.create(
                model=model or self.DEFAULT_EMBEDDING_MODEL, input=texts, **kwargs
            )

            embeddings = [
                {"index": item.index, "embedding": item.embedding}
                for item in response.data
            ]

            usage = response.usage
            return {
                "embeddings": embeddings,
                "model": response.model,
                "usage": {
                    "prompt_tokens": usage.prompt_tokens if usage else None,
                    "total_tokens": usage.total_tokens if usage else None,
                },
            }

        except Exception as e:
            raise self._wrap_error(e) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """Probe a model with a tiny chat request and report its availability.

        API failures are not raised: any exception from the probe is reported
        as ``is_available=False`` with the error text in ``last_error``.

        Args:
            model: Configured alias or Ark model ID; defaults to ``default_model``.

        Returns:
            ModelHealth: ``response_time`` is in milliseconds; ``id`` is the
            model name as requested by the caller.
        """
        requested = model or self.default_model
        started = time.perf_counter()  # monotonic, unaffected by clock changes

        try:
            await self.client.chat.completions.create(
                model=self._resolve_model(requested),
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
            )
        except Exception as e:
            return ModelHealth(
                id=requested,
                name=f"火山方舟 {requested}",
                is_available=False,
                response_time=(time.perf_counter() - started) * 1000,
                last_error=str(e),
            )

        return ModelHealth(
            id=requested,
            name=f"火山方舟 {requested}",
            is_available=True,
            response_time=(time.perf_counter() - started) * 1000,
            last_error=None,
        )

    @staticmethod
    def _extract_status_code(e: Exception) -> int | None:
        """Return an integer HTTP status from ``status_code`` or ``code``, if any.

        Only non-zero integers are accepted: ``code`` may hold a non-numeric
        error-code string on some exceptions, and comparing that against 500
        would raise TypeError while handling the original error.
        """
        for attr in ("status_code", "code"):
            value = getattr(e, attr, None)
            if isinstance(value, int) and not isinstance(value, bool) and value:
                return value
        return None

    def _wrap_error(self, e: Exception) -> PlatformError:
        """Translate an SDK/transport exception into a standard PlatformError.

        Classification order: rate limit (429 or "rate limit" in the message),
        auth failure (401/403 or "authentication"), server error (>= 500),
        timeout, otherwise unknown. Rate limits, server errors and timeouts
        are marked retryable.
        """
        message = str(e)
        lowered = message.lower()
        status = self._extract_status_code(e)

        if status == 429 or "rate limit" in lowered:
            return PlatformError(
                message,
                platform="volcengine_ark",
                retryable=True,
                error_type=PlatformErrorType.RATE_LIMIT,
                status_code=status,
            )
        if status in (401, 403) or "authentication" in lowered:
            return PlatformError(
                message,
                platform="volcengine_ark",
                retryable=False,
                error_type=PlatformErrorType.AUTH_FAILED,
                status_code=status,
            )
        if status is not None and status >= 500:
            return PlatformError(
                message,
                platform="volcengine_ark",
                retryable=True,
                error_type=PlatformErrorType.SERVER_ERROR,
                status_code=status,
            )

        # SDK timeout errors may say "timed out" rather than "timeout", so
        # also match on the exception class name (e.g. *TimeoutError).
        is_timeout = (
            isinstance(e, TimeoutError)
            or "timeout" in lowered
            or "timed out" in lowered
            or "timeout" in type(e).__name__.lower()
        )
        if is_timeout:
            return PlatformError(
                message,
                platform="volcengine_ark",
                retryable=True,
                error_type=PlatformErrorType.TIMEOUT,
            )

        return PlatformError(
            message,
            platform="volcengine_ark",
            retryable=False,
            error_type=PlatformErrorType.UNKNOWN,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the model aliases callers may request.

        Follows the aliases loaded into ``PRESET_MODELS`` from configuration,
        falling back to the built-in default alias when none are loaded.
        """
        return list(self.PRESET_MODELS) or [self.DEFAULT_MODEL_ALIAS]
|