Files
meijiaka-zy/python-api/app/main.py
T
小鱼开发 755ecc9abe refactor(config): 统一配置体系,禁用热重载,清理兼容层
- 删除 .gitlab-ci.yml
- 删除 runtime_config.py 兼容层
- Pydantic Settings + YAML 三层配置分离
- 统一 PlatformConfigLoader 加载器
- docker-compose 移除重复 environment 覆盖
- volcengine base_url 从 YAML 读取
- 微信支付/SMS 空值启动时拦截
- 日志仅输出控制台,不写文件
- 更新 model_router 注释
2026-05-07 18:42:47 +08:00

337 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FastAPI 应用入口
================
"""
import logging
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.v1.router import api_router
from app.config import get_settings
from app.core.exceptions import PlatformError
from app.db.session import close_db, init_db
from app.schemas.common import ApiResponse
settings = get_settings()

# Logging goes to the console (stdout) only; in containers Docker collects stdout.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Resolve the level case-insensitively: main() hands the same setting to uvicorn
# lower-cased, so "info" and "INFO" must both work here. A raw getattr on "info"
# would return the function logging.info rather than a level, so verify that a
# numeric level was actually found and fail with a clear message otherwise.
log_level = getattr(logging, str(settings.LOG_LEVEL).upper(), None)
if not isinstance(log_level, int):
    raise ValueError(f"Invalid LOG_LEVEL: {settings.LOG_LEVEL!r}")
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
async def _init_database() -> None:
    """Create DB tables at startup, only in DEBUG mode for development/staging.

    Failures are logged as warnings and swallowed so the app can still start.
    """
    if not (settings.DEBUG and settings.ENV in ("development", "staging")):
        return
    logger.info("Initializing database tables...")
    try:
        # Import the models so they are registered on the metadata before create.
        from app.models import (
            PointBatch,  # noqa: F401
            PointRechargeOrder,  # noqa: F401
            PointTransaction,  # noqa: F401
            User,  # noqa: F401
            UserDevice,  # noqa: F401
            UserPoint,  # noqa: F401
        )

        await init_db()
        logger.info("Database tables initialized")
    except Exception as e:
        logger.warning(f"Database initialization skipped: {e}")


def _load_model_config() -> None:
    """Load AI platform/model definitions via the config loader; log and continue on failure."""
    try:
        from app.core.config_loader import get_config_loader

        config_loader = get_config_loader()
        platforms_count = len(config_loader.get_all_platforms())
        models_count = len(config_loader.get_enabled_models())
        logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
    except Exception as e:
        logger.warning(f"Failed to load models from config: {e}")


def _load_material_config() -> None:
    """Load the stock-footage material config; log and continue on failure.

    Per the original comment, remote CDN is preferred with a local JSON fallback
    (implemented inside material_service, not visible here).
    """
    try:
        from app.services.material_service import load_config

        load_config()
        logger.info("Loaded material config")
    except Exception as e:
        logger.warning(f"Failed to load material config: {e}")


def _create_http_clients() -> dict:
    """Build one httpx.AsyncClient per platform so pool exhaustion/faults stay isolated."""
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


def _init_providers(state) -> None:
    """Create platform providers on ``state``, injecting the shared HTTP clients.

    Vidu is mandatory (errors propagate and abort startup). The two Volcengine
    providers are optional: on failure they are logged and set to None.
    """
    from app.ai.providers.vidu_provider import ViduProvider
    from app.ai.providers.volcengine_caption_provider import VolcengineCaptionProvider

    state.vidu_provider = ViduProvider(client=state.http_clients["vidu"])
    logger.info("Vidu Provider initialized")

    # Volcengine caption provider
    try:
        state.volcengine_caption_provider = VolcengineCaptionProvider(
            client=state.http_clients["volcengine_caption"]
        )
        logger.info("Volcengine Caption Provider initialized")
    except Exception as e:
        logger.warning(f"Volcengine Caption Provider 初始化跳过: {e}")
        state.volcengine_caption_provider = None

    # Volcengine Ark provider (optional; per the original note it needs an API key)
    try:
        from app.ai.providers.volcengine_provider import VolcengineProvider

        state.volcengine_provider = VolcengineProvider()
        logger.info("Volcengine Provider initialized")
    except Exception as e:
        logger.warning(f"Volcengine Provider 初始化跳过: {e}")
        state.volcengine_provider = None


def _init_gateway(state) -> None:
    """Wrap providers in adapters and register them on a new PlatformGateway.

    Adapters whose provider is unavailable are stored as None so that
    ``state.<name>_adapter`` always exists, matching the provider/service attributes.
    """
    from app.ai.adapters.vidu_adapter import ViduAdapter
    from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
    from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
    from app.platform_gateway import PlatformGateway

    state.vidu_adapter = ViduAdapter(state.vidu_provider)
    logger.info("ViduAdapter initialized")

    gateway = PlatformGateway()
    state.platform_gateway = gateway
    gateway.register("vidu", state.vidu_adapter)

    state.volcengine_caption_adapter = None
    if state.volcengine_caption_provider:
        state.volcengine_caption_adapter = VolcengineCaptionAdapter(
            state.volcengine_caption_provider
        )
        gateway.register("volcengine_caption", state.volcengine_caption_adapter)
        logger.info("VolcengineCaptionAdapter initialized")

    state.volcengine_ark_adapter = None
    if state.volcengine_provider:
        state.volcengine_ark_adapter = VolcengineArkAdapter(state.volcengine_provider)
        gateway.register("volcengine_ark", state.volcengine_ark_adapter)
        logger.info("VolcengineArkAdapter registered to PlatformGateway")

    logger.info("PlatformGateway initialized")


async def _init_model_router(gateway) -> None:
    """Initialize the ModelRouter with the gateway so model calls go through it; failures are logged."""
    try:
        from app.ai.model_router import get_model_router

        await get_model_router(gateway=gateway)
        logger.info("ModelRouter initialized with PlatformGateway")
    except Exception as e:
        logger.warning(f"ModelRouter 初始化失败: {e}")


def _init_services(state) -> None:
    """Create gateway-backed services; the caption service is None when its provider is missing."""
    from app.services.vidu_service import ViduService
    from app.services.volcengine_caption_service import VolcengineCaptionService

    state.vidu_service = ViduService(state.platform_gateway)
    logger.info("Vidu Service initialized")

    if state.volcengine_caption_provider:
        state.volcengine_caption_service = VolcengineCaptionService(state.platform_gateway)
        logger.info("Volcengine Caption Service initialized")
    else:
        state.volcengine_caption_service = None


def _init_llm_gateway(state) -> None:
    """Create the optional LLMGateway over the Ark adapter (kept for backward compatibility)."""
    if state.volcengine_provider:
        from app.ai.gateways.llm_gateway import LLMGateway

        state.llm_gateway = LLMGateway(
            adapters={"volcengine_ark": state.volcengine_ark_adapter},
            fallback_chains={},
        )
        logger.info("LLMGateway initialized")
    else:
        state.llm_gateway = None


async def _shutdown(state) -> None:
    """Release resources in reverse construction order: gateway/adapters, HTTP clients, then the DB.

    Safe to call after a partial startup: missing attributes are skipped.
    """
    logger.info("Shutting down...")

    # Close the gateway (which closes its adapters) first, while the shared
    # HTTP clients the adapters depend on are still open.
    gateway = getattr(state, "platform_gateway", None)
    if gateway is not None:
        try:
            await gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning(f"PlatformGateway close error: {e}")

    for name, client in getattr(state, "http_clients", {}).items():
        try:
            if not client.is_closed:
                await client.aclose()
                logger.info(f"HTTP Client closed: {name}")
        except Exception as e:
            logger.warning(f"HTTP Client close error: {name}: {e}")

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared resources on startup, release them on shutdown.

    Startup order: optional DB table creation, model and material config loading,
    per-platform HTTP client pools, providers, adapters + PlatformGateway,
    ModelRouter, services, and the optional LLMGateway.

    Cleanup runs in ``finally``, so it also executes when startup fails after
    the HTTP client pools were created, or when an exception is raised into the
    lifespan at shutdown.
    """
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")

    await _init_database()
    _load_model_config()
    _load_material_config()

    state = app.state
    state.http_clients = _create_http_clients()
    logger.info("HTTP Client pool initialized")

    try:
        _init_providers(state)
        _init_gateway(state)
        await _init_model_router(state.platform_gateway)
        _init_services(state)
        _init_llm_gateway(state)
        yield
    finally:
        await _shutdown(state)
def _register_exception_handlers(application: FastAPI) -> None:
    """Attach the platform-error handler and the catch-all exception handler."""

    @application.exception_handler(PlatformError)
    async def platform_error_handler(request, exc: PlatformError):
        """Map a third-party platform failure onto the unified response envelope."""
        http_status = exc.to_http_status()
        body = {
            "code": exc.status_code or http_status,
            "message": str(exc),
            "data": None,
        }
        if settings.DEBUG:
            # Platform diagnostics are only exposed in debug mode.
            body["detail"] = {
                "platform": exc.platform,
                "error_type": exc.error_type,
                "retryable": exc.retryable,
            }
        return JSONResponse(status_code=http_status, content=body)

    @application.exception_handler(Exception)
    async def global_exception_handler(request, exc):
        """Log any unhandled error and answer with a generic 500 envelope."""
        logger.exception("Unhandled exception")
        error_detail = {"error": str(exc)} if settings.DEBUG else None
        payload = {
            "code": 500,
            "message": "服务器内部错误",
            "data": None,
            "detail": error_detail,
        }
        return JSONResponse(status_code=500, content=payload)


def _register_system_routes(application: FastAPI) -> None:
    """Add the health-check and root endpoints."""
    # NOTE: the endpoint docstrings become OpenAPI descriptions, so they are kept verbatim.

    @application.get("/health", tags=["System"])
    async def health_check():
        """服务健康检查"""
        status_info = {
            "status": "healthy",
            "version": settings.APP_VERSION,
            "environment": settings.ENV,
        }
        return ApiResponse(code=200, data=status_info, message="服务运行正常")

    @application.get("/", tags=["System"])
    async def root():
        """API 根路径"""
        service_info = {
            "name": settings.APP_NAME,
            "version": settings.APP_VERSION,
            "docs": "/docs" if settings.DEBUG else None,
        }
        return ApiResponse(code=200, data=service_info, message="美家卡智影 API 服务")


def create_app() -> FastAPI:
    """Build the FastAPI application: middleware, API router, error handlers and system routes.

    Interactive docs (/docs, /redoc) are only enabled when DEBUG is on.
    """
    debug = settings.DEBUG
    application = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url="/docs" if debug else None,
        redoc_url="/redoc" if debug else None,
        lifespan=lifespan,
    )

    # CORS: every origin in debug mode (avoids cross-origin friction while
    # developing); otherwise only the configured allowlist.
    origins = ["*"] if debug else settings.cors_origins_list
    application.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    application.include_router(api_router, prefix="/api/v1")
    _register_exception_handlers(application)
    _register_system_routes(application)
    return application


# Module-level ASGI application; main() serves it via the "app.main:app" import string.
app = create_app()
def main():
    """Command-line entry point: serve the application with uvicorn.

    In DEBUG mode a single auto-reloading process is used; otherwise the
    configured number of workers is started.
    """
    debug_mode = settings.DEBUG
    worker_count = 1 if debug_mode else settings.WORKERS
    uvicorn.run(
        # An import string (not the app object) is what uvicorn needs for reload/workers.
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        workers=worker_count,
        reload=debug_mode,
        log_level=settings.LOG_LEVEL.lower(),
    )


if __name__ == "__main__":
    main()