feat: init meijiaka-zj project from ai-meijiaka template

This commit is contained in:
小鱼开发
2026-04-20 16:39:57 +08:00
commit 74983ce5ec
291 changed files with 76164 additions and 0 deletions
+417
View File
@@ -0,0 +1,417 @@
"""
AI 模型路由 V2 - 基于文件配置
=================================
从 YAML 配置文件加载平台/模型配置,支持热重载。
"""
import asyncio
import logging
from collections.abc import AsyncIterator
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
from app.ai.providers.generic_llm_provider import MockProvider
from app.ai.providers.klingai_provider import KlingAIProvider
from app.ai.providers.volcengine_provider import VolcengineProvider
from app.config import get_settings
from app.core.config_loader import AIModelConfigLoader, get_config_loader
logger = logging.getLogger(__name__)
class PlatformInstance:
"""平台实例包装器"""
def __init__(self, config: dict):
self.config = config
self.provider = self._create_provider()
def _create_provider(self):
"""根据平台类型创建 Provider
API Key 从 Settings 读取(符合配置规范)
"""
provider_type = self.config.get("provider", "mock")
settings = get_settings()
if provider_type == "volcengine":
# 从 Settings 读取 API Key
api_key = settings.VOLCENGINE_API_KEY
if not api_key:
raise ProviderError(
"Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY"
)
return VolcengineProvider(
api_key=api_key,
base_url=self.config.get("base_url") or settings.VOLCENGINE_BASE_URL,
)
elif provider_type == "klingai":
# 从 Settings 读取 AK/SK
access_key = settings.KLINGAI_ACCESS_KEY
secret_key = settings.KLINGAI_SECRET_KEY
if not access_key or not secret_key:
raise ProviderError(
"KlingAI Access/Secret Key 未配置,请在 .env 中设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY"
)
return KlingAIProvider(
config={
"access_key": access_key,
"secret_key": secret_key,
"base_url": self.config.get("base_url"),
}
)
elif provider_type == "mock":
return MockProvider()
else:
raise ProviderError(f"不支持的 Provider 类型: {provider_type}")
async def generate(
self, model_name: str, prompt: str, **kwargs
) -> GenerationResult:
"""调用生成"""
return await self.provider.generate(prompt=prompt, model=model_name, **kwargs)
async def generate_stream(
self, model_name: str, prompt: str, **kwargs
) -> AsyncIterator[str]:
"""流式生成"""
async for chunk in self.provider.generate_stream(
prompt=prompt, model=model_name, **kwargs
):
yield chunk
async def health_check(self, model_name: str | None = None) -> ModelHealth:
"""健康检查"""
return await self.provider.health_check(model_name)
class ModelRouter:
"""
模型路由 V2 - 基于文件配置
支持:
- 从 YAML 文件加载配置
- 多平台配置
- 每平台多模型
- 模型自动选择
- 故障降级
- 配置热重载
"""
def __init__(self):
self.platforms: dict[str, PlatformInstance] = {}
self._config_loader: AIModelConfigLoader | None = None
self._initialized = False
async def initialize(self, db_session=None):
"""初始化路由(db_session 参数保留兼容性,实际不使用)"""
if self._initialized:
return
# 从文件配置加载
self._config_loader = get_config_loader()
self._load_from_config()
self._initialized = True
logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")
def _load_from_config(self):
"""从配置文件加载平台和模型"""
self.platforms = {}
# 加载平台
for platform in self._config_loader.get_all_platforms():
try:
# PlatformInstance 自动从 Settings 读取 API Key
self.platforms[platform.id] = PlatformInstance(
{
"id": platform.id,
"name": platform.name,
"provider": platform.provider,
"base_url": platform.base_url,
}
)
logger.info(f"平台 {platform.id} 初始化成功")
except Exception as e:
logger.warning(f"平台 {platform.id} 初始化失败: {e}")
# 加载模型到 Provider(用于模型名称映射)
volcengine_models = []
for model in self._config_loader.get_enabled_models():
if model.platform_id == "volcengine":
volcengine_models.append(
{
"id": model.id,
"model_name": model.model_name,
}
)
if volcengine_models:
VolcengineProvider.load_models_from_config(volcengine_models)
logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")
def reload_config(self) -> bool:
"""重新加载配置"""
if self._config_loader and self._config_loader.reload():
self._load_from_config()
return True
return False
def get_model_config(self, model_id: str) -> dict | None:
"""获取模型配置"""
if self._config_loader:
model = self._config_loader.get_model(model_id)
if model:
return {
"id": model.id,
"platform_id": model.platform_id,
"model_name": model.model_name,
"display_name": model.display_name,
"capabilities": model.capabilities,
"default_params": model.default_params,
"cost_per_1k_input": model.cost_per_1k_input,
"cost_per_1k_output": model.cost_per_1k_output,
"max_tokens_limit": model.max_tokens_limit,
}
return None
def list_models(
self, capability: str | None = None, platform_id: str | None = None
) -> list[dict]:
"""列出可用模型"""
models = []
if self._config_loader:
if capability:
config_models = self._config_loader.get_models_by_capability(capability)
elif platform_id:
config_models = self._config_loader.get_models_by_platform(platform_id)
else:
config_models = self._config_loader.get_enabled_models()
for model in config_models:
models.append(
{
"id": model.id,
"platform_id": model.platform_id,
"model_name": model.model_name,
"display_name": model.display_name,
"capabilities": model.capabilities,
"default_params": model.default_params,
"cost_per_1k_input": model.cost_per_1k_input,
"cost_per_1k_output": model.cost_per_1k_output,
"max_tokens_limit": model.max_tokens_limit,
}
)
return models
def list_platforms(self) -> list[dict]:
"""列出所有平台"""
if self._config_loader:
return [
{
"id": p.id,
"name": p.name,
"provider": p.provider,
}
for p in self._config_loader.get_all_platforms()
]
return []
def select_model_for_task(self, task_type: str) -> str | None:
"""根据任务类型选择最佳模型"""
# 先检查任务默认配置
if self._config_loader:
default_model = self._config_loader.get_default_model_for_task(task_type)
if default_model:
model = self._config_loader.get_model(default_model)
if model and model.is_enabled:
return default_model
# 按能力匹配
candidates = self._config_loader.get_models_by_capability(task_type)
if candidates:
return candidates[0].id
return None
async def generate(
self,
prompt: str,
model_id: str | None = None,
task_type: str | None = None,
**kwargs,
) -> GenerationResult:
"""
生成文本
Args:
prompt: 提示词
model_id: 指定模型 IDNone 则自动选择
task_type: 任务类型(用于自动选模型)
"""
# 确定模型
if model_id is None:
if task_type:
model_id = self.select_model_for_task(task_type)
if model_id is None:
# 使用第一个可用模型
models = (
self._config_loader.get_enabled_models()
if self._config_loader
else []
)
if models:
model_id = models[0].id
else:
raise ProviderError("没有可用的模型")
if self._config_loader:
model = self._config_loader.get_model(model_id)
if not model:
raise ProviderError(f"模型不存在: {model_id}")
platform = self.platforms.get(model.platform_id)
if not platform:
raise ProviderError(f"平台不存在: {model.platform_id}")
# 合并默认参数
params = {**model.default_params, **kwargs}
# 调用生成
try:
result = await platform.generate(
prompt=prompt, model_name=model.model_name, **params
)
return result
except Exception as e:
logger.error(f"模型 {model_id} 生成失败: {e}")
raise
async def generate_stream_with_progress(
self,
prompt: str,
model_id: str | None = None,
task_type: str | None = None,
**kwargs,
):
"""
流式生成文本,带进度信息
Args:
prompt: 提示词
model_id: 指定模型 ID
task_type: 任务类型
**kwargs: 其他参数
Yields:
dict: 包含 type, content, total_chars 等字段
"""
# 确定模型
if model_id is None:
if task_type:
model_id = self.select_model_for_task(task_type)
if model_id is None:
models = (
self._config_loader.get_enabled_models()
if self._config_loader
else []
)
if models:
model_id = models[0].id
else:
raise ProviderError("没有可用的模型")
model = self._config_loader.get_model(model_id) if self._config_loader else None
if not model:
raise ProviderError(f"模型不存在: {model_id}")
platform = self.platforms.get(model.platform_id)
if not platform:
raise ProviderError(f"平台不存在: {model.platform_id}")
# 合并默认参数
params = {**model.default_params, **kwargs}
# 检查 provider 是否有 generate_stream_with_progress 方法
provider = platform.provider
if hasattr(provider, "generate_stream_with_progress"):
async for chunk in provider.generate_stream_with_progress(
prompt=prompt, model=model.model_name, **params
):
yield chunk
else:
# 降级到普通流式生成
full_content = ""
async for content in provider.generate_stream(
prompt=prompt, model=model.model_name, **params
):
full_content += content
yield {
"type": "chunk",
"content": content,
"total_chars": len(full_content),
}
yield {
"type": "usage",
"prompt_tokens": 0,
"completion_tokens": 0,
}
async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
"""检查模型健康状态"""
results = {}
if model_id:
model = (
self._config_loader.get_model(model_id) if self._config_loader else None
)
if model:
platform = self.platforms.get(model.platform_id)
if platform:
results[model_id] = await platform.health_check(model.model_name)
else:
# 检查所有模型
if self._config_loader:
for model in self._config_loader.get_enabled_models():
platform = self.platforms.get(model.platform_id)
if platform:
try:
results[model.id] = await platform.health_check(
model.model_name
)
except Exception as e:
results[model.id] = ModelHealth(
id=model.id,
name=model.display_name,
is_available=False,
response_time=0,
last_error=str(e),
)
return results
# 全局单例
_model_router: ModelRouter | None = None
_init_lock = asyncio.Lock()
async def get_model_router(db_session=None) -> ModelRouter:
"""获取 ModelRouter 单例(线程安全)
使用双重检查锁定模式确保并发安全。
"""
global _model_router
if _model_router is None:
async with _init_lock:
# 双重检查,防止在获取锁期间其他协程已初始化
if _model_router is None:
logger.info("Initializing ModelRouter singleton...")
_model_router = ModelRouter()
await _model_router.initialize(db_session)
logger.info("ModelRouter singleton initialized")
return _model_router