30536276ba
核心变更:
- 统一第三方接口架构:所有服务走 PlatformGateway(call_sync/submit_task/query_task/handle_webhook)
- 视频生成(Vidu 对口型)纳入 Async Engine,与 script/subtitle/tts 统一为 POST /tasks/{task_type} 模式
- 新增 VideoHandler、TTSHandler,完善 ScriptHandler/SubtitleHandler
- PlatformGateway 生成 internal_task_id,建立 Redis 双向映射,callback 场景传入 Async Engine task_id 保证映射一致
- SlotManager 新增 acquire_ctx 上下文管理器,所有 Handler 统一使用
- ViduAdapter 状态映射归一化(normalize_state/denormalize_state)
- 移除 ViduService Semaphore 和 tenacity 重试,并发控制完全交予 SlotManager
- nonce 防重放下沉到 CallbackCapable 协议
- Service 层错误统一为 PlatformError,路由层错误信息脱敏
- 废弃 /voice/lip-sync,清理 vidu.py 遗留路由
Bug 修复:
- VideoHandler 轮询阶段后添加 continue,防止已提交任务重复创建
- voice.py synthesize_to_file 变量名冲突(request vs request_body)
- PlatformGateway.submit_task 空 data 防护
- ScriptHandler 动态导入 asyncio 改为模块级导入
- SubtitleHandler 完成时补充 progress=100
文档:
- 更新 AGENTS.md 核心功能、运行时架构、异步调度描述
342 lines
11 KiB
Python
342 lines
11 KiB
Python
"""
|
||
FastAPI 应用入口
|
||
================
|
||
"""
|
||
|
||
import logging
|
||
import sys
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from app.api.v1.router import api_router
|
||
from app.config import get_settings
|
||
from app.core.exceptions import PlatformError
|
||
from app.db.session import close_db, init_db
|
||
from app.schemas.common import ApiResponse
|
||
|
||
settings = get_settings()

# Logging: mirror every record to stdout and to a log file.
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# Resolve LOG_LEVEL case-insensitively. logging.getLevelName() returns the int
# value for a registered level name and a "Level <name>" string otherwise, so an
# unknown name falls back to INFO (warned about below) instead of raising
# AttributeError while this module is being imported.
_requested_log_level = str(settings.LOG_LEVEL).upper()
_resolved_log_level = logging.getLevelName(_requested_log_level)
log_level = _resolved_log_level if isinstance(_resolved_log_level, int) else logging.INFO

# Log directory lives under the user's Documents folder.
log_dir = Path.home() / "Documents" / "Meijiaka-zy" / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

# File name carries the local date taken once at import time, so a long-running
# process keeps appending to the file of the day it started.
log_file = log_dir / f"api_{datetime.now().strftime('%Y%m%d')}.log"

# Configure the root logger with a console handler and an appending file handler.
logging.basicConfig(
    level=log_level,
    format=log_format,
    handlers=[
        logging.StreamHandler(sys.stdout),  # console output
        logging.FileHandler(log_file, encoding="utf-8", mode="a"),  # file output
    ],
)

logger = logging.getLogger(__name__)
logger.info("日志文件位置: %s", log_file)
if not isinstance(_resolved_log_level, int):
    logger.warning("Unknown LOG_LEVEL %r; falling back to INFO", settings.LOG_LEVEL)
|
||
|
||
|
||
def _create_http_clients() -> dict:
    # Build one pooled httpx.AsyncClient per platform (isolated pools for fault isolation).
    import httpx

    return {
        "vidu": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=20),
        ),
        "volcengine_caption": httpx.AsyncClient(
            timeout=httpx.Timeout(60.0, connect=5.0),
            limits=httpx.Limits(max_connections=10, max_keepalive_connections=10),
        ),
        "default": httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=5.0),
            limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
        ),
    }


