e58159fc42
Phase 1: 异常体系统一 - 新增 PlatformError / PlatformErrorType 标准定义 - 改造所有 Provider 异常抛出为 PlatformError - 注册全局 PlatformError exception handler Phase 2: Adapter Protocol - 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable) - 新增 app/ai/adapters/constants.py(Method 常量) - 新增 PlatformConfigLoader(config/platform-config.yaml) Phase 3: HTTP Client 统一 - ViduProvider 从 aiohttp 迁移到 httpx(注入方式) - VolcengineCaptionService 改为注入 http_client - lifespan 统一管理所有 Client 创建和关闭 Phase 4: Gateway 骨架 + Adapter 实现 - 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter - 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook) - 新增 LLMGateway(带 Fallback 降级链) - lifespan 注册所有 Adapter 和 Gateway Phase 6: 清理与验证 - 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL - Provider 改为从 PlatformConfigLoader 读取 base_url - 清理 volcengine_caption_service 全局单例 - config_loader 默认路径改为 platform-config.yaml - Scheduler 注入共享 HTTP client - vidu.py 回调路由使用 Adapter 验签和解析 - ruff 全量通过,应用启动测试通过
314 lines
10 KiB
Python
314 lines
10 KiB
Python
"""
|
||
FastAPI application entry point
===============================
|
||
"""
|
||
|
||
import logging
|
||
import sys
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from app.api.v1.router import api_router
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError
|
||
from app.db.session import close_db, init_db
|
||
from app.schemas.common import ApiResponse
|
||
|
||
settings = get_settings()

# Logging setup: every record goes to the console and, when possible, to a
# per-day log file.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Resolve the configured level name case-insensitively. A plain
# getattr(logging, "info") would return the logging.info *function*, and an
# unknown name would raise AttributeError at import time; anything that does
# not resolve to an integer level falls back to INFO.
_resolved_level = getattr(logging, str(settings.LOG_LEVEL).upper(), None)
log_level = _resolved_level if isinstance(_resolved_level, int) else logging.INFO

# Log directory lives under the user's Documents folder.
log_dir = Path.home() / "Documents" / "Meijiaka-zy" / "logs"

# One log file per calendar day; the date is fixed when the module is imported.
log_file = log_dir / f"api_{datetime.now().strftime('%Y%m%d')}.log"

# The console handler is always installed. The file handler is best effort so
# that an unwritable home directory does not prevent the module from importing.
_log_handlers = [logging.StreamHandler(sys.stdout)]
_file_logging_error = None
try:
    log_dir.mkdir(parents=True, exist_ok=True)
    _log_handlers.append(logging.FileHandler(log_file, encoding="utf-8", mode="a"))
except OSError as e:
    _file_logging_error = e

# Configure the root logger.
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
if _file_logging_error is None:
    logger.info(f"日志文件位置: {log_file}")
else:
    logger.warning(f"File logging disabled, console only: {_file_logging_error}")
|
||
|
||
|
||
async def _init_database() -> None:
    """Create database tables when DEBUG is on in development/staging (best effort)."""
    if not (settings.DEBUG and settings.ENV in ("development", "staging")):
        return

    logger.info("Initializing database tables...")
    try:
        # Make sure the models are registered on the metadata before creating tables
        from app.models import User  # noqa: F401

        await init_db()
        logger.info("Database tables initialized")
    except Exception as e:
        # Deliberately non-fatal: the application still starts without table creation
        logger.warning(f"Database initialization skipped: {e}")


def _load_static_configs() -> None:
    """Load the AI platform/model config and the material config (both best effort)."""
    # AI model configuration (from the YAML file)
    try:
        from app.core.config_loader import get_config_loader

        config_loader = get_config_loader()
        platforms_count = len(config_loader.get_all_platforms())
        models_count = len(config_loader.get_enabled_models())

        logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
    except Exception as e:
        logger.warning(f"Failed to load models from config: {e}")

    # Material (stock footage) configuration: remote CDN preferred, local JSON fallback
    try:
        from app.services.material_service import load_config

        load_config()
        logger.info("Loaded material config")
    except Exception as e:
        logger.warning(f"Failed to load material config: {e}")


