Files
meijiaka-zy/python-api/app/main.py
T
小鱼开发 e58159fc42 refactor: 第三方平台架构改造(Adapter Protocol + Gateway)
Phase 1: 异常体系统一
- 新增 PlatformError / PlatformErrorType 标准定义
- 改造所有 Provider 异常抛出为 PlatformError
- 注册全局 PlatformError exception handler

Phase 2: Adapter Protocol
- 新增 app/ai/adapters/base.py(PlatformAdapter + SyncCapable + TaskCapable + CallbackCapable)
- 新增 app/ai/adapters/constants.py(Method 常量)
- 新增 PlatformConfigLoader(config/platform-config.yaml)

Phase 3: HTTP Client 统一
- ViduProvider 从 aiohttp 迁移到 httpx(注入方式)
- VolcengineCaptionService 改为注入 http_client
- lifespan 统一管理所有 Client 创建和关闭

Phase 4: Gateway 骨架 + Adapter 实现
- 新增 ViduAdapter / VolcengineArkAdapter / VolcengineCaptionAdapter
- 新增 PlatformGateway(call_sync / submit_task / query_task / handle_webhook)
- 新增 LLMGateway(带 Fallback 降级链)
- lifespan 注册所有 Adapter 和 Gateway

Phase 6: 清理与验证
- 从 Settings 移除 VIDU_BASE_URL / VOLCENGINE_BASE_URL
- Provider 改为从 PlatformConfigLoader 读取 base_url
- 清理 volcengine_caption_service 全局单例
- config_loader 默认路径改为 platform-config.yaml
- Scheduler 注入共享 HTTP client
- vidu.py 回调路由使用 Adapter 验签和解析
- ruff 全量通过,应用启动测试通过
2026-05-04 16:07:16 +08:00

