Files
meijiaka-zy/python-api/app/ai/providers/volcengine_provider.py
T
小鱼开发 d0057ecc2c feat: 脚本生成流式优化 - Ark SDK 迁移至 httpx SSE + reasoning_effort 关闭思考过程
- volcengine_provider: Ark SDK 同步迭代器 → AsyncOpenAI → httpx 原始 SSE
  - generate_stream_with_progress 使用 httpx 直接请求,消除 80s+ 缓冲
  - 新增 generate_stream (AsyncOpenAI) 作为备用方案
  - enable_thinking 替换为 reasoning_effort,支持思考程度控制
- ai_models.yaml: 默认 LLM 改为 doubao-seed-2-0-pro,添加 reasoning_effort: minimal
- model_router: 透传 reasoning_effort 参数
- script_service: 4 阶段 SSE 精简 (start→analyzing→generating→complete)
- script.py: SSE 直连端点 /script/generate/stream
- 前端 ScriptCreation: 直连 SSE 端点,弃用调度器轮询模式
2026-04-26 20:17:12 +08:00

530 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
火山方舟官方 SDK Provider
==========================
基于火山方舟官方 Python SDK 实现,支持:
- 文本生成 (Chat Completions)
- 流式输出
- 图片生成
- 向量化
- 深度思考
- 工具调用
安装依赖:
pip install 'volcengine-python-sdk[ark]'
文档:
https://www.volcengine.com/docs/82379
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from app.ai.providers.base import (
GenerationResult,
LLMProvider,
ModelHealth,
ProviderError,
)
logger = logging.getLogger(__name__)

# Optional dependency: the official Volcengine Ark runtime SDK, which provides
# the synchronous `Ark` client. Its absence is recorded in a flag instead of
# failing the import of this module.
try:
    from volcenginesdkarkruntime import Ark
except ImportError:
    VOLCENGINE_SDK_AVAILABLE = False
    logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
else:
    VOLCENGINE_SDK_AVAILABLE = True

# Optional dependency: the OpenAI SDK, which provides `AsyncOpenAI` for the
# streaming code path. Again only a flag plus a warning when it is missing.
try:
    from openai import AsyncOpenAI
except ImportError:
    ASYNC_OPENAI_AVAILABLE = False
    logger.warning("OpenAI SDK 未安装,流式生成将不可用")
else:
    ASYNC_OPENAI_AVAILABLE = True