def _create_http_clients() -> dict:
    """Build one httpx.AsyncClient per platform, each with its own timeout and pool limits."""
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


def _init_platform_stack(app: FastAPI) -> None:
    """Create providers, adapters and gateways on app.state using the shared HTTP clients."""
    # Providers (shared clients injected)
    from app.ai.providers.vidu_provider import ViduProvider
    from app.services.vidu_service import ViduService
    from app.services.volcengine_caption_service import VolcengineCaptionService

    app.state.vidu_provider = ViduProvider(client=app.state.http_clients["vidu"])
    app.state.vidu_service = ViduService(app.state.vidu_provider)
    logger.info("Vidu Provider & Service initialized")

    # Volcengine caption service: always attempted, since APPID/TOKEN may be configured later
    try:
        app.state.volcengine_caption_service = VolcengineCaptionService(
            client=app.state.http_clients["volcengine_caption"]
        )
        logger.info("Volcengine Caption Service initialized")
    except Exception as e:
        logger.warning(f"Volcengine Caption Service 初始化跳过: {e}")
        app.state.volcengine_caption_service = None

    # Volcengine Ark provider (optional; requires an API key)
    try:
        from app.ai.providers.volcengine_provider import VolcengineProvider

        app.state.volcengine_provider = VolcengineProvider()
        logger.info("Volcengine Provider initialized")
    except Exception as e:
        logger.warning(f"Volcengine Provider 初始化跳过: {e}")
        app.state.volcengine_provider = None

    # Adapters (wrapping the providers)
    from app.ai.adapters.vidu_adapter import ViduAdapter
    from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
    from app.platform_gateway import PlatformGateway

    app.state.vidu_adapter = ViduAdapter(app.state.vidu_provider)
    logger.info("ViduAdapter initialized")

    # Platform gateway
    app.state.platform_gateway = PlatformGateway()
    app.state.platform_gateway.register("vidu", app.state.vidu_adapter)

    if app.state.volcengine_caption_service:
        app.state.volcengine_caption_adapter = VolcengineCaptionAdapter(
            app.state.volcengine_caption_service
        )
        app.state.platform_gateway.register(
            "volcengine_caption", app.state.volcengine_caption_adapter
        )
        logger.info("VolcengineCaptionAdapter initialized")

    logger.info("PlatformGateway initialized")

    # LLM gateway (optional; only available when the Volcengine provider was created)
    if app.state.volcengine_provider:
        from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
        from app.ai.gateways.llm_gateway import LLMGateway

        app.state.volcengine_ark_adapter = VolcengineArkAdapter(
            app.state.volcengine_provider
        )
        app.state.llm_gateway = LLMGateway(
            adapters={"volcengine_ark": app.state.volcengine_ark_adapter},
            fallback_chains={},
        )
        logger.info("LLMGateway initialized")
    else:
        app.state.llm_gateway = None


