""" 火山方舟官方 SDK Provider ========================== 基于火山方舟官方 Python SDK 实现,支持: - 文本生成 (Chat Completions) - 流式输出 - 图片生成 - 向量化 - 深度思考 - 工具调用 安装依赖: pip install 'volcengine-python-sdk[ark]' 文档: https://www.volcengine.com/docs/82379 """ from __future__ import annotations import logging import time from collections.abc import AsyncIterator from app.ai.providers.base import ( GenerationResult, LLMProvider, ModelHealth, ProviderError, ) logger = logging.getLogger(__name__) # 尝试导入火山方舟 SDK try: from volcenginesdkarkruntime import Ark VOLCENGINE_SDK_AVAILABLE = True except ImportError: VOLCENGINE_SDK_AVAILABLE = False logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'") class VolcengineProvider(LLMProvider): """ 火山方舟官方 SDK Provider 支持多模态能力: - 文本对话 (Chat Completions) - 图片生成 (Image Generation) - 向量化 (Embeddings) - 深度思考 (Reasoning) """ provider_id = "volcengine" provider_name = "火山方舟" # 默认配置 DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3" DEFAULT_TIMEOUT = 1800 # 秒,方舟推荐 1800 秒以上 # 模型 ID 映射(从配置文件加载) PRESET_MODELS: dict[str, str] = {} @classmethod def load_models_from_config(cls, models: list[dict]): """ 从配置文件加载模型列表 Args: models: 模型列表,每个模型包含 model_name 字段 """ cls.PRESET_MODELS = {} for model in models: model_id = model.get("model_name") model_alias = model.get("id") if model_id and model_alias: cls.PRESET_MODELS[model_alias] = model_id # 确保至少有一个默认模型 if not cls.PRESET_MODELS: cls.PRESET_MODELS = { "doubao-seed-2-0-lite": "doubao-seed-2-0-lite-260215", } logger.info(f"已加载 {len(cls.PRESET_MODELS)} 个模型: {list(cls.PRESET_MODELS.keys())}") def __init__( self, api_key: str | None = None, base_url: str | None = None, timeout: int = DEFAULT_TIMEOUT, default_model: str | None = None, **kwargs, ): """ 初始化火山方舟 Provider Args: api_key: API Key,从火山方舟控制台获取 base_url: API 基础地址,默认北京节点 timeout: 请求超时时间(秒) default_model: 默认模型(Model ID) """ super().__init__(api_key, base_url, **kwargs) if not VOLCENGINE_SDK_AVAILABLE: raise ProviderError( "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'", provider_id=self.provider_id, ) if not self.api_key: raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id) self.timeout = timeout # 使用模型 ID 映射(自动映射) if default_model: self.default_model = self.PRESET_MODELS.get(default_model, default_model) elif self.PRESET_MODELS: # 兜底:使用 doubao-seed-2-0-lite 或第一个可用的模型 self.default_model = self.PRESET_MODELS.get( "doubao-seed-2-0-lite", list(self.PRESET_MODELS.values())[0] ) else: # 兜底:使用一个默认模型ID(如果用户未配置任何模型) self.default_model = "doubao-seed-2-0-lite-260215" self.client = self._create_client() def _create_client(self) -> Ark: """创建火山方舟客户端""" return Ark( api_key=self.api_key, base_url=self.base_url or self.DEFAULT_BASE_URL, timeout=self.timeout, ) async def generate( self, prompt: str, model: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, system_prompt: str | None = None, **kwargs, ) -> GenerationResult: """ 同步生成文本 Args: prompt: 用户提示词 model: 模型 ID(如 doubao-seed-2-0-pro-260215) temperature: 随机性 (0-2) max_tokens: 最大生成 token 数 system_prompt: 系统提示词(可选) **kwargs: 额外参数(如 enable_thinking 启用深度思考) Returns: GenerationResult: 生成结果 """ try: # 构建消息 messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) # 映射模型名称到模型 ID model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model # 构建请求参数 request_params = { "model": model_id, "messages": messages, "temperature": temperature, } if max_tokens: request_params["max_tokens"] = max_tokens # 额外参数(如深度思考) if "enable_thinking" in kwargs: request_params["extra_body"] = {"enable_thinking": kwargs["enable_thinking"]} # 调用 API completion = self.client.chat.completions.create(**request_params) # 解析结果 content = completion.choices[0].message.content or "" usage = None if completion.usage: usage = { "prompt_tokens": completion.usage.prompt_tokens, "completion_tokens": completion.usage.completion_tokens, "total_tokens": completion.usage.total_tokens, } return GenerationResult( content=content, usage=usage, model=completion.model or model or self.default_model, ) except Exception as e: raise ProviderError( f"火山方舟生成失败: {str(e)}", provider_id=self.provider_id, original_error=e ) async def generate_stream( self, prompt: str, model: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, system_prompt: str | None = None, **kwargs, ) -> AsyncIterator[str]: """ 流式生成文本 Args: prompt: 用户提示词 model: 模型名称 temperature: 随机性 max_tokens: 最大 token 数 system_prompt: 系统提示词(可选) **kwargs: 额外参数 Yields: str: 生成的文本片段 """ try: messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model request_params = { "model": model_id, "messages": messages, "temperature": temperature, "stream": True, } if max_tokens: request_params["max_tokens"] = max_tokens stream = self.client.chat.completions.create(**request_params) for chunk in stream: if chunk.choices and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content except Exception as e: raise ProviderError( f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e ) async def generate_stream_with_progress( self, prompt: str, model: str | None = None, temperature: float = 0.7, max_tokens: int | None = 8000, system_prompt: str | None = None, **kwargs, ) -> AsyncIterator[dict]: """ 流式生成文本,带进度信息 Args: prompt: 用户提示词 model: 模型名称 temperature: 随机性 max_tokens: 最大 token 数 system_prompt: 系统提示词(可选) Yields: dict: { "type": "chunk" | "usage", "content": str, # 文本片段(type=chunk时) "total_tokens": int, # 累计token数(type=chunk时) "prompt_tokens": int, # 提示词token数(type=usage时) "completion_tokens": int, # 生成token数(type=usage时) } """ try: messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model request_params = { "model": model_id, "messages": messages, "temperature": temperature, "stream": True, } if max_tokens: request_params["max_tokens"] = max_tokens # 流式调用 stream = self.client.chat.completions.create(**request_params) total_chars = 0 prompt_tokens = 0 completion_tokens = 0 for chunk in stream: # 获取文本内容 if chunk.choices and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content total_chars += len(content) yield { "type": "chunk", "content": content, "total_chars": total_chars, } # 获取使用统计(最后一个chunk) if chunk.usage: prompt_tokens = chunk.usage.prompt_tokens completion_tokens = chunk.usage.completion_tokens # 发送最终统计 yield { "type": "usage", "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, } except Exception as e: raise ProviderError( f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e ) async def generate_image( self, prompt: str, model: str | None = None, size: str = "1024x1024", **kwargs ) -> dict: """ 生成图片(Seedream 系列) Args: prompt: 图片提示词 model: 图片模型 ID size: 图片尺寸 Returns: dict: 包含图片 URL 或 base64 数据 """ try: # 图片生成需要单独的图片模型,不在当前配置中 # 如需使用,请在模型广场开通 doubao-seed-1.6 并配置 image_model = model or "doubao-seed-1.6-flash-250828" response = self.client.images.generate( model=image_model, prompt=prompt, size=size, **kwargs ) # 解析图片结果 images = [] for img in response.data: images.append( { "url": img.url, "b64_json": img.b64_json, "revised_prompt": img.revised_prompt, } ) return { "images": images, "model": response.model, } except Exception as e: raise ProviderError( f"火山方舟图片生成失败: {str(e)}", provider_id=self.provider_id, original_error=e ) async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict: """ 文本向量化 Args: texts: 文本列表 model: 向量化模型 Returns: dict: 包含向量化结果 """ try: response = self.client.embeddings.create( model=model or "doubao-embedding-1.5", input=texts, **kwargs ) embeddings = [] for item in response.data: embeddings.append( { "index": item.index, "embedding": item.embedding, } ) return { "embeddings": embeddings, "model": response.model, "usage": { "prompt_tokens": response.usage.prompt_tokens, "total_tokens": response.usage.total_tokens, }, } except Exception as e: raise ProviderError( f"火山方舟向量化失败: {str(e)}", provider_id=self.provider_id, original_error=e ) async def health_check(self, model: str | None = None) -> ModelHealth: """健康检查""" start_time = time.time() test_model = model or self.default_model try: response = self.client.chat.completions.create( model=test_model, messages=[{"role": "user", "content": "Hi"}], max_tokens=5, ) response_time = (time.time() - start_time) * 1000 return ModelHealth( id=test_model, name=f"火山方舟 {test_model}", is_available=True, response_time=response_time, last_error=None, ) except Exception as e: return ModelHealth( id=test_model, name=f"火山方舟 {test_model}", is_available=False, response_time=(time.time() - start_time) * 1000, last_error=str(e), ) @property def available_models(self) -> list[str]: """返回可用模型列表(与 ai_models.yaml 配置保持一致)""" return [ "doubao-seed-2-0-pro", "deepseek-v3-2", "doubao-seed-2-0-lite", "doubao-lite-32k", ]