# Listing metadata from the original export: 465 lines, 14 KiB, Python.
"""
Volcengine Ark official SDK provider
====================================

Built on the official Volcengine Ark Python SDK. Supports:
- Text generation (Chat Completions)
- Streaming output
- Image generation
- Embeddings
- Deep thinking (reasoning)
- Tool calling

Install the dependency:
    pip install 'volcengine-python-sdk[ark]'

Documentation:
    https://www.volcengine.com/docs/82379
"""

from __future__ import annotations

import asyncio
import functools
import logging
import time
from collections.abc import AsyncIterator
from typing import Any, Callable

from app.ai.providers.base import (
    GenerationResult,
    LLMProvider,
    ModelHealth,
    ProviderError,
)

logger = logging.getLogger(__name__)

# Try to import the Volcengine Ark SDK. The import is optional at module load
# time; VOLCENGINE_SDK_AVAILABLE records whether it succeeded.
try:
    from volcenginesdkarkruntime import Ark

    VOLCENGINE_SDK_AVAILABLE = True
except ImportError:
    # Keep the name bound so that module-level references to ``Ark`` (for
    # example when type hints are resolved at runtime) do not raise NameError.
    Ark = None  # type: ignore[assignment,misc]
    VOLCENGINE_SDK_AVAILABLE = False
    logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")


