feat: init meijiaka-zj project from ai-meijiaka template
This commit is contained in:
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
AI 模型路由 V2 - 基于文件配置
|
||||
=================================
|
||||
|
||||
从 YAML 配置文件加载平台/模型配置,支持热重载。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
|
||||
from app.ai.providers.generic_llm_provider import MockProvider
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||||
from app.config import get_settings
|
||||
from app.core.config_loader import AIModelConfigLoader, get_config_loader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlatformInstance:
|
||||
"""平台实例包装器"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.provider = self._create_provider()
|
||||
|
||||
def _create_provider(self):
|
||||
"""根据平台类型创建 Provider
|
||||
|
||||
API Key 从 Settings 读取(符合配置规范)
|
||||
"""
|
||||
provider_type = self.config.get("provider", "mock")
|
||||
settings = get_settings()
|
||||
|
||||
if provider_type == "volcengine":
|
||||
# 从 Settings 读取 API Key
|
||||
api_key = settings.VOLCENGINE_API_KEY
|
||||
if not api_key:
|
||||
raise ProviderError(
|
||||
"Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY"
|
||||
)
|
||||
return VolcengineProvider(
|
||||
api_key=api_key,
|
||||
base_url=self.config.get("base_url") or settings.VOLCENGINE_BASE_URL,
|
||||
)
|
||||
elif provider_type == "klingai":
|
||||
# 从 Settings 读取 AK/SK
|
||||
access_key = settings.KLINGAI_ACCESS_KEY
|
||||
secret_key = settings.KLINGAI_SECRET_KEY
|
||||
if not access_key or not secret_key:
|
||||
raise ProviderError(
|
||||
"KlingAI Access/Secret Key 未配置,请在 .env 中设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY"
|
||||
)
|
||||
return KlingAIProvider(
|
||||
config={
|
||||
"access_key": access_key,
|
||||
"secret_key": secret_key,
|
||||
"base_url": self.config.get("base_url"),
|
||||
}
|
||||
)
|
||||
elif provider_type == "mock":
|
||||
return MockProvider()
|
||||
else:
|
||||
raise ProviderError(f"不支持的 Provider 类型: {provider_type}")
|
||||
|
||||
async def generate(
|
||||
self, model_name: str, prompt: str, **kwargs
|
||||
) -> GenerationResult:
|
||||
"""调用生成"""
|
||||
return await self.provider.generate(prompt=prompt, model=model_name, **kwargs)
|
||||
|
||||
async def generate_stream(
|
||||
self, model_name: str, prompt: str, **kwargs
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式生成"""
|
||||
async for chunk in self.provider.generate_stream(
|
||||
prompt=prompt, model=model_name, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def health_check(self, model_name: str | None = None) -> ModelHealth:
|
||||
"""健康检查"""
|
||||
return await self.provider.health_check(model_name)
|
||||
|
||||
|
||||
class ModelRouter:
|
||||
"""
|
||||
模型路由 V2 - 基于文件配置
|
||||
|
||||
支持:
|
||||
- 从 YAML 文件加载配置
|
||||
- 多平台配置
|
||||
- 每平台多模型
|
||||
- 模型自动选择
|
||||
- 故障降级
|
||||
- 配置热重载
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.platforms: dict[str, PlatformInstance] = {}
|
||||
self._config_loader: AIModelConfigLoader | None = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self, db_session=None):
|
||||
"""初始化路由(db_session 参数保留兼容性,实际不使用)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 从文件配置加载
|
||||
self._config_loader = get_config_loader()
|
||||
self._load_from_config()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")
|
||||
|
||||
def _load_from_config(self):
|
||||
"""从配置文件加载平台和模型"""
|
||||
self.platforms = {}
|
||||
|
||||
# 加载平台
|
||||
for platform in self._config_loader.get_all_platforms():
|
||||
try:
|
||||
# PlatformInstance 自动从 Settings 读取 API Key
|
||||
self.platforms[platform.id] = PlatformInstance(
|
||||
{
|
||||
"id": platform.id,
|
||||
"name": platform.name,
|
||||
"provider": platform.provider,
|
||||
"base_url": platform.base_url,
|
||||
}
|
||||
)
|
||||
logger.info(f"平台 {platform.id} 初始化成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"平台 {platform.id} 初始化失败: {e}")
|
||||
|
||||
# 加载模型到 Provider(用于模型名称映射)
|
||||
volcengine_models = []
|
||||
for model in self._config_loader.get_enabled_models():
|
||||
if model.platform_id == "volcengine":
|
||||
volcengine_models.append(
|
||||
{
|
||||
"id": model.id,
|
||||
"model_name": model.model_name,
|
||||
}
|
||||
)
|
||||
|
||||
if volcengine_models:
|
||||
VolcengineProvider.load_models_from_config(volcengine_models)
|
||||
logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")
|
||||
|
||||
def reload_config(self) -> bool:
|
||||
"""重新加载配置"""
|
||||
if self._config_loader and self._config_loader.reload():
|
||||
self._load_from_config()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict | None:
|
||||
"""获取模型配置"""
|
||||
if self._config_loader:
|
||||
model = self._config_loader.get_model(model_id)
|
||||
if model:
|
||||
return {
|
||||
"id": model.id,
|
||||
"platform_id": model.platform_id,
|
||||
"model_name": model.model_name,
|
||||
"display_name": model.display_name,
|
||||
"capabilities": model.capabilities,
|
||||
"default_params": model.default_params,
|
||||
"cost_per_1k_input": model.cost_per_1k_input,
|
||||
"cost_per_1k_output": model.cost_per_1k_output,
|
||||
"max_tokens_limit": model.max_tokens_limit,
|
||||
}
|
||||
return None
|
||||
|
||||
def list_models(
|
||||
self, capability: str | None = None, platform_id: str | None = None
|
||||
) -> list[dict]:
|
||||
"""列出可用模型"""
|
||||
models = []
|
||||
|
||||
if self._config_loader:
|
||||
if capability:
|
||||
config_models = self._config_loader.get_models_by_capability(capability)
|
||||
elif platform_id:
|
||||
config_models = self._config_loader.get_models_by_platform(platform_id)
|
||||
else:
|
||||
config_models = self._config_loader.get_enabled_models()
|
||||
|
||||
for model in config_models:
|
||||
models.append(
|
||||
{
|
||||
"id": model.id,
|
||||
"platform_id": model.platform_id,
|
||||
"model_name": model.model_name,
|
||||
"display_name": model.display_name,
|
||||
"capabilities": model.capabilities,
|
||||
"default_params": model.default_params,
|
||||
"cost_per_1k_input": model.cost_per_1k_input,
|
||||
"cost_per_1k_output": model.cost_per_1k_output,
|
||||
"max_tokens_limit": model.max_tokens_limit,
|
||||
}
|
||||
)
|
||||
|
||||
return models
|
||||
|
||||
def list_platforms(self) -> list[dict]:
|
||||
"""列出所有平台"""
|
||||
if self._config_loader:
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"provider": p.provider,
|
||||
}
|
||||
for p in self._config_loader.get_all_platforms()
|
||||
]
|
||||
return []
|
||||
|
||||
def select_model_for_task(self, task_type: str) -> str | None:
|
||||
"""根据任务类型选择最佳模型"""
|
||||
# 先检查任务默认配置
|
||||
if self._config_loader:
|
||||
default_model = self._config_loader.get_default_model_for_task(task_type)
|
||||
if default_model:
|
||||
model = self._config_loader.get_model(default_model)
|
||||
if model and model.is_enabled:
|
||||
return default_model
|
||||
|
||||
# 按能力匹配
|
||||
candidates = self._config_loader.get_models_by_capability(task_type)
|
||||
if candidates:
|
||||
return candidates[0].id
|
||||
|
||||
return None
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""
|
||||
生成文本
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_id: 指定模型 ID,None 则自动选择
|
||||
task_type: 任务类型(用于自动选模型)
|
||||
"""
|
||||
# 确定模型
|
||||
if model_id is None:
|
||||
if task_type:
|
||||
model_id = self.select_model_for_task(task_type)
|
||||
if model_id is None:
|
||||
# 使用第一个可用模型
|
||||
models = (
|
||||
self._config_loader.get_enabled_models()
|
||||
if self._config_loader
|
||||
else []
|
||||
)
|
||||
if models:
|
||||
model_id = models[0].id
|
||||
else:
|
||||
raise ProviderError("没有可用的模型")
|
||||
|
||||
if self._config_loader:
|
||||
model = self._config_loader.get_model(model_id)
|
||||
if not model:
|
||||
raise ProviderError(f"模型不存在: {model_id}")
|
||||
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if not platform:
|
||||
raise ProviderError(f"平台不存在: {model.platform_id}")
|
||||
|
||||
# 合并默认参数
|
||||
params = {**model.default_params, **kwargs}
|
||||
|
||||
# 调用生成
|
||||
try:
|
||||
result = await platform.generate(
|
||||
prompt=prompt, model_name=model.model_name, **params
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型 {model_id} 生成失败: {e}")
|
||||
raise
|
||||
|
||||
async def generate_stream_with_progress(
|
||||
self,
|
||||
prompt: str,
|
||||
model_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
流式生成文本,带进度信息
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_id: 指定模型 ID
|
||||
task_type: 任务类型
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
dict: 包含 type, content, total_chars 等字段
|
||||
"""
|
||||
# 确定模型
|
||||
if model_id is None:
|
||||
if task_type:
|
||||
model_id = self.select_model_for_task(task_type)
|
||||
if model_id is None:
|
||||
models = (
|
||||
self._config_loader.get_enabled_models()
|
||||
if self._config_loader
|
||||
else []
|
||||
)
|
||||
if models:
|
||||
model_id = models[0].id
|
||||
else:
|
||||
raise ProviderError("没有可用的模型")
|
||||
|
||||
model = self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
if not model:
|
||||
raise ProviderError(f"模型不存在: {model_id}")
|
||||
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if not platform:
|
||||
raise ProviderError(f"平台不存在: {model.platform_id}")
|
||||
|
||||
# 合并默认参数
|
||||
params = {**model.default_params, **kwargs}
|
||||
|
||||
# 检查 provider 是否有 generate_stream_with_progress 方法
|
||||
provider = platform.provider
|
||||
if hasattr(provider, "generate_stream_with_progress"):
|
||||
async for chunk in provider.generate_stream_with_progress(
|
||||
prompt=prompt, model=model.model_name, **params
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
# 降级到普通流式生成
|
||||
full_content = ""
|
||||
async for content in provider.generate_stream(
|
||||
prompt=prompt, model=model.model_name, **params
|
||||
):
|
||||
full_content += content
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"content": content,
|
||||
"total_chars": len(full_content),
|
||||
}
|
||||
|
||||
yield {
|
||||
"type": "usage",
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
}
|
||||
|
||||
async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
|
||||
"""检查模型健康状态"""
|
||||
results = {}
|
||||
|
||||
if model_id:
|
||||
model = (
|
||||
self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
)
|
||||
if model:
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if platform:
|
||||
results[model_id] = await platform.health_check(model.model_name)
|
||||
else:
|
||||
# 检查所有模型
|
||||
if self._config_loader:
|
||||
for model in self._config_loader.get_enabled_models():
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if platform:
|
||||
try:
|
||||
results[model.id] = await platform.health_check(
|
||||
model.model_name
|
||||
)
|
||||
except Exception as e:
|
||||
results[model.id] = ModelHealth(
|
||||
id=model.id,
|
||||
name=model.display_name,
|
||||
is_available=False,
|
||||
response_time=0,
|
||||
last_error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# 全局单例
|
||||
_model_router: ModelRouter | None = None
|
||||
_init_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_model_router(db_session=None) -> ModelRouter:
|
||||
"""获取 ModelRouter 单例(线程安全)
|
||||
|
||||
使用双重检查锁定模式确保并发安全。
|
||||
"""
|
||||
global _model_router
|
||||
if _model_router is None:
|
||||
async with _init_lock:
|
||||
# 双重检查,防止在获取锁期间其他协程已初始化
|
||||
if _model_router is None:
|
||||
logger.info("Initializing ModelRouter singleton...")
|
||||
_model_router = ModelRouter()
|
||||
await _model_router.initialize(db_session)
|
||||
logger.info("ModelRouter singleton initialized")
|
||||
return _model_router
|
||||
Reference in New Issue
Block a user