141 lines
3.1 KiB
Python
141 lines
3.1 KiB
Python
"""
|
||
LLM Provider 抽象基类
|
||
=====================
|
||
|
||
定义所有 AI 模型提供商的统一接口。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from abc import ABC, abstractmethod
|
||
from collections.abc import AsyncIterator
|
||
|
||
from pydantic import BaseModel
|
||
|
||
|
||
class ModelHealth(BaseModel):
|
||
"""模型健康状态"""
|
||
|
||
id: str
|
||
name: str
|
||
is_available: bool
|
||
response_time: float # 毫秒
|
||
last_error: str | None = None
|
||
|
||
|
||
class GenerationResult(BaseModel):
|
||
"""生成结果"""
|
||
|
||
content: str
|
||
usage: dict | None = None # token 用量等
|
||
model: str # 实际使用的模型
|
||
|
||
|
||
class LLMProvider(ABC):
|
||
"""
|
||
LLM 提供商抽象基类
|
||
|
||
所有 AI 模型提供商(OpenAI、文心一言、通义千问等)需实现此接口。
|
||
"""
|
||
|
||
# 提供商标识
|
||
provider_id: str = ""
|
||
provider_name: str = ""
|
||
|
||
def __init__(self, api_key: str | None = None, base_url: str | None = None, **kwargs):
|
||
"""
|
||
初始化 Provider
|
||
|
||
Args:
|
||
api_key: API 密钥
|
||
base_url: 自定义 Base URL(用于代理或私有部署)
|
||
**kwargs: 其他配置参数
|
||
"""
|
||
self.api_key = api_key
|
||
self.base_url = base_url
|
||
self.config = kwargs
|
||
|
||
@abstractmethod
|
||
async def generate(
|
||
self,
|
||
prompt: str,
|
||
model: str | None = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: int | None = None,
|
||
**kwargs,
|
||
) -> GenerationResult:
|
||
"""
|
||
同步生成文本
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称,None 则使用默认模型
|
||
temperature: 随机性(0-2)
|
||
max_tokens: 最大生成 token 数
|
||
**kwargs: 额外参数
|
||
|
||
Returns:
|
||
GenerationResult: 生成结果
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def generate_stream(
|
||
self,
|
||
prompt: str,
|
||
model: str | None = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: int | None = None,
|
||
**kwargs,
|
||
) -> AsyncIterator[str]:
|
||
"""
|
||
流式生成文本
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
model: 模型名称
|
||
temperature: 随机性
|
||
max_tokens: 最大 token 数
|
||
**kwargs: 额外参数
|
||
|
||
Yields:
|
||
str: 生成的文本片段
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def health_check(self, model: str | None = None) -> ModelHealth:
|
||
"""
|
||
健康检查
|
||
|
||
Args:
|
||
model: 指定模型,None 则检查默认模型
|
||
|
||
Returns:
|
||
ModelHealth: 健康状态
|
||
"""
|
||
pass
|
||
|
||
@property
|
||
@abstractmethod
|
||
def available_models(self) -> list[str]:
|
||
"""返回可用的模型列表"""
|
||
pass
|
||
|
||
|
||
class ProviderError(Exception):
|
||
"""Provider 调用异常"""
|
||
|
||
def __init__(
|
||
self, message: str, provider_id: str = "", original_error: Exception | None = None
|
||
):
|
||
super().__init__(message)
|
||
self.provider_id = provider_id
|
||
self.original_error = original_error
|
||
|
||
|
||
class ModelUnavailableError(ProviderError):
|
||
"""模型不可用异常"""
|
||
|
||
pass
|