""" 火山方舟官方 SDK Provider ========================== 基于火山方舟官方 Python SDK 实现,支持: - 文本生成 (Chat Completions) - 流式输出 - 向量化 - 深度思考 - 工具调用 安装依赖: pip install 'volcengine-python-sdk[ark]' 文档: https://www.volcengine.com/docs/82379 """ from __future__ import annotations import logging import time from app.ai.providers.base import ( GenerationResult, LLMProvider, ModelHealth, ProviderError, ) from app.core.exceptions import PlatformError, PlatformErrorType logger = logging.getLogger(__name__) # 尝试导入火山方舟 SDK try: from volcenginesdkarkruntime import AsyncArk VOLCENGINE_SDK_AVAILABLE = True except ImportError: VOLCENGINE_SDK_AVAILABLE = False logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'") class VolcengineProvider(LLMProvider): """ 火山方舟官方 SDK Provider 支持能力: - 文本对话 (Chat Completions) - 向量化 (Embeddings) - 深度思考 (Reasoning) """ provider_id = "volcengine" provider_name = "火山方舟" # 默认配置 DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3" DEFAULT_TIMEOUT = 1800 # 秒,方舟推荐 1800 秒以上 # 模型 ID 映射(从配置文件加载) PRESET_MODELS: dict[str, str] = {} @classmethod def load_models_from_config(cls, models: list[dict]): """ 从配置文件加载模型列表 Args: models: 模型列表,每个模型包含 model_name 字段 """ cls.PRESET_MODELS = {} for model in models: model_id = model.get("model_name") model_alias = model.get("id") if model_id and model_alias: cls.PRESET_MODELS[model_alias] = model_id # 确保至少有一个默认模型 if not cls.PRESET_MODELS: cls.PRESET_MODELS = { "doubao-seed-2-0-pro": "doubao-seed-2-0-pro-260215", } logger.info(f"已加载 {len(cls.PRESET_MODELS)} 个模型: {list(cls.PRESET_MODELS.keys())}") def __init__( self, api_key: str | None = None, base_url: str | None = None, timeout: int = DEFAULT_TIMEOUT, default_model: str | None = None, **kwargs, ): """ 初始化火山方舟 Provider Args: api_key: API Key,从火山方舟控制台获取 base_url: API 基础地址,默认北京节点 timeout: 请求超时时间(秒) default_model: 默认模型(Model ID) """ from app.config import get_settings from app.core.platform_config import get_platform_config_loader settings = get_settings() # API Key 从环境变量读取 resolved_api_key = api_key or settings.VOLCENGINE_API_KEY # base_url 从 platform-config.yaml 读取,fallback 到代码常量 loader = get_platform_config_loader() platform = loader.get_platform("volcengine_ark") yaml_base_url = platform.base_url if platform else None resolved_base_url = base_url or yaml_base_url or self.DEFAULT_BASE_URL super().__init__(resolved_api_key, resolved_base_url, **kwargs) if not VOLCENGINE_SDK_AVAILABLE: raise ProviderError( "火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'", provider_id=self.provider_id, ) if not self.api_key: raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id) self.timeout = timeout # 使用模型 ID 映射(自动映射) if default_model: self.default_model = self.PRESET_MODELS.get(default_model, default_model) elif self.PRESET_MODELS: self.default_model = list(self.PRESET_MODELS.values())[0] else: # 兜底:使用一个默认模型ID(如果用户未配置任何模型) self.default_model = "doubao-seed-2-0-pro-260215" self.client = self._create_client() def _create_client(self) -> AsyncArk: """创建火山方舟异步客户端""" return AsyncArk( api_key=self.api_key, base_url=self.base_url or self.DEFAULT_BASE_URL, timeout=self.timeout, ) async def generate( self, prompt: str, model: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, system_prompt: str | None = None, **kwargs, ) -> GenerationResult: """ 同步生成文本 Args: prompt: 用户提示词 model: 模型 ID(如 doubao-seed-2-0-pro-260215) temperature: 随机性 (0-2) max_tokens: 最大生成 token 数 system_prompt: 系统提示词(可选) **kwargs: 额外参数(如 reasoning_effort 控制思考程度) Returns: GenerationResult: 生成结果 """ try: # 构建消息 messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) # 映射模型名称到模型 ID model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model # 构建请求参数 request_params = { "model": model_id, "messages": messages, "temperature": temperature, } if max_tokens: request_params["max_tokens"] = max_tokens # 额外参数(如深度思考) if "reasoning_effort" in kwargs: request_params["reasoning_effort"] = kwargs["reasoning_effort"] if kwargs.get("response_format") == "json_object": request_params["response_format"] = {"type": "json_object"} # 调用 API completion = await self.client.chat.completions.create(**request_params) # 解析结果 content = completion.choices[0].message.content or "" usage = None if completion.usage: usage = { "prompt_tokens": completion.usage.prompt_tokens, "completion_tokens": completion.usage.completion_tokens, "total_tokens": completion.usage.total_tokens, } return GenerationResult( content=content, usage=usage, model=completion.model or model or self.default_model, ) except Exception as e: raise self._wrap_error(e) async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict: """ 文本向量化 Args: texts: 文本列表 model: 向量化模型 Returns: dict: 包含向量化结果 """ try: response = await self.client.embeddings.create( model=model or "doubao-embedding-1.5", input=texts, **kwargs ) embeddings = [] for item in response.data: embeddings.append( { "index": item.index, "embedding": item.embedding, } ) return { "embeddings": embeddings, "model": response.model, "usage": { "prompt_tokens": response.usage.prompt_tokens, "total_tokens": response.usage.total_tokens, }, } except Exception as e: raise self._wrap_error(e) async def health_check(self, model: str | None = None) -> ModelHealth: """健康检查""" start_time = time.time() test_model = model or self.default_model try: await self.client.chat.completions.create( model=test_model, messages=[{"role": "user", "content": "Hi"}], max_tokens=5, ) response_time = (time.time() - start_time) * 1000 return ModelHealth( id=test_model, name=f"火山方舟 {test_model}", is_available=True, response_time=response_time, last_error=None, ) except Exception as e: return ModelHealth( id=test_model, name=f"火山方舟 {test_model}", is_available=False, response_time=(time.time() - start_time) * 1000, last_error=str(e), ) def _wrap_error(self, e: Exception) -> PlatformError: """把 SDK 异常翻译为标准 PlatformError""" message = str(e) status = getattr(e, "status_code", None) or getattr(e, "code", None) if status == 429 or "rate limit" in message.lower(): return PlatformError( message, platform="volcengine_ark", retryable=True, error_type=PlatformErrorType.RATE_LIMIT, status_code=status, ) elif status in (401, 403) or "authentication" in message.lower(): return PlatformError( message, platform="volcengine_ark", retryable=False, error_type=PlatformErrorType.AUTH_FAILED, status_code=status, ) elif status and status >= 500: return PlatformError( message, platform="volcengine_ark", retryable=True, error_type=PlatformErrorType.SERVER_ERROR, status_code=status, ) elif "timeout" in message.lower() or isinstance(e, TimeoutError): return PlatformError( message, platform="volcengine_ark", retryable=True, error_type=PlatformErrorType.TIMEOUT, ) else: return PlatformError( message, platform="volcengine_ark", retryable=False, error_type=PlatformErrorType.UNKNOWN, ) @property def available_models(self) -> list[str]: """返回可用模型列表(与 platform-config.yaml 配置保持一致)""" return [ "doubao-seed-2-0-pro", ]