314 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FastAPI 应用入口
================
"""
import logging
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.v1.router import api_router
from app.config import get_settings
from app.core.exceptions import PlatformError
from app.db.session import close_db, init_db
from app.schemas.common import ApiResponse
settings = get_settings()

# Logging: every record goes both to stdout and to a dated log file.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# Normalise case so "info" works as well as "INFO"; main() already treats
# LOG_LEVEL case-insensitively (it passes LOG_LEVEL.lower() to uvicorn).
# An unknown level name still fails loudly with AttributeError.
log_level = getattr(logging, settings.LOG_LEVEL.upper())

# Log directory lives under the user's Documents folder.
log_dir = Path.home() / "Documents" / "Meijiaka-zy" / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
# File name carries the date; it is computed once at import time, so a
# process that keeps running past midnight keeps writing to this file.
log_file = log_dir / f"api_{datetime.now().strftime('%Y%m%d')}.log"

# Configure the root logger (console + file, appending).
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_file, encoding="utf-8", mode="a"),
    ],
)
logger = logging.getLogger(__name__)
logger.info("日志文件位置: %s", log_file)
async def _init_database() -> None:
    """Create database tables, but only in DEBUG development/staging environments.

    Best-effort: a failure is logged as a warning and startup continues.
    """
    if not (settings.DEBUG and settings.ENV in ("development", "staging")):
        return
    logger.info("Initializing database tables...")
    try:
        # Import the models so they are registered on the metadata before init_db().
        from app.models import User  # noqa: F401

        await init_db()
        logger.info("Database tables initialized")
    except Exception as e:
        logger.warning("Database initialization skipped: %s", e)


def _load_static_configs() -> None:
    """Load the AI platform/model YAML config and the material (B-roll) config.

    Both loads are best-effort: a failure is logged as a warning and startup
    continues.
    """
    try:
        from app.core.config_loader import get_config_loader

        config_loader = get_config_loader()
        platforms_count = len(config_loader.get_all_platforms())
        models_count = len(config_loader.get_enabled_models())
        logger.info(
            "Loaded %s platforms, %s models from config file",
            platforms_count,
            models_count,
        )
    except Exception as e:
        logger.warning("Failed to load models from config: %s", e)

    # Material config (per the original note: remote CDN first, local JSON fallback).
    try:
        from app.services.material_service import load_config

        load_config()
        logger.info("Loaded material config")
    except Exception as e:
        logger.warning("Failed to load material config: %s", e)


def _create_http_clients() -> dict:
    """Create one ``httpx.AsyncClient`` per platform plus a ``"default"`` client.

    Separate pools isolate failures: one platform exhausting its connections
    cannot starve the others.

    Returns:
        Mapping of pool name ("vidu", "volcengine_caption", "default") to client.
    """
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        # The caption client gets a longer overall timeout (60s instead of 30s).
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


def _init_platforms(app: FastAPI) -> None:
    """Build providers, services, adapters and gateways and store them on ``app.state``.

    Requires ``app.state.http_clients`` to be populated already.

    The Vidu stack is mandatory: its errors propagate and abort startup. The
    Volcengine caption service and Volcengine (Ark) provider are optional.
    Each is set to ``None`` on failure, and the components depending on it are
    skipped (``llm_gateway`` is then ``None``).
    """
    from app.ai.providers.vidu_provider import ViduProvider
    from app.services.vidu_service import ViduService
    from app.services.volcengine_caption_service import VolcengineCaptionService

    state = app.state
    clients = state.http_clients

    # Vidu: provider gets the shared "vidu" client injected.
    state.vidu_provider = ViduProvider(client=clients["vidu"])
    state.vidu_service = ViduService(state.vidu_provider)
    logger.info("Vidu Provider & Service initialized")

    # Caption service is always attempted, since APPID/TOKEN may be configured later.
    try:
        state.volcengine_caption_service = VolcengineCaptionService(
            client=clients["volcengine_caption"]
        )
        logger.info("Volcengine Caption Service initialized")
    except Exception as e:
        logger.warning("Volcengine Caption Service 初始化跳过: %s", e)
        state.volcengine_caption_service = None

    # Volcengine Ark provider is optional (needs an API key).
    try:
        from app.ai.providers.volcengine_provider import VolcengineProvider

        state.volcengine_provider = VolcengineProvider()
        logger.info("Volcengine Provider initialized")
    except Exception as e:
        logger.warning("Volcengine Provider 初始化跳过: %s", e)
        state.volcengine_provider = None

    # Adapters wrap providers/services; the gateway routes calls to them by name.
    from app.ai.adapters.vidu_adapter import ViduAdapter
    from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
    from app.platform_gateway import PlatformGateway

    state.vidu_adapter = ViduAdapter(state.vidu_provider)
    logger.info("ViduAdapter initialized")

    state.platform_gateway = PlatformGateway()
    state.platform_gateway.register("vidu", state.vidu_adapter)
    if state.volcengine_caption_service:
        state.volcengine_caption_adapter = VolcengineCaptionAdapter(
            state.volcengine_caption_service
        )
        state.platform_gateway.register(
            "volcengine_caption", state.volcengine_caption_adapter
        )
        logger.info("VolcengineCaptionAdapter initialized")
    logger.info("PlatformGateway initialized")

    # LLM gateway only exists when the Volcengine provider was created.
    if state.volcengine_provider:
        from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
        from app.ai.gateways.llm_gateway import LLMGateway

        state.volcengine_ark_adapter = VolcengineArkAdapter(state.volcengine_provider)
        state.llm_gateway = LLMGateway(
            adapters={"volcengine_ark": state.volcengine_ark_adapter},
            fallback_chains={},
        )
        logger.info("LLMGateway initialized")
    else:
        state.llm_gateway = None


async def _shutdown(app: FastAPI) -> None:
    """Release the resources created during startup, dependents first.

    The gateway (whose ``close_all`` closes its adapters) is closed before
    the shared HTTP clients those adapters were built on, and the database
    last. Gateway and client closes are best-effort: errors are logged and
    the remaining steps still run. Tolerates a partially completed startup.
    """
    logger.info("Shutting down...")
    state = app.state

    gateway = getattr(state, "platform_gateway", None)
    if gateway is not None:
        try:
            await gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning("PlatformGateway close error: %s", e)

    for name, client in getattr(state, "http_clients", {}).items():
        try:
            if not client.is_closed:
                await client.aclose()
            logger.info("HTTP Client closed: %s", name)
        except Exception as e:
            logger.warning("HTTP Client close error: %s: %s", name, e)

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared resources on startup, release them on shutdown.

    Startup runs the following steps:
      - DB tables, in DEBUG development/staging only.
      - YAML model and material configs.
      - Per-platform HTTP clients.
      - Providers, adapters and gateways.

    Cleanup runs in a ``finally`` block. Clients and the DB are therefore
    released even if startup fails part-way, for example when a mandatory
    Vidu component raises, or if the serving phase ends with an exception.
    """
    logger.info("Starting %s v%s", settings.APP_NAME, settings.APP_VERSION)
    await _init_database()
    _load_static_configs()

    app.state.http_clients = _create_http_clients()
    logger.info("HTTP Client pool initialized")
    try:
        _init_platforms(app)
        yield
    finally:
        await _shutdown(app)