async def _shutdown(app: FastAPI) -> None:
    """Release the gateway, the HTTP clients and the database, tolerating partial startup."""
    logger.info("Shutting down...")

    # Close the gateway first: its adapters were built on top of the shared
    # HTTP clients, so the clients must outlive them.
    gateway = getattr(app.state, "platform_gateway", None)
    if gateway is not None:
        try:
            await gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning(f"PlatformGateway close error: {e}")

    # Close every HTTP client; the mapping is absent if startup failed early.
    for name, client in getattr(app.state, "http_clients", {}).items():
        try:
            if not client.is_closed:
                await client.aclose()
            logger.info(f"HTTP Client closed: {name}")
        except Exception as e:
            logger.warning(f"HTTP Client close error: {name}: {e}")

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    - Startup: optionally initialise the database, load the model and material
      configs, create the per-platform HTTP clients, then build the providers,
      adapters and gateways on ``app.state``.
    - Shutdown: close the gateway, the HTTP clients and the database.

    Cleanup runs in a ``finally`` block, so resources that were already created
    are released even when startup fails part-way or an exception is raised
    into the context manager at the ``yield``.
    """
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")

    try:
        await _init_database()
        _load_static_configs()

        # HTTP client pool (one client per platform, for fault isolation)
        app.state.http_clients = _create_http_clients()
        logger.info("HTTP Client pool initialized")

        _init_platform_stack(app)

        yield
    finally:
        await _shutdown(app)
|
||
|
||
|
||
def create_app() -> FastAPI:
    """Build and configure the FastAPI application instance.

    Adds the CORS middleware, mounts the v1 router under ``/api/v1``,
    registers handlers for ``PlatformError`` and for any other ``Exception``,
    and defines the ``/health`` and ``/`` system routes.
    """
    # Interactive docs are only served in debug mode
    docs_url, redoc_url = ("/docs", "/redoc") if settings.DEBUG else (None, None)

    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url=docs_url,
        redoc_url=redoc_url,
        lifespan=lifespan,
    )

    # CORS: debug mode allows every origin to avoid cross-origin friction;
    # otherwise the configured origin list is used.
    if settings.DEBUG:
        origins = ["*"]
    else:
        origins = settings.cors_origins_list
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Routes
    app.include_router(api_router, prefix="/api/v1")

    async def platform_error_handler(request, exc: PlatformError):
        """Turn a third-party platform error into the unified JSON error body."""
        status = exc.to_http_status()
        payload = {
            "code": exc.status_code or status,
            "message": str(exc),
            "data": None,
        }
        # Diagnostic fields are only exposed in debug mode
        if settings.DEBUG:
            payload.update(
                detail={
                    "platform": exc.platform,
                    "error_type": exc.error_type,
                    "retryable": exc.retryable,
                }
            )
        return JSONResponse(status_code=status, content=payload)

    async def global_exception_handler(request, exc):
        """Catch-all handler: log the exception and answer with a generic 500 body."""
        logger.exception("Unhandled exception")
        detail = None
        if settings.DEBUG:
            detail = {"error": str(exc)}
        body = {
            "code": 500,
            "message": "服务器内部错误",
            "data": None,
            "detail": detail,
        }
        return JSONResponse(status_code=500, content=body)

    # PlatformError handler first, then the catch-all for everything else
    app.add_exception_handler(PlatformError, platform_error_handler)
    app.add_exception_handler(Exception, global_exception_handler)

    # NOTE: FastAPI publishes a route function's docstring as the operation
    # description in the OpenAPI schema, so the two route docstrings below are
    # runtime text and are kept exactly as they were.

    # Health check
    @app.get("/health", tags=["System"])
    async def health_check():
        """服务健康检查"""
        info = {
            "status": "healthy",
            "version": settings.APP_VERSION,
            "environment": settings.ENV,
        }
        return ApiResponse(code=200, data=info, message="服务运行正常")

    # Root route
    @app.get("/", tags=["System"])
    async def root():
        """API 根路径"""
        info = {
            "name": settings.APP_NAME,
            "version": settings.APP_VERSION,
            "docs": "/docs" if settings.DEBUG else None,
        }
        return ApiResponse(code=200, data=info, message="美家卡智影 API 服务")

    return app
|
||
|
||
|
||
# 创建应用实例
|
||
# Module-level application object, built once when this module is imported.
app = create_app()
|
||
|
||
|
||
def main():
    """Command-line entry point: serve ``app.main:app`` with uvicorn.

    In debug mode auto-reload is enabled and a single worker is used;
    otherwise ``settings.WORKERS`` workers are started without reload.
    """
    debug = settings.DEBUG
    server_options = {
        "host": settings.HOST,
        "port": settings.PORT,
        # Debug mode always runs a single worker
        "workers": 1 if debug else settings.WORKERS,
        "reload": debug,
        # uvicorn is given the level name in lowercase
        "log_level": settings.LOG_LEVEL.lower(),
    }
    uvicorn.run("app.main:app", **server_options)


if __name__ == "__main__":
    main()
|
||
# test
|