class VolcengineProvider(LLMProvider):
    """
    LLM provider backed by the official Volcengine Ark SDK.

    Capabilities exposed by this class:
    - Chat completions (single response and streaming)
    - Image generation
    - Embeddings
    - Deep thinking, via the ``enable_thinking`` keyword argument

    The client created by ``_create_client`` is the synchronous ``Ark``
    client, so every SDK call made from an ``async`` method is dispatched to
    the event loop's default executor instead of being run inline.
    """

    provider_id = "volcengine"
    provider_name = "火山方舟"

    # Default configuration.
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
    DEFAULT_TIMEOUT = 1800  # seconds; original note: Ark recommends 1800 s or more

    # Alias -> model ID mapping, filled by load_models_from_config().
    PRESET_MODELS: dict[str, str] = {}

    # Fallback alias / model ID used when no model has been configured.
    _FALLBACK_ALIAS = "doubao-seed-2-0-lite"
    _FALLBACK_MODEL_ID = "doubao-seed-2-0-lite-260215"

    # Sentinel returned by next() when a blocking SDK stream is exhausted.
    _STREAM_END = object()

    @classmethod
    def load_models_from_config(cls, models: list[dict]) -> None:
        """
        Rebuild ``PRESET_MODELS`` from a list of model config entries.

        Args:
            models: Config entries. Each entry contributes one mapping from
                its ``id`` (alias) to its ``model_name`` (model ID); entries
                missing either field are skipped. ``None`` or an empty list
                yields the single built-in fallback mapping.
        """
        presets: dict[str, str] = {}
        for entry in models or []:
            alias = entry.get("id")
            model_id = entry.get("model_name")
            if alias and model_id:
                presets[alias] = model_id

        # Always leave at least one usable default model.
        if not presets:
            presets = {cls._FALLBACK_ALIAS: cls._FALLBACK_MODEL_ID}

        # Assign once so PRESET_MODELS is never observed half-filled.
        cls.PRESET_MODELS = presets
        logger.info("已加载 %d 个模型: %s", len(presets), list(presets))

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: int = DEFAULT_TIMEOUT,
        default_model: str | None = None,
        **kwargs,
    ):
        """
        Initialise the provider and create the Ark client.

        Args:
            api_key: Ark API key.
            base_url: API base URL; ``DEFAULT_BASE_URL`` is used when unset.
            timeout: Request timeout in seconds.
            default_model: Default model, given as an alias from
                ``PRESET_MODELS`` or as a raw model ID.
            **kwargs: Forwarded to the base class initialiser.

        Raises:
            ProviderError: If the Ark SDK is not installed or no API key is
                configured.
        """
        super().__init__(api_key, base_url, **kwargs)

        if not VOLCENGINE_SDK_AVAILABLE:
            raise ProviderError(
                "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
                provider_id=self.provider_id,
            )

        if not self.api_key:
            raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)

        self.timeout = timeout
        self.default_model = self._pick_default_model(default_model)
        self.client = self._create_client()

    def _pick_default_model(self, default_model: str | None) -> str:
        """Resolve the default model ID from an alias, the presets, or the fallback."""
        if default_model:
            # Unknown names are treated as raw model IDs.
            return self.PRESET_MODELS.get(default_model, default_model)
        if self.PRESET_MODELS:
            # Prefer the fallback alias, otherwise the first configured model.
            if self._FALLBACK_ALIAS in self.PRESET_MODELS:
                return self.PRESET_MODELS[self._FALLBACK_ALIAS]
            return next(iter(self.PRESET_MODELS.values()))
        return self._FALLBACK_MODEL_ID

    def _create_client(self) -> Ark:
        """Create the Ark SDK client from the configured key, URL and timeout."""
        return Ark(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
        )

    def _resolve_model(self, model: str | None) -> str:
        """Map an alias to its model ID; fall back to the default model when unset."""
        return self.PRESET_MODELS.get(model, model) if model else self.default_model

    def _build_chat_params(
        self,
        prompt: str,
        model: str | None,
        temperature: float,
        max_tokens: int | None,
        system_prompt: str | None,
        extras: dict,
        stream: bool = False,
    ) -> dict:
        """Build the keyword arguments for ``client.chat.completions.create``."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        params: dict = {
            "model": self._resolve_model(model),
            "messages": messages,
            "temperature": temperature,
        }
        if stream:
            params["stream"] = True
        if max_tokens:
            params["max_tokens"] = max_tokens
        # Deep thinking is forwarded the same way for every chat method.
        if "enable_thinking" in extras:
            params["extra_body"] = {"enable_thinking": extras["enable_thinking"]}
        return params

    async def _run_blocking(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Run a blocking callable in the default executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    async def _iter_stream(self, stream: Any) -> AsyncIterator[Any]:
        """Yield items from a blocking iterable, fetching each one off the event loop."""
        iterator = iter(stream)
        while True:
            # next() with a sentinel avoids raising StopIteration through the executor.
            chunk = await self._run_blocking(next, iterator, self._STREAM_END)
            if chunk is self._STREAM_END:
                return
            yield chunk

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Generate a complete (non-streaming) chat response.

        Args:
            prompt: User prompt.
            model: Model alias or model ID; the default model is used when unset.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when falsy.
            system_prompt: Optional system prompt.
            **kwargs: Extra options. Only ``enable_thinking`` is read; it is
                sent in the request's ``extra_body``.

        Returns:
            GenerationResult: Text of the first choice, token usage (``None``
            when the response carries none) and the model name.

        Raises:
            ProviderError: If building the request, the SDK call or parsing
                the response fails.
        """
        try:
            params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs
            )
            completion = await self._run_blocking(
                self.client.chat.completions.create, **params
            )

            content = completion.choices[0].message.content or ""
            usage = None
            if completion.usage:
                usage = {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }

            return GenerationResult(
                content=content,
                usage=usage,
                model=completion.model or model or self.default_model,
            )

        except Exception as e:
            raise ProviderError(
                f"火山方舟生成失败: {e}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """
        Stream a chat response as text fragments.

        Args:
            prompt: User prompt.
            model: Model alias or model ID; the default model is used when unset.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when falsy.
            system_prompt: Optional system prompt.
            **kwargs: Extra options. Only ``enable_thinking`` is read.

        Yields:
            str: Each non-empty content delta of the first choice.

        Raises:
            ProviderError: If the request or the stream fails.
        """
        try:
            params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs, stream=True
            )
            stream = await self._run_blocking(
                self.client.chat.completions.create, **params
            )

            async for chunk in self._iter_stream(stream):
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {e}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream_with_progress(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = 8000,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[dict]:
        """
        Stream a chat response together with progress and usage events.

        Args:
            prompt: User prompt.
            model: Model alias or model ID; the default model is used when unset.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate; omitted from the
                request when falsy.
            system_prompt: Optional system prompt.
            **kwargs: Extra options. Only ``enable_thinking`` is read.

        Yields:
            dict: One event per content delta, then exactly one usage event:
                {"type": "chunk", "content": str, "total_chars": int}
                    ``total_chars`` is the running character count so far.
                {"type": "usage", "prompt_tokens": int, "completion_tokens": int}
                    Both counts stay 0 if no chunk carried usage data.

        Raises:
            ProviderError: If the request or the stream fails.
        """
        try:
            params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs, stream=True
            )
            # Ask the API to append a usage chunk to the stream; without this
            # option the usage event below would report zeros.
            # NOTE(review): relies on the SDK accepting ``stream_options`` —
            # confirm against the installed SDK version.
            params["stream_options"] = {"include_usage": True}

            stream = await self._run_blocking(
                self.client.chat.completions.create, **params
            )

            total_chars = 0
            prompt_tokens = 0
            completion_tokens = 0

            async for chunk in self._iter_stream(stream):
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    total_chars += len(content)
                    yield {
                        "type": "chunk",
                        "content": content,
                        "total_chars": total_chars,
                    }

                # Usage may arrive on a chunk with no choices, so this check
                # is independent of the content check above.
                chunk_usage = getattr(chunk, "usage", None)
                if chunk_usage:
                    prompt_tokens = chunk_usage.prompt_tokens
                    completion_tokens = chunk_usage.completion_tokens

            # Final statistics event.
            yield {
                "type": "usage",
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
            }

        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {e}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_image(
        self, prompt: str, model: str | None = None, size: str = "1024x1024", **kwargs
    ) -> dict:
        """
        Generate images from a prompt.

        Args:
            prompt: Image prompt.
            model: Image model ID; a built-in default is used when unset.
            size: Image size string passed to the SDK.
            **kwargs: Forwarded unchanged to ``client.images.generate``.

        Returns:
            dict: ``{"images": [{"url", "b64_json", "revised_prompt"}, ...],
            "model": str}``. Fields the SDK item does not provide are ``None``.

        Raises:
            ProviderError: If the SDK call or parsing the response fails.
        """
        try:
            # Original note: image generation needs a dedicated image model
            # that is not part of the current config.
            # NOTE(review): the default below looks like a chat model ID
            # rather than an image model — confirm before relying on it.
            image_model = model or "doubao-seed-1.6-flash-250828"
            response = await self._run_blocking(
                self.client.images.generate,
                model=image_model,
                prompt=prompt,
                size=size,
                **kwargs,
            )

            images = [
                {
                    "url": getattr(img, "url", None),
                    "b64_json": getattr(img, "b64_json", None),
                    "revised_prompt": getattr(img, "revised_prompt", None),
                }
                for img in response.data
            ]

            return {
                "images": images,
                "model": response.model,
            }

        except Exception as e:
            raise ProviderError(
                f"火山方舟图片生成失败: {e}", provider_id=self.provider_id, original_error=e
            ) from e

    async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
        """
        Embed a list of texts.

        Args:
            texts: Texts to embed.
            model: Embedding model; a built-in default is used when unset.
            **kwargs: Forwarded unchanged to ``client.embeddings.create``.

        Returns:
            dict: ``{"embeddings": [{"index", "embedding"}, ...], "model": str,
            "usage": {"prompt_tokens", "total_tokens"}}``. ``usage`` is
            ``None`` when the response carries no usage data.

        Raises:
            ProviderError: If the SDK call or parsing the response fails.
        """
        try:
            response = await self._run_blocking(
                self.client.embeddings.create,
                model=model or "doubao-embedding-1.5",
                input=texts,
                **kwargs,
            )

            embeddings = [
                {"index": item.index, "embedding": item.embedding}
                for item in response.data
            ]

            usage = None
            if response.usage:
                usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "total_tokens": response.usage.total_tokens,
                }

            return {
                "embeddings": embeddings,
                "model": response.model,
                "usage": usage,
            }

        except Exception as e:
            raise ProviderError(
                f"火山方舟向量化失败: {e}", provider_id=self.provider_id, original_error=e
            ) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """
        Probe a model with a tiny chat request.

        Args:
            model: Model alias or model ID to probe; the default model is
                used when unset.

        Returns:
            ModelHealth: Availability, elapsed time in milliseconds and the
            error text on failure. Failures are reported in the result, not
            raised. ``id`` and ``name`` use the identifier the caller passed.
        """
        test_model = model or self.default_model
        # perf_counter is monotonic, so the elapsed time cannot go backwards.
        started = time.perf_counter()
        last_error = None

        try:
            await self._run_blocking(
                self.client.chat.completions.create,
                # Send the resolved model ID, as generate() does.
                model=self._resolve_model(model),
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
            )
        except Exception as e:
            last_error = str(e)

        return ModelHealth(
            id=test_model,
            name=f"火山方舟 {test_model}",
            is_available=last_error is None,
            response_time=(time.perf_counter() - started) * 1000,
            last_error=last_error,
        )

    @property
    def available_models(self) -> list[str]:
        """Return the model aliases this provider advertises.

        The list is hard-coded; the original note says it is meant to mirror
        the ``ai_models.yaml`` configuration.
        """
        return [
            "doubao-seed-2-0-pro",
            "deepseek-v3-2",
            "doubao-seed-2-0-lite",
            "doubao-lite-32k",
        ]