def create_app() -> FastAPI:
    """Build and configure the FastAPI application instance.

    Configuration steps:
      - Attach ``lifespan``.
      - Add CORS middleware.
      - Mount the v1 API router under ``/api/v1``.
      - Register the PlatformError and catch-all exception handlers.
      - Add the ``/health`` and ``/`` system endpoints.

    Interactive docs (``/docs``, ``/redoc``) are exposed only when DEBUG is on.
    """
    debug = settings.DEBUG
    application = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url="/docs" if debug else None,
        redoc_url="/redoc" if debug else None,
        lifespan=lifespan,
    )

    # CORS: any origin in DEBUG (avoids cross-origin friction during development),
    # otherwise only the configured origin list.
    if debug:
        origins = ["*"]
    else:
        origins = settings.cors_origins_list
    application.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    application.include_router(api_router, prefix="/api/v1")

    async def platform_error_handler(request, exc: PlatformError):
        """Map a third-party PlatformError to an ApiResponse-shaped JSON body."""
        http_status = exc.to_http_status()
        payload = {
            "code": exc.status_code or http_status,
            "message": str(exc),
            "data": None,
        }
        if settings.DEBUG:
            # Platform diagnostics are only exposed in DEBUG mode.
            payload["detail"] = dict(
                platform=exc.platform,
                error_type=exc.error_type,
                retryable=exc.retryable,
            )
        return JSONResponse(status_code=http_status, content=payload)

    async def global_exception_handler(request, exc):
        """Catch-all: log the traceback and answer with a generic 500 body."""
        logger.exception("Unhandled exception")
        detail = {"error": str(exc)} if settings.DEBUG else None
        payload = {
            "code": 500,
            "message": "服务器内部错误",
            "data": None,
            "detail": detail,
        }
        return JSONResponse(status_code=500, content=payload)

    application.add_exception_handler(PlatformError, platform_error_handler)
    application.add_exception_handler(Exception, global_exception_handler)

    # Route docstrings below are left as-is: FastAPI publishes them as the
    # OpenAPI description of each endpoint.
    @application.get("/health", tags=["System"])
    async def health_check():
        """服务健康检查"""
        status_info = {
            "status": "healthy",
            "version": settings.APP_VERSION,
            "environment": settings.ENV,
        }
        return ApiResponse(code=200, data=status_info, message="服务运行正常")

    @application.get("/", tags=["System"])
    async def root():
        """API 根路径"""
        service_info = {
            "name": settings.APP_NAME,
            "version": settings.APP_VERSION,
            "docs": "/docs" if settings.DEBUG else None,
        }
        return ApiResponse(code=200, data=service_info, message="美家卡智影 API 服务")

    return application
# Module-level ASGI application; main() serves it via the import string "app.main:app".
app: FastAPI = create_app()
def main():
    """Command-line entry point: serve the API with uvicorn.

    In DEBUG mode a single worker runs with auto-reload enabled; otherwise
    ``settings.WORKERS`` workers are started and reload is off.
    """
    debug_mode = settings.DEBUG
    worker_count = 1 if debug_mode else settings.WORKERS
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        workers=worker_count,
        reload=debug_mode,
        # The configured level name is lowercased before being handed to uvicorn.
        log_level=settings.LOG_LEVEL.lower(),
    )
if __name__ == "__main__":
    # Running this module directly starts the server via main().
    main()