class VolcengineProvider(LLMProvider):
    """
    Volcengine Ark (火山方舟) provider.

    Capabilities:
    - Text chat (Chat Completions), blocking and streaming
    - Image generation (Images API)
    - Embeddings
    - Reasoning control through the ``reasoning_effort`` request field

    The official Ark SDK client (``self.client``) is synchronous. Every call made
    through it is dispatched with ``asyncio.to_thread`` so the event loop is not
    blocked while the request is in flight. Streaming uses either ``AsyncOpenAI``
    (``generate_stream``) or a raw httpx SSE request
    (``generate_stream_with_progress``).
    """

    provider_id = "volcengine"
    provider_name = "火山方舟"

    # Default configuration
    DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
    DEFAULT_TIMEOUT = 1800  # seconds; Ark recommends 1800 s or more
    # httpx timeout (seconds) for the raw SSE request in generate_stream_with_progress
    STREAM_TIMEOUT = 300

    # Alias -> Ark model ID mapping, populated by load_models_from_config()
    PRESET_MODELS: dict[str, str] = {}

    @classmethod
    def load_models_from_config(cls, models: list[dict]):
        """
        Rebuild the alias -> Ark model ID mapping from configuration.

        Args:
            models: Model entries. Each entry maps ``entry["id"]`` (the alias)
                to ``entry["model_name"]`` (the Ark model ID). Entries missing
                either value are skipped.

        If no entry is usable, a single ``doubao-seed-2-0-lite`` mapping is
        installed so a default model is always available.
        """
        cls.PRESET_MODELS = {}
        for model in models:
            model_id = model.get("model_name")
            model_alias = model.get("id")
            if model_id and model_alias:
                cls.PRESET_MODELS[model_alias] = model_id
        # Guarantee at least one model
        if not cls.PRESET_MODELS:
            cls.PRESET_MODELS = {
                "doubao-seed-2-0-lite": "doubao-seed-2-0-lite-260215",
            }
        logger.info(f"已加载 {len(cls.PRESET_MODELS)} 个模型: {list(cls.PRESET_MODELS.keys())}")

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: int = DEFAULT_TIMEOUT,
        default_model: str | None = None,
        **kwargs,
    ):
        """
        Initialise the Volcengine Ark provider.

        Args:
            api_key: API key obtained from the Ark console (required).
            base_url: API base URL; defaults to the Beijing endpoint.
            timeout: Request timeout in seconds for the SDK and AsyncOpenAI clients.
            default_model: Default model, either an alias from PRESET_MODELS or
                a raw Ark model ID.

        Raises:
            ProviderError: If the Ark SDK is not installed or no API key is set.
        """
        super().__init__(api_key, base_url, **kwargs)
        if not VOLCENGINE_SDK_AVAILABLE:
            raise ProviderError(
                "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
                provider_id=self.provider_id,
            )
        if not self.api_key:
            raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)
        self.timeout = timeout
        if default_model:
            # Accept either an alias or a raw model ID
            self.default_model = self.PRESET_MODELS.get(default_model, default_model)
        elif self.PRESET_MODELS:
            # Fallback: prefer doubao-seed-2-0-lite, otherwise the first configured model
            self.default_model = self.PRESET_MODELS.get(
                "doubao-seed-2-0-lite", next(iter(self.PRESET_MODELS.values()))
            )
        else:
            # Fallback when no model has been configured at all
            self.default_model = "doubao-seed-2-0-lite-260215"
        self.client = self._create_client()
        # May be None when the OpenAI SDK is missing; only generate_stream needs it.
        self.async_client = self._create_async_client()

    def _create_client(self) -> Ark:
        """Create the synchronous Ark SDK client."""
        return Ark(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
        )

    def _create_async_client(self) -> AsyncOpenAI | None:
        """
        Create the AsyncOpenAI client used by generate_stream.

        Returns None when the OpenAI SDK is not installed. The provider then
        stays usable for non-streaming calls, and generate_stream raises
        ProviderError.
        """
        if not ASYNC_OPENAI_AVAILABLE:
            return None
        import httpx  # installed as a dependency of the openai package

        return AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url or self.DEFAULT_BASE_URL,
            timeout=self.timeout,
            # Request uncompressed responses; presumably so streamed chunks are
            # not held back by compression buffering (TODO confirm).
            http_client=httpx.AsyncClient(headers={"Accept-Encoding": "identity"}),
        )

    def _resolve_model_id(self, model: str | None) -> str:
        """Map a model alias to its Ark model ID; fall back to default_model when empty."""
        if not model:
            return self.default_model
        return self.PRESET_MODELS.get(model, model)

    @staticmethod
    def _build_chat_messages(prompt: str, system_prompt: str | None) -> list[dict]:
        """Build the Chat Completions message list (optional system message first)."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    async def generate(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> GenerationResult:
        """
        Generate a complete (non-streaming) chat response.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID (e.g. doubao-seed-2-0-pro-260215);
                defaults to ``self.default_model``.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens to generate (falsy = server default).
            system_prompt: Optional system prompt.
            **kwargs: Only ``reasoning_effort`` is read; it is forwarded via ``extra_body``.

        Returns:
            GenerationResult: Text content, token usage (if reported) and model name.

        Raises:
            ProviderError: Wrapping any failure of the underlying request.
        """
        try:
            request_params = {
                "model": self._resolve_model_id(model),
                "messages": self._build_chat_messages(prompt, system_prompt),
                "temperature": temperature,
            }
            if max_tokens:
                request_params["max_tokens"] = max_tokens
            if "reasoning_effort" in kwargs:
                request_params["extra_body"] = {"reasoning_effort": kwargs["reasoning_effort"]}
            # The Ark SDK client is synchronous; keep it off the event loop.
            completion = await asyncio.to_thread(
                self.client.chat.completions.create, **request_params
            )
            content = completion.choices[0].message.content or ""
            usage = None
            if completion.usage:
                usage = {
                    "prompt_tokens": completion.usage.prompt_tokens,
                    "completion_tokens": completion.usage.completion_tokens,
                    "total_tokens": completion.usage.total_tokens,
                }
            return GenerationResult(
                content=content,
                usage=usage,
                model=completion.model or model or self.default_model,
            )
        except Exception as e:
            raise ProviderError(
                f"火山方舟生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        """
        Stream a chat response as text fragments.

        Uses the AsyncOpenAI client instead of the synchronous Ark SDK streaming
        iterator.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (falsy = server default).
            system_prompt: Optional system prompt.
            **kwargs: Only ``reasoning_effort`` is read; it is forwarded via ``extra_body``.

        Yields:
            str: Each non-empty content delta.

        Raises:
            ProviderError: If the OpenAI SDK is unavailable or the request fails.
        """
        if self.async_client is None:
            raise ProviderError(
                "OpenAI SDK 未安装,流式生成不可用", provider_id=self.provider_id
            )
        try:
            request_params = {
                "model": self._resolve_model_id(model),
                "messages": self._build_chat_messages(prompt, system_prompt),
                "temperature": temperature,
                "stream": True,
            }
            if max_tokens:
                request_params["max_tokens"] = max_tokens
            if "reasoning_effort" in kwargs:
                request_params["extra_body"] = {"reasoning_effort": kwargs["reasoning_effort"]}
            stream = await self.async_client.chat.completions.create(**request_params)
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_stream_with_progress(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = 8000,
        system_prompt: str | None = None,
        **kwargs,
    ) -> AsyncIterator[dict]:
        """
        Stream a chat response together with progress and usage information.

        Sends a raw httpx SSE request so that request headers and body are fully
        under our control.

        Args:
            prompt: User prompt.
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (falsy = server default).
            system_prompt: Optional system prompt.
            **kwargs: Only ``reasoning_effort`` is read; it is sent as a
                top-level request field.

        Yields:
            dict: One of two event shapes:
            - ``{"type": "chunk", "content": str, "total_chars": int}`` for each
              non-empty content delta. ``total_chars`` is the cumulative number
              of characters yielded so far.
            - ``{"type": "usage", "prompt_tokens": int, "completion_tokens": int}``
              once at the end. Counts are 0 if the server reported no usage.

        Raises:
            ProviderError: On an HTTP error status, an error payload inside the
                stream, or any transport or parsing failure.
        """
        import json

        import httpx

        try:
            model_id = self._resolve_model_id(model)
            request_body = {
                "model": model_id,
                "messages": self._build_chat_messages(prompt, system_prompt),
                "temperature": temperature,
                "stream": True,
                # Without this, OpenAI-compatible streams omit token usage and the
                # trailing "usage" event would always report zeros.
                "stream_options": {"include_usage": True},
            }
            if max_tokens:
                request_body["max_tokens"] = max_tokens
            if "reasoning_effort" in kwargs:
                request_body["reasoning_effort"] = kwargs["reasoning_effort"]
            req_start = time.perf_counter()
            logger.info(f"[Volcengine] request_body model={model_id}, max_tokens={max_tokens}, reasoning_effort={kwargs.get('reasoning_effort', 'NOT_SET')}")
            total_chars = 0
            prompt_tokens = 0
            completion_tokens = 0
            chunk_idx = 0  # parsed JSON data events
            content_chunks = 0  # data events that carried non-empty content
            async with httpx.AsyncClient() as client:
                async with client.stream(
                    "POST",
                    f"{self.base_url or self.DEFAULT_BASE_URL}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json",
                        "Accept": "text/event-stream",
                    },
                    json=request_body,
                    timeout=self.STREAM_TIMEOUT,
                ) as response:
                    logger.info(f"[Volcengine] response status={response.status_code}, headers={dict(response.headers)}")
                    if response.status_code >= 400:
                        # An error response is a plain JSON body, not SSE. Without
                        # this check the caller would get an empty stream plus a
                        # zero-usage event instead of an error.
                        body = await response.aread()
                        detail = body.decode("utf-8", errors="replace")[:1000]
                        raise ProviderError(
                            f"火山方舟流式生成失败: HTTP {response.status_code} {detail}",
                            provider_id=self.provider_id,
                        )
                    async for raw_line in response.aiter_lines():
                        # Only SSE "data:" fields carry payloads; skip blank
                        # separators, comments and other fields.
                        if not raw_line.startswith("data:"):
                            continue
                        data = raw_line[5:].strip()
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(chunk, dict):
                            continue
                        chunk_idx += 1
                        if chunk.get("error"):
                            raise ProviderError(
                                f"火山方舟流式生成失败: {chunk['error']}",
                                provider_id=self.provider_id,
                            )
                        choices = chunk.get("choices") or []
                        delta = (choices[0].get("delta") or {}) if choices else {}
                        content = delta.get("content") or ""
                        if chunk_idx == 1:
                            logger.info(f"[Volcengine] 首 chunk 耗时: {time.perf_counter() - req_start:.3f}s, idx={chunk_idx}, content={content!r}, delta_keys={list(delta.keys())}")
                        if content:
                            content_chunks += 1
                            if content_chunks == 1:
                                logger.info(f"[Volcengine] 首个有内容 chunk 耗时: {time.perf_counter() - req_start:.3f}s, idx={chunk_idx}, content_len={len(content)}")
                            total_chars += len(content)
                            yield {
                                "type": "chunk",
                                "content": content,
                                "total_chars": total_chars,
                            }
                        # With include_usage the final data event carries usage
                        # and an empty choices list.
                        usage = chunk.get("usage")
                        if usage:
                            prompt_tokens = usage.get("prompt_tokens") or 0
                            completion_tokens = usage.get("completion_tokens") or 0
            logger.info(f"[Volcengine] 流结束, 总chunk数={chunk_idx}, 有内容chunk数={content_chunks}")
            yield {
                "type": "usage",
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
            }
        except ProviderError:
            raise
        except Exception as e:
            raise ProviderError(
                f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def generate_image(
        self, prompt: str, model: str | None = None, size: str = "1024x1024", **kwargs
    ) -> dict:
        """
        Generate images via the Ark Images API.

        Args:
            prompt: Image prompt.
            model: Image model ID; defaults to "doubao-seed-1.6-flash-250828".
            size: Image size, e.g. "1024x1024".
            **kwargs: Forwarded unchanged to ``client.images.generate``.

        Returns:
            dict: ``{"images": [{"url", "b64_json", "revised_prompt"}, ...], "model": str}``.

        Raises:
            ProviderError: Wrapping any failure of the underlying request.
        """
        try:
            # Image generation needs a dedicated image model that is not part of
            # the current config; enable one in the Ark model marketplace and
            # pass it explicitly.
            # NOTE(review): the default ID names a doubao-seed chat model rather
            # than a Seedream image model; confirm it is valid for images.generate.
            image_model = model or "doubao-seed-1.6-flash-250828"
            # The Ark SDK client is synchronous; keep it off the event loop.
            response = await asyncio.to_thread(
                self.client.images.generate,
                model=image_model,
                prompt=prompt,
                size=size,
                **kwargs,
            )
            images = [
                {
                    "url": img.url,
                    "b64_json": img.b64_json,
                    # Not every response item type is guaranteed to expose this field.
                    "revised_prompt": getattr(img, "revised_prompt", None),
                }
                for img in response.data
            ]
            return {
                "images": images,
                "model": response.model,
            }
        except Exception as e:
            raise ProviderError(
                f"火山方舟图片生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
        """
        Compute text embeddings.

        Args:
            texts: Texts to embed.
            model: Embedding model; defaults to "doubao-embedding-1.5".
            **kwargs: Forwarded unchanged to ``client.embeddings.create``.

        Returns:
            dict: ``{"embeddings": [{"index", "embedding"}, ...], "model": str,
            "usage": {"prompt_tokens", "total_tokens"}}``.

        Raises:
            ProviderError: Wrapping any failure of the underlying request.
        """
        try:
            # The Ark SDK client is synchronous; keep it off the event loop.
            response = await asyncio.to_thread(
                self.client.embeddings.create,
                model=model or "doubao-embedding-1.5",
                input=texts,
                **kwargs,
            )
            embeddings = [
                {
                    "index": item.index,
                    "embedding": item.embedding,
                }
                for item in response.data
            ]
            return {
                "embeddings": embeddings,
                "model": response.model,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "total_tokens": response.usage.total_tokens,
                },
            }
        except Exception as e:
            raise ProviderError(
                f"火山方舟向量化失败: {str(e)}", provider_id=self.provider_id, original_error=e
            ) from e

    async def health_check(self, model: str | None = None) -> ModelHealth:
        """
        Probe a model with a tiny chat request and report its availability.

        Args:
            model: Model alias or Ark model ID; defaults to ``self.default_model``.
                Aliases are resolved through PRESET_MODELS for the request, but
                the reported id/name use the value that was passed in.

        Returns:
            ModelHealth: Availability, response time in milliseconds and last error.
            Never raises for request failures; they are reported in ``last_error``.
        """
        start_time = time.perf_counter()
        test_model = model or self.default_model
        try:
            # The Ark SDK client is synchronous; keep it off the event loop.
            await asyncio.to_thread(
                self.client.chat.completions.create,
                model=self._resolve_model_id(model),
                messages=[{"role": "user", "content": "Hi"}],
                max_tokens=5,
            )
            return ModelHealth(
                id=test_model,
                name=f"火山方舟 {test_model}",
                is_available=True,
                response_time=(time.perf_counter() - start_time) * 1000,
                last_error=None,
            )
        except Exception as e:
            return ModelHealth(
                id=test_model,
                name=f"火山方舟 {test_model}",
                is_available=False,
                response_time=(time.perf_counter() - start_time) * 1000,
                last_error=str(e),
            )

    @property
    def available_models(self) -> list[str]:
        """Return the model aliases offered by this provider (kept in sync with ai_models.yaml by hand)."""
        return [
            "doubao-seed-2-0-pro",
            "deepseek-v3-2",
            "doubao-seed-2-0-lite",
            "doubao-lite-32k",
        ]