Files
meijiaka-zy/python-api/app/main.py
T
小鱼开发 447f3c2ffe feat: 空镜素材系统数据库化 + 修复积分不足弹窗叠加
后端:
- 新增 BrollCategory/BrollMaterial/BrollTag 模型及表(mjk_categories/materials/tags)
- 新增 Alembic 迁移 69274ce979a5
- 新增 broll_category/broll_material CRUD 层
- 重构 material_service:删除 JSON 配置,改用 PostgreSQL + Redis 去重
- 新增 /materials/batch-match 接口,删除 /materials/reload
- usage_count 原子递增,Redis 失败自动降级

前端:
- materials API 改为 projectId 去重,新增 batchMatch
- VideoGeneration 批量匹配改用 batchMatch,删除 usedUrls 手动维护
- 修复积分不足时进度弹窗与充值弹窗叠加的 bug
- 操作前预检积分,不足时显示提示条+立即充值按钮
2026-05-11 17:40:38 +08:00

361 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FastAPI 应用入口
================
"""
import logging
import sys
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.v1.router import api_router
from app.config import get_settings
from app.core.exceptions import PlatformError
from app.db.session import close_db, init_db
from app.schemas.common import ApiResponse
settings = get_settings()

# Logging goes to the console only (stdout); in containers Docker collects stdout.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Resolve the configured level name case-insensitively ("info" -> logging.INFO).
# Only integer attributes are real levels; anything else (unknown names, or
# non-level attributes such as BASIC_FORMAT) falls back to INFO instead of
# crashing the import with an opaque AttributeError.
log_level = getattr(logging, str(settings.LOG_LEVEL).upper(), None)
_invalid_log_level = not isinstance(log_level, int)
if _invalid_log_level:
    log_level = logging.INFO
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
if _invalid_log_level:
    # Reported only after basicConfig so the warning actually reaches stdout.
    logger.warning("Unknown LOG_LEVEL %r, falling back to INFO", settings.LOG_LEVEL)
async def _init_database_tables() -> None:
    """Create database tables, only for DEBUG runs in development/staging.

    Any failure is logged as a warning and swallowed so startup continues.
    """
    if not (settings.DEBUG and settings.ENV in ("development", "staging")):
        return
    logger.info("Initializing database tables...")
    try:
        # Import every model so it is registered on the metadata before
        # init_db() creates the tables.
        from app.models import (
            BrollCategory,  # noqa: F401
            BrollMaterial,  # noqa: F401
            BrollTag,  # noqa: F401
            PointBatch,  # noqa: F401
            PointRechargeOrder,  # noqa: F401
            PointTransaction,  # noqa: F401
            User,  # noqa: F401
            UserDevice,  # noqa: F401
            UserPoint,  # noqa: F401
        )

        await init_db()
        logger.info("Database tables initialized")
    except Exception as e:
        logger.warning(f"Database initialization skipped: {e}")


def _load_model_config() -> None:
    """Load the AI model configuration (YAML) and log platform/model counts.

    Failures are logged as a warning and do not abort startup.
    """
    try:
        from app.core.config_loader import get_config_loader

        config_loader = get_config_loader()
        platforms_count = len(config_loader.get_all_platforms())
        models_count = len(config_loader.get_enabled_models())
        logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
    except Exception as e:
        logger.warning(f"Failed to load models from config: {e}")


def _create_http_clients() -> dict:
    """Build one pooled ``httpx.AsyncClient`` per platform (fault isolation).

    Returns:
        Mapping of pool name ("vidu", "volcengine_caption", "default") to client.
    """
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


def _init_providers(app: FastAPI) -> None:
    """Create platform providers, injecting the shared per-platform clients.

    The Vidu provider is mandatory (errors propagate). The Volcengine caption
    and Volcengine Ark providers are optional: if construction raises, the
    error is logged and the corresponding ``app.state`` attribute is ``None``.
    """
    from app.ai.providers.vidu_provider import ViduProvider
    from app.ai.providers.volcengine_caption_provider import VolcengineCaptionProvider

    clients = app.state.http_clients
    app.state.vidu_provider = ViduProvider(client=clients["vidu"])
    logger.info("Vidu Provider initialized")

    # Volcengine caption provider
    try:
        app.state.volcengine_caption_provider = VolcengineCaptionProvider(
            client=clients["volcengine_caption"]
        )
        logger.info("Volcengine Caption Provider initialized")
    except Exception as e:
        logger.warning(f"Volcengine Caption Provider 初始化跳过: {e}")
        app.state.volcengine_caption_provider = None

    # Volcengine Ark provider (optional, requires an API key)
    try:
        from app.ai.providers.volcengine_provider import VolcengineProvider

        app.state.volcengine_provider = VolcengineProvider()
        logger.info("Volcengine Provider initialized")
    except Exception as e:
        logger.warning(f"Volcengine Provider 初始化跳过: {e}")
        app.state.volcengine_provider = None


def _init_gateway(app: FastAPI) -> None:
    """Wrap providers in adapters and register them on a ``PlatformGateway``.

    Adapters are registered only for providers that were created successfully.
    """
    from app.ai.adapters.vidu_adapter import ViduAdapter
    from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
    from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
    from app.platform_gateway import PlatformGateway

    app.state.vidu_adapter = ViduAdapter(app.state.vidu_provider)
    logger.info("ViduAdapter initialized")

    gateway = PlatformGateway()
    app.state.platform_gateway = gateway
    gateway.register("vidu", app.state.vidu_adapter)

    if app.state.volcengine_caption_provider:
        app.state.volcengine_caption_adapter = VolcengineCaptionAdapter(
            app.state.volcengine_caption_provider
        )
        gateway.register("volcengine_caption", app.state.volcengine_caption_adapter)
        logger.info("VolcengineCaptionAdapter initialized")

    if app.state.volcengine_provider:
        app.state.volcengine_ark_adapter = VolcengineArkAdapter(app.state.volcengine_provider)
        gateway.register("volcengine_ark", app.state.volcengine_ark_adapter)
        logger.info("VolcengineArkAdapter registered to PlatformGateway")

    logger.info("PlatformGateway initialized")


async def _init_model_router(app: FastAPI) -> None:
    """Initialise the ModelRouter with the gateway so calls go through it.

    Failures are logged as a warning and do not abort startup.
    """
    try:
        from app.ai.model_router import get_model_router

        await get_model_router(gateway=app.state.platform_gateway)
        logger.info("ModelRouter initialized with PlatformGateway")
    except Exception as e:
        logger.warning(f"ModelRouter 初始化失败: {e}")


def _init_services(app: FastAPI) -> None:
    """Create gateway-backed services and the optional LLM gateway.

    Optional services are stored as ``None`` when their provider is missing.
    """
    from app.services.vidu_service import ViduService
    from app.services.volcengine_caption_service import VolcengineCaptionService

    app.state.vidu_service = ViduService(app.state.platform_gateway)
    logger.info("Vidu Service initialized")

    if app.state.volcengine_caption_provider:
        app.state.volcengine_caption_service = VolcengineCaptionService(
            app.state.platform_gateway
        )
        logger.info("Volcengine Caption Service initialized")
    else:
        app.state.volcengine_caption_service = None

    # LLM gateway (optional, kept for backward compatibility)
    if app.state.volcengine_provider:
        from app.ai.gateways.llm_gateway import LLMGateway

        app.state.llm_gateway = LLMGateway(
            adapters={"volcengine_ark": app.state.volcengine_ark_adapter},
            fallback_chains={},
        )
        logger.info("LLMGateway initialized")
    else:
        app.state.llm_gateway = None


async def _shutdown(app: FastAPI) -> None:
    """Release resources created during startup; tolerant of partial startup.

    The gateway (which closes its adapters) is closed before the shared HTTP
    clients the adapters were built on, then the database is closed.
    """
    if hasattr(app.state, "platform_gateway"):
        try:
            await app.state.platform_gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning(f"PlatformGateway close error: {e}")

    for name, client in getattr(app.state, "http_clients", {}).items():
        try:
            if not client.is_closed:
                await client.aclose()
                logger.info(f"HTTP Client closed: {name}")
        except Exception as e:
            logger.warning(f"HTTP Client close error: {name}: {e}")

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle management.

    Startup does the following:
    - creates tables (DEBUG in development/staging only);
    - loads the model configuration;
    - builds the HTTP client pool, providers, adapters, gateway, model router
      and services, storing them on ``app.state``.

    Shutdown (in ``finally``, so it also runs if startup fails after the
    client pool exists, or if the lifespan is exited via an exception)
    releases them.
    """
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    await _init_database_tables()
    _load_model_config()

    app.state.http_clients = _create_http_clients()
    logger.info("HTTP Client pool initialized")
    try:
        _init_providers(app)
        _init_gateway(app)
        await _init_model_router(app)
        _init_services(app)
        yield
    finally:
        logger.info("Shutting down...")
        await _shutdown(app)
