""" AI 模型路由 V2 - 基于文件配置 ================================= 从 YAML 配置文件加载平台/模型配置,支持热重载。 """ import asyncio import logging from collections.abc import AsyncIterator from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError from app.ai.providers.generic_llm_provider import MockProvider from app.ai.providers.klingai_provider import KlingAIProvider from app.ai.providers.volcengine_provider import VolcengineProvider from app.config import get_settings from app.core.config_loader import AIModelConfigLoader, get_config_loader logger = logging.getLogger(__name__) class PlatformInstance: """平台实例包装器""" def __init__(self, config: dict): self.config = config self.provider = self._create_provider() def _create_provider(self): """根据平台类型创建 Provider API Key 从 Settings 读取(符合配置规范) """ provider_type = self.config.get("provider", "mock") settings = get_settings() if provider_type == "volcengine": # 从 Settings 读取 API Key api_key = settings.VOLCENGINE_API_KEY if not api_key: raise ProviderError( "Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY" ) return VolcengineProvider( api_key=api_key, base_url=self.config.get("base_url") or settings.VOLCENGINE_BASE_URL, ) elif provider_type == "klingai": # 从 Settings 读取 AK/SK access_key = settings.KLINGAI_ACCESS_KEY secret_key = settings.KLINGAI_SECRET_KEY if not access_key or not secret_key: raise ProviderError( "KlingAI Access/Secret Key 未配置,请在 .env 中设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY" ) return KlingAIProvider( config={ "access_key": access_key, "secret_key": secret_key, "base_url": self.config.get("base_url"), } ) elif provider_type == "mock": return MockProvider() else: raise ProviderError(f"不支持的 Provider 类型: {provider_type}") async def generate( self, model_name: str, prompt: str, **kwargs ) -> GenerationResult: """调用生成""" return await self.provider.generate(prompt=prompt, model=model_name, **kwargs) async def generate_stream( self, model_name: str, prompt: str, **kwargs ) -> AsyncIterator[str]: """流式生成""" async for chunk in self.provider.generate_stream( prompt=prompt, model=model_name, **kwargs ): yield chunk async def health_check(self, model_name: str | None = None) -> ModelHealth: """健康检查""" return await self.provider.health_check(model_name) class ModelRouter: """ 模型路由 V2 - 基于文件配置 支持: - 从 YAML 文件加载配置 - 多平台配置 - 每平台多模型 - 模型自动选择 - 故障降级 - 配置热重载 """ def __init__(self): self.platforms: dict[str, PlatformInstance] = {} self._config_loader: AIModelConfigLoader | None = None self._initialized = False async def initialize(self, db_session=None): """初始化路由(db_session 参数保留兼容性,实际不使用)""" if self._initialized: return # 从文件配置加载 self._config_loader = get_config_loader() self._load_from_config() self._initialized = True logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台") def _load_from_config(self): """从配置文件加载平台和模型""" self.platforms = {} # 加载平台 for platform in self._config_loader.get_all_platforms(): try: # PlatformInstance 自动从 Settings 读取 API Key self.platforms[platform.id] = PlatformInstance( { "id": platform.id, "name": platform.name, "provider": platform.provider, "base_url": platform.base_url, } ) logger.info(f"平台 {platform.id} 初始化成功") except Exception as e: logger.warning(f"平台 {platform.id} 初始化失败: {e}") # 加载模型到 Provider(用于模型名称映射) volcengine_models = [] for model in self._config_loader.get_enabled_models(): if model.platform_id == "volcengine": volcengine_models.append( { "id": model.id, "model_name": model.model_name, } ) if volcengine_models: VolcengineProvider.load_models_from_config(volcengine_models) logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider") def reload_config(self) -> bool: """重新加载配置""" if self._config_loader and self._config_loader.reload(): self._load_from_config() return True return False def get_model_config(self, model_id: str) -> dict | None: """获取模型配置""" if self._config_loader: model = self._config_loader.get_model(model_id) if model: return { "id": model.id, "platform_id": model.platform_id, "model_name": model.model_name, "display_name": model.display_name, "capabilities": model.capabilities, "default_params": model.default_params, "cost_per_1k_input": model.cost_per_1k_input, "cost_per_1k_output": model.cost_per_1k_output, "max_tokens_limit": model.max_tokens_limit, } return None def list_models( self, capability: str | None = None, platform_id: str | None = None ) -> list[dict]: """列出可用模型""" models = [] if self._config_loader: if capability: config_models = self._config_loader.get_models_by_capability(capability) elif platform_id: config_models = self._config_loader.get_models_by_platform(platform_id) else: config_models = self._config_loader.get_enabled_models() for model in config_models: models.append( { "id": model.id, "platform_id": model.platform_id, "model_name": model.model_name, "display_name": model.display_name, "capabilities": model.capabilities, "default_params": model.default_params, "cost_per_1k_input": model.cost_per_1k_input, "cost_per_1k_output": model.cost_per_1k_output, "max_tokens_limit": model.max_tokens_limit, } ) return models def list_platforms(self) -> list[dict]: """列出所有平台""" if self._config_loader: return [ { "id": p.id, "name": p.name, "provider": p.provider, } for p in self._config_loader.get_all_platforms() ] return [] def select_model_for_task(self, task_type: str) -> str | None: """根据任务类型选择最佳模型""" # 先检查任务默认配置 if self._config_loader: default_model = self._config_loader.get_default_model_for_task(task_type) if default_model: model = self._config_loader.get_model(default_model) if model and model.is_enabled: return default_model # 按能力匹配 candidates = self._config_loader.get_models_by_capability(task_type) if candidates: return candidates[0].id return None async def generate( self, prompt: str, model_id: str | None = None, task_type: str | None = None, **kwargs, ) -> GenerationResult: """ 生成文本 Args: prompt: 提示词 model_id: 指定模型 ID,None 则自动选择 task_type: 任务类型(用于自动选模型) """ # 确定模型 if model_id is None: if task_type: model_id = self.select_model_for_task(task_type) if model_id is None: # 使用第一个可用模型 models = ( self._config_loader.get_enabled_models() if self._config_loader else [] ) if models: model_id = models[0].id else: raise ProviderError("没有可用的模型") if self._config_loader: model = self._config_loader.get_model(model_id) if not model: raise ProviderError(f"模型不存在: {model_id}") platform = self.platforms.get(model.platform_id) if not platform: raise ProviderError(f"平台不存在: {model.platform_id}") # 合并默认参数 params = {**model.default_params, **kwargs} # 调用生成 try: result = await platform.generate( prompt=prompt, model_name=model.model_name, **params ) return result except Exception as e: logger.error(f"模型 {model_id} 生成失败: {e}") raise async def generate_stream_with_progress( self, prompt: str, model_id: str | None = None, task_type: str | None = None, **kwargs, ): """ 流式生成文本,带进度信息 Args: prompt: 提示词 model_id: 指定模型 ID task_type: 任务类型 **kwargs: 其他参数 Yields: dict: 包含 type, content, total_chars 等字段 """ # 确定模型 if model_id is None: if task_type: model_id = self.select_model_for_task(task_type) if model_id is None: models = ( self._config_loader.get_enabled_models() if self._config_loader else [] ) if models: model_id = models[0].id else: raise ProviderError("没有可用的模型") model = self._config_loader.get_model(model_id) if self._config_loader else None if not model: raise ProviderError(f"模型不存在: {model_id}") platform = self.platforms.get(model.platform_id) if not platform: raise ProviderError(f"平台不存在: {model.platform_id}") # 合并默认参数 params = {**model.default_params, **kwargs} # 检查 provider 是否有 generate_stream_with_progress 方法 provider = platform.provider if hasattr(provider, "generate_stream_with_progress"): async for chunk in provider.generate_stream_with_progress( prompt=prompt, model=model.model_name, **params ): yield chunk else: # 降级到普通流式生成 full_content = "" async for content in provider.generate_stream( prompt=prompt, model=model.model_name, **params ): full_content += content yield { "type": "chunk", "content": content, "total_chars": len(full_content), } yield { "type": "usage", "prompt_tokens": 0, "completion_tokens": 0, } async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]: """检查模型健康状态""" results = {} if model_id: model = ( self._config_loader.get_model(model_id) if self._config_loader else None ) if model: platform = self.platforms.get(model.platform_id) if platform: results[model_id] = await platform.health_check(model.model_name) else: # 检查所有模型 if self._config_loader: for model in self._config_loader.get_enabled_models(): platform = self.platforms.get(model.platform_id) if platform: try: results[model.id] = await platform.health_check( model.model_name ) except Exception as e: results[model.id] = ModelHealth( id=model.id, name=model.display_name, is_available=False, response_time=0, last_error=str(e), ) return results # 全局单例 _model_router: ModelRouter | None = None _init_lock = asyncio.Lock() async def get_model_router(db_session=None) -> ModelRouter: """获取 ModelRouter 单例(线程安全) 使用双重检查锁定模式确保并发安全。 """ global _model_router if _model_router is None: async with _init_lock: # 双重检查,防止在获取锁期间其他协程已初始化 if _model_router is None: logger.info("Initializing ModelRouter singleton...") _model_router = ModelRouter() await _model_router.initialize(db_session) logger.info("ModelRouter singleton initialized") return _model_router