async def _shutdown_resources(app: FastAPI) -> None:
    # Release what lifespan created, tolerating a partially-initialised app.state.
    logger.info("Shutting down...")

    # Close the gateway (which closes its adapters) BEFORE the HTTP clients:
    # adapters wrap providers that were handed the shared clients, so the
    # clients must outlive them during shutdown.
    gateway = getattr(app.state, "platform_gateway", None)
    if gateway is not None:
        try:
            await gateway.close_all()
            logger.info("PlatformGateway closed")
        except Exception as e:
            logger.warning(f"PlatformGateway close error: {e}")

    # Close any client still open (a provider may already have closed its own).
    for name, client in getattr(app.state, "http_clients", {}).items():
        try:
            if not client.is_closed:
                await client.aclose()
                logger.info(f"HTTP Client closed: {name}")
        except Exception as e:
            logger.warning(f"HTTP Client close error: {name}: {e}")

    await close_db()
    logger.info("Cleanup complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan management.

    Startup:
      - In debug builds with ENV development/staging, create database tables.
      - Load the AI platform/model YAML config and the material config
        (failures are logged and skipped).
      - Build per-platform HTTP client pools and wire providers -> adapters ->
        PlatformGateway -> ModelRouter / services / LLMGateway onto app.state.
        Optional providers that fail to initialise are stored as None.

    Shutdown (in a ``finally``, so it also runs if startup fails partway):
      - Close the PlatformGateway, then the shared HTTP clients, then the DB.
    """
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")

    try:
        # Auto-create tables only for debug builds in development/staging.
        if settings.DEBUG and settings.ENV in ("development", "staging"):
            logger.info("Initializing database tables...")
            try:
                # Make sure model classes are registered on the metadata before init_db().
                from app.models import User  # noqa: F401

                await init_db()
                logger.info("Database tables initialized")
            except Exception as e:
                logger.warning(f"Database initialization skipped: {e}")

        # Load AI platform/model configuration from the YAML config file.
        try:
            from app.core.config_loader import get_config_loader

            config_loader = get_config_loader()
            platforms_count = len(config_loader.get_all_platforms())
            models_count = len(config_loader.get_enabled_models())
            logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
        except Exception as e:
            logger.warning(f"Failed to load models from config: {e}")

        # Load the stock-footage material config (remote CDN preferred, local JSON fallback).
        try:
            from app.services.material_service import load_config

            load_config()
            logger.info("Loaded material config")
        except Exception as e:
            logger.warning(f"Failed to load material config: {e}")

        # Per-platform HTTP client pools, injected into the providers below.
        app.state.http_clients = _create_http_clients()
        logger.info("HTTP Client pool initialized")

        # Providers receive the shared clients.
        from app.ai.providers.vidu_provider import ViduProvider
        from app.ai.providers.volcengine_caption_provider import VolcengineCaptionProvider

        app.state.vidu_provider = ViduProvider(client=app.state.http_clients["vidu"])
        logger.info("Vidu Provider initialized")

        # Volcengine caption provider is optional: any init failure leaves it None.
        try:
            app.state.volcengine_caption_provider = VolcengineCaptionProvider(
                client=app.state.http_clients["volcengine_caption"]
            )
            logger.info("Volcengine Caption Provider initialized")
        except Exception as e:
            logger.warning(f"Volcengine Caption Provider 初始化跳过: {e}")
            app.state.volcengine_caption_provider = None

        # Volcengine Ark provider is optional (it needs an API key); failure leaves it None.
        try:
            from app.ai.providers.volcengine_provider import VolcengineProvider

            app.state.volcengine_provider = VolcengineProvider()
            logger.info("Volcengine Provider initialized")
        except Exception as e:
            logger.warning(f"Volcengine Provider 初始化跳过: {e}")
            app.state.volcengine_provider = None

        # Adapters wrap providers and are registered on the PlatformGateway by name.
        from app.ai.adapters.vidu_adapter import ViduAdapter
        from app.ai.adapters.volcengine_caption_adapter import VolcengineCaptionAdapter
        from app.ai.adapters.volcengine_ark_adapter import VolcengineArkAdapter
        from app.platform_gateway import PlatformGateway

        app.state.vidu_adapter = ViduAdapter(app.state.vidu_provider)
        logger.info("ViduAdapter initialized")

        app.state.platform_gateway = PlatformGateway()
        app.state.platform_gateway.register("vidu", app.state.vidu_adapter)

        if app.state.volcengine_caption_provider:
            app.state.volcengine_caption_adapter = VolcengineCaptionAdapter(
                app.state.volcengine_caption_provider
            )
            app.state.platform_gateway.register(
                "volcengine_caption", app.state.volcengine_caption_adapter
            )
            logger.info("VolcengineCaptionAdapter initialized")

        if app.state.volcengine_provider:
            app.state.volcengine_ark_adapter = VolcengineArkAdapter(
                app.state.volcengine_provider
            )
            app.state.platform_gateway.register(
                "volcengine_ark", app.state.volcengine_ark_adapter
            )
            logger.info("VolcengineArkAdapter registered to PlatformGateway")

        logger.info("PlatformGateway initialized")

        # ModelRouter is handed the gateway so its underlying calls go through PlatformGateway.
        try:
            from app.ai.model_router import get_model_router

            await get_model_router(gateway=app.state.platform_gateway)
            logger.info("ModelRouter initialized with PlatformGateway")
        except Exception as e:
            logger.warning(f"ModelRouter 初始化失败: {e}")

        # Services are constructed with the gateway.
        from app.services.vidu_service import ViduService
        from app.services.volcengine_caption_service import VolcengineCaptionService

        app.state.vidu_service = ViduService(app.state.platform_gateway)
        logger.info("Vidu Service initialized")

        if app.state.volcengine_caption_provider:
            app.state.volcengine_caption_service = VolcengineCaptionService(
                app.state.platform_gateway
            )
            logger.info("Volcengine Caption Service initialized")
        else:
            app.state.volcengine_caption_service = None

        # LLMGateway is optional and kept for backward compatibility.
        if app.state.volcengine_provider:
            from app.ai.gateways.llm_gateway import LLMGateway

            app.state.llm_gateway = LLMGateway(
                adapters={"volcengine_ark": app.state.volcengine_ark_adapter},
                fallback_chains={},
            )
            logger.info("LLMGateway initialized")
        else:
            app.state.llm_gateway = None

        yield
    finally:
        await _shutdown_resources(app)
|
||
|
||
|
||
def create_app() -> FastAPI:
    """
    Create and configure the FastAPI application instance.

    Registers:
      - CORS middleware (any origin when ``settings.DEBUG``, otherwise
        ``settings.cors_origins_list``);
      - the v1 API router under ``/api/v1``;
      - exception handlers for ``PlatformError`` and for any other uncaught
        exception, both returning a ``{"code", "message", "data"}`` body and
        adding a ``detail`` payload only in debug mode;
      - ``GET /health`` and ``GET /`` system routes.

    Interactive docs (``/docs``, ``/redoc``) are enabled only in debug mode.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(
        title=settings.APP_NAME,
        version=settings.APP_VERSION,
        description="美家卡智影 - AI 视频创作后端 API",
        docs_url="/docs" if settings.DEBUG else None,
        redoc_url="/redoc" if settings.DEBUG else None,
        lifespan=lifespan,
    )

    # CORS: in debug builds allow every origin to avoid cross-origin friction
    # during development; otherwise restrict to the configured origins.
    allow_origins = ["*"] if settings.DEBUG else settings.cors_origins_list
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Mount the versioned API routes.
    app.include_router(api_router, prefix="/api/v1")

    @app.exception_handler(PlatformError)
    async def platform_error_handler(request, exc: PlatformError):
        """Map a third-party platform error to the standard error response."""
        http_status = exc.to_http_status()
        # Record upstream failures server-side; otherwise the only trace of
        # them is the response sent back to the client.
        logger.warning(
            "PlatformError on %s %s: platform=%s error_type=%s retryable=%s: %s",
            request.method,
            request.url.path,
            exc.platform,
            exc.error_type,
            exc.retryable,
            exc,
        )
        content = {
            "code": exc.status_code or http_status,
            "message": str(exc),
            "data": None,
        }
        # Platform diagnostics are exposed to clients only in debug mode.
        if settings.DEBUG:
            content["detail"] = {
                "platform": exc.platform,
                "error_type": exc.error_type,
                "retryable": exc.retryable,
            }
        return JSONResponse(status_code=http_status, content=content)

    @app.exception_handler(Exception)
    async def global_exception_handler(request, exc):
        """Catch-all: log the traceback and return a generic 500 response."""
        # Include the request line so the traceback can be tied to an endpoint.
        logger.exception("Unhandled exception on %s %s", request.method, request.url.path)
        return JSONResponse(
            status_code=500,
            content={
                "code": 500,
                "message": "服务器内部错误",
                "data": None,
                "detail": {"error": str(exc)} if settings.DEBUG else None,
            },
        )

    @app.get("/health", tags=["System"])
    async def health_check():
        """Service health check: report status, version and environment."""
        return ApiResponse(
            code=200,
            data={
                "status": "healthy",
                "version": settings.APP_VERSION,
                "environment": settings.ENV,
            },
            message="服务运行正常",
        )

    @app.get("/", tags=["System"])
    async def root():
        """API root: report name, version and (in debug mode) the docs path."""
        return ApiResponse(
            code=200,
            data={
                "name": settings.APP_NAME,
                "version": settings.APP_VERSION,
                "docs": "/docs" if settings.DEBUG else None,
            },
            message="美家卡智影 API 服务",
        )

    return app
|
||
|
||
|
||
# Module-level application instance; main() serves it via the "app.main:app"
# import string passed to uvicorn.
app = create_app()
|
||
|
||
|
||
def main():
    """Command-line entry point: serve ``app.main:app`` with uvicorn.

    Host, port, worker count and log level come from ``settings``. In debug
    mode the server runs a single worker with auto-reload enabled.
    """
    debug_mode = settings.DEBUG
    worker_count = 1 if debug_mode else settings.WORKERS
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        workers=worker_count,
        reload=debug_mode,
        log_level=settings.LOG_LEVEL.lower(),
    )
|
||
|
||
|
||
# Start the server when this file is executed directly as a script.
if __name__ == "__main__":
    main()
|
||
# test
|