Files
meijiaka-zy/python-api/app/main.py
T
2026-05-26 19:21:23 +08:00

369 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FastAPI 应用入口
================
"""
import logging
import sys
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.v1.router import api_router
from app.config import get_settings
from app.core.exceptions import PlatformError
from app.db.session import close_db
from app.schemas.common import ApiResponse
settings = get_settings()

# Logging: console (stdout) only; in container deployments Docker collects stdout.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Resolve LOG_LEVEL case-insensitively: main() lowercases the same setting for
# uvicorn, so "info" and "INFO" must both work here. Names that are not a
# logging level (e.g. a typo, or a non-level attribute such as "BASIC_FORMAT")
# fail fast with a clear message instead of an opaque AttributeError or a
# non-int level being passed to basicConfig.
log_level = getattr(logging, str(settings.LOG_LEVEL).upper(), None)
if not isinstance(log_level, int):
    raise ValueError(f"Invalid LOG_LEVEL: {settings.LOG_LEVEL!r}")
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
def _load_model_config() -> None:
    """Load the AI model configuration (from the YAML config file) and log a summary.

    Best-effort: any failure is logged as a warning and startup continues.
    """
    try:
        from app.core.config_loader import get_config_loader

        config_loader = get_config_loader()
        platforms_count = len(config_loader.get_all_platforms())
        models_count = len(config_loader.get_enabled_models())
        logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
    except Exception as e:
        logger.warning(f"Failed to load models from config: {e}")


def _create_http_clients() -> dict:
    """Create the HTTP client pool: one ``httpx.AsyncClient`` per platform.

    Each platform gets its own client and connection limits, so a slow or
    failing platform cannot exhaust another platform's pool (fault isolation).
    The "default" client serves callers without a dedicated platform client.
    """
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "volcengine_mediakit": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


def _init_optional(label: str, factory):
    """Build an optional component, tolerating construction failure.

    Args:
        label: Human-readable name used in the log messages.
        factory: Zero-argument callable that builds and returns the component.

    Returns:
        The component, or ``None`` if ``factory`` raised (the error is logged
        as a warning and the corresponding platform is simply not enabled).
    """
    try:
        component = factory()
    except Exception as e:
        logger.warning(f"{label} 初始化跳过: {e}")
        return None
    logger.info(f"{label} initialized")
    return component


def _create_volcengine_provider():
    """Build the Volcengine Ark provider; the import is done here so an import
    failure is treated like any other construction failure (optional, needs an API key)."""
    from app.ai.providers.volcengine_provider import VolcengineProvider

    return VolcengineProvider()


async def _shutdown(app: FastAPI) -> None:
    """Release everything created during startup.

    Teardown runs in reverse order of construction: the PlatformGateway (which
    closes its adapters) goes first, because the adapters wrap providers that
    were built on the shared HTTP clients; the clients are closed afterwards,
    then the database. Each step tolerates the corresponding startup step not
    having run, so this is also safe after a failed startup.
    """
    logger.info("Shutting down...")

    if hasattr(app.state, "platform_gateway"):
        try:
            await app.state.platform_gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning(f"PlatformGateway close error: {e}")

    for name, client in getattr(app.state, "http_clients", {}).items():
        try:
            # An adapter/provider may already have closed its client.
            if not client.is_closed:
                await client.aclose()
                logger.info(f"HTTP Client closed: {name}")
        except Exception as e:
            logger.warning(f"HTTP Client close error: {name}: {e}")

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan management.

    - Startup: load the model config, create the per-platform HTTP client
      pool, then build providers -> adapters -> PlatformGateway ->
      ModelRouter -> services, storing each on ``app.state``. Optional
      platforms whose provider fails to initialize are stored as ``None``.
    - Shutdown: close the gateway, the HTTP clients and the database.

    Everything after the HTTP clients are created runs inside ``try/finally``,
    so the clients are closed even if a later startup step raises.
    """
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")
    _load_model_config()

    app.state.http_clients = _create_http_clients()
    logger.info("HTTP Client pool initialized")

    try:
        # --- Providers (the shared per-platform clients are injected) ---
        from app.ai.providers.vidu_provider import ViduProvider
        from app.ai.providers.volcengine_caption_provider import VolcengineCaptionProvider
        from app.ai.providers.volcengine_mediakit_provider import VolcengineMediakitProvider

        clients = app.state.http_clients
        app.state.vidu_provider = ViduProvider(client=clients["vidu"])
        logger.info("Vidu Provider initialized")

        app.state.volcengine_caption_provider = _init_optional(
            "Volcengine Caption Provider",
            lambda: VolcengineCaptionProvider(client=clients["volcengine_caption"]),
        )
        app.state.volcengine_mediakit_provider = _init_optional(
            "Volcengine Mediakit Provider",
            lambda: VolcengineMediakitProvider(client=clients["volcengine_mediakit"]),
        )
        app.state.volcengine_provider = _init_optional(
            "Volcengine Provider", _create_volcengine_provider
        )

        # --- Adapters (wrap providers) and the PlatformGateway ---
        from app.ai.adapters.vidu_adapter import ViduAdapter
        from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
        from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
        from app.ai.adapters.volcengine_mediakit_adapter import VolcengineMediakitAdapter
        from app.platform_gateway import PlatformGateway

        app.state.vidu_adapter = ViduAdapter(app.state.vidu_provider)
        logger.info("ViduAdapter initialized")

        gateway = PlatformGateway()
        app.state.platform_gateway = gateway
        gateway.register("vidu", app.state.vidu_adapter)

        if app.state.volcengine_caption_provider:
            app.state.volcengine_caption_adapter = VolcengineCaptionAdapter(
                app.state.volcengine_caption_provider
            )
            gateway.register("volcengine_caption", app.state.volcengine_caption_adapter)
            logger.info("VolcengineCaptionAdapter initialized")

        if app.state.volcengine_mediakit_provider:
            app.state.volcengine_mediakit_adapter = VolcengineMediakitAdapter(
                app.state.volcengine_mediakit_provider
            )
            gateway.register("volcengine_mediakit", app.state.volcengine_mediakit_adapter)
            logger.info("VolcengineMediakitAdapter initialized")

        if app.state.volcengine_provider:
            app.state.volcengine_ark_adapter = VolcengineArkAdapter(
                app.state.volcengine_provider
            )
            gateway.register("volcengine_ark", app.state.volcengine_ark_adapter)
            logger.info("VolcengineArkAdapter registered to PlatformGateway")

        logger.info("PlatformGateway initialized")

        # --- ModelRouter: given the gateway so its calls go through PlatformGateway ---
        try:
            from app.ai.model_router import get_model_router

            await get_model_router(gateway=gateway)
            logger.info("ModelRouter initialized with PlatformGateway")
        except Exception as e:
            logger.warning(f"ModelRouter 初始化失败: {e}")

        # --- Services (built on the gateway) ---
        from app.services.vidu_service import ViduService
        from app.services.volcengine_caption_service import VolcengineCaptionService
        from app.services.volcengine_mediakit_service import VolcengineMediakitService

        app.state.vidu_service = ViduService(gateway)
        logger.info("Vidu Service initialized")

        if app.state.volcengine_caption_provider:
            app.state.volcengine_caption_service = VolcengineCaptionService(gateway)
            logger.info("Volcengine Caption Service initialized")
        else:
            app.state.volcengine_caption_service = None

        if app.state.volcengine_mediakit_provider:
            app.state.volcengine_mediakit_service = VolcengineMediakitService(gateway)
            logger.info("Volcengine Mediakit Service initialized")
        else:
            app.state.volcengine_mediakit_service = None

        # --- LLM Gateway (optional; kept for backward compatibility) ---
        if app.state.volcengine_provider:
            from app.ai.gateways.llm_gateway import LLMGateway

            app.state.llm_gateway = LLMGateway(
                adapters={"volcengine_ark": app.state.volcengine_ark_adapter},
                fallback_chains={},
            )
            logger.info("LLMGateway initialized")
        else:
            app.state.llm_gateway = None

        yield
    finally:
        await _shutdown(app)