def create_app() -> FastAPI:
    """Create and configure the FastAPI application instance.

    Sets up the following:
    - CORS;
    - the v1 API router under ``/api/v1``;
    - unified error handlers (platform errors, HTTP errors, unhandled errors);
    - the ``/health`` and ``/`` system endpoints.
    """
    # Starlette's router raises its own HTTPException for 404/405, which is the
    # base class of FastAPI's HTTPException. Registering the handler on the base
    # class makes those responses use the unified format too.
    from starlette.exceptions import HTTPException as StarletteHTTPException

    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url="/docs" if settings.DEBUG else None,
        redoc_url="/redoc" if settings.DEBUG else None,
        lifespan=lifespan,
    )

    # CORS: allow every origin in DEBUG to avoid cross-origin friction during
    # development; otherwise only the configured allow-list.
    allow_origins = ["*"] if settings.DEBUG else settings.cors_origins_list
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Routes
    app.include_router(api_router, prefix="/api/v1")

    # Error responses must follow the same origin policy as the middleware,
    # otherwise the handlers below would widen CORS for error bodies.
    allowed_origins = set(allow_origins)
    allow_any_origin = "*" in allowed_origins

    def _cors_response(request, status_code: int, content: dict, headers=None) -> JSONResponse:
        """Build a JSONResponse that carries CORS headers for allowed origins.

        The handler registered for ``Exception`` is run by Starlette's
        ServerErrorMiddleware, outside CORSMiddleware, so its response would
        otherwise have no CORS headers and the browser would block it.

        The request Origin is echoed back only when the configured policy
        permits it. Reflecting arbitrary origins with credentials would bypass
        the allow-list.
        """
        response = JSONResponse(status_code=status_code, content=content, headers=headers)
        origin = request.headers.get("origin")
        if origin and (allow_any_origin or origin in allowed_origins):
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            # The header value depends on Origin; make shared caches aware.
            response.headers.add_vary_header("Origin")
        return response

    @app.exception_handler(PlatformError)
    async def platform_error_handler(request, exc: PlatformError):
        """Map third-party platform errors to the unified error format.

        Platform diagnostics are exposed only in DEBUG.
        """
        http_status = exc.to_http_status()
        content = {
            "code": exc.status_code or http_status,
            "message": str(exc),
            "data": None,
        }
        if settings.DEBUG:
            content["detail"] = {
                "platform": exc.platform,
                "error_type": exc.error_type,
                "retryable": exc.retryable,
            }
        return _cors_response(request, http_status, content)

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc: StarletteHTTPException):
        """Convert any HTTPException into the unified error format.

        AppException subclasses carry a ``message`` attribute; plain
        HTTPExceptions use a string ``detail`` as the message. Headers attached
        to the exception (e.g. ``Allow`` on 405, ``WWW-Authenticate`` on 401)
        are preserved.
        """
        message = getattr(exc, "message", None) or (
            exc.detail if isinstance(exc.detail, str) else "请求失败"
        )
        detail = exc.detail if isinstance(exc.detail, dict) else None
        return _cors_response(
            request,
            exc.status_code,
            {
                "code": exc.status_code,
                "message": message,
                "data": None,
                "detail": detail,
            },
            headers=getattr(exc, "headers", None),
        )

    @app.exception_handler(Exception)
    async def global_exception_handler(request, exc):
        """Catch-all: log the traceback and return a generic 500 response.

        The exception text is included only in DEBUG.
        """
        logger.exception("Unhandled exception")
        return _cors_response(
            request,
            500,
            {
                "code": 500,
                "message": "服务器内部错误",
                "data": None,
                "detail": {"error": str(exc)} if settings.DEBUG else None,
            },
        )

    @app.get("/health", tags=["System"])
    async def health_check():
        """Liveness endpoint reporting version and environment."""
        return ApiResponse(
            code=200,
            data={
                "status": "healthy",
                "version": settings.APP_VERSION,
                "environment": settings.ENV,
            },
            message="服务运行正常",
        )

    @app.get("/", tags=["System"])
    async def root():
        """API root: service name, version, and docs path (DEBUG only)."""
        return ApiResponse(
            code=200,
            data={
                "name": settings.APP_NAME,
                "version": settings.APP_VERSION,
                "docs": "/docs" if settings.DEBUG else None,
            },
            message="美家卡智影 API 服务",
        )

    return app


# Module-level application instance (referenced as "app.main:app").
app = create_app()
def main() -> None:
    """Command-line entry point: serve ``app.main:app`` with uvicorn.

    In DEBUG mode a single worker is started with auto-reload enabled.
    Otherwise ``settings.WORKERS`` workers run without reload.
    """
    debug = settings.DEBUG
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        workers=1 if debug else settings.WORKERS,
        reload=debug,
        log_level=settings.LOG_LEVEL.lower(),
    )


if __name__ == "__main__":
    main()