Files
meijiaka-zy/python-api/app/ai/providers/volcengine_provider.py
T

465 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山方舟官方 SDK Provider
==========================
基于火山方舟官方 Python SDK 实现,支持:
- 文本生成 (Chat Completions)
- 流式输出
- 图片生成
- 向量化
- 深度思考
- 工具调用
安装依赖:
pip install 'volcengine-python-sdk[ark]'
文档:
https://www.volcengine.com/docs/82379
"""
from __future__ import annotations

import asyncio
import functools
import logging
import time
from collections.abc import AsyncIterator

from app.ai.providers.base import (
    GenerationResult,
    LLMProvider,
    ModelHealth,
    ProviderError,
)
logger = logging.getLogger(__name__)

# The Ark runtime SDK is an optional dependency: record whether it imported so
# the provider can raise a clear error at construction time instead of at import.
try:
    from volcenginesdkarkruntime import Ark
except ImportError:
    VOLCENGINE_SDK_AVAILABLE = False
    logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
else:
    VOLCENGINE_SDK_AVAILABLE = True
class VolcengineProvider(LLMProvider):
    """
    LLM provider built on the official Volcengine Ark Python SDK.

    Exposed capabilities:
    - text chat (Chat Completions), non-streaming and streaming
    - image generation (``client.images.generate``)
    - text embeddings (``client.embeddings.create``)
    - an ``enable_thinking`` flag forwarded to the API via ``extra_body``

    The ``Ark`` client created here is called synchronously, so every SDK call
    is dispatched to the event loop's default executor. Calling it inline from
    these ``async`` methods would stall every other coroutine on the loop for
    the whole request.
    """

    provider_id = "volcengine"
    provider_name = "火山方舟"

    # Default endpoint (Beijing region) and per-request timeout.
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
    DEFAULT_TIMEOUT = 1800  # seconds; Ark recommends 1800 s or more

    # Chat model used when neither the caller nor the config supplies one.
    FALLBACK_MODEL_ALIAS = "doubao-seed-2-0-lite"
    FALLBACK_MODEL_ID = "doubao-seed-2-0-lite-260215"

    # Defaults for the non-chat endpoints; callers can override per call.
    # NOTE(review): this ID looks like a chat model rather than a Seedream image
    # model -- confirm it is accepted by the image generation endpoint.
    DEFAULT_IMAGE_MODEL = "doubao-seed-1.6-flash-250828"
    DEFAULT_EMBEDDING_MODEL = "doubao-embedding-1.5"

    # Alias -> Ark model ID mapping, replaced by load_models_from_config().
    PRESET_MODELS: dict[str, str] = {}

    @classmethod
    def load_models_from_config(cls, models: list[dict] | None) -> None:
        """
        Replace PRESET_MODELS with the alias -> model ID pairs found in ``models``.

        Args:
            models: Config entries. Each entry maps ``entry["id"]`` (the alias)
                to ``entry["model_name"]`` (the Ark model ID). Entries missing
                either field are skipped; ``None`` is treated as empty.

        If no usable entry is found, a single fallback model is registered so the
        provider always has at least one model.
        """
        presets: dict[str, str] = {}
        for entry in models or ():
            model_id = entry.get("model_name")
            model_alias = entry.get("id")
            if model_id and model_alias:
                presets[model_alias] = model_id
        if not presets:
            presets = {cls.FALLBACK_MODEL_ALIAS: cls.FALLBACK_MODEL_ID}
        # Assign once so concurrent readers never see a half-populated mapping.
        cls.PRESET_MODELS = presets
        logger.info("已加载 %d 个模型: %s", len(presets), list(presets.keys()))

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: int = DEFAULT_TIMEOUT,
        default_model: str | None = None,
        **kwargs,
    ):
        """
        Initialize the provider and its Ark client.

        Args:
            api_key: Ark API key (from the Ark console).
            base_url: API base URL; DEFAULT_BASE_URL is used when falsy.
            timeout: Request timeout in seconds, passed to the Ark client.
            default_model: Default model alias or Ark model ID. Aliases are
                resolved through PRESET_MODELS.
            **kwargs: Forwarded to the base provider.

        Raises:
            ProviderError: The Ark SDK is not installed or no API key is set.
        """
        super().__init__(api_key, base_url, **kwargs)
        if not VOLCENGINE_SDK_AVAILABLE:
            raise ProviderError(
                "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
                provider_id=self.provider_id,
            )
        if not self.api_key:
            raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)
        self.timeout = timeout
        self.default_model = self._pick_default_model(default_model)
        self.client = self._create_client()

    def _pick_default_model(self, requested: str | None) -> str:
        """Choose the default model ID: the requested one, the configured fallback, or the first configured model."""
        if requested:
            return self.PRESET_MODELS.get(requested, requested)
        if self.PRESET_MODELS:
            return self.PRESET_MODELS.get(
                self.FALLBACK_MODEL_ALIAS, next(iter(self.PRESET_MODELS.values()))
            )
        # Nothing configured at all: fall back to a hard-coded model ID.
        return self.FALLBACK_MODEL_ID

    def _create_client(self) -> Ark:
        """Create the synchronous Ark SDK client."""
        return Ark(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
        )

    def _resolve_model(self, model: str | None) -> str:
        """Map a configured alias to its Ark model ID; use the default model when ``model`` is falsy."""
        if not model:
            return self.default_model
        return self.PRESET_MODELS.get(model, model)

    def _build_chat_params(
        self,
        prompt: str,
        model: str | None,
        temperature: float,
        max_tokens: int | None,
        system_prompt: str | None,
        extra: dict,
        stream: bool = False,
    ) -> dict:
        """
        Build keyword arguments for ``client.chat.completions.create``.

        The optional system prompt precedes the user prompt. ``max_tokens`` is
        only sent when truthy, and ``extra["enable_thinking"]`` (if present) is
        forwarded in ``extra_body``.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        params = {
            "model": self._resolve_model(model),
            "messages": messages,
            "temperature": temperature,
        }
        if stream:
            params["stream"] = True
        if max_tokens:
            params["max_tokens"] = max_tokens
        if "enable_thinking" in extra:
            params["extra_body"] = {"enable_thinking": extra["enable_thinking"]}
        return params

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a synchronous SDK call in the loop's default executor and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

    @staticmethod
    def _close_stream(stream) -> None:
        """Best-effort release of a streaming response, if the SDK object exposes ``close()``."""
        close = getattr(stream, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Never let cleanup mask the error (or result) that ended the stream.
                logger.debug("关闭火山方舟流式响应失败", exc_info=True)

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Generate a complete (non-streaming) reply.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (omitted when falsy).
            system_prompt: Optional system prompt.
            **kwargs: Extra options; ``enable_thinking`` is forwarded via ``extra_body``.

        Returns:
            GenerationResult with the reply text, token usage (if reported) and model name.

        Raises:
            ProviderError: Any failure while building, sending or parsing the request.
        """
        try:
            request_params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs
            )
            completion = await self._run_blocking(
                self.client.chat.completions.create, **request_params
            )
            content = completion.choices[0].message.content or ""
            usage = None
            if completion.usage:
                usage = {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            return GenerationResult(
                content=content,
                usage=usage,
                model=completion.model or model or self.default_model,
            )
        except Exception as e:
            raise ProviderError(
                f"火山方舟生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """
        Stream the reply as text fragments.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (omitted when falsy).
            system_prompt: Optional system prompt.
            **kwargs: Extra options; ``enable_thinking`` is forwarded via ``extra_body``.

        Yields:
            str: Each non-empty ``delta.content`` fragment.

        Raises:
            ProviderError: Any failure while opening or reading the stream.
        """
        try:
            request_params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs, stream=True
            )
            stream = await self._run_blocking(
                self.client.chat.completions.create, **request_params
            )
            try:
                chunks = iter(stream)
                while True:
                    # Pull each chunk off-loop; None marks the end of the stream.
                    chunk = await self._run_blocking(next, chunks, None)
                    if chunk is None:
                        break
                    if chunk.choices and chunk.choices[0].delta.content:
                        yield chunk.choices[0].delta.content
            finally:
                self._close_stream(stream)
        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream_with_progress(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = 8000,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[dict]:
        """
        Stream the reply with progress information, then a final usage event.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (omitted when falsy).
            system_prompt: Optional system prompt.
            **kwargs: Extra options; ``enable_thinking`` is forwarded via ``extra_body``.

        Yields:
            dict: For each fragment, ``{"type": "chunk", "content": str,
            "total_chars": int}``, where ``total_chars`` is the cumulative
            character count so far. Once at the end, ``{"type": "usage",
            "prompt_tokens": int, "completion_tokens": int}``; both counts
            are 0 if the service reported no usage.

        Raises:
            ProviderError: Any failure while opening or reading the stream.
        """
        try:
            request_params = self._build_chat_params(
                prompt, model, temperature, max_tokens, system_prompt, kwargs, stream=True
            )
            # Without include_usage the stream carries no usage data, and the
            # final event would always report zero tokens.
            request_params["stream_options"] = {"include_usage": True}
            stream = await self._run_blocking(
                self.client.chat.completions.create, **request_params
            )
            total_chars = 0
            prompt_tokens = 0
            completion_tokens = 0
            try:
                chunks = iter(stream)
                while True:
                    chunk = await self._run_blocking(next, chunks, None)
                    if chunk is None:
                        break
                    if chunk.choices and chunk.choices[0].delta.content:
                        content = chunk.choices[0].delta.content
                        total_chars += len(content)
                        yield {
                            "type": "chunk",
                            "content": content,
                            "total_chars": total_chars,
                        }
                    # Usage arrives on a trailing chunk (typically with empty choices).
                    usage = getattr(chunk, "usage", None)
                    if usage:
                        prompt_tokens = usage.prompt_tokens
                        completion_tokens = usage.completion_tokens
            finally:
                self._close_stream(stream)
            yield {
                "type": "usage",
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
            }
        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_image(
        self, prompt: str, model: str | None = None, size: str = "1024x1024", **kwargs
    ) -> dict:
        """
        Generate images via ``client.images.generate``.

        Args:
            prompt: Image prompt.
            model: Image model ID; defaults to DEFAULT_IMAGE_MODEL.
            size: Image size string, e.g. "1024x1024".
            **kwargs: Forwarded verbatim to the SDK call.

        Returns:
            dict: ``{"images": [{"url", "b64_json", "revised_prompt"}, ...],
            "model": str}``. Attributes the SDK does not return are None.

        Raises:
            ProviderError: Any failure while calling the API or parsing the result.
        """
        try:
            image_model = model or self.DEFAULT_IMAGE_MODEL
            response = await self._run_blocking(
                self.client.images.generate,
                model=image_model,
                prompt=prompt,
                size=size,
                **kwargs,
            )
            images = [
                {
                    "url": getattr(img, "url", None),
                    "b64_json": getattr(img, "b64_json", None),
                    "revised_prompt": getattr(img, "revised_prompt", None),
                }
                for img in response.data
            ]
            return {
                "images": images,
                "model": getattr(response, "model", None) or image_model,
            }
        except Exception as e:
            raise ProviderError(
                f"火山方舟图片生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
        """
        Embed a list of texts via ``client.embeddings.create``.

        Args:
            texts: Texts to embed.
            model: Embedding model ID; defaults to DEFAULT_EMBEDDING_MODEL.
            **kwargs: Forwarded verbatim to the SDK call.

        Returns:
            dict: ``{"embeddings": [{"index", "embedding"}, ...], "model": str,
            "usage": {"prompt_tokens", "total_tokens"}}``. ``usage`` is None
            when the response carries no usage.

        Raises:
            ProviderError: Any failure while calling the API or parsing the result.
        """
        try:
            response = await self._run_blocking(
                self.client.embeddings.create,
                model=model or self.DEFAULT_EMBEDDING_MODEL,
                input=texts,
                **kwargs,
            )
            embeddings = [
                {"index": item.index, "embedding": item.embedding} for item in response.data
            ]
            usage = getattr(response, "usage", None)
            return {
                "embeddings": embeddings,
                "model": response.model,
                "usage": (
                    {
                        "prompt_tokens": usage.prompt_tokens,
                        "total_tokens": usage.total_tokens,
                    }
                    if usage is not None
                    else None
                ),
            }
        except Exception as e:
            raise ProviderError(
                f"火山方舟向量化失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """
        Probe a model with a tiny chat request and report availability.

        ``model`` may be an alias or an Ark model ID (defaults to
        ``self.default_model``). Failures are reported in the result, not raised.
        ``response_time`` is in milliseconds.
        """
        test_model = model or self.default_model
        # perf_counter is monotonic, so wall-clock adjustments cannot skew latency.
        start = time.perf_counter()
        try:
            await self._run_blocking(
                self.client.chat.completions.create,
                model=self._resolve_model(test_model),
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
            )
        except Exception as e:
            return ModelHealth(
                id=test_model,
                name=f"火山方舟 {test_model}",
                is_available=False,
                response_time=(time.perf_counter() - start) * 1000,
                last_error=str(e),
            )
        return ModelHealth(
            id=test_model,
            name=f"火山方舟 {test_model}",
            is_available=True,
            response_time=(time.perf_counter() - start) * 1000,
            last_error=None,
        )

    @property
    def available_models(self) -> list[str]:
        """
        Return the model aliases this provider advertises.

        NOTE(review): this list is hard-coded and is meant to mirror
        ai_models.yaml; it does not read PRESET_MODELS, so keep both in sync by hand.
        """
        return [
            "doubao-seed-2-0-pro",
            "deepseek-v3-2",
            "doubao-seed-2-0-lite",
            "doubao-lite-32k",
        ]