def create_app() -> FastAPI:
    """Create and configure the FastAPI application instance.

    Registers CORS, the ``/api/v1`` router, exception handlers that render
    errors as ``{"code", "message", ...}`` JSON, and the ``/health`` and ``/``
    system routes.
    """
    # Starlette's HTTPException is the base class of fastapi.HTTPException.
    # Registering the handler on it also covers errors raised by the framework
    # itself (e.g. 404 for unknown routes, 405), not only the ones our code raises.
    from starlette.exceptions import HTTPException as StarletteHTTPException

    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url="/docs" if settings.DEBUG else None,
        redoc_url="/redoc" if settings.DEBUG else None,
        lifespan=lifespan,
    )

    # CORS: in DEBUG allow every origin to avoid cross-origin friction during
    # development; otherwise use the configured allowlist.
    allow_origins = ["*"] if settings.DEBUG else settings.cors_origins_list
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Routes
    app.include_router(api_router, prefix="/api/v1")

    def _cors_response(request, status_code: int, content: dict, headers=None) -> JSONResponse:
        """Build a JSON error response that carries CORS headers.

        Responses returned by custom exception handlers may bypass
        CORSMiddleware, in which case the browser would block them, so the
        headers are added here. The request's Origin is echoed only when the
        CORS policy above allows it; echoing arbitrary origins together with
        ``Allow-Credentials`` would expose error bodies to any site.

        Args:
            request: The incoming request (only its ``Origin`` header is read).
            status_code: HTTP status of the response.
            content: JSON-serializable body.
            headers: Optional extra response headers (e.g. from an HTTPException).
        """
        response = JSONResponse(status_code=status_code, content=content, headers=headers)
        origin = request.headers.get("origin")
        if origin and ("*" in allow_origins or origin in allow_origins):
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            # The response now depends on the Origin request header; tell caches.
            vary = response.headers.get("vary")
            response.headers["Vary"] = f"{vary}, Origin" if vary else "Origin"
        return response

    @app.exception_handler(PlatformError)
    async def platform_error_handler(request, exc: PlatformError):
        """Render errors from third-party platform calls in the unified format."""
        http_status = exc.to_http_status()
        content = {
            "code": exc.status_code or http_status,
            "message": str(exc),
            "data": None,
        }
        if settings.DEBUG:
            content["detail"] = {
                "platform": exc.platform,
                "error_type": exc.error_type,
                "retryable": exc.retryable,
            }
        return _cors_response(request, http_status, content)

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc: StarletteHTTPException):
        """Convert FastAPI/Starlette's default ``{detail}`` body into the unified error format."""
        # AppException subclasses carry a ``message`` field; plain HTTPExceptions
        # use ``detail`` as the message when it is a string.
        message = getattr(exc, "message", None) or (
            exc.detail if isinstance(exc.detail, str) else "请求失败"
        )
        detail = exc.detail if isinstance(exc.detail, dict) else None
        return _cors_response(
            request,
            exc.status_code,
            {
                "code": exc.status_code,
                "message": message,
                "detail": detail,
            },
            # Preserve headers such as WWW-Authenticate / Retry-After set on the exception.
            headers=getattr(exc, "headers", None),
        )

    @app.exception_handler(Exception)
    async def global_exception_handler(request, exc):
        """Catch-all handler: log the traceback and return a generic 500."""
        logger.exception("Unhandled exception")
        return _cors_response(
            request,
            500,
            {
                "code": 500,
                "message": "服务器内部错误",
                "data": None,
                # Only expose the exception text when DEBUG is on.
                "detail": {"error": str(exc)} if settings.DEBUG else None,
            },
        )

    # Health check at the root path, for Docker / Nginx load balancing.
    @app.get("/health", tags=["System"])
    async def health_check():
        """Service health check."""
        return ApiResponse(
            code=200,
            data={"status": "healthy"},
            message="服务运行正常",
        )

    @app.get("/", tags=["System"])
    async def root():
        """API root: service name, version and (in DEBUG) the docs path."""
        return ApiResponse(
            code=200,
            data={
                "name": settings.APP_NAME,
                "version": settings.APP_VERSION,
                "docs": "/docs" if settings.DEBUG else None,
            },
            message="美家卡智影 API 服务",
        )

    return app
# Create the application instance at import time; main() points uvicorn at
# it through the "app.main:app" import string.
app = create_app()
def main():
    """Command-line entry point: serve the application with uvicorn.

    In DEBUG mode the server runs a single worker with auto-reload enabled;
    otherwise it uses the configured worker count and no reload.
    """
    debug_mode = settings.DEBUG
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        workers=1 if debug_mode else settings.WORKERS,
        reload=debug_mode,
        log_level=settings.LOG_LEVEL.lower(),
    )


if __name__ == "__main__":
    main()