"""
AI model management and generation API
======================================

Provides model listing, text generation, script generation and polishing
endpoints. Model configuration is stored in config/ai_models.yaml and supports
hot reloading.
"""

import logging
import math

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from app.ai.model_router import get_model_router
from app.ai.prompts import (
    SCRIPT_TYPES,
    VIDEO_STYLES,
    PolishPromptBuilder,
    ScriptPromptBuilder,
)
from app.core.config_loader import get_config_loader
from app.schemas.common import ApiResponse, success_response
from app.services.ai_response_utils import (
    safe_parse_ai_json_response,
    validate_and_normalize_shots,
    validate_shots_structure,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Per-shot duration (seconds) assumed when a shot has no usable "duration";
# matches the original "5s" fallback used when the key was missing.
_DEFAULT_SHOT_DURATION_SECONDS = 5


# ============ Request / response schemas ============


class PlatformResponse(BaseModel):
    """Platform summary."""

    id: str
    name: str
    provider: str


class ModelResponse(BaseModel):
    """Model summary."""

    id: str
    platform_id: str
    model_name: str
    display_name: str
    capabilities: list[str]
    default_params: dict
    is_enabled: bool
    full_model_id: str


class GenerateRequest(BaseModel):
    """Free-form text generation request."""

    prompt: str = Field(..., description="提示词")
    model_id: str | None = Field(None, description="指定模型 ID")
    task_type: str | None = Field(
        None, description="任务类型,用于自动选模型: script/polish"
    )
    temperature: float | None = Field(None, description="随机性 (0-2)")
    max_tokens: int | None = Field(None, description="最大生成长度")


class GenerateResponse(BaseModel):
    """Free-form text generation result."""

    content: str
    model: str
    usage: dict | None


class HealthResponse(BaseModel):
    """Aggregated model health-check result."""

    status: str
    total_models: int
    available_models: int
    models: list[dict]


# ============ API routes ============


@router.get("/platforms", response_model=ApiResponse[list[PlatformResponse]])
async def list_platforms():
    """List all configured platforms."""
    config_loader = get_config_loader()
    platforms = config_loader.get_all_platforms()
    return success_response(
        data=[
            PlatformResponse(id=p.id, name=p.name, provider=p.provider)
            for p in platforms
        ]
    )


@router.get("/models", response_model=ApiResponse[list[ModelResponse]])
async def list_models(capability: str | None = None):
    """List models, optionally filtered by capability.

    Args:
        capability: Capability filter, e.g. script, polish, chat.
    """
    model_router = await get_model_router()
    models = model_router.list_models(capability=capability)
    return success_response(
        data=[
            ModelResponse(
                id=m["id"],
                platform_id=m["platform_id"],
                model_name=m["model_name"],
                display_name=m["display_name"],
                capabilities=m["capabilities"],
                default_params=m["default_params"],
                # Every model returned by list_models is treated as enabled.
                is_enabled=True,
                full_model_id=f"{m['platform_id']}/{m['id']}",
            )
            for m in models
        ]
    )


@router.post("/generate", response_model=ApiResponse[GenerateResponse])
async def generate_text(data: GenerateRequest):
    """Generate text, routed to the matching platform by the model router.

    Raises:
        HTTPException: 500 with the underlying error message if generation fails.
    """
    model_router = await get_model_router()
    try:
        result = await model_router.generate(
            prompt=data.prompt,
            model_id=data.model_id,
            task_type=data.task_type,
            temperature=data.temperature,
            max_tokens=data.max_tokens,
        )
    except Exception as e:
        logger.exception("文本生成失败")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return success_response(
        data=GenerateResponse(
            content=result.content,
            model=result.model,
            usage=result.usage,
        )
    )


@router.get("/health", response_model=ApiResponse[HealthResponse])
async def health_check(model_id: str | None = None):
    """Check model health; all models, or only ``model_id`` when given."""
    model_router = await get_model_router()
    health_results = await model_router.health_check(model_id)

    models_status = [
        {
            "id": mid,
            "name": health.name,
            "is_available": health.is_available,
            "response_time": health.response_time,
            "last_error": health.last_error,
        }
        for mid, health in health_results.items()
    ]
    available_count = sum(1 for item in models_status if item["is_available"])

    return success_response(
        data={
            "status": "healthy" if available_count > 0 else "unhealthy",
            "total_models": len(models_status),
            "available_models": available_count,
            "models": models_status,
        }
    )


@router.get("/platforms/{platform_id}/test", response_model=ApiResponse[dict])
async def test_platform_connection(platform_id: str):
    """Test connectivity to a single platform.

    Connection failures are reported in the payload (``success: False``)
    rather than raised.

    Raises:
        HTTPException: 404 if the platform is not configured.
    """
    from app.ai.model_router import PlatformInstance

    config_loader = get_config_loader()
    platform = config_loader.get_platform(platform_id)
    if not platform:
        raise HTTPException(status_code=404, detail="平台不存在")

    try:
        # PlatformInstance reads the API key from Settings on its own.
        instance = PlatformInstance(
            {
                "id": platform.id,
                "name": platform.name,
                "provider": platform.provider,
                "base_url": platform.base_url,
            }
        )
        result = await instance.provider.health_check()
        return success_response(
            data={
                "platform_id": platform_id,
                "success": result.is_available,
                "response_time": result.response_time,
                "message": "连接成功" if result.is_available else result.last_error,
            }
        )
    except Exception as e:
        # Best-effort probe: report the failure instead of raising, but keep a trace.
        logger.warning("平台连接测试失败 %s: %s", platform_id, e)
        return success_response(
            data={
                "platform_id": platform_id,
                "success": False,
                "error": str(e),
            }
        )


@router.post("/reload", response_model=ApiResponse[dict])
async def reload_config():
    """Reload the model configuration file."""
    model_router = await get_model_router()
    reloaded = model_router.reload_config()
    if reloaded:
        return success_response(data={"reloaded": True}, message="配置已重新加载")
    return success_response(data={"reloaded": False}, message="配置文件无变化")


# =============================================================================
# Prompt template API
# =============================================================================


class PromptTemplatesResponse(BaseModel):
    """Available prompt template options."""

    script_types: list[dict]
    video_styles: list[dict]
    tones: list[str]


class ScriptGenerateRequest(BaseModel):
    """Script generation request."""

    topic: str = Field(..., description="脚本主题", examples=["水电改造的3个致命错误"])
    duration: int = Field(30, ge=15, le=120, description="视频时长(秒)")
    script_type: str = Field("干货型", description="脚本类型")
    video_style: str = Field("口播", description="视频风格")
    tone: str | None = Field(None, description="语气风格")
    requirements: str | None = Field(None, description="额外要求")
    model_id: str | None = Field(None, description="指定模型ID,默认使用系统默认模型")


class ScriptGenerateResponse(BaseModel):
    """Script generation result, shaped for front-end display."""

    success: bool
    # Shot list; each shot contains shot_number, type, scene/prompt, voiceover,
    # duration, word_count. None when the AI output could not be used.
    script: list[dict | None] | None
    total_duration: int | None  # estimated total duration (seconds)
    target_duration: int  # requested duration (seconds)
    total_word_count: int | None  # total voiceover characters (for display)
    segment_count: int | None  # number of "segment" shots (for display)
    empty_shot_count: int | None  # number of "empty_shot" shots (for display)
    script_type: str
    model: str
    usage: dict | None
    error: str | None
    raw_content: str | None


class PolishRequest(BaseModel):
    """Polish request."""

    content: str = Field(..., description="需要润色的内容")
    polish_type: str = Field("voiceover", description="润色类型:scene/voiceover")
    model_id: str | None = Field(None, description="指定模型ID")


class PolishResponse(BaseModel):
    """Polish result."""

    success: bool
    original: str
    polished: str | None
    polish_type: str
    model: str
    usage: dict | None


@router.get("/prompts/templates", response_model=ApiResponse[PromptTemplatesResponse])
async def get_prompt_templates():
    """Return all prompt template options: script types, video styles and tones."""
    return success_response(
        data={
            "script_types": [
                {
                    "id": key,
                    "name": value["name"],
                    "description": value["description"],
                    "key_points": value["key_points"],
                }
                for key, value in SCRIPT_TYPES.items()
                if key != "default"
            ],
            "video_styles": [
                {
                    "id": key,
                    "name": value["name"],
                    "description": value["description"],
                }
                for key, value in VIDEO_STYLES.items()
            ],
            "tones": ["专业", "亲和", "幽默", "严肃", "激情"],
        }
    )


@router.post("/prompts/build", response_model=ApiResponse[dict])
async def build_system_prompt(
    duration: int = 30,
    script_type: str = "干货型",
    video_style: str = "口播",
    tone: str | None = None,
):
    """Build the system prompt for preview/debugging and return it with its length."""
    builder = ScriptPromptBuilder()
    prompt = builder.build(
        duration=duration,
        script_type=script_type,
        video_style=video_style,
        industry="家装",
        tone=tone,
    )
    return success_response(
        data={
            "system_prompt": prompt,
            "length": len(prompt),
            "parameters": {
                "duration": duration,
                "script_type": script_type,
                "video_style": video_style,
                "tone": tone,
            },
        }
    )


def _parse_duration_seconds(value: object) -> float:
    """Parse a shot duration such as ``5``, ``"5s"``, ``"5秒"`` or ``"2.5s"``.

    Missing, non-finite or unparseable values fall back to
    ``_DEFAULT_SHOT_DURATION_SECONDS`` so that one malformed shot does not
    fail the whole request.
    """
    default = float(_DEFAULT_SHOT_DURATION_SECONDS)
    # bool is an int subclass but is never a meaningful duration.
    if value is None or isinstance(value, bool):
        return default
    if isinstance(value, (int, float)):
        seconds = float(value)
    else:
        text = str(value).strip().rstrip("sS秒").strip()
        try:
            seconds = float(text)
        except ValueError:
            logger.warning("无法解析镜头时长: %r", value)
            return default
    return seconds if math.isfinite(seconds) else default


def _compute_script_stats(script) -> dict:
    """Compute display statistics over the dict entries of ``script``.

    Returns a dict with total_duration (rounded seconds), total_word_count,
    segment_count and empty_shot_count; non-dict entries are ignored.
    """
    shots = [shot for shot in script if isinstance(shot, dict)]
    return {
        "total_duration": round(
            sum(_parse_duration_seconds(shot.get("duration")) for shot in shots)
        ),
        "total_word_count": sum(len(str(shot.get("voiceover") or "")) for shot in shots),
        "segment_count": sum(1 for shot in shots if shot.get("type") == "segment"),
        "empty_shot_count": sum(
            1 for shot in shots if shot.get("type") == "empty_shot"
        ),
    }


def _script_failure_payload(data: ScriptGenerateRequest, result, error: str) -> dict:
    """Build the ``success=False`` payload used when the AI output is unusable.

    The raw model output is echoed back in ``raw_content`` for debugging.
    """
    return {
        "success": False,
        "script": None,
        "total_duration": None,
        "target_duration": data.duration,
        "total_word_count": None,
        "segment_count": None,
        "empty_shot_count": None,
        "script_type": data.script_type,
        "model": result.model,
        "usage": result.usage,
        "error": error,
        "raw_content": result.content,
    }


@router.post("/scripts/generate", response_model=ApiResponse[ScriptGenerateResponse])
async def generate_script(data: ScriptGenerateRequest):
    """Generate a home-renovation short-video script.

    Uses the prompt templates to produce a mixed script of segments and empty
    shots, and returns display statistics (segment count, empty-shot count,
    total word count). Unparseable or malformed AI output is reported with
    ``success: False`` and the raw content.

    Raises:
        HTTPException: 500 if the model call (or processing) fails unexpectedly.
    """
    model_router = await get_model_router()

    builder = ScriptPromptBuilder()
    system_prompt = builder.build(
        duration=data.duration,
        script_type=data.script_type,
        video_style=data.video_style,
        industry="家装",
        # Tone comes from the dedicated `tone` field; free-form requirements
        # go to custom_requirements.
        tone=data.tone,
        custom_requirements=data.requirements,
    )

    user_prompt = (
        f'主题是"{data.topic}"\n'
        "\n"
        "要求:\n"
        "1. 严格按照时长要求控制\n"
        "2. 每个镜头的配音字数必须匹配时长(4-5字/秒)\n"
        "3. 空镜必须有画外音,不能为空\n"
        "4. 只返回JSON数组,不要有其他文字"
    )
    full_prompt = f"{system_prompt}\n\n【用户输入】\n{user_prompt}\n\n请生成脚本,只返回JSON数组:"

    try:
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="script",
            temperature=0.7,
            max_tokens=2500,
        )

        success_parsed, parsed_data, error_msg = safe_parse_ai_json_response(
            result.content
        )
        if not success_parsed:
            logger.error("AI 响应解析失败: %s", error_msg)
            return success_response(
                data=_script_failure_payload(data, result, error_msg or "JSON解析失败")
            )

        try:
            script = validate_and_normalize_shots(parsed_data)
        except Exception as e:
            logger.exception("分镜数据标准化失败: %s", e)
            return success_response(
                data=_script_failure_payload(data, result, f"分镜数据格式错误: {e}")
            )

        is_valid, validation_errors = validate_shots_structure(script)
        if not is_valid:
            # Non-fatal: the script is still returned, the issues are only logged.
            logger.warning("分镜结构验证失败: %s", validation_errors)

        stats = _compute_script_stats(script)
        return success_response(
            data={
                "success": True,
                "script": script,
                "total_duration": stats["total_duration"],
                "target_duration": data.duration,
                "total_word_count": stats["total_word_count"],
                "segment_count": stats["segment_count"],
                "empty_shot_count": stats["empty_shot_count"],
                "script_type": data.script_type,
                "model": result.model,
                "usage": result.usage,
                "error": None,
                "raw_content": None,
            }
        )
    except Exception as e:
        logger.exception("脚本生成失败")
        raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}") from e


@router.post("/scripts/polish", response_model=ApiResponse[PolishResponse])
async def polish_script_content(data: PolishRequest):
    """Polish a scene description or voiceover text.

    Raises:
        HTTPException: 500 if the model call fails.
    """
    model_router = await get_model_router()

    builder = PolishPromptBuilder()
    system_prompt = builder.build(data.polish_type)
    full_prompt = f"{system_prompt}\n\n【待润色内容】\n{data.content}\n\n请润色:"

    try:
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="polish",
            temperature=0.6,
            max_tokens=1000,
        )
    except Exception as e:
        logger.exception("润色失败")
        raise HTTPException(status_code=500, detail=f"润色失败: {str(e)}") from e

    return success_response(
        data={
            "success": True,
            "original": data.content,
            "polished": result.content,
            "polish_type": data.polish_type,
            "model": result.model,
            "usage": result.usage,
        }
    )