feat: init meijiaka-zj project from ai-meijiaka template
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
# 配置架构规则
|
||||
|
||||
description: 配置管理架构规范
|
||||
|
||||
## 规则
|
||||
|
||||
### 配置读取
|
||||
- 所有配置必须通过 `from app.config import get_settings` 读取
|
||||
- 禁止直接使用 `os.getenv()` 或 `os.environ.get()`
|
||||
- 禁止在服务层、API 层直接使用环境变量
|
||||
|
||||
### 添加新配置
|
||||
1. 在 `app/config.py` 的 `Settings` 类中定义字段
|
||||
2. 使用 `Field(default=..., description="...")` 提供默认值和说明
|
||||
3. 敏感信息使用 `str | None = None` 类型
|
||||
4. 更新 `.env.example` 文档
|
||||
|
||||
### 在服务中使用配置
|
||||
```python
|
||||
from app.config import get_settings
|
||||
|
||||
def some_function():
|
||||
settings = get_settings()
|
||||
api_key = settings.SOME_API_KEY
|
||||
```
|
||||
|
||||
### 禁止的写法
|
||||
```python
|
||||
import os
|
||||
|
||||
# ❌ 禁止
|
||||
api_key = os.getenv("SOME_API_KEY")
|
||||
api_key = os.environ.get("SOME_API_KEY")
|
||||
```
|
||||
|
||||
### 推荐的写法
|
||||
```python
|
||||
from app.config import get_settings
|
||||
|
||||
# ✅ 正确
|
||||
settings = get_settings()
|
||||
api_key = settings.SOME_API_KEY
|
||||
```
|
||||
@@ -0,0 +1,9 @@
|
||||
[ 117ms] [INFO] %cDownload the React DevTools for a better development experience: https://react.dev/link/react-devtools font-weight:bold @ http://localhost:1420/node_modules/.vite/deps/react-dom_client.js?v=a15e99c2:20102
|
||||
[ 148ms] [LOG] [ScriptCreation] segments changed: 0 @ http://localhost:1420/src/pages/VideoCreation/ScriptCreation.tsx:52
|
||||
[ 148ms] [LOG] [ScriptCreation] segments changed: 0 @ http://localhost:1420/src/pages/VideoCreation/ScriptCreation.tsx:52
|
||||
[ 152ms] [ERROR] [authStore] 加载认证状态失败: TypeError: Cannot read properties of undefined (reading 'invoke')
|
||||
at invoke (http://localhost:1420/node_modules/.vite/deps/chunk-G7S6KQDI.js?v=a15e99c2:109:37)
|
||||
at loadFromStorage (http://localhost:1420/src/store/authStore.ts:72:28) @ http://localhost:1420/src/store/authStore.ts:83
|
||||
[ 152ms] [ERROR] [authStore] 加载认证状态失败: TypeError: Cannot read properties of undefined (reading 'invoke')
|
||||
at invoke (http://localhost:1420/node_modules/.vite/deps/chunk-G7S6KQDI.js?v=a15e99c2:109:37)
|
||||
at loadFromStorage (http://localhost:1420/src/store/authStore.ts:72:28) @ http://localhost:1420/src/store/authStore.ts:83
|
||||
@@ -0,0 +1,31 @@
|
||||
- generic [ref=e4]:
|
||||
- generic [ref=e5]:
|
||||
- generic [ref=e6]:
|
||||
- img "美家卡 智影" [ref=e7]
|
||||
- generic [ref=e8]: 美家卡 智影
|
||||
- paragraph [ref=e9]: AI 驱动的智能视频创作平台
|
||||
- generic [ref=e10]:
|
||||
- heading "欢迎登录" [level=2] [ref=e11]
|
||||
- paragraph [ref=e12]: 使用手机号验证码快速登录
|
||||
- generic [ref=e13]:
|
||||
- generic [ref=e14]:
|
||||
- generic [ref=e15]: 手机号
|
||||
- generic [ref=e16]:
|
||||
- generic [ref=e17]: "+86"
|
||||
- textbox "请输入手机号" [active] [ref=e18]
|
||||
- generic [ref=e19]:
|
||||
- generic [ref=e20]: 验证码
|
||||
- generic [ref=e21]:
|
||||
- textbox "请输入验证码" [ref=e22]
|
||||
- button "获取验证码" [disabled] [ref=e23]
|
||||
- button "登录" [disabled] [ref=e24]
|
||||
- generic [ref=e25]:
|
||||
- checkbox "我已阅读并同意《用户服务协议》和《隐私政策》" [ref=e26] [cursor=pointer]
|
||||
- generic [ref=e27] [cursor=pointer]:
|
||||
- text: 我已阅读并同意
|
||||
- link "《用户服务协议》" [ref=e28]:
|
||||
- /url: "#"
|
||||
- text: 和
|
||||
- link "《隐私政策》" [ref=e29]:
|
||||
- /url: "#"
|
||||
- generic [ref=e30]: meijiaka.cn
|
||||
@@ -0,0 +1 @@
|
||||
3.13
|
||||
@@ -0,0 +1,784 @@
|
||||
<!-- From: /Users/0fun/work/ai-meijiaka/AGENTS.md -->
|
||||
# 美家卡智影 (Meijiaka AI Video) - AI 视频创作平台
|
||||
|
||||
## 项目概述
|
||||
|
||||
美家卡智影是一个 AI 驱动的视频创作桌面应用,采用 **Tauri + React + FastAPI** 混合架构。用户可以通过 AI 生成脚本、创建数字人视频,最终合成完整的营销视频。
|
||||
|
||||
### 核心功能
|
||||
|
||||
- **AI 脚本生成**: 基于 LLM 自动生成视频脚本和分镜
|
||||
- **数字人视频**: 基于 KlingAI 创建数字人视频片段
|
||||
- **字幕生成**: 基于火山引擎豆包语音自动生成字幕并压制到视频
|
||||
- **封面制作**: 提取视频首帧并叠加字幕样式生成封面
|
||||
- **视频合成**: 本地 FFmpeg 处理视频拼接、音频混流、导出成品
|
||||
- **形象克隆**: 基于 KlingAI 的自定义数字人形象管理
|
||||
- **项目管理**: 项目数据本地 JSON 文件存储,认证状态云端同步
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
ai-meijiaka/
|
||||
├── python-api/ # FastAPI 后端服务(AI 代理 + 认证 + 任务调度)
|
||||
│ ├── app/
|
||||
│ │ ├── api/v1/ # API 路由 (REST): auth, script, ai_models, klingai,
|
||||
│ │ │ # qiniu, video, avatar, system,
|
||||
│ │ │ # caption, tasks
|
||||
│ │ ├── ai/ # AI 模型路由、Provider、提示词模板
|
||||
│ │ ├── core/ # 安全、配置加载、Token管理器、Redis客户端、异常处理
|
||||
│ │ ├── crud/ # 数据访问层(users, model_usage, avatar)
|
||||
│ │ ├── db/ # 数据库配置(PostgreSQL + asyncpg + SQLAlchemy 2.0)
|
||||
│ │ ├── models/ # SQLAlchemy 模型(users, model_usage_logs, avatars)
|
||||
│ │ ├── schemas/ # Pydantic 校验模型
|
||||
│ │ ├── services/ # AI 服务代理、DTO标准化、七牛/字幕/视频服务
|
||||
│ │ ├── scheduler/ # Async Engine 异步任务调度(video, image, script,
|
||||
│ │ │ # subtitle, copy, avatar_clone)
|
||||
│ │ ├── config.py # Pydantic Settings 配置管理
|
||||
│ │ └── main.py # FastAPI 入口(含生命周期管理)
|
||||
│ ├── config/ # AI 模型配置文件(ai_models.yaml),支持热重载
|
||||
│ ├── alembic/ # 数据库迁移
|
||||
│ ├── scripts/ # 初始化/测试脚本
|
||||
│ ├── pyproject.toml # Python 依赖和工具配置
|
||||
│ ├── requirements.lock # uv 锁定依赖版本
|
||||
│ ├── Makefile # 常用命令封装
|
||||
│ ├── docker-compose.yml
|
||||
│ └── Dockerfile
|
||||
│
|
||||
├── tauri-app/ # Tauri 桌面应用(业务数据本地存储)
|
||||
│ ├── src/ # React 前端源码
|
||||
│ │ ├── api/
|
||||
│ │ │ ├── adapters/ # 数据转换层(前后端字段映射)
|
||||
│ │ │ ├── generated/ # OpenAPI 自动生成类型(只读)
|
||||
│ │ │ ├── modules/ # API 模块封装(HTTP + IPC)
|
||||
│ │ │ ├── client.ts # HTTP 客户端(自动 camelCase↔snake_case)
|
||||
│ │ │ ├── types.ts # 手写核心类型
|
||||
│ │ │ └── ipc.ts # Tauri IPC 调用封装
|
||||
│ │ ├── components/ # 可复用组件
|
||||
│ │ ├── pages/ # 页面组件
|
||||
│ │ ├── store/ # Zustand 状态管理(+ Immer + persist)
|
||||
│ │ ├── hooks/ # 自定义 React Hooks
|
||||
│ │ ├── styles/ # 全局 CSS 变量、主题
|
||||
│ │ └── utils/ # 工具函数
|
||||
│ ├── src-tauri/ # Rust 后端源码
|
||||
│ │ ├── src/
|
||||
│ │ │ ├── lib.rs # Tauri 应用入口,命令注册
|
||||
│ │ │ ├── ffmpeg_cmd.rs # FFmpeg 命令封装
|
||||
│ │ │ ├── video_processing.rs # 视频合成业务逻辑
|
||||
│ │ │ ├── storage/ # 本地存储引擎(原子写入、文件锁、路径净化)
|
||||
│ │ │ ├── commands/ # IPC 命令按领域拆分(project/asset/auth/avatar)
|
||||
│ │ │ ├── api_proxy.rs # Python API 代理转发
|
||||
│ │ │ ├── auth.rs # 认证命令(已迁移至 commands/auth_state.rs)
|
||||
│ │ │ ├── avatar_cache.rs # 头像缓存管理
|
||||
│ │ │ └── utils.rs # 通用工具函数
|
||||
│ │ ├── Cargo.toml
|
||||
│ │ ├── tauri.conf.json
|
||||
│ │ └── binaries/ # 嵌入式 FFmpeg
|
||||
│ ├── package.json
|
||||
│ ├── vite.config.ts
|
||||
│ ├── tsconfig.json
|
||||
│ └── eslint.config.js
|
||||
│
|
||||
└── docs/ # 项目文档
|
||||
├── anytocopy-api.md
|
||||
├── anytocopy-integration.md
|
||||
├── app-update-system.md
|
||||
├── database-design.md
|
||||
├── kling-api-dev.md
|
||||
├── migrate-avatars-to-local.md
|
||||
├── qiniu-kodo-python-sdk-guide.md
|
||||
├── video-generation-flow.md
|
||||
└── volcengine-video-caption-api.md
|
||||
```
|
||||
|
||||
## 技术栈
|
||||
|
||||
### 后端 (python-api)
|
||||
|
||||
**⚠️ Python 版本要求: 3.13+** (项目使用 `|` 类型注解语法)
|
||||
|
||||
| 组件 | 技术 | 版本 | 用途 |
|
||||
|------|------|------|------|
|
||||
| Python | - | 3.13+ | 运行环境 |
|
||||
| Web 框架 | FastAPI | 0.116+ | REST API |
|
||||
| 数据库 | PostgreSQL | 15+ | 用户认证 + 成本统计 + 形象管理 |
|
||||
| ORM | SQLAlchemy | 2.0 (异步) | 数据模型 |
|
||||
| 缓存/调度 | Redis + Async Engine | 5.2+ / 自定义 | 异步任务槽位调度 |
|
||||
| AI SDK | OpenAI / volcengine | 1.58+ / 5.0+ | LLM 调用 |
|
||||
| 认证 | python-jose + passlib | 3.4+ / 1.7+ | JWT 认证 |
|
||||
| 对象存储 | qiniu | 7.13+ | 七牛云存储 |
|
||||
| HTTP 客户端 | httpx + aiohttp | 0.28+ / 3.13+ | 异步 HTTP |
|
||||
| 包管理/构建 | uv | - | 虚拟环境、依赖锁定、Docker 构建 |
|
||||
|
||||
**后端架构说明**:
|
||||
- 后端为"轻量云账号 + 全本地业务数据"模式
|
||||
- 云端仅存储:用户账户、形象元数据、成本统计
|
||||
- 业务数据(项目/脚本/媒体)全部本地存储
|
||||
- 任务调度使用**自定义 Async Engine**(基于 Redis 的槽位管理),**非 Celery**
|
||||
|
||||
### 前端 (tauri-app)
|
||||
|
||||
| 组件 | 技术 | 版本 | 用途 |
|
||||
|------|------|------|------|
|
||||
| 桌面框架 | Tauri | 2.x | 桌面应用壳 |
|
||||
| UI 框架 | React | 19.1+ | 用户界面 |
|
||||
| 路由 | React Router DOM | 7.x | 页面路由(主壳使用 NavigationContext) |
|
||||
| 状态管理 | Zustand | 5.x | 全局状态 + Immer 中间件 |
|
||||
| 数据获取 | SWR | 2.x | 请求缓存 |
|
||||
| 虚拟列表 | @tanstack/react-virtual | 3.x | 大数据列表渲染 |
|
||||
| 构建工具 | Vite | 7.x | 构建、开发服务器 |
|
||||
| 测试 | Vitest + @testing-library | 4.x | 单元测试 |
|
||||
| 类型生成 | openapi-typescript | 7.x | 从 OpenAPI 生成 TS 类型 |
|
||||
|
||||
### Rust 后端 (src-tauri/src)
|
||||
|
||||
| 模块 | 用途 |
|
||||
|------|------|
|
||||
| lib.rs | Tauri 应用入口,命令注册 |
|
||||
| ffmpeg_cmd.rs | FFmpeg 命令封装(首帧提取、字幕压制、封面合成) |
|
||||
| video_processing.rs | 视频合成业务逻辑 |
|
||||
| storage/engine.rs | 本地存储引擎(原子写入、文件锁、路径净化) |
|
||||
| storage/paths.rs | 集中化路径计算 |
|
||||
| commands/project.rs | 项目本地存储 IPC 命令 |
|
||||
| commands/asset.rs | 资源文件保存 IPC 命令 |
|
||||
| commands/auth_state.rs | 认证状态文件持久化 |
|
||||
| api_proxy.rs | Python API 代理转发 |
|
||||
| avatar_cache.rs | 头像视频缓存管理 |
|
||||
|
||||
## 开发环境搭建
|
||||
|
||||
### 1. 启动 Python 后端
|
||||
|
||||
```bash
|
||||
cd python-api
|
||||
|
||||
# 方式一:Docker Compose(推荐)
|
||||
cp .env.example .env
|
||||
docker-compose up -d
|
||||
|
||||
# 方式二:本地开发(若 Docker 不可用)
|
||||
# 启动 PostgreSQL 和 Redis
|
||||
docker-compose up -d db redis
|
||||
|
||||
# 安装依赖(使用 uv)
|
||||
uv pip install -e ".[dev]"
|
||||
|
||||
# 启动开发服务器(注意:Docker API 会占用 8080 端口)
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# 另开终端启动 Async Engine Scheduler(必须同时启动,否则任务不会执行)
|
||||
python -m app.scheduler.main
|
||||
```
|
||||
|
||||
后端服务地址:
|
||||
- API: http://localhost:8080/api/v1
|
||||
- 文档: http://localhost:8080/docs
|
||||
- 健康检查: http://localhost:8080/health
|
||||
|
||||
**Docker Compose 服务组成**(4 个服务):
|
||||
- `db`: PostgreSQL 15
|
||||
- `redis`: Redis 7
|
||||
- `api`: FastAPI 开发服务器(端口 8080→8000)
|
||||
- `scheduler`: Async Engine 统一调度器,处理所有第三方异步任务
|
||||
|
||||
### 2. 启动 Tauri 前端
|
||||
|
||||
```bash
|
||||
cd tauri-app
|
||||
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 开发模式(自动启动 Vite + Tauri)
|
||||
npm run tauri dev
|
||||
```
|
||||
|
||||
前端窗口:
|
||||
- Vite 开发服务器: http://localhost:1420
|
||||
- 应用窗口: 1440×960(最小 960×640,可调整大小)
|
||||
|
||||
## 构建命令
|
||||
|
||||
### Python 后端
|
||||
|
||||
```bash
|
||||
cd python-api
|
||||
|
||||
# 使用 Makefile(推荐)
|
||||
make dev # 安装开发依赖 + pre-commit 钩子
|
||||
make lint # ruff + mypy
|
||||
make format # black + ruff --fix
|
||||
make test # pytest
|
||||
make test-cov # 覆盖率报告
|
||||
make security # bandit + pip-audit
|
||||
make lint-semantic # 语义层禁词检查
|
||||
make ci # 运行所有 CI 检查(format-check + lint + lint-semantic + test + security)
|
||||
make docker-run # Docker Compose 启动全部服务
|
||||
make scheduler # 启动 Async Engine Scheduler
|
||||
|
||||
# 手动命令
|
||||
black app/
|
||||
ruff check app/
|
||||
mypy app/
|
||||
bandit -c pyproject.toml -r app/
|
||||
pip-audit
|
||||
pytest
|
||||
pytest --cov=app
|
||||
|
||||
# 导出 OpenAPI 文档到前端
|
||||
python3 -c "
|
||||
import logging
|
||||
logging.disable(logging.WARNING)
|
||||
from app.main import app
|
||||
import json
|
||||
print(json.dumps(app.openapi(), indent=2, ensure_ascii=False))
|
||||
" > ../tauri-app/src/api/generated/openapi.json
|
||||
|
||||
# Docker 构建
|
||||
docker build -t meijiaka-api .
|
||||
```
|
||||
|
||||
### Tauri 前端
|
||||
|
||||
```bash
|
||||
cd tauri-app
|
||||
|
||||
# 开发
|
||||
npm run dev # 纯 Vite 开发(不启动 Tauri)
|
||||
npm run tauri dev # 完整 Tauri 开发模式
|
||||
|
||||
# 构建
|
||||
npm run build # 前端生产构建
|
||||
npm run tauri build # 打包桌面应用
|
||||
|
||||
# 测试
|
||||
npm run test # 运行 Vitest
|
||||
npm run test:ui # UI 模式
|
||||
npm run test:coverage # 覆盖率报告
|
||||
|
||||
# 代码质量
|
||||
npm run lint # ESLint 检查
|
||||
npm run lint:fix # ESLint 自动修复
|
||||
npm run format # Prettier 格式化
|
||||
npm run format:check # Prettier 格式检查
|
||||
npm run stylelint # CSS 检查
|
||||
npm run stylelint:fix # CSS 自动修复
|
||||
|
||||
# 类型生成
|
||||
npm run gen:api # 从 OpenAPI 生成 TypeScript 类型
|
||||
```
|
||||
|
||||
## 架构说明
|
||||
|
||||
### 混合路由架构
|
||||
|
||||
前端 API 调用采用 **智能路由** 策略:
|
||||
|
||||
1. **HTTP 直连 Python**: 纯数据 API(脚本生成、模型管理、任务轮询等)
|
||||
2. **Tauri IPC → Rust**: 需要本地能力的 API(FFmpeg、文件系统)
|
||||
|
||||
路由决策在 `tauri-app/src/api/client.ts` 中实现。HTTP 客户端会自动处理 `camelCase` ↔ `snake_case` 字段名转换。需要走 Rust IPC 的 API 包括:
|
||||
- `video_composite_synthesis` // FFmpeg 视频合成
|
||||
- `burn_subtitle` // 字幕压制
|
||||
- `extract_video_first_frame` // 首帧提取
|
||||
- `generate_cover_image` // 封面生成
|
||||
- `save_project_meta*` / `load_project_meta*` // 本地文件系统
|
||||
- `save_project_segments*` / `load_project_segments*`
|
||||
- `save_project_asset` / `get_video_save_path` / `get_image_save_path`
|
||||
- `save_final_product`
|
||||
- 头像缓存相关 API
|
||||
|
||||
**添加新 API 流程**:
|
||||
1. Python 端实现端点
|
||||
2. 前端直接调用(默认 HTTP)
|
||||
3. 仅当需要本地能力时,在 Rust 中添加命令并在 `lib.rs` 注册
|
||||
|
||||
### AI Provider 架构
|
||||
|
||||
后端 AI 模块采用 **多 Provider 路由** 设计:
|
||||
|
||||
```
|
||||
app/ai/
|
||||
├── model_router.py # 模型路由器(自动降级)
|
||||
├── providers/
|
||||
│ ├── base.py # Provider 抽象基类
|
||||
│ ├── generic_llm_provider.py # 通用 OpenAI 兼容 Provider
|
||||
│ ├── volcengine_provider.py # 火山方舟官方 SDK
|
||||
│ └── klingai_provider.py # KlingAI 数字人
|
||||
└── prompts/ # 提示词模板(禁止硬编码)
|
||||
```
|
||||
|
||||
支持的 AI 平台:
|
||||
- **火山方舟** (字节跳动) - 推荐,性价比高
|
||||
- **OpenAI** - GPT 系列
|
||||
- **文心一言** (百度)
|
||||
- **通义千问** (阿里云)
|
||||
- **可灵 AI** (快手) - 视频生成、数字人、形象克隆
|
||||
|
||||
AI 模型配置位于 `python-api/config/ai_models.yaml`,支持热重载,无需重启服务即可更新模型配置。
|
||||
|
||||
### Async Engine(异步任务调度)
|
||||
|
||||
**⚠️ 重要:项目不使用 Celery,使用自定义 Async Engine**
|
||||
|
||||
架构:
|
||||
```
|
||||
API (POST /tasks/{type}) → Redis JobRegistry → AsyncEngine tick loop → Handlers
|
||||
```
|
||||
|
||||
组件:
|
||||
- **`AsyncEngine`** (`app/scheduler/engine.py`): 每 ~10s 执行 `tick()`,加载运行中任务,按类型分组,并行分发给 Handler,通过 Pipeline 应用 `StateChange`,清理已完成任务
|
||||
- **`JobRegistry`** (`app/scheduler/registry.py`): Redis-based 任务 CRUD,使用 `job:{id}` hash + `scheduler:running_tasks` SET
|
||||
- **`SlotManager`** (`app/scheduler/slot_manager.py`): Redis Lua 原子脚本实现并发槽位抢占/释放
|
||||
- **`JobRecord`** / **`StateChange`** (`app/scheduler/models.py`): 调度器内部类型
|
||||
|
||||
已注册的 Handler(`app/scheduler/main.py`):
|
||||
|
||||
| Handler | 槽位数 | Redis Key | 用途 |
|
||||
|---------|--------|-----------|------|
|
||||
| VideoHandler | 18 | `kling:video_slots` | Kling 视频生成(omni + image2video) |
|
||||
| ImageHandler | 9 | `kling:image_slots` | Kling 图片生成 |
|
||||
| ScriptHandler | 10 | `script:slots` | LLM 脚本生成(含 AnyToCopy 视频文案提取) |
|
||||
| SubtitleHandler | 5 | `volc:subtitle_slots` | 火山引擎字幕/自动对齐 |
|
||||
| CopyHandler | 5 | `anytocopy:slots` | AnyToCopy 视频文案提取 |
|
||||
| AvatarHandler | 2 | `kling:avatar_slots` | Kling 形象克隆(状态机: pending→voice_processing→element_pending→element_processing→succeed) |
|
||||
|
||||
### TokenManager(API 认证 Token 管理)
|
||||
|
||||
`app/core/token_manager.py` 提供通用的 API 认证 Token 缓存与自动刷新:
|
||||
|
||||
```python
|
||||
from app.core.token_manager import JWTTokenStrategy, TokenManager
|
||||
|
||||
class MyProvider:
|
||||
def __init__(self, access_key: str, secret_key: str):
|
||||
self._token_strategy = JWTTokenStrategy(
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
expires_in=1800, # 30分钟
|
||||
)
|
||||
|
||||
async def _get_headers(self) -> dict[str, str]:
|
||||
token_info = await TokenManager.get_instance().get_token(self._token_strategy)
|
||||
return {"Authorization": f"Bearer {token_info.token}"}
|
||||
```
|
||||
|
||||
**特性**:
|
||||
- Token 缓存(避免重复生成)
|
||||
- 自动刷新(Token 即将过期时自动刷新)
|
||||
- 并发安全(双重检查锁定,确保并发请求只生成一次 Token)
|
||||
- 后台预热(提前 10 分钟刷新,避免请求时等待)
|
||||
- 支持 JWT、OAuth2 等多种策略
|
||||
|
||||
### 本地存储引擎(Rust)
|
||||
|
||||
Rust 层实现了 defense-in-depth 的本地存储系统:`src-tauri/src/storage/`
|
||||
|
||||
- **`engine.rs`**: 核心原子操作
|
||||
- `sanitize_id()` — 白名单 `[a-zA-Z0-9_-]+`,防御路径遍历
|
||||
- `sanitize_filename()` — 提取纯文件名,拒绝目录组件
|
||||
- `atomic_write_json()` / `atomic_write_bytes()` — 先写 `.tmp` 再 `rename` 原子替换
|
||||
- `with_file_lock()` — 通过 `fs2` 实现独占文件锁
|
||||
- `read_json<T>()` — 安全读取,文件不存在返回 `None`
|
||||
- **`paths.rs`**: 集中路径计算
|
||||
- `~/Documents/Meijiaka/projects/{id}/` (meta.json, segments.json, assets/)
|
||||
- `~/Documents/Meijiaka/products/`
|
||||
- `{app_config_dir}/auth.json`
|
||||
- `{app_data_dir}/avatars/`
|
||||
|
||||
**所有本地 JSON 读写必须经过 StorageEngine,禁止在命令处理器中直接调用 `fs::write`**。
|
||||
|
||||
### 数据库模型
|
||||
|
||||
后端仅保留 **3 个表**:
|
||||
|
||||
```
|
||||
users -- 用户账户信息(mobile, nickname, avatar_url)
|
||||
model_usage_logs -- 大模型调用记录(token, 成本, 响应时间)
|
||||
avatars -- 克隆形象元数据(云端备份,前端已迁移至本地 JSON)
|
||||
```
|
||||
|
||||
**业务数据本地存储**:
|
||||
- 项目/脚本/分镜 → 前端本地 JSON 文件(`~/Documents/Meijiaka/projects/`)
|
||||
- 音频/视频/图片文件 → 本地磁盘
|
||||
- 用户配置 → localStorage(少量 UI 状态)
|
||||
|
||||
### 数据流规范
|
||||
|
||||
```
|
||||
用户输入主题 ──→ 后端 AI 生成脚本 ──→ 后端返回分镜列表 ──→ 前端保存到本地
|
||||
│ │
|
||||
└────────────────── 后端不存储脚本数据 ────────────────────────┘
|
||||
```
|
||||
|
||||
### 本地存储结构
|
||||
|
||||
```
|
||||
~/Documents/Meijiaka/ # 用户文档目录
|
||||
├── config.json # 全局配置
|
||||
├── projects/ # 项目数据
|
||||
│ └── {project_id}/
|
||||
│ ├── meta.json # 项目元数据
|
||||
│ ├── segments.json # 分镜数据
|
||||
│ └── assets/ # 资源文件(封面、成品等)
|
||||
├── products/ # 成品视频目录
|
||||
├── avatars.json # 形象列表(本地)
|
||||
└── cache/ # 缓存目录
|
||||
```
|
||||
|
||||
项目元数据 `meta.json` 关键字段:
|
||||
- `id`, `title`, `topic`, `status` (draft | published)
|
||||
- `currentStep`: 1=脚本生成, 2=形象视频, 3=字幕压制, 4=封面制作, 5=视频合成
|
||||
- `createdAt`, `updatedAt`, `exportedAt`
|
||||
- `coverPath`, `finalVideoPath`
|
||||
- `selectedElementId`, `selectedHumanId`
|
||||
- `coverConfig`, `scriptDuration`, `scriptType`
|
||||
|
||||
分镜数据 `segments.json` 字段:
|
||||
- `id`, `type` (segment | empty_shot), `scene`, `voiceover`, `duration`
|
||||
- `videoPath`, `videoUrl`, `elementId`, `voiceId`
|
||||
- `alignmentResult`, `burnedVideoPath`, `burnedAt`
|
||||
|
||||
### 前端导航
|
||||
|
||||
主应用壳使用 **自定义 NavigationContext**(React Context)实现页面切换,映射 `Record<PageType, ComponentType>`。`react-router-dom` 已安装但主要用于未来扩展或特定路由场景,当前主流程不使用 BrowserRouter 进行导航。
|
||||
|
||||
### 状态管理
|
||||
|
||||
六個專門的 Zustand store:
|
||||
|
||||
| Store | 职责 | 持久化 |
|
||||
|-------|------|--------|
|
||||
| `authStore` | JWT、UserInfo、登录/登出 | Tauri `auth.json`(或 localStorage fallback) |
|
||||
| `projectStore` | 分镜、currentStep、选题、封面配置 | **仅 UI 标志**通过 `persist`;业务数据显式写入本地 JSON |
|
||||
| `taskStore` | 异步任务状态/进度/消息 | **无**(内存 only,真相源在后端 Redis) |
|
||||
| `uiStore` | Toast 通知队列 | 无 |
|
||||
| `progressStore` | 全局进度模态框 | 无 |
|
||||
| `settingsStore` | 主题模式、用户偏好 | localStorage |
|
||||
|
||||
`projectStore` **不自动保存**。数据在显式过渡点持久化到磁盘(如进入 step 2、调用 `setFinalVideoPath` 时触发 `saveMetaToLocalFile`)。`saveMetaToLocalFile()` 通过 Promise 链串行化写入,避免并发覆盖。
|
||||
|
||||
## 开发规范
|
||||
|
||||
### 核心原则
|
||||
|
||||
1. **后端环境优先使用 Docker Compose**: 开发时通过 `docker-compose up -d` 启动后端。前端默认连接 `http://127.0.0.1:8080/api/v1`。
|
||||
2. **接口契约优先**: 后端承诺无论使用什么 AI 模型,输出永远符合同一个 Schema
|
||||
3. **类型单一来源**: 后端 Schema 是权威,前端通过 OpenAPI 生成类型
|
||||
4. **Adapter 层隔离**: 前后端字段差异只允许在 Adapter 层处理
|
||||
5. **数据库分层**: API → Service → CRUD → Model,禁止跨层调用
|
||||
6. **提示词文件化**: 除前端输入外,后端不允许硬编码任何 Prompt
|
||||
7. **配置统一管理**: 所有配置通过 `get_settings()` 读取,禁止直接使用 `os.getenv()`
|
||||
8. **本地存储必须经过 StorageEngine**: Rust 层所有文件操作使用 `atomic_write_json` + `with_file_lock`
|
||||
|
||||
### 配置管理规范
|
||||
|
||||
**架构层级:**
|
||||
```
|
||||
.env (Layer 1) ──→ Settings (Layer 2) ──→ 服务层 (Layer 3)
|
||||
↑
|
||||
唯一配置出口
|
||||
```
|
||||
|
||||
**强制规范:**
|
||||
- **所有服务**必须使用 `from app.config import get_settings` 读取配置
|
||||
- **禁止**在服务层使用 `os.getenv()` 或 `os.environ.get()`
|
||||
- **所有配置项**必须在 `app/config.py` 的 `Settings` 类中定义
|
||||
- **敏感信息**(API Keys、Secrets)必须通过环境变量注入
|
||||
- **业务默认值**可以硬编码在 `Settings` 中
|
||||
|
||||
**添加新配置流程:**
|
||||
1. 在 `app/config.py` 的 `Settings` 类中添加字段定义
|
||||
2. 在 `.env` 中添加实际值(敏感信息)或使用默认值
|
||||
3. 在服务层通过 `get_settings()` 读取
|
||||
4. 更新 `.env.example` 文档
|
||||
|
||||
### 语义层防护网
|
||||
|
||||
项目强制执行语义分层,禁止供应商术语泄漏到业务层:
|
||||
|
||||
| 层级 | 职责 | 禁词示例 |
|
||||
|------|------|----------|
|
||||
| Layer 6 (Presentation) | API Schema | `element_id`, `kling_task_id` |
|
||||
| Layer 4 (Orchestration) | Scheduler | `task_id`(应使用 `job_id`) |
|
||||
| Layer 3 (Domain) | Service | 供应商特定术语 |
|
||||
| Layer 2 (Adapter) | Provider | 允许使用供应商原生术语 |
|
||||
|
||||
Makefile 提供 `make lint-semantic` 进行自动化检查:
|
||||
- API 层(除 `klingai.py`)禁止使用 `element_id`(应使用 `provider_element_id` 或 `human_id`)
|
||||
- Scheduler 层禁止使用 `task_id`(应使用 `job_id`)
|
||||
- 全局禁止 `kling_task_id`(应使用 `provider_task_id`)
|
||||
- Scheduler Redis key 必须使用 `job:` 而非 `task:`
|
||||
|
||||
### 快速参考
|
||||
|
||||
| 场景 | 正确做法 |
|
||||
|------|---------|
|
||||
| 后端换 AI 模型 | 修改 `services/ai_response_utils.py` 标准化层,不修改 Schema |
|
||||
| 后端新增字段 | `Optional[T] = Field(None)`,向后兼容 |
|
||||
| 后端修改字段 | 保留旧字段,标记 deprecated,逐步迁移 |
|
||||
| 前端需要新字段 | Store 中 `extends` 基础类型 |
|
||||
| 数据清洗 | **只在** Adapter 层,禁止在组件层 |
|
||||
| 新增数据库实体 | 创建 Model → CRUD → API(分层开发)|
|
||||
| 数据库查询 | 在 CRUD 层封装,API 层调用 |
|
||||
| 事务管理 | API 层控制,通过 `get_db` 依赖注入 |
|
||||
| 新增提示词 | 创建 `.txt` 文件,使用 `_load_prompt()` 加载 |
|
||||
| 新增本地文件操作 | 使用 `storage::engine` 原子写入 + 文件锁 |
|
||||
|
||||
### 后端分层架构
|
||||
|
||||
```
|
||||
API Layer (api/v1/*.py)
|
||||
↓ 调用
|
||||
Service Layer (services/*.py) - 可选,复杂业务
|
||||
↓ 调用
|
||||
CRUD Layer (crud/*.py)
|
||||
↓ 调用
|
||||
Model Layer (models/*.py)
|
||||
↓ 调用
|
||||
Database Layer (db/*.py)
|
||||
```
|
||||
|
||||
**禁止**:
|
||||
- API 层直接操作 Model
|
||||
- CRUD 层返回 Schema(应返回 Model)
|
||||
- Service 层直接操作数据库(应通过 CRUD)
|
||||
- 在业务代码中写 SQL
|
||||
|
||||
### 前端类型规范
|
||||
|
||||
| 层级 | 类型来源 | 说明 |
|
||||
|------|----------|------|
|
||||
| 后端 Schema | `python-api/app/schemas/*.py` | Pydantic 模型,OpenAPI 生成源 |
|
||||
| 前端基础类型 | `tauri-app/src/api/types.ts` | 手写的核心类型,与后端对齐 |
|
||||
| 前端完整类型 | `tauri-app/src/api/generated/schema.ts` | OpenAPI 自动生成,只读 |
|
||||
| Store 扩展 | `tauri-app/src/store/*.ts` | `extends` 基础类型添加前端字段 |
|
||||
|
||||
### 间距规范
|
||||
|
||||
前端使用基于 4px 的网格系统,定义在 `tauri-app/src/styles/variables.css`:
|
||||
|
||||
| 变量 | 值 | 使用场景 |
|
||||
|------|-----|----------|
|
||||
| `--spacing-2xs` | 2px | 微调控件、边框线 |
|
||||
| `--spacing-xs` | 4px | 紧凑间隙、图标边距 |
|
||||
| `--spacing-sm` | 8px | 小间隙、按钮内边距-y |
|
||||
| `--spacing-md` | 12px | 标准间隙、卡片内边距 |
|
||||
| `--spacing-lg` | 16px | 大间隙、区块间距 |
|
||||
| `--spacing-xl` | 24px | 页面区块、内容分隔 |
|
||||
| `--spacing-2xl` | 32px | 大区块间距、页面边距 |
|
||||
| `--spacing-3xl` | 48px | 页面级间距、Hero 区域 |
|
||||
|
||||
## 代码风格
|
||||
|
||||
### Python
|
||||
|
||||
- **格式化**: Black (line-length: 100, target-version: py313)
|
||||
- **检查**: Ruff (E, F, I, N, W, UP, B, C4, SIM)
|
||||
- **类型**: MyPy(非严格模式全局,但 `app.schemas.*`、`app.crud.*`、`app.scheduler.handlers.*` 强制严格模式)
|
||||
- **文档**: 中文注释,Google Style Docstrings
|
||||
- **安全**: Bandit + pip-audit
|
||||
- **Git Hooks**: pre-commit(Black、Ruff、uv lock 同步检查)
|
||||
|
||||
### TypeScript/React
|
||||
|
||||
- **类型**: 严格 TypeScript 模式(`strict: true`, `noUnusedLocals: true`, `noUnusedParameters: true`)
|
||||
- **组件**: 函数组件 + Hooks
|
||||
- **状态**: Zustand 管理全局状态(配合 Immer 处理不可变更新)
|
||||
- **样式**: 普通 CSS + CSS 变量(`tauri-app/src/styles/variables.css`)
|
||||
- **ESLint**: 使用 `eslint.config.js`(Flat Config),含 React Hooks 和 React Refresh 规则
|
||||
- **Prettier**: semi=true, singleQuote=true, tabWidth=2, printWidth=100
|
||||
- **Stylelint**: `stylelint-config-standard`,禁止 magic px 用于 `border-radius` 和 `font-size`
|
||||
|
||||
### Rust
|
||||
|
||||
- **格式化**: rustfmt
|
||||
- **检查**: cargo clippy
|
||||
- **注释**: 中文文档注释
|
||||
|
||||
### 提交规范
|
||||
|
||||
```
|
||||
feat: 新功能
|
||||
fix: 修复
|
||||
docs: 文档
|
||||
refactor: 重构
|
||||
test: 测试
|
||||
chore: 构建/工具
|
||||
```
|
||||
|
||||
## 测试策略
|
||||
|
||||
### 后端测试
|
||||
|
||||
```bash
|
||||
cd python-api
|
||||
|
||||
# 运行所有测试
|
||||
pytest -v
|
||||
|
||||
# 覆盖率报告
|
||||
pytest --cov=app --cov-report=html --cov-report=term
|
||||
```
|
||||
|
||||
**测试配置** (`pyproject.toml`):
|
||||
- asyncio_mode = "auto"
|
||||
- 测试文件命名: `test_*.py`
|
||||
|
||||
> **注**:当前项目中 `python-api/tests/` 目录尚未创建,后端测试待补充。
|
||||
|
||||
### 前端测试
|
||||
|
||||
```bash
|
||||
cd tauri-app
|
||||
|
||||
# 运行 Vitest
|
||||
npm run test
|
||||
|
||||
# UI 模式
|
||||
npm run test:ui
|
||||
|
||||
# 覆盖率报告
|
||||
npm run test:coverage
|
||||
```
|
||||
|
||||
**测试配置**:
|
||||
- 测试框架: Vitest 4.x + @testing-library/react + jsdom
|
||||
- 测试文件: `src/**/*.test.ts(x)`
|
||||
- Mock 配置: `src/__tests__/setup.ts`
|
||||
- 自动 Mock: localStorage, Tauri API (`@tauri-apps/api/core`)
|
||||
- 示例测试: `src/store/__tests__/authStore.test.tsx`
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
1. **SECRET_KEY**: 生产环境必须修改为强随机密钥(`get_settings()` 会在生产环境校验)
|
||||
2. **CORS**: 生产环境限制为实际前端域名,开发环境 `DEBUG=true` 时允许所有来源
|
||||
3. **API Keys**: 不要提交到 Git,使用 `.env` 文件注入
|
||||
4. **FFmpeg**: 嵌入的二进制文件需验证来源
|
||||
5. **文件上传**: 限制文件类型和大小,防止攻击
|
||||
6. **路径遍历**: Rust StorageEngine 的 `sanitize_id()` 和 `sanitize_filename()` 防御路径遍历攻击
|
||||
7. **原子写入**: 所有本地 JSON 使用 `atomic_write_json`(先写 `.tmp` 再 `rename`)
|
||||
8. **文件锁**: 并发 RMW 操作使用 `with_file_lock` 防止竞态
|
||||
9. **日志**: 后端日志写入 `~/Documents/Meijiaka/logs/api_YYYYMMDD.log`
|
||||
|
||||
## 配置说明
|
||||
|
||||
### Python 后端 (.env)
|
||||
|
||||
关键环境变量:
|
||||
|
||||
```bash
|
||||
# 数据库 (PostgreSQL)
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/meijiaka
|
||||
|
||||
# Redis
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
|
||||
# JWT 密钥(生产环境必须修改)
|
||||
SECRET_KEY=your-secret-key-here-change-in-production
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080
|
||||
|
||||
# AI API Keys
|
||||
VOLCENGINE_API_KEY=your-volcengine-key
|
||||
VOLCENGINE_CAPTION_APPID=your-caption-appid
|
||||
VOLCENGINE_CAPTION_TOKEN=your-caption-token
|
||||
KLINGAI_ACCESS_KEY=your-kling-access-key
|
||||
KLINGAI_SECRET_KEY=your-kling-secret-key
|
||||
OPENAI_API_KEY=sk-your-openai-key
|
||||
|
||||
# 七牛云存储
|
||||
QINIU_ACCESS_KEY=your-qiniu-access-key
|
||||
QINIU_SECRET_KEY=your-qiniu-secret-key
|
||||
QINIU_VIDEO_BUCKET=media-liche
|
||||
QINIU_IMAGE_BUCKET=img-liche
|
||||
|
||||
# CORS 允许的前端地址
|
||||
CORS_ORIGINS=http://localhost:1420,http://127.0.0.1:1420,http://localhost:8080
|
||||
```
|
||||
|
||||
### Tauri 配置 (tauri.conf.json)
|
||||
|
||||
```json
|
||||
{
|
||||
"productName": "美家卡智影",
|
||||
"identifier": "cn.meijiaka.ai-video",
|
||||
"build": {
|
||||
"devUrl": "http://localhost:1420",
|
||||
"frontendDist": "../dist"
|
||||
},
|
||||
"bundle": {
|
||||
"externalBin": ["binaries/ffmpeg"],
|
||||
"resources": {
|
||||
"fonts/*": "fonts/"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### AI 模型配置 (config/ai_models.yaml)
|
||||
|
||||
模型配置文件支持热重载,无需重启服务即可更新模型配置。主要配置项:
|
||||
|
||||
- **platforms**: AI 平台配置(mock, volcengine, klingai)
|
||||
- **models**: 可用模型列表及其能力标签 [script, polish, chat, image, embedding, vision]
|
||||
- **task_defaults**: 任务类型到模型的默认映射
|
||||
|
||||
## 视频创作流程
|
||||
|
||||
1. **脚本生成** (Step 1) - AI 生成视频脚本和分镜
|
||||
2. **形象视频** (Step 2) - 选择数字人形象,生成视频片段
|
||||
3. **字幕压制** (Step 3) - 生成字幕并压制到视频中
|
||||
4. **封面制作** (Step 4) - 生成视频封面
|
||||
5. **视频合成** (Step 5) - FFmpeg 拼接视频片段,导出最终视频
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 火山方舟如何配置?
|
||||
|
||||
1. 注册火山引擎账号并实名认证
|
||||
2. 创建 API Key
|
||||
3. 开通模型并创建推理接入点
|
||||
4. 在 `.env` 中设置 `VOLCENGINE_API_KEY`
|
||||
|
||||
### Q: 可灵 AI 如何配置?
|
||||
|
||||
1. 前往可灵 AI 开发者平台 https://klingai.com/document-api
|
||||
2. 获取 Access Key 和 Secret Key
|
||||
3. 在 `.env` 中设置 `KLINGAI_ACCESS_KEY` 和 `KLINGAI_SECRET_KEY`
|
||||
|
||||
### Q: FFmpeg 在哪里?
|
||||
|
||||
Tauri 应用已嵌入 FFmpeg 二进制文件:
|
||||
- 位置: `tauri-app/src-tauri/binaries/ffmpeg-*`
|
||||
- 使用: Rust 层通过 `ffmpeg_cmd` 模块调用
|
||||
- 打包时会作为 `externalBin` 资源嵌入
|
||||
|
||||
### Q: 后端换了 AI 模型,输出格式变了怎么办?
|
||||
|
||||
修改 `services/ai_response_utils.py` 中的标准化函数,增加新的字段映射,**不要**修改 API Schema。
|
||||
|
||||
### Q: 如何新增/修改提示词?
|
||||
|
||||
1. 创建文件: `app/ai/prompts/my_prompt.txt`
|
||||
2. 加载使用: `prompt = self._load_prompt("my_prompt")`
|
||||
3. **禁止**: 在 Python 代码中直接写 `"""你是一位..."""`
|
||||
|
||||
### Q: 项目数据是如何持久化的?
|
||||
|
||||
- 项目元数据(`meta.json`)和分镜数据(`segments.json`)保存在 `~/Documents/Meijiaka/projects/{project_id}/`
|
||||
- 不通过 Zustand `persist` 保存项目数据,而是通过 `localProjectApi` 显式调用 Tauri IPC 写入文件
|
||||
- `projectStore` 的 `persist` 中间件仅保存少量 UI 状态
|
||||
|
||||
### Q: Async Engine 和 Celery 有什么区别?
|
||||
|
||||
本项目使用**自定义 Async Engine** 替代 Celery:
|
||||
- 基于 Redis 的槽位管理(SlotManager),限制各类型任务的并发数
|
||||
- 独立的 `scheduler` 进程(`python -m app.scheduler.main`)
|
||||
- 每个 Handler 实现 `AsyncHandler` 接口,状态机驱动任务生命周期
|
||||
- 优势:更细粒度的并发控制、统一状态机、无 Celery 依赖
|
||||
|
||||
---
|
||||
|
||||
**最后更新**: 2026-04-17
|
||||
**架构模式**: 单机版(轻量云账号 + 全本地业务数据)
|
||||
@@ -0,0 +1,518 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## 项目概述
|
||||
|
||||
美家卡智影 (Meijiaka AI Video) - AI 视频创作平台。一个 AI 驱动的桌面应用,采用 **Tauri + React + FastAPI** 混合架构,用户可以通过 AI 生成脚本、创建数字人视频,自动生成字幕,最终本地合成完整的营销视频。
|
||||
|
||||
### 环境要求
|
||||
|
||||
| 组件 | 版本要求 |
|
||||
|------|----------|
|
||||
| Python | **3.13+** (代码使用 `|` 类型注解语法) |
|
||||
| Node.js | 20+ |
|
||||
| Rust | 1.70+ |
|
||||
| Docker | 20+ (可选,用于数据库) |
|
||||
|
||||
核心设计理念:**轻量云账号 + 全本地业务数据** - 云端只存储用户认证和使用日志,所有项目/脚本/媒体都存在用户本地。
|
||||
|
||||
## 架构
|
||||
|
||||
### 混合架构
|
||||
|
||||
- **FastAPI 后端**: 处理 AI 模型调用、用户认证、API 服务
|
||||
- **Tauri + React 前端**: 桌面 UI,React 负责渲染,Tauri 提供系统能力
|
||||
- **Rust 后端**: 通过 Tauri IPC 处理本地操作(FFmpeg 视频处理、文件系统访问)
|
||||
|
||||
### 存储策略
|
||||
|
||||
核心设计理念:**轻量云账号 + 全本地业务数据** - 云端只存储用户认证和使用日志,所有项目/脚本/媒体都存在用户本地。
|
||||
|
||||
- **云端**: PostgreSQL 只存储 2 张表:`users` (用户账户)、`model_usage_logs` (用量统计)
|
||||
- `avatars` 表已废弃:数字人名片元数据现在纯本地存储 `avatars.json`
|
||||
- **本地**: JSON 文件存储项目/脚本/分镜数据、数字人元数据,用户磁盘存储媒体文件,FFmpeg 处理视频合成
|
||||
- **缓存/队列**: Redis + Async Engine Scheduler 处理异步任务
|
||||
|
||||
### 混合通信模式
|
||||
|
||||
| 通信模式 | 使用场景 | 前端调用方式 |
|
||||
|---------|---------|------------|
|
||||
| HTTP → FastAPI | AI 生成、认证、配置管理 | `client.get/post/put/delete()` |
|
||||
| Tauri IPC → Rust | FFmpeg 视频处理、本地文件系统 | `ipc.request()` 或直接 `invoke()` |
|
||||
|
||||
**通信模块**:
|
||||
- `tauri-app/src/api/client.ts` - HTTP 客户端,自动处理 camelCase/snake_case 转换
|
||||
- `tauri-app/src/api/ipc.ts` - IPC 客户端
|
||||
- `tauri-app/src/api/modules/localStorage.ts` - 本地项目存储(走 IPC)
|
||||
- `tauri-app/src/api/modules/videoComposite.ts` - 视频合成(走 IPC)
|
||||
|
||||
### AI Provider 架构
|
||||
|
||||
后端 AI 模块采用多 Provider 设计:
|
||||
- `app/ai/model_router.py` - 模型路由器,支持自动降级
|
||||
- `app/ai/providers/base.py` - 抽象基类
|
||||
- `app/ai/providers/*` - 具体实现(OpenAI、火山引擎、KlingAI 等)
|
||||
- `app/ai/prompts/` - 提示词模板文件
|
||||
|
||||
支持的 AI 平台:火山方舟(推荐)、OpenAI、百度文心一言、阿里云通义千问、KlingAI(数字人视频生成)。
|
||||
|
||||
模型配置文件:`python-api/config/ai_models.yaml`(支持热重载)
|
||||
|
||||
### Token 管理
|
||||
|
||||
外部 API 认证 Token 使用 `app/core/token_manager.py` 统一管理:
|
||||
- Token 缓存(避免重复生成)
|
||||
- 自动刷新(Token 即将过期时自动刷新)
|
||||
- 并发安全(双重检查锁定)
|
||||
- 支持 JWT、OAuth2 等多种策略
|
||||
|
||||
### 数据流
|
||||
|
||||
1. **脚本生成**: 用户输入 → FastAPI AI 代理 → 标准化输出 → 前端保存到本地 JSON
|
||||
2. **数字人视频**: 后端调用 KlingAI API → 返回视频 URL → 前端下载并本地存储
|
||||
3. **视频合成**: 前端 → Tauri IPC → Rust 后端 → FFmpeg → 渲染最终视频文件
|
||||
|
||||
### 本地存储结构(用户机器)
|
||||
|
||||
```
|
||||
~/Documents/Meijiaka/
|
||||
├── config.json # 全局应用配置
|
||||
├── projects/
|
||||
│ └── {project_id}/
|
||||
│ ├── meta.json # 项目元数据
|
||||
│ ├── segments.json # 脚本/分镜数据
|
||||
│ └── assets/ # 媒体文件
|
||||
├── avatars/
|
||||
│ └── {avatar_id}/
|
||||
│ ├── meta.json # 数字人名片配置
|
||||
│ └── source.mp4 # 源视频
|
||||
└── cache/ # 临时文件
|
||||
```
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
ai-meijiaka/
|
||||
├── python-api/ # FastAPI 后端服务
|
||||
│ ├── app/
|
||||
│ │ ├── api/v1/ # REST API 端点
|
||||
│ │ ├── ai/ # AI 模型路由和 Provider
|
||||
│ │ ├── ai/prompts/ # 提示词模板文件
|
||||
│ │ ├── core/ # 安全、配置、异常处理
|
||||
│ │ ├── db/ # 数据库配置
|
||||
│ │ ├── models/ # SQLAlchemy 数据模型
|
||||
│ │ ├── schemas/ # Pydantic 验证模型
|
||||
│ │ ├── services/ # 业务逻辑和 AI 服务代理
|
||||
│ │ ├── scheduler/ # Async Engine 统一异步调度器
|
||||
│ │ ├── config.py # 配置管理
|
||||
│ │ └── main.py # 应用入口
|
||||
│ ├── config/ # AI 模型配置(YAML)
|
||||
│ ├── tests/ # pytest 测试套件
|
||||
│ ├── scripts/ # 管理和测试脚本
|
||||
│ └── docker-compose.yml # Docker 服务编排
|
||||
│
|
||||
├── tauri-app/ # Tauri 桌面应用
|
||||
│ ├── src/ # React 前端源码
|
||||
│ │ ├── api/ # API 客户端和类型
|
||||
│ │ │ ├── adapters/ # 前后端字段差异适配
|
||||
│ │ │ ├── generated/ # OpenAPI 自动生成类型
|
||||
│ │ │ └── modules/ # API 模块封装
|
||||
│ │ ├── components/ # 可复用 React 组件
|
||||
│ │ ├── pages/ # 页面组件(路由)
|
||||
│ │ ├── store/ # Zustand 全局状态管理
|
||||
│ │ ├── hooks/ # 自定义 React Hooks
|
||||
│ │ └── utils/ # 前端工具函数
|
||||
│ ├── src-tauri/ # Rust 后端
|
||||
│ │ ├── src/
|
||||
│ │ │ ├── lib.rs # Tauri 应用入口,命令注册
|
||||
│ │ │ ├── commands/ # 按领域拆分的命令模块
|
||||
│ │ │ │ ├── asset.rs # 资源文件操作
|
||||
│ │ │ │ ├── auth_state.rs # 认证状态管理
|
||||
│ │ │ │ ├── avatar.rs # 数字人头像管理
|
||||
│ │ │ │ ├── product.rs # 产品相关
|
||||
│ │ │ │ └── project.rs # 项目存储操作
|
||||
│ │ │ ├── storage/ # 存储引擎分层
|
||||
│ │ │ │ ├── mod.rs # 模块导出
|
||||
│ │ │ │ ├── paths.rs # 路径计算
|
||||
│ │ │ │ ├── engine.rs # 核心存储引擎(原子写+文件锁)
|
||||
│ │ │ │ ├── auth.rs # 认证存储
|
||||
│ │ │ │ ├── project.rs # 项目存储
|
||||
│ │ │ │ ├── avatar.rs # 头像存储
|
||||
│ │ │ │ └── cache.rs # 缓存存储
|
||||
│ │ │ ├── ffmpeg_cmd.rs # FFmpeg 命令封装
|
||||
│ │ │ ├── video_processing.rs # 视频合成逻辑
|
||||
│ │ │ ├── api_proxy.rs # Python API 代理
|
||||
│ │ │ ├── avatar_cache.rs # 头像视频缓存管理
|
||||
│ │ │ └── utils.rs # 通用工具函数
|
||||
│ │ ├── binaries/ # 嵌入的 FFmpeg 可执行文件
|
||||
│ │ └── Cargo.toml # Rust 依赖配置
|
||||
│ └── package.json # NPM 依赖和脚本
|
||||
│
|
||||
└── docs/ # 开发文档
|
||||
```
|
||||
|
||||
## 常用命令
|
||||
|
||||
### 后端 (python-api)
|
||||
|
||||
项目使用 `uv` 进行依赖管理,并提供了 `Makefile` 封装常用命令:
|
||||
|
||||
```bash
|
||||
cd python-api
|
||||
|
||||
# 使用 uv 和 Makefile(推荐)
|
||||
make dev # 安装开发依赖并配置 pre-commit
|
||||
make docker-run # 使用 Docker Compose 启动所有服务(db, redis, api, scheduler)
|
||||
make run # 启动 FastAPI 开发服务器
|
||||
make scheduler # 启动 Async Engine Scheduler
|
||||
make lint # 运行代码检查 (ruff + mypy)
|
||||
make format # 格式化代码
|
||||
make test # 运行所有测试
|
||||
make security # 运行安全扫描 (bandit + pip-audit)
|
||||
|
||||
# 手动方式
|
||||
# 安装依赖
|
||||
python -m venv venv && source venv/bin/activate
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# 启动 PostgreSQL + Redis(必需)
|
||||
docker-compose up -d db redis
|
||||
|
||||
# 启动 FastAPI 开发服务器
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# 启动 Async Engine Scheduler(另开终端)
|
||||
python -m app.scheduler.main
|
||||
|
||||
# 代码质量
|
||||
black app/ # 格式化代码(行宽 100)
|
||||
ruff check app/ # 代码检查
|
||||
mypy app/ # 严格类型检查
|
||||
bandit -c pyproject.toml -r app/ # 安全扫描
|
||||
pip-audit # 依赖漏洞检测
|
||||
python scripts/check_config_architecture.py # 检查配置架构一致性
|
||||
|
||||
# 导出 OpenAPI 文档到前端
|
||||
python3 -c "
|
||||
import logging
|
||||
logging.disable(logging.WARNING)
|
||||
from app.main import app
|
||||
import json
|
||||
print(json.dumps(app.openapi(), indent=2, ensure_ascii=False))
|
||||
" > ../tauri-app/src/api/generated/openapi.json
|
||||
|
||||
# 测试
|
||||
pytest # 运行所有测试
|
||||
pytest tests/test_script.py -v # 运行单个测试文件
|
||||
pytest --cov=app # 覆盖率报告
|
||||
|
||||
# Docker
|
||||
docker-compose up -d # 启动所有服务(db, redis, api, scheduler)
|
||||
|
||||
# 端口占用检查
|
||||
lsof -i :8080 # 检查 8080 端口占用
|
||||
```
|
||||
|
||||
**可用 Makefile 命令:**
|
||||
|
||||
| 命令 | 用途 |
|
||||
|------|------|
|
||||
| `make help` | 显示帮助信息 |
|
||||
| `make install` | 安装生产依赖(使用 lock 文件)|
|
||||
| `make dev` | 安装开发依赖并配置 pre-commit |
|
||||
| `make update-lock` | 更新 requirements.lock |
|
||||
| `make lint` | 运行代码检查 (ruff + mypy) |
|
||||
| `make format` | 格式化代码 (black + ruff) |
|
||||
| `make format-check` | 检查代码格式(不修改)|
|
||||
| `make test` | 运行测试 |
|
||||
| `make test-cov` | 运行测试并生成覆盖率报告 |
|
||||
| `make security` | 运行安全扫描 |
|
||||
| `make run` | 启动开发服务器 |
|
||||
| `make scheduler` | 启动 Async Engine Scheduler |
|
||||
| `make docker-run` | Docker Compose 启动全部服务 |
|
||||
| `make docker-down` | 停止 Docker 服务 |
|
||||
| `make clean` | 清理缓存文件 |
|
||||
| `make ci` | 运行所有 CI 检查 |
|
||||
|
||||
### 前端 (tauri-app)
|
||||
|
||||
```bash
|
||||
cd tauri-app
|
||||
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 开发
|
||||
npm run dev # 仅启动 Vite(不打开 Tauri 窗口)
|
||||
npm run tauri dev # 完整 Tauri 桌面开发模式
|
||||
|
||||
# 构建
|
||||
npm run build # 前端生产构建
|
||||
npm run tauri build # 打包桌面应用(.dmg/.exe/.AppImage)
|
||||
|
||||
# 代码质量
|
||||
npm run lint # ESLint 检查 JS/TS
|
||||
npm run lint:fix # ESLint 自动修复
|
||||
npm run format # Prettier 格式化代码
|
||||
npm run stylelint # CSS 检查
|
||||
|
||||
# 测试
|
||||
npm run test # 运行 Vitest
|
||||
npm run test:coverage # 覆盖率报告
|
||||
npm run test:ui # 打开 Vitest UI
|
||||
|
||||
# 类型生成
|
||||
npm run gen:api # 从 OpenAPI schema 生成 TypeScript 类型
|
||||
```
|
||||
|
||||
### 数据库迁移
|
||||
|
||||
项目使用 Alembic 进行数据库迁移:
|
||||
|
||||
```bash
|
||||
cd python-api
|
||||
|
||||
# 生成新迁移(修改模型后)
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# 应用迁移
|
||||
alembic upgrade head
|
||||
|
||||
# 回滚迁移
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
### 开发提示
|
||||
|
||||
- **Tauri 调试**: 使用 `npm run tauri dev` 时,Rust 后端日志在终端输出,前端日志在浏览器控制台
|
||||
- **本地项目路径**: 项目数据保存在 `~/Documents/Meijiaka/projects/{project_id}/`
|
||||
- **配置修改**: AI 模型配置 `python-api/config/ai_models.yaml` 支持热重载,无需重启服务
|
||||
- **类型同步**: 修改后端 API 后,记得重新导出 OpenAPI 并运行 `npm run gen:api`
|
||||
- **Async Engine Scheduler**: 系统使用 Slot-Based Scheduler 统一调度所有第三方异步任务:
|
||||
- `video` - 数字人视频生成(18 slots)
|
||||
- `avatar_clone` - 形象克隆(2 slots)
|
||||
- `image` - 图片生成(9 slots)
|
||||
- `subtitle` - 字幕生成(5 slots)
|
||||
- `copy` - 文案提取(5 slots)
|
||||
- **任务状态**: 任务状态唯一真相源为后端 Redis,`taskStore` 不持久化,启动时从后端 `GET /tasks` 查询
|
||||
- **项目数据**: 项目元数据和分镜数据通过 IPC 显式写入本地文件,不通过 Zustand persist 持久化
|
||||
- **字幕渲染**: 使用 `assjs` 库进行 ASS/SSA 字幕预览渲染,WASM 和 Worker 文件通过 Vite 插件复制到 `public/` 目录,修改资源路径后需要检查插件配置
|
||||
|
||||
## 开发规范
|
||||
|
||||
### 后端 (Python)
|
||||
|
||||
- **格式化**: Black (行宽: 100)
|
||||
- **检查**: Ruff
|
||||
- **类型**: MyPy (strict 模式)
|
||||
- **架构**: API → Service → CRUD → Model,禁止跨层调用
|
||||
- **数据库**: 始终使用异步 SQLAlchemy,事务在 API 层控制
|
||||
- **AI 集成**: 无论使用什么提供者,输出 Schema 必须保持一致,在 Service 层标准化
|
||||
- **提示词**: 所有提示词放在 `app/ai/prompts/` 单独文件,不硬编码
|
||||
- **配置管理**: 所有配置通过 `from app.config import get_settings` 读取,禁止直接使用 `os.getenv()`,所有配置项必须在 `Settings` 类中定义
|
||||
|
||||
### 配置管理强制规范
|
||||
|
||||
**架构层级:**
|
||||
```
|
||||
.env (Layer 1) ──→ Settings (Layer 2) ──→ 服务层 (Layer 3)
|
||||
↑
|
||||
唯一配置出口
|
||||
```
|
||||
|
||||
**强制规则:**
|
||||
- **所有服务**必须使用 `from app.config import get_settings` 读取配置
|
||||
- **禁止**在服务层、API 层直接使用 `os.getenv()` 或 `os.environ.get()`
|
||||
- **所有配置项**必须在 `app/config.py` 的 `Settings` 类中定义
|
||||
- **敏感信息**(API Keys、Secrets)必须通过环境变量注入
|
||||
- **业务默认值**可以硬编码在 `Settings` 中
|
||||
|
||||
**添加新配置流程:**
|
||||
1. 在 `app/config.py` 的 `Settings` 类中添加字段定义
|
||||
2. 使用 `Field(default=..., description="...")` 提供默认值和说明
|
||||
3. 敏感信息使用 `str | None = None` 类型
|
||||
4. 更新 `.env.example` 文档
|
||||
|
||||
### Rust (Tauri 后端)
|
||||
|
||||
- **格式化**: `rustfmt`(默认配置)
|
||||
- **检查**: `cargo clippy`(零警告)
|
||||
- **模块组织**: 命令按领域拆分到 `src/commands/{domain}.rs`,在 `lib.rs` 中注册
|
||||
- **存储分层**: 存储逻辑按领域拆分到 `src/storage/{domain}.rs`
|
||||
- **命令参数**: Tauri IPC 命令必须使用 Args 结构体接收参数:
|
||||
```rust
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SaveProjectMetaArgs {
|
||||
pub project_id: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
```
|
||||
- **禁止**: 命令函数直接使用 camelCase 参数名(会产生 `non_snake_case` 警告)
|
||||
|
||||
### 本地数据存储规范(Tauri/Rust)
|
||||
|
||||
**分层架构:**
|
||||
```
|
||||
Layer 1: 页面组件(Pages/Components) — 只操作 Store,禁止直接调用 IPC save
|
||||
Layer 2: Zustand Store(内存状态) — Immer 不可变更新
|
||||
Layer 3: PersistManager(持久化协调) — debounce 批量、flush 强制、错误上报
|
||||
Layer 4: API 模块(localStorageApi 等) — 类型安全的 IPC 调用封装
|
||||
Layer 5: Rust StorageEngine(文件系统) — sanitize + atomic_write + file_lock
|
||||
```
|
||||
|
||||
**强制规范:**
|
||||
1. **禁止页面组件直接调用 `localProjectApi.saveXxx()`** — 必须通过 Store → PersistManager
|
||||
2. **禁止 Rust 命令函数直接 `fs::write`** — 必须通过 `StorageEngine::atomic_write_json`
|
||||
3. **所有 ID 参数必须 `sanitize_id`** — 路径参数白名单校验(`[a-zA-Z0-9_-]+`)
|
||||
4. **所有 JSON 写操作必须原子化** — 临时文件 + `fs::rename`
|
||||
5. **RMW 操作必须加锁** — `with_file_lock` 或 Mutex
|
||||
|
||||
**StorageEngine 核心能力:**
|
||||
- `sanitize_id(id)` — ID 白名单校验,防御路径遍历
|
||||
- `sanitize_filename(name)` — 提取纯文件名,拒绝目录组件
|
||||
- `atomic_write_json(path, value)` — 先写 `.tmp` 再 rename,防崩溃截断
|
||||
- `with_file_lock(path, f)` — 文件锁保护 RMW 操作
|
||||
- `read_json<T>(path)` — 安全读取,文件不存在返回 `None`,损坏返回 `Err`
|
||||
|
||||
### 前端 (TypeScript/React)
|
||||
|
||||
- **类型**: 严格 TypeScript 模式
|
||||
- **组件**: 函数组件 + Hooks
|
||||
- **状态管理**: Zustand 管理全局状态,Immer 处理不可变更新
|
||||
- **数据获取**: SWR 缓存,自动 localStorage 降级
|
||||
- **API 客户端**: 从后端 OpenAPI schema 自动生成类型
|
||||
- **命名风格**: camelCase(自动与后端 snake_case 转换)
|
||||
- **本地存储**: 项目数据通过 Tauri IPC 保存到 `~/Documents/Meijiaka/projects/`
|
||||
|
||||
### 提交规范
|
||||
|
||||
```
|
||||
feat: 新功能
|
||||
fix: 修复
|
||||
docs: 文档
|
||||
refactor: 重构
|
||||
test: 测试
|
||||
chore: 构建/工具
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
### 后端 (.env)
|
||||
|
||||
```bash
|
||||
# 数据库
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/meijiaka
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# JWT 认证
|
||||
SECRET_KEY=your-secret-key-here
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080
|
||||
|
||||
# AI 服务凭证
|
||||
VOLCENGINE_API_KEY=your-volcengine-key
|
||||
VOLCENGINE_CAPTION_APPID=your-caption-appid
|
||||
VOLCENGINE_CAPTION_TOKEN=your-caption-token
|
||||
OPENAI_API_KEY=sk-your-openai-key
|
||||
KLINGAI_ACCESS_KEY=your-kling-access-key
|
||||
KLINGAI_SECRET_KEY=your-kling-secret-key
|
||||
|
||||
# 七牛云存储(数字人视频持久化)
|
||||
QINIU_ACCESS_KEY=your-qiniu-access-key
|
||||
QINIU_SECRET_KEY=your-qiniu-secret-key
|
||||
QINIU_VIDEO_BUCKET=media-bucket
|
||||
QINIU_IMAGE_BUCKET=image-bucket
|
||||
|
||||
# CORS 配置
|
||||
CORS_ORIGINS=http://localhost:1420,http://127.0.0.1:1420,http://localhost:8080
|
||||
```
|
||||
|
||||
## 服务地址
|
||||
|
||||
- API: http://localhost:8080/api/v1
|
||||
- 文档: http://localhost:8080/docs
|
||||
- Vite 开发服务器: http://localhost:1420
|
||||
|
||||
## 关键开发文件
|
||||
|
||||
| 文件 | 用途 |
|
||||
|------|------|
|
||||
| `python-api/app/main.py` | FastAPI 应用入口 |
|
||||
| `python-api/app/api/v1/*.py` | API 端点定义 |
|
||||
| `python-api/app/ai/model_router.py` | AI 模型路由和降级 |
|
||||
| `python-api/app/services/*.py` | 业务逻辑和 AI 响应标准化 |
|
||||
| `python-api/config/ai_models.yaml` | AI 模型配置 |
|
||||
| `tauri-app/src/App.tsx` | 主 React 组件 |
|
||||
| `tauri-app/src/api/client.ts` | 智能路由的 API 客户端 |
|
||||
| `tauri-app/src/store/projectStore.ts` | 项目状态管理 |
|
||||
| `tauri-app/src-tauri/src/lib.rs` | Rust 命令注册 |
|
||||
| `tauri-app/src-tauri/src/commands/project.rs` | 项目存储 IPC 命令 |
|
||||
| `tauri-app/src-tauri/src/storage/engine.rs` | 核心存储引擎(原子写+校验)|
|
||||
| `tauri-app/src-tauri/src/video_processing.rs` | FFmpeg 视频合成 |
|
||||
| `tauri-app/src-tauri/src/avatar_cache.rs` | 头像视频缓存管理 |
|
||||
| `python-api/app/core/token_manager.py` | API Token 缓存与自动刷新 |
|
||||
| `python-api/app/config.py` | Pydantic Settings 配置管理 |
|
||||
| `tauri-app/src/pages/VideoCreation/SubtitleBurning.tsx` | 字幕压制页面(ASS 字幕渲染) |
|
||||
| `tauri-app/src/hooks/useAssJsRenderer.ts` | assjs 字幕渲染 Hook |
|
||||
| `tauri-app/src/utils/assGenerator.ts` | ASS 字幕文件生成工具 |
|
||||
|
||||
## 额外开发文档
|
||||
|
||||
项目 `docs/` 目录包含详细的深度开发文档:
|
||||
|
||||
| 文档 | 主题 |
|
||||
|------|------|
|
||||
| `docs/video-generation-flow.md` | 完整视频生成流程说明 |
|
||||
| `docs/kling-api-dev.md` | KlingAI 数字人视频 API 对接开发文档 |
|
||||
| `docs/app-update-system.md` | 应用自动更新系统设计 |
|
||||
| `docs/anytocopy-integration.md` | 版权素材集成说明 |
|
||||
| `docs/anytocopy-api.md` | 版权素材 API 文档 |
|
||||
| `docs/volcengine-video-caption-api.md` | 火山引擎字幕 API 对接 |
|
||||
| `docs/qiniu-kodo-python-sdk-guide.md` | 七牛云存储 SDK 集成指南 |
|
||||
| `docs/database-design.md` | 数据库设计文档 |
|
||||
| `docs/unified-async-scheduler.md` | 统一异步调度器设计 |
|
||||
| `docs/semantic-refactoring-plan.md` | 后端语义重构计划 |
|
||||
| `docs/migrate-avatars-to-local.md` | 头像数据迁移到本地说明 |
|
||||
|
||||
## 统一术语表(语义治理)
|
||||
|
||||
后端代码已完成语义治理重构,所有开发必须遵守统一术语表,禁止使用废弃别名。
|
||||
|
||||
整个后端划分为 6 个语义层级,每一层只使用属于该层的术语:
|
||||
|
||||
```
|
||||
Layer 6: Presentation (API Schema / 前端适配层) → Segment, Human, Job, Script
|
||||
Layer 5: Application (API 路由) → Segment, Human, Job, Project
|
||||
Layer 4: Orchestration (Scheduler / SlotManager) → Job, JobRecord, Slot, Handler
|
||||
Layer 3: Domain (Service / 业务逻辑) → Segment, Human, VideoComposition, Caption
|
||||
Layer 2: Adapter (Provider Client) → KlingJob, KlingElement, VolcJob, ProviderTaskId
|
||||
Layer 1: Infrastructure (DB / Redis / HTTP) → 底层技术术语
|
||||
```
|
||||
|
||||
### 术语对照表
|
||||
|
||||
| 业务概念 | 官方术语 | 使用层级 | 禁止使用的别名 |
|
||||
|---------|---------|---------|--------------|
|
||||
| 视频分镜 | `Segment` | Layer 3-6 | `shot`, `scene_desc` |
|
||||
| 数字人形象 | `Human` / `Avatar` | Layer 3-6(DB 用 `avatar`,API 用 `human_id`) | `element`, `character` |
|
||||
| 调度器工作单元 | `Job` | Layer 4 | `task` |
|
||||
| 供应商侧任务 | `ProviderJob` | Layer 2 | `kling_task`, `volc_task` |
|
||||
| 供应商任务 ID | `provider_task_id` | Layer 2-4 | `kling_task_id`, `video_task_id`, `image_task_id` |
|
||||
| 分镜状态 | `SegmentStatus` | Layer 3-4 | 裸字符串 |
|
||||
| 调度器状态 | `JobStatus` | Layer 4 | 裸字符串 |
|
||||
| 形象克隆状态 | `AvatarCloneStatus` | Layer 3 | 裸字符串 |
|
||||
| Kling 原始状态 | `KlingTaskStatus` | **Layer 2 仅限** | 泄漏到 Layer 3+ |
|
||||
|
||||
### 分层禁令
|
||||
|
||||
1. **API 层 (`app/api/v1/`)**:禁止出现 `element_id`, `kling_task_id`, `shot_type`, `omni`
|
||||
2. **Scheduler 层 (`app/scheduler/`)**:禁止出现 `task_id`(应为 `job_id`),禁止构造供应商 prompt 语法
|
||||
3. **Service 层 (`app/services/`)**:禁止出现 `<<<element_1>>>` 等供应商专用语法
|
||||
4. **Provider 层 (`app/ai/providers/`)**:允许使用 `element_id`, `kling_task_id`, `KlingTaskStatus`
|
||||
|
||||
### 类型禁令
|
||||
|
||||
- 跨层传递的接口禁止裸用 `dict[str, Any]`。`params`、`result`、`changes` 等字段必须使用 Pydantic 模型或 TypedDict
|
||||
- 状态字段禁止使用裸字符串,必须使用对应的 `StrEnum`
|
||||
- CRUD 层 `obj_in` 禁止裸字典,必须使用 `CreateSchema` / `UpdateSchema`
|
||||
Vendored
BIN
Binary file not shown.
@@ -0,0 +1,357 @@
|
||||
# AnyToCopy API 开发文档
|
||||
|
||||
> 原文档:https://www.anytocopy.com/account/api/docs
|
||||
>
|
||||
> 功能:支持 50+ 平台视频文案提取、视频去水印
|
||||
|
||||
## 概述
|
||||
|
||||
AnyToCopy API 提供视频/图片文案提取功能,支持抖音、小红书、快手等 50+ 平台。
|
||||
|
||||
**核心功能**:
|
||||
- 视频文案提取(语音转文字)
|
||||
- 视频去水印下载
|
||||
- 图片去水印下载
|
||||
- 支持 50+ 内容平台
|
||||
|
||||
---
|
||||
|
||||
## 基础信息
|
||||
|
||||
| 项目 | 内容 |
|
||||
|------|------|
|
||||
| **Base URL** | `https://api.anytocopy.com/vip/open-api/v1` |
|
||||
| **协议** | HTTPS |
|
||||
| **数据格式** | JSON |
|
||||
|
||||
### 鉴权方式
|
||||
|
||||
在请求头中携带 API Key 和 Secret:
|
||||
|
||||
```http
|
||||
X-API-Key: your_api_key
|
||||
X-API-Secret: your_api_secret
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 接口列表
|
||||
|
||||
### 1. 提交视频文案提取任务
|
||||
|
||||
创建提取任务,返回 `taskId` 用于后续查询。
|
||||
|
||||
#### 请求
|
||||
|
||||
| 项目 | 内容 |
|
||||
|------|------|
|
||||
| **Method** | POST |
|
||||
| **Endpoint** | `/video/extract` |
|
||||
|
||||
#### 请求参数
|
||||
|
||||
| 参数名 | 类型 | 必填 | 说明 |
|
||||
|--------|------|------|------|
|
||||
| `workUrl` | String | 是 | 作品链接(支持抖音、小红书等) |
|
||||
| `taskType` | String | 否 | 任务类型,默认 `TEXT`(文案提取) |
|
||||
|
||||
#### curl 示例
|
||||
|
||||
```bash
|
||||
curl -X POST 'https://api.anytocopy.com/vip/open-api/v1/video/extract?workUrl=https://v.douyin.com/xxx&taskType=TEXT' \
|
||||
-H 'X-API-Key: your_api_key' \
|
||||
-H 'X-API-Secret: your_api_secret'
|
||||
```
|
||||
|
||||
#### 响应示例
|
||||
|
||||
**成功响应(HTTP 200)**:
|
||||
```json
|
||||
{
|
||||
"msg": "任务已提交",
|
||||
"code": 200,
|
||||
"data": "2008802706718072832"
|
||||
}
|
||||
```
|
||||
|
||||
**失败响应(并发限制)**:
|
||||
```json
|
||||
{
|
||||
"msg": "您的并发任务已达上限(5/5),请等待任务完成后再试",
|
||||
"code": 500
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 查询任务状态和结果
|
||||
|
||||
根据 `taskId` 查询任务进度与提取结果。
|
||||
|
||||
#### 请求
|
||||
|
||||
| 项目 | 内容 |
|
||||
|------|------|
|
||||
| **Method** | GET |
|
||||
| **Endpoint** | `/video/query` |
|
||||
|
||||
#### 请求参数
|
||||
|
||||
| 参数名 | 类型 | 必填 | 说明 |
|
||||
|--------|------|------|------|
|
||||
| `taskId` | String | 是 | 任务 ID(提交任务时返回) |
|
||||
|
||||
#### curl 示例
|
||||
|
||||
```bash
|
||||
curl -X GET 'https://api.anytocopy.com/vip/open-api/v1/video/query?taskId=2008802706718072832' \
|
||||
-H 'X-API-Key: your_api_key' \
|
||||
-H 'X-API-Secret: your_api_secret'
|
||||
```
|
||||
|
||||
#### 响应示例
|
||||
|
||||
**任务完成(SUCCESS)**:
|
||||
```json
|
||||
{
|
||||
"msg": "操作成功",
|
||||
"code": 200,
|
||||
"data": {
|
||||
"taskId": "2008802706718072832",
|
||||
"title": "小个子女生如何逆袭第一眼大美女",
|
||||
"content": "#听劝改造[话题]# #如何找到自己的风格",
|
||||
"videoUrl": "https://sns-video-bd.xhscdn.com/f0370019b934b9b6e_258.mp4",
|
||||
"videoUrlList": ["https://sns-video-bd.xhscdn.com/stream/79258.mp4"],
|
||||
"imageUrlList": ["https://ci.xiaohongshu.com/1040g2sg31r0hdqhjnge05q"],
|
||||
"cover": "https://ci.xiaohongshu.com/1040g2sg31r0hdqhjnge05",
|
||||
"textContent": "小个子女生真的不要再和别人卷身高上的天赋...",
|
||||
"platform": "xhs",
|
||||
"audioUrl": "https://pub-6026ae78487b47e5bd4a5b8a0d9ae5aa.r2.dev/audio.mp3",
|
||||
"duration": 156.36,
|
||||
"workType": "video",
|
||||
"status": "SUCCESS",
|
||||
"errorMessage": "视频处理成功!",
|
||||
"createBy": "60227",
|
||||
"createTime": "2026-01-07 15:27:42"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**任务处理中(WAITING)**:
|
||||
```json
|
||||
{
|
||||
"msg": "操作成功",
|
||||
"code": 200,
|
||||
"data": {
|
||||
"taskId": "2008805155734429696",
|
||||
"title": "今日摘抄,不知道原创是谁,太多了",
|
||||
"content": "今日摘抄,不知道原创是谁,太多了,可以在",
|
||||
"videoUrl": "https://sns-video-qc.xhscdn.com/stream/79/258.mp4",
|
||||
"videoUrlList": ["https://sns-video-qc.xhscdn.com/stream/79/258.mp4"],
|
||||
"imageUrlList": ["https://ci.xiaohongshu.com/spectrum/1040g0k031qo"],
|
||||
"cover": "https://ci.xiaohongshu.com/spectrum/1040g0k031qoj7pr9gm905",
|
||||
"textContent": "",
|
||||
"platform": "xhs",
|
||||
"audioUrl": null,
|
||||
"duration": null,
|
||||
"workType": "video",
|
||||
"status": "WAITING",
|
||||
"errorMessage": "作品内容提取中...",
|
||||
"createBy": "60227",
|
||||
"createTime": "2026-01-07 15:37:26"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**任务失败(FAILURE)**:
|
||||
```json
|
||||
{
|
||||
"msg": "操作成功",
|
||||
"code": 200,
|
||||
"data": {
|
||||
"taskId": "2008805155734429696",
|
||||
"status": "FAILURE",
|
||||
"errorMessage": "任务执行失败"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 响应状态码
|
||||
|
||||
| 状态码 | 说明 | 场景 |
|
||||
|--------|------|------|
|
||||
| 200 | 成功 | 任务创建成功或查询成功 |
|
||||
| 500 | 失败 | 并发任务已达上限或其他错误 |
|
||||
|
||||
---
|
||||
|
||||
## 任务状态说明
|
||||
|
||||
| 状态值 | 说明 | 处理建议 |
|
||||
|--------|------|----------|
|
||||
| `WAITING` | 任务等待中或处理中 | 继续轮询查询任务状态 |
|
||||
| `PROCESSING` | 任务处理中 | 继续轮询查询任务状态 |
|
||||
| `SUCCESS` | 任务执行成功 | 可获取完整的提取结果数据 |
|
||||
| `FAILED` / `FAILURE` | 任务执行失败 | 检查 `errorMessage` 字段获取失败原因 |
|
||||
|
||||
---
|
||||
|
||||
## 响应字段说明
|
||||
|
||||
| 字段名 | 类型 | 说明 |
|
||||
|--------|------|------|
|
||||
| `taskId` | String | 任务唯一标识 |
|
||||
| `title` | String | 作品标题 |
|
||||
| `content` | String | 作品正文内容 |
|
||||
| `textContent` | String | 视频语音转文字文案(任务完成后) |
|
||||
| `videoUrl` | String | 视频下载链接(无水印) |
|
||||
| `audioUrl` | String | 音频文件链接(任务完成后) |
|
||||
| `imageUrlList` | Array | 图片链接列表 |
|
||||
| `cover` | String | 封面图片链接 |
|
||||
| `platform` | String | 平台标识(xhs、douyin 等) |
|
||||
| `duration` | Number | 视频时长(秒) |
|
||||
| `workType` | String | 作品类型(video、image) |
|
||||
| `status` | String | 任务状态(WAITING、SUCCESS、FAILURE) |
|
||||
| `errorMessage` | String | 状态描述或错误信息 |
|
||||
| `createBy` | String | 创建者 ID |
|
||||
| `createTime` | String | 创建时间 |
|
||||
|
||||
---
|
||||
|
||||
## 接口使用流程
|
||||
|
||||
```
|
||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
│ 1. 提交任务 │ --> │ 2. 轮询查询 │ --> │ 3. 处理结果 │
|
||||
└─────────────┘ └─────────────┘ └─────────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
POST /video/extract GET /video/query status = SUCCESS
|
||||
获取 taskId 每 3-5 秒查询一次 获取完整结果
|
||||
```
|
||||
|
||||
### 推荐调用流程
|
||||
|
||||
1. **提交任务**
|
||||
- 调用 `POST /video/extract` 接口,传入作品链接
|
||||
- 成功后返回 `taskId`,用于后续查询
|
||||
|
||||
2. **轮询查询**
|
||||
- 使用返回的 `taskId` 调用 `GET /video/query` 接口
|
||||
- 建议每隔 **3-5 秒** 查询一次任务状态
|
||||
|
||||
3. **处理结果**
|
||||
- 当 `status` 为 `SUCCESS` 时,获取完整的提取结果(标题、正文、视频、音频等)
|
||||
- 若为 `FAILURE`,检查 `errorMessage` 了解失败原因
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践
|
||||
|
||||
- **轮询间隔建议**:3-5 秒,避免过于频繁请求
|
||||
- **最大轮询次数**:建议设置 60 次上限,避免无限轮询
|
||||
- **安全保管**:妥善保管 API Key 和 Secret,不要泄露到客户端
|
||||
- **并发限制**:并发任务上限为 5 个,合理安排任务提交
|
||||
|
||||
---
|
||||
|
||||
## 支持平台
|
||||
|
||||
支持 50+ 平台,主要包括:
|
||||
|
||||
| 平台 | 标识 | 说明 |
|
||||
|------|------|------|
|
||||
| 小红书 | xhs | 视频、图文 |
|
||||
| 抖音 | douyin | 视频 |
|
||||
| 快手 | kuaishou | 视频 |
|
||||
| ... | ... | 更多平台 |
|
||||
|
||||
---
|
||||
|
||||
## Python 集成示例
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
class AnyToCopyClient:
|
||||
BASE_URL = "https://api.anytocopy.com/vip/open-api/v1"
|
||||
|
||||
def __init__(self, api_key: str, api_secret: str):
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.headers = {
|
||||
"X-API-Key": api_key,
|
||||
"X-API-Secret": api_secret,
|
||||
}
|
||||
|
||||
async def submit_task(self, work_url: str, task_type: str = "TEXT") -> dict:
|
||||
"""提交视频文案提取任务"""
|
||||
url = f"{self.BASE_URL}/video/extract"
|
||||
params = {"workUrl": work_url, "taskType": task_type}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=self.headers, params=params) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def query_task(self, task_id: str) -> dict:
|
||||
"""查询任务状态和结果"""
|
||||
url = f"{self.BASE_URL}/video/query"
|
||||
params = {"taskId": task_id}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=self.headers, params=params) as resp:
|
||||
return await resp.json()
|
||||
|
||||
async def extract_video(self, work_url: str, max_retries: int = 60) -> dict:
|
||||
"""完整的视频提取流程(提交 + 轮询)"""
|
||||
# 1. 提交任务
|
||||
submit_result = await self.submit_task(work_url)
|
||||
if submit_result.get("code") != 200:
|
||||
raise Exception(f"提交任务失败: {submit_result.get('msg')}")
|
||||
|
||||
task_id = submit_result["data"]
|
||||
print(f"任务已提交,taskId: {task_id}")
|
||||
|
||||
# 2. 轮询查询
|
||||
for i in range(max_retries):
|
||||
await asyncio.sleep(3) # 每 3 秒查询一次
|
||||
|
||||
query_result = await self.query_task(task_id)
|
||||
if query_result.get("code") != 200:
|
||||
continue
|
||||
|
||||
data = query_result.get("data", {})
|
||||
status = data.get("status")
|
||||
|
||||
if status == "SUCCESS":
|
||||
print(f"任务完成!")
|
||||
return data
|
||||
elif status == "FAILURE":
|
||||
raise Exception(f"任务失败: {data.get('errorMessage')}")
|
||||
else:
|
||||
print(f"[{i+1}/{max_retries}] 任务处理中...")
|
||||
|
||||
raise Exception("轮询超时,任务未完成")
|
||||
|
||||
|
||||
# 使用示例
|
||||
async def main():
|
||||
client = AnyToCopyClient(
|
||||
api_key="your_api_key",
|
||||
api_secret="your_api_secret"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await client.extract_video("https://v.douyin.com/xxxxx")
|
||||
print(f"标题: {result['title']}")
|
||||
print(f"文案: {result['textContent']}")
|
||||
print(f"视频: {result['videoUrl']}")
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
@@ -0,0 +1,128 @@
|
||||
# AnyToCopy 视频文案提取集成
|
||||
|
||||
## 功能概述
|
||||
|
||||
脚本生成 API 现已支持自动识别视频链接并提取文案。
|
||||
|
||||
- **自动检测**:输入创作主题时自动检测是否为视频链接
|
||||
- **智能提取**:支持小红书、抖音、快手等 50+ 平台
|
||||
- **无缝集成**:提取的文案自动用于脚本生成
|
||||
|
||||
## 支持平台
|
||||
|
||||
| 平台 | 示例链接 |
|
||||
|------|----------|
|
||||
| 小红书 | `https://xhslink.com/xxx` |
|
||||
| 抖音 | `https://v.douyin.com/xxx` |
|
||||
| 快手 | `https://v.kuaishou.com/xxx` |
|
||||
| 哔哩哔哩 | `https://b23.tv/xxx` |
|
||||
| 微博 | `https://weibo.com/xxx` |
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 1. 普通文案生成(原有功能)
|
||||
|
||||
```json
|
||||
POST /api/v1/ai/scripts/generate
|
||||
{
|
||||
"topic": "家装验收的5个细节",
|
||||
"duration": 60,
|
||||
"script_type": "professional"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 视频链接提取文案后生成
|
||||
|
||||
```json
|
||||
POST /api/v1/ai/scripts/generate
|
||||
{
|
||||
"topic": "https://v.douyin.com/AbC123",
|
||||
"duration": 60,
|
||||
"script_type": "professional"
|
||||
}
|
||||
```
|
||||
|
||||
**流程**:
|
||||
1. 检测输入为视频链接
|
||||
2. 调用 AnyToCopy API 提取视频文案
|
||||
3. 使用提取的文案作为创作主题生成脚本
|
||||
|
||||
### 3. 混合输入(链接 + 说明)
|
||||
|
||||
```json
|
||||
POST /api/v1/ai/scripts/generate
|
||||
{
|
||||
"topic": "参考这个视频的风格 https://v.douyin.com/AbC123,写一个关于装修验收的脚本",
|
||||
"duration": 60,
|
||||
"script_type": "professional"
|
||||
}
|
||||
```
|
||||
|
||||
**流程**:
|
||||
1. 从文本中提取视频链接
|
||||
2. 提取视频文案
|
||||
3. 将提取的文案与原始说明结合生成脚本
|
||||
|
||||
## 流式生成(SSE)
|
||||
|
||||
视频提取过程会显示在进度中:
|
||||
|
||||
```
|
||||
data: {"type": "analyzing", "progress": 5, "message": "检测到视频链接,正在提取文案..."}
|
||||
|
||||
data: {"type": "analyzing", "progress": 10, "message": "视频文案提取成功,共 1200 字符"}
|
||||
|
||||
data: {"type": "generating", "progress": 15, "message": "正在创作脚本..."}
|
||||
...
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
在 `.env` 文件中配置 AnyToCopy API:
|
||||
|
||||
```bash
|
||||
# AnyToCopy 视频文案提取服务
|
||||
ANYTOCOPY_API_KEY=your-api-key
|
||||
ANYTOCOPY_API_SECRET=your-api-secret
|
||||
ANYTOCOPY_BASE_URL=https://api.anytocopy.com/vip/open-api/v1
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **API Key**:需要从 AnyToCopy 官网获取 API Key
|
||||
2. **并发限制**:AnyToCopy 限制并发任务数为 5
|
||||
3. **提取时间**:视频文案提取通常需要 10-30 秒
|
||||
4. **失败处理**:如果提取失败,会自动使用原始输入继续生成脚本
|
||||
|
||||
## 代码集成
|
||||
|
||||
### 服务层
|
||||
|
||||
```python
|
||||
from app.services.anytocopy_service import get_anytocopy_service
|
||||
|
||||
anytocopy = get_anytocopy_service()
|
||||
result = await anytocopy.extract_text_from_input("https://v.douyin.com/xxx")
|
||||
|
||||
if result["is_video_url"]:
|
||||
extracted_text = result["extracted_text"]
|
||||
# 使用提取的文案
|
||||
```
|
||||
|
||||
### 独立使用 AnyToCopy 服务
|
||||
|
||||
```python
|
||||
from app.services.anytocopy_service import AnyToCopyService
|
||||
|
||||
service = AnyToCopyService({
|
||||
"api_key": "your-key",
|
||||
"api_secret": "your-secret",
|
||||
})
|
||||
|
||||
# 提交任务
|
||||
result = await service.submit_task("https://v.douyin.com/xxx")
|
||||
task_id = result["data"]
|
||||
|
||||
# 查询结果
|
||||
query_result = await service.query_task(task_id)
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,357 @@
|
||||
# 美家卡智影 - 数据库设计规范
|
||||
|
||||
> 企业级统一数据库命名和设计规范,遵循 "轻量云 + 全本地业务数据" 架构设计。
|
||||
|
||||
---
|
||||
|
||||
## 一、整体设计原则
|
||||
|
||||
| 原则 | 说明 |
|
||||
|------|------|
|
||||
| **统一前缀** | 所有业务表统一使用 `mjk_` 前缀(美家卡全称缩写)|
|
||||
| **无外键设计** | 不使用数据库外键约束,数据一致性由业务层保证,架构更简洁灵活 |
|
||||
| **软删除优先** | 使用 `deleted_at` 时间戳做软删除,保留历史数据便于排查 |
|
||||
| **全小写下划线** | 命名全小写,单词用下划线分隔 |
|
||||
|
||||
---
|
||||
|
||||
## 二、表命名规范
|
||||
|
||||
### 命名格式
|
||||
|
||||
```
|
||||
mjk_{module}_{description}[_logs]
|
||||
```
|
||||
|
||||
- `mjk_` - 统一项目前缀
|
||||
- `module` - 业务模块名称
|
||||
- `description` - 内容描述
|
||||
- `_logs` 后缀 - 日志/统计类表(按事件增长)
|
||||
|
||||
### 示例
|
||||
|
||||
| 表名 | 说明 |
|
||||
|------|------|
|
||||
| `mjk_users` | 用户账户表 |
|
||||
| `mjk_model_usage_logs` | AI 模型调用日志表 |
|
||||
| `mjk_avatars` | 数字人名片表(已废弃,数据迁移到本地)|
|
||||
| `mjk_interface_request_logs` | 接口请求记录表 |
|
||||
|
||||
---
|
||||
|
||||
## 三、字段命名规范
|
||||
|
||||
| 场景 | 规则 | 示例 |
|
||||
|------|------|------|
|
||||
| **主键** | 统一命名 `id`,类型 `BIGSERIAL` / `BIGINT` | `id BIGSERIAL PRIMARY KEY` |
|
||||
| **外键引用** | 格式 `{referenced_table}_{primary_key}`,不要加前缀 | 引用 `mjk_users.id` → `user_id` |
|
||||
| **布尔类型** | 前缀 `is_` 或 `has_` | `is_deleted`, `has_attachment` |
|
||||
| **时间戳** | 后缀 `_at`,类型 `TIMESTAMP WITH TIME ZONE` | `created_at`, `updated_at`, `started_at`, `finished_at` |
|
||||
| **状态字段** | 字段名固定 `status`,类型 `VARCHAR(N)`,存储枚举字符串 | `status VARCHAR(20) NOT NULL` |
|
||||
| **软删除** | 字段名 `deleted_at`,允许 `NULL`,`NULL` 表示未删除 | `deleted_at TIMESTAMP WITH TIME ZONE` |
|
||||
|
||||
---
|
||||
|
||||
## 四、约束与索引命名规范
|
||||
|
||||
| 对象 | 命名格式 | 示例 |
|
||||
|------|---------|------|
|
||||
| **主键** | `{table_name}_pkey`(PostgreSQL 默认) | `mjk_interface_request_logs_pkey` |
|
||||
| **唯一约束** | `uk_{table_name}_{column_list}` | `uk_mjk_interface_request_logs_request_id` |
|
||||
| **普通索引** | `idx_{table_name}_{column_list}` | `idx_mjk_interface_request_logs_user_id` |
|
||||
|
||||
---
|
||||
|
||||
## 五、所有业务表结构
|
||||
|
||||
---
|
||||
|
||||
### 1. `mjk_users` - 用户基本信息表
|
||||
|
||||
存储用户基本认证信息,云端只存账户,不存业务数据。
|
||||
|
||||
```sql
|
||||
CREATE TABLE mjk_users (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
mobile VARCHAR(20) UNIQUE NOT NULL,
|
||||
nickname VARCHAR(64),
|
||||
avatar_url TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE UNIQUE INDEX idx_mjk_users_mobile ON mjk_users(mobile);
|
||||
```
|
||||
|
||||
**字段说明**:
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `id` | 用户唯一ID |
|
||||
| `mobile` | 手机号(登录账号,唯一)|
|
||||
| `nickname` | 用户昵称 |
|
||||
| `avatar_url` | 头像URL |
|
||||
| `created_at` | 创建时间 |
|
||||
| `updated_at` | 最后更新时间 |
|
||||
|
||||
---
|
||||
|
||||
### 2. `mjk_user_credits` - 用户积分账户记录表
|
||||
|
||||
记录用户积分账户的所有变动(充值、消费),每个变动一条记录。
|
||||
|
||||
```sql
|
||||
CREATE TABLE mjk_user_credits (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id VARCHAR(50) NOT NULL,
|
||||
change_type VARCHAR(20) NOT NULL, -- recharge / consume
|
||||
change_credits INTEGER NOT NULL, -- 变动积分数(充值正,消费负)
|
||||
balance_before INTEGER NOT NULL, -- 变动前余额
|
||||
balance_after INTEGER NOT NULL, -- 变动后余额
|
||||
interface_type VARCHAR(50), -- 消费接口类型(消费时才有)
|
||||
request_id VARCHAR(64), -- 关联接口请求ID
|
||||
remark VARCHAR(200), -- 备注(充值订单号等)
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_mjk_user_credits_user_id ON mjk_user_credits(user_id);
|
||||
CREATE INDEX idx_mjk_user_credits_created_at ON mjk_user_credits(created_at);
|
||||
CREATE INDEX idx_mjk_user_credits_change_type ON mjk_user_credits(change_type);
|
||||
```
|
||||
|
||||
**字段说明**:
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `id` | 记录ID |
|
||||
| `user_id` | 关联用户ID |
|
||||
| `change_type` | 变动类型:`recharge`(充值) / `consume`(消费) |
|
||||
| `change_credits` | 变动积分,充值为正,消费为负 |
|
||||
| `balance_before` | 变动前积分余额 |
|
||||
| `balance_after` | 变动后积分余额 |
|
||||
| `interface_type` | 消费接口类型(仅消费时有)|
|
||||
| `request_id` | 关联接口请求ID(可用于追溯)|
|
||||
| `remark` | 备注,充值时存订单号 |
|
||||
| `created_at` | 变动时间 |
|
||||
|
||||
**余额计算**:用户当前余额 = `sum(change_credits)`,可以随时计算,也可以在用户表存冗余字段加速查询。
|
||||
|
||||
---
|
||||
|
||||
### 3. `mjk_model_usage_logs` - AI 模型调用日志表
|
||||
|
||||
记录每一次 AI 模型调用,用于成本统计和监控。
|
||||
|
||||
```sql
|
||||
CREATE TABLE mjk_model_usage_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
model_id VARCHAR(100) NOT NULL,
|
||||
platform_id VARCHAR(50) NOT NULL,
|
||||
task_type VARCHAR(50) NOT NULL,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
completion_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
total_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
cost_cny FLOAT NOT NULL DEFAULT 0.0,
|
||||
response_time_ms INTEGER,
|
||||
success BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
error_message TEXT,
|
||||
user_id VARCHAR(50),
|
||||
project_id VARCHAR(50),
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_mjk_model_usage_logs_user_id ON mjk_model_usage_logs(user_id);
|
||||
CREATE INDEX idx_mjk_model_usage_logs_created_at ON mjk_model_usage_logs(created_at);
|
||||
```
|
||||
|
||||
**字段说明**:
|
||||
- `model_id` - AI 模型ID
|
||||
- `platform_id` - AI 平台ID(openai/volcengine/klingai 等)
|
||||
- `task_type` - 任务类型(script/polish/chat 等)
|
||||
- `prompt_tokens` - 输入 Token 数
|
||||
- `completion_tokens` - 输出 Token 数
|
||||
- `total_tokens` - 总 Token 数
|
||||
- `cost_cny` - 消耗金额(人民币元)
|
||||
- `response_time_ms` - 响应时间(毫秒)
|
||||
- `success` - 是否成功
|
||||
- `error_message` - 错误信息
|
||||
- `user_id` - 关联用户ID
|
||||
- `project_id` - 关联项目ID
|
||||
- `created_at` - 创建时间
|
||||
|
||||
---
|
||||
|
||||
### 4. `mjk_interface_request_logs` - 接口请求记录表(新增)
|
||||
|
||||
**按后端接口类型记录所有用户请求,统计积分消耗**。这张表是顶层的接口请求统计,每一次前端调用后端接口都记一条。
|
||||
|
||||
```sql
|
||||
CREATE TABLE mjk_interface_request_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
request_id VARCHAR(64) NOT NULL,
|
||||
user_id VARCHAR(50) NOT NULL,
|
||||
interface_type VARCHAR(50) NOT NULL,
|
||||
interface_name VARCHAR(100),
|
||||
status VARCHAR(20) NOT NULL,
|
||||
cost_credits INTEGER NOT NULL DEFAULT 0,
|
||||
started_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
finished_at TIMESTAMP WITH TIME ZONE,
|
||||
error_message TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- 唯一约束
|
||||
ALTER TABLE mjk_interface_request_logs
|
||||
ADD CONSTRAINT uk_mjk_interface_request_logs_request_id
|
||||
UNIQUE (request_id);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_mjk_interface_request_logs_user_id
|
||||
ON mjk_interface_request_logs(user_id);
|
||||
CREATE INDEX idx_mjk_interface_request_logs_interface_type
|
||||
ON mjk_interface_request_logs(interface_type);
|
||||
CREATE INDEX idx_mjk_interface_request_logs_status
|
||||
ON mjk_interface_request_logs(status);
|
||||
CREATE INDEX idx_mjk_interface_request_logs_created_at
|
||||
ON mjk_interface_request_logs(created_at);
|
||||
```
|
||||
|
||||
**字段说明**:
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `id` | 日志记录自增ID |
|
||||
| `request_id` | 本次请求唯一ID(全局唯一)|
|
||||
| `user_id` | 请求用户ID |
|
||||
| `interface_type` | 接口类型(枚举见下方)|
|
||||
| `interface_name` | 接口名称(可读描述)|
|
||||
| `status` | 请求状态:`success` / `failed` |
|
||||
| `cost_credits` | 消耗积分数 |
|
||||
| `started_at` | 请求开始时间 |
|
||||
| `finished_at` | 请求结束时间 |
|
||||
| `error_message` | 失败原因 |
|
||||
| `created_at` | 记录创建时间 |
|
||||
|
||||
**`interface_type` 枚举值**:
|
||||
|
||||
| 值 | 说明 |
|
||||
|----|------|
|
||||
| `script_generate` | 脚本生成 |
|
||||
| `script_polish` | 脚本润色 |
|
||||
| `avatar_clone` | 数字人克隆 |
|
||||
| `video_generate` | 数字人视频生成 |
|
||||
| `subtitle_generate` | 字幕打轴生成 |
|
||||
| `image_generate` | 封面图片生成 |
|
||||
|
||||
---
|
||||
|
||||
### 5. `mjk_avatars` - 数字人名片表(已废弃)
|
||||
|
||||
> **迁移计划**:原 `avatars` 表已废弃,所有数字人元数据全量迁移到用户本地存储。
|
||||
> 路径:`~/Documents/Meijiaka/avatars/{avatar_id}/meta.json`
|
||||
|
||||
保留本表仅用于存量数据兼容,后续可删除。
|
||||
|
||||
```sql
|
||||
CREATE TABLE mjk_avatars (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(50) NOT NULL,
|
||||
name VARCHAR(64) NOT NULL,
|
||||
voice_id VARCHAR(64),
|
||||
element_id BIGINT,
|
||||
voice_task_id VARCHAR(128),
|
||||
element_task_id VARCHAR(128),
|
||||
video_url TEXT NOT NULL,
|
||||
trial_url TEXT,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
fail_reason TEXT,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_mjk_avatars_user_id ON mjk_avatars(user_id);
|
||||
CREATE INDEX idx_mjk_avatars_voice_task_id ON mjk_avatars(voice_task_id);
|
||||
CREATE INDEX idx_mjk_avatars_element_task_id ON mjk_avatars(element_task_id);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 六、本地存储结构(业务数据)
|
||||
|
||||
所有业务数据(项目、脚本、数字人)都存在用户本地磁盘,云端只存储日志和统计:
|
||||
|
||||
```
|
||||
~/Documents/Meijiaka/
|
||||
├── config.json # 全局应用配置
|
||||
├── projects/
|
||||
│ └── {project_id}/
|
||||
│ ├── meta.json # 项目元数据
|
||||
│ ├── segments.json # 脚本/分镜数据
|
||||
│ └── assets/ # 媒体文件
|
||||
├── avatars/
|
||||
│ └── {avatar_id}/
|
||||
│ ├── meta.json # 数字人元数据(id/name/voice_id/element_id/status 等)
|
||||
│ └── source.mp4 # 原始上传视频
|
||||
└── cache/ # 临时文件
|
||||
```
|
||||
|
||||
**`avatars/{avatar_id}/meta.json` 结构**:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "avt_xxx",
|
||||
"name": "我的数字人",
|
||||
"voiceId": "kling-voice-id",
|
||||
"elementId": 12345678,
|
||||
"voiceTaskId": "kling-task-id",
|
||||
"elementTaskId": "kling-task-id",
|
||||
"videoUrl": "https://.../source.mp4",
|
||||
"trialUrl": "https://.../trial.wav",
|
||||
"status": "succeed",
|
||||
"failReason": null,
|
||||
"createdAt": "2026-04-16T10:00:00Z",
|
||||
"updatedAt": "2026-04-16T10:05:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、架构总结
|
||||
|
||||
| 数据类型 | 存储位置 | 说明 |
|
||||
|---------|----------|------|
|
||||
| 用户基本信息 | 云端 `mjk_users` | 必须存云端 |
|
||||
| 用户积分变动记录 | 云端 `mjk_user_credits` | 记录充值/消费流水,统计用户余额 |
|
||||
| AI 模型调用日志 | 云端 `mjk_model_usage_logs` | AI 模型细粒度调用日志(成本统计)|
|
||||
| 接口请求记录 | 云端 `mjk_interface_request_logs` | 按后端接口记录请求、状态、消耗积分 |
|
||||
| 项目/脚本/分镜 | 用户本地 JSON | 全本地业务数据 |
|
||||
| 数字人元数据/原始视频 | 用户本地文件 | 全本地业务数据(原云端表已废弃)|
|
||||
| 合成输出视频 | 用户本地文件 | 全本地 |
|
||||
|
||||
完美符合设计理念:**轻量云账号 + 全本地业务数据**。
|
||||
|
||||
---
|
||||
|
||||
## 八、迁移说明
|
||||
|
||||
### 从无前缀版本迁移到统一前缀版本
|
||||
|
||||
1. 使用 Alembic 自动重命名所有现有表
|
||||
```sql
|
||||
ALTER TABLE users RENAME TO mjk_users;
|
||||
ALTER TABLE model_usage_logs RENAME TO mjk_model_usage_logs;
|
||||
ALTER TABLE avatars RENAME TO mjk_avatars;
|
||||
```
|
||||
2. 新建两张表:
|
||||
- `mjk_user_credits` - 用户积分变动记录表
|
||||
- `mjk_interface_request_logs` - 接口请求记录表
|
||||
3. 修改所有 SQLAlchemy 模型中的 `__tablename__`
|
||||
4. 后续:将 `mjk_avatars` 数据迁移到用户本地后可删除该表
|
||||
|
||||
---
|
||||
|
||||
*版本:v1.1*
|
||||
*创建日期:2026-04-16*
|
||||
*更新:新增 `mjk_user_credits` 积分账户表*
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,101 @@
|
||||
# 美家卡-智剪 (Meijiaka Smart Cut) 项目开发实施方案
|
||||
|
||||
基于您的最新反馈与确认,本项目将以《golden-purring-crown.md》(方案A)为主要交互蓝本进行落地,明确采用 **手动匹配分镜视频**、**完全本地化数据存储** 和 **沿用现有架构新建仓库** 的策略。
|
||||
|
||||
## 目标与改动背景
|
||||
|
||||
**项目背景**:衍生自现有的「美家卡智影」,新项目「美家卡智剪」侧重针对用户已有的视频素材,利用 AI 进行配音、并完成拼接与后期制作。
|
||||
|
||||
核心工作流程(6 步):
|
||||
1. **脚本生成** (基于主题生成具有预估时长的分镜与旁白)
|
||||
2. **视频剪辑 (新)** (用户手动为**每一个单分镜**导入对应长度的视频素材短片)
|
||||
3. **音色配音 (新)** (用户本地维护音色特征,使用大模型 TTS 为所有分镜批量生成口播音频)
|
||||
4. **字幕压制** (自动打轴并挂载 ASS 字幕,复用智影功能)
|
||||
5. **封面制作** (根据首分镜首帧和文字生成封面,复用智影功能)
|
||||
6. **视频合成** (所有片段首尾拼接成短视频,并将原有环境音替换/混音为合成音频,复用智影功能)
|
||||
|
||||
---
|
||||
|
||||
## 核心设计决策 (User Confirmed)
|
||||
|
||||
1. **交互模式**:不采用长视频自动切割算法。**必须采用单一分镜独立手动导入视频的交互**。
|
||||
2. **数据存储**:**纯本地文件系统**。所有业务数据(项目元数据、分镜配置、克隆好的本地音色记录等)全部保存在用户本地磁盘路径下,**不保存在云端数据库中**。
|
||||
3. **架构剥离**:通过拷贝文件系统级别进行剥离 (`rsync ai-meijiaka -> meijiaka-zj`),保留现有混合路由和本地缓存设计。
|
||||
|
||||
---
|
||||
|
||||
## Proposed Changes
|
||||
|
||||
### 1. 架构剥离与仓库初始化
|
||||
在同级目录下快速搭建衍生仓库,移除不相干的缓存依赖。
|
||||
|
||||
#### [NEW] `meijiaka-zj/` (新建本地项目根目录)
|
||||
- 配置应用标识词修正( `产品名: 美家卡智剪`, `Bundle Identifier: cn.meijiaka.ai-video-editor` 等)。
|
||||
- 修改并初始化 git 记录。
|
||||
|
||||
---
|
||||
|
||||
### 2. 后端 API (Python FastAPI)
|
||||
|
||||
不再新建数据库表结构,将所有新 API 的核心转为与本地文件、大语言模型 API 之间的交互代理,由 AsyncEngine 发起。
|
||||
|
||||
#### [NEW] `python-api/app/scheduler/handlers/tts_handler.py`
|
||||
创建用于处理批量语音生成的并发 Dispatcher。
|
||||
|
||||
#### [NEW] `python-api/app/services/voice_clone_service.py` & `tts_service.py`
|
||||
包装调用 `KlingAIProvider`:
|
||||
- 创建克隆音色的调用逻辑(由于无数据库,云端成功后的声纹特征及 `voice_id` 将通过 API 抛回前端并由 Tauri 存入本地 JSON 集合中)。
|
||||
- 提供语音合成和查询能力的端点。
|
||||
|
||||
#### [NEW] `python-api/app/api/v1/voice.py`
|
||||
仅暴露无状态/代理转发类型的路由给前端:克隆状态查询、提交合成等。无 DB 依赖。
|
||||
|
||||
---
|
||||
|
||||
### 3. Rust 系统能力扩展 (src-tauri)
|
||||
|
||||
由于采用本地存储,需要在 Rust 层扩展音频文件和声纹文件的安全存储指令。
|
||||
|
||||
#### [NEW] `tauri-app/src-tauri/src/storage/voice.rs`
|
||||
新增声音本地缓存与描述管理,目录范例:`~/Documents/Meijiaka/voices/` (用于存储 voice meta.json 和相关的 reference audio)。
|
||||
|
||||
#### [NEW] `tauri-app/src-tauri/src/commands/voice.rs`
|
||||
由 Tauri 提供存储IPC API给前端:读取本地音色列表、写入新克隆的音色等。
|
||||
|
||||
#### [MODIFY] `tauri-app/src-tauri/src/ffmpeg_cmd.rs`
|
||||
**[重要机制更新]**: 实现目标音频覆盖处理,提供类似 `replace_audio_in_video` 的函数,依靠 `-c:v copy -c:a aac -shortest -map 0:v:0 -map 1:a:0` 剥离原声并在对应的短视频片段上压入新的 TTS 朗读声音。
|
||||
|
||||
---
|
||||
|
||||
### 4. 前端应用层 (tauri-app / React)
|
||||
|
||||
调整原应用状态数据,创建本地数据绑定。
|
||||
|
||||
#### [MODIFY] `tauri-app/src/store/projectStore.ts`
|
||||
扩展原 `SmartCutShot` 阶段参数,支持记录新增加的独立视频源地址 (`mediaPath`) 和单独段落的合成语音地址 (`audioPath`)。
|
||||
|
||||
#### [NEW] `tauri-app/src/store/voiceStore.ts`
|
||||
与 `src-tauri` 通过 IPC 交互:
|
||||
- 从本地加载用户维护在 `voices/` 下的所有自定义音色。
|
||||
- 处理前端的缓存与显示。
|
||||
|
||||
#### [NEW] `tauri-app/src/pages/VideoCreation/VideoEditing.tsx` (Step 2)
|
||||
重定义分镜视频导入步骤:
|
||||
- 为左侧每一个生成的分镜文案展示独立的 Upload/Select Box。
|
||||
- 用户可以点选或拖动,调用系统弹窗将 `mp4` 一对一绑定给自己心仪的旁白节点。
|
||||
|
||||
#### [NEW] `tauri-app/src/pages/VideoCreation/VoiceDubbing.tsx` (Step 3)
|
||||
批量克隆与TTS应用页面:
|
||||
- 渲染本地和预定义的云端默认音库。
|
||||
- 前端批量发起所有含旁白分镜的异步合成任务,获取 URL 后调用 Rust 保留至项目对应的 `audio/` 子目录中。
|
||||
|
||||
---
|
||||
|
||||
## Verification Plan
|
||||
|
||||
### Manual Verification (端到端走通测试)
|
||||
- **环境**: 在新目录 `meijiaka-zj` 启动前后端服务。
|
||||
- **Step 1**: 使用纯业务旁白的模版生成分镜文案。
|
||||
- **Step 2**: 对列表中独立出现的 3 个分镜卡片,依次上传/拖入 3 个独立的 `.mp4` 文件以测试前端映射逻辑。
|
||||
- **Step 3(关键测试)**: 选择一个克隆音色发起全局合成。观察 `tts_slots` 运转状况。完毕后查验对应项目的物理存储路径内正确生成了 `.mp3` 音轨。
|
||||
- **Step 6**: 打包合成,测试 `ffmpeg_cmd.rs` 中音频替代逻辑是否执行无误,输出画面不掉帧、声音是合成口音的短片。
|
||||
@@ -0,0 +1,833 @@
|
||||
# 美家卡智剪 — 产品技术方案
|
||||
|
||||
> 基于「美家卡智影」架构的 AI 辅助短视频剪辑产品方案
|
||||
> 版本: v2.0 | 日期: 2026-04-20
|
||||
|
||||
---
|
||||
|
||||
## 一、产品定位
|
||||
|
||||
| 维度 | 美家卡智影(现有) | 美家卡智剪(新项目) |
|
||||
|------|-------------------|---------------------|
|
||||
| **核心能力** | AI 数字人视频生成 | AI 音色克隆 + 语音合成 + 素材智能剪辑 |
|
||||
| **视频来源** | KlingAI 生成数字人视频 | 用户导入长视频素材 |
|
||||
| **声音来源** | KlingAI 预设/自定义音色 + 数字人 | 用户克隆音色 / 预设音色 + TTS |
|
||||
| **目标场景** | 口播视频、营销视频从无到有 | 已有长素材快速剪辑成片、声音克隆配音 |
|
||||
| **核心差异** | 「生成式」创作 | 「剪辑式」创作 + AI 声音 |
|
||||
|
||||
### 一句话定义
|
||||
> **美家卡智剪** = 导入长视频 + AI 文案分镜 + 自动切割 + 音色克隆 + 语音合成 + 字幕压制 + 封面合成 + 视频导出
|
||||
|
||||
---
|
||||
|
||||
## 二、核心流程设计(6 步)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ │
|
||||
│ Step 1 → Step 2 → Step 3 → Step 4 → S5 → S6 │
|
||||
│ 脚本生成 → 视频剪辑 → 音色配音 → 字幕压制 → 封面 → 合成 │
|
||||
│ │
|
||||
│ ├─ AI文案 ├─ 导入长视频 ├─ 音色克隆 ├─ 自动打轴 │
|
||||
│ ├─ 粘贴文案 ├─ 自动切割 ├─ 预设音色 ├─ ASS字幕 │
|
||||
│ ├─ 智能分镜 │ (按分镜时长) ├─ 分镜TTS ├─ FFmpeg压制 │
|
||||
│ │ │ ├─ 试听/调整 │
|
||||
│ │ │ │ │
|
||||
│ [改造] [全新] [全新] [复用] [复用] │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 各步骤详细说明
|
||||
|
||||
---
|
||||
|
||||
#### Step 1 — 脚本生成(Script Generation)
|
||||
|
||||
**文案输入(3 种方式):**
|
||||
1. **AI 生成**:输入主题/关键词,LLM 生成短视频文案
|
||||
2. **直接粘贴**:用户粘贴已准备好的文案,系统自动分镜
|
||||
3. **导入文件**:支持 `.txt` / `.docx` / `.srt` 导入
|
||||
|
||||
**智能分镜:**
|
||||
- 按句子/段落自动拆分分镜
|
||||
- 每个分镜含:`voiceover`(旁白文案)、`duration`(预估时长)
|
||||
- 支持拖拽调整分镜顺序、合并、拆分
|
||||
- 文案字数根据目标时长自动约束(15s≈40字 / 30s≈80字 / 60s≈160字)
|
||||
|
||||
**输出:**
|
||||
- `segments[]`:分镜列表,每个分镜含文案和预估时长
|
||||
- 此步骤与智影 Step 1 基本一致,Prompt 调整为生成纯旁白文案(不含场景描述)
|
||||
|
||||
---
|
||||
|
||||
#### Step 2 — 视频剪辑(Video Editing)
|
||||
|
||||
**核心逻辑:导入一个长视频,按分镜时长自动切割。**
|
||||
|
||||
**流程:**
|
||||
1. 用户导入一个长视频文件(`.mp4/.mov`)
|
||||
2. 系统提取视频总时长
|
||||
3. 按分镜数量和预估时长自动计算切割点
|
||||
4. 调用 FFmpeg 将长视频切割为 N 个片段
|
||||
5. 每个片段自动绑定到对应分镜
|
||||
|
||||
**自动切割算法:**
|
||||
```
|
||||
总视频时长 = T
|
||||
分镜数 = N
|
||||
分镜预估时长 = [d1, d2, ..., dN]
|
||||
预估总时长 = D = d1 + d2 + ... + dN
|
||||
|
||||
如果 D <= T:
|
||||
按比例分配: 每个分镜实际时长 = di * (T / D)
|
||||
切割点: cumsum([d1*T/D, d2*T/D, ...])
|
||||
|
||||
如果 D > T:
|
||||
提示用户: 文案预估总时长超过视频时长,建议缩短文案或导入更长视频
|
||||
```
|
||||
|
||||
**界面示意:**
|
||||
```
|
||||
┌────────────────────────────────────────────┐
|
||||
│ 分镜列表 │ 素材导入 │
|
||||
│ ├─ 分镜1 (5s) │ ├─ 📁 点击导入 │
|
||||
│ ├─ 分镜2 (8s) │ │ 或拖拽视频 │
|
||||
│ ├─ 分镜3 (7s) │ │ │
|
||||
│ └─ 分镜4 (5s) │ │ 🎬 素材.mp4 │
|
||||
│ │ │ 时长: 25s │
|
||||
│ 预估总时长: 25s │ │ 分辨率: 1080p │
|
||||
│ │ └────────────────│
|
||||
│ [自动切割] │ │
|
||||
└────────────────────────────────────────────┘
|
||||
|
||||
切割结果预览:
|
||||
┌────────────────────────────────────────────┐
|
||||
│ 分镜1 ←→ 🎬 [00:00 - 00:05] (5s) │
|
||||
│ 分镜2 ←→ 🎬 [00:05 - 00:13] (8s) │
|
||||
│ 分镜3 ←→ 🎬 [00:13 - 00:20] (7s) │
|
||||
│ 分镜4 ←→ 🎬 [00:20 - 00:25] (5s) │
|
||||
└────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**技术实现:**
|
||||
- 前端:文件选择 → 调用 Rust IPC `import_media` → 保存到项目 `media/` 目录
|
||||
- Rust:`split_video` 命令使用 FFmpeg `-ss` + `-t` 截取片段
|
||||
- 每个片段保存为 `shot_{index}.mp4`,路径写入 `segment.mediaPath`
|
||||
|
||||
---
|
||||
|
||||
#### Step 3 — 音色配音(Voice & Dubbing)
|
||||
|
||||
**音色管理:**
|
||||
- **预设音色**:接入 KlingAI 官方预设音色(温柔女声、播报男声等)
|
||||
- **我的音色**:用户克隆的音色列表
|
||||
- 克隆方式:录音(10-20 秒)或上传音频文件
|
||||
- 克隆状态:处理中 / 完成 / 失败
|
||||
- 支持预览、重命名、删除
|
||||
|
||||
**语音合成(TTS):**
|
||||
- 为每个分镜独立选择音色
|
||||
- 支持统一设置(一键应用到全部分镜)
|
||||
- 可调节语速(0.8x - 2.0x)
|
||||
- 实时试听、重新生成
|
||||
|
||||
**批量合成:**
|
||||
- 一键合成所有分镜音频
|
||||
- 后台 Async Engine 并行处理(受槽位限制)
|
||||
- 实时进度显示
|
||||
|
||||
---
|
||||
|
||||
#### Step 4 — 字幕压制(Subtitle Burning)
|
||||
|
||||
**基本复用智影现有逻辑,数据源变化:**
|
||||
- 原:基于数字人视频的音频流进行自动打轴
|
||||
- 新:基于 TTS 合成的音频文件进行自动打轴
|
||||
|
||||
**流程:**
|
||||
1. 提交 `subtitle` 任务(`mode: auto_align`)
|
||||
2. 参数:`audioUrl`(TTS 音频)+ `audioText`(分镜文案)
|
||||
3. 返回 `alignmentResult`(utterances 时间轴)
|
||||
4. 用户选择字幕样式(颜色/字号/描边/位置)
|
||||
5. 调用 Rust IPC `burn_subtitle` 压制 ASS 字幕到视频
|
||||
|
||||
**输出:**
|
||||
- 每个分镜生成 `burnedVideoPath`(素材视频 + TTS 音频 + ASS 字幕)
|
||||
|
||||
---
|
||||
|
||||
#### Step 5 — 封面制作(Cover Design)
|
||||
|
||||
**完全复用智影现有逻辑:**
|
||||
1. 提取第一个分镜视频的首帧作为背景
|
||||
2. 用户输入封面标题
|
||||
3. 选择字体样式(抖音美好体等)
|
||||
4. 调用 Rust IPC `generate_cover_image` 合成封面
|
||||
|
||||
---
|
||||
|
||||
#### Step 6 — 视频合成(Video Composite)
|
||||
|
||||
**完全复用智影现有逻辑:**
|
||||
1. 收集所有分镜的 `burnedVideoPath`
|
||||
2. 如有封面图,先转为 0.5s 封面视频
|
||||
3. 调用 Rust IPC `video_composite_synthesis` 拼接所有片段
|
||||
4. 输出最终成品到 `~/Documents/Meijiaka/products/`
|
||||
|
||||
---
|
||||
|
||||
## 三、功能模块对比矩阵
|
||||
|
||||
| 模块 | 智影(现有) | 智剪(新) | 复用度 |
|
||||
|------|-------------|-----------|--------|
|
||||
| **脚本生成** | AI 生成脚本 | AI 生成文案 + 粘贴/导入 | 🔶 改造 |
|
||||
| **视频生成** | KlingAI 数字人 | 素材导入 + **自动切割** | 🔴 新增 |
|
||||
| **音色管理** | KlingAI Element 绑定音色 | 独立音色克隆 + 预设库 | 🔶 改造 |
|
||||
| **语音合成** | 数字人自带口播 | TTS 独立合成音频 | 🔴 新增 |
|
||||
| **字幕压制** | 自动打轴+FFmpeg | 完全复用 | 🟢 复用 |
|
||||
| **封面制作** | 首帧+标题+FFmpeg | 完全复用 | 🟢 复用 |
|
||||
| **视频合成** | FFmpeg concat | 完全复用 | 🟢 复用 |
|
||||
| **本地存储** | meta.json + segments.json | 扩展字段 | 🔶 改造 |
|
||||
| **任务调度** | 6 个 Handler | 新增 TTS Handler | 🔶 改造 |
|
||||
| **用户认证** | JWT + 手机号 | 完全复用 | 🟢 复用 |
|
||||
| **形象克隆** | Avatar 完整流程 | 简化为音色克隆 | 🔶 改造 |
|
||||
|
||||
> 🟢 完全复用 | 🔶 需要改造 | 🔴 全新开发
|
||||
|
||||
---
|
||||
|
||||
## 四、前端架构方案
|
||||
|
||||
### 4.1 页面结构
|
||||
|
||||
```
|
||||
tauri-app/src/pages/
|
||||
├── VideoCreation/
|
||||
│ ├── index.tsx # 6步流程容器(复用,调整步骤名)
|
||||
│ ├── ScriptCreation.tsx # Step 1: 脚本生成(复用改造)
|
||||
│ ├── VideoEditing.tsx # Step 2: 视频剪辑(全新)
|
||||
│ ├── VoiceDubbing.tsx # Step 3: 音色配音(全新)
|
||||
│ ├── SubtitleBurning.tsx # Step 4: 字幕压制(复用)
|
||||
│ ├── CoverDesign.tsx # Step 5: 封面制作(复用)
|
||||
│ └── VideoComposite.tsx # Step 6: 视频合成(复用)
|
||||
```
|
||||
|
||||
### 4.2 Store 设计
|
||||
|
||||
#### projectStore(改造)
|
||||
|
||||
```typescript
|
||||
interface SmartCutState {
|
||||
// === Step 1: 脚本与分镜 ===
|
||||
segments: SmartCutShot[];
|
||||
topic?: string;
|
||||
scriptType?: string;
|
||||
|
||||
// === Step 3: 音色配音 ===
|
||||
defaultVoiceId?: string; // 默认音色
|
||||
|
||||
// === Step 5+6: 封面与合成 ===
|
||||
coverPath?: string;
|
||||
coverConfig?: CoverConfig;
|
||||
finalVideoPath?: string;
|
||||
exportedAt?: string;
|
||||
|
||||
// === 流程状态 ===
|
||||
currentStep: number; // 1-6
|
||||
}
|
||||
|
||||
interface SmartCutShot {
|
||||
id: string;
|
||||
type: 'segment' | 'empty_shot';
|
||||
voiceover: string; // 旁白文案
|
||||
duration: number; // 预估/实际时长
|
||||
|
||||
// === Step 2: 视频剪辑后绑定 ===
|
||||
mediaPath?: string; // 切割后的视频片段路径
|
||||
mediaStartTime?: number; // 在原视频中的起始时间(秒)
|
||||
mediaEndTime?: number; // 在原视频中的结束时间(秒)
|
||||
|
||||
// === Step 3: 配音配置 ===
|
||||
ttsConfig?: TTSConfig;
|
||||
audioPath?: string; // TTS 合成音频本地路径
|
||||
audioUrl?: string; // TTS 音频远程 URL
|
||||
|
||||
// === Step 4: 字幕与后期 ===
|
||||
alignmentResult?: AlignmentResult;
|
||||
burnedVideoPath?: string;
|
||||
burnedAt?: string;
|
||||
}
|
||||
|
||||
interface TTSConfig {
|
||||
voiceId: string;
|
||||
voiceName: string;
|
||||
speed: number; // 0.8 - 2.0
|
||||
}
|
||||
```
|
||||
|
||||
#### voiceStore(新增)
|
||||
|
||||
```typescript
|
||||
interface VoiceState {
|
||||
// 预设音色
|
||||
presetVoices: PresetVoice[];
|
||||
presetVoicesLoading: boolean;
|
||||
|
||||
// 用户克隆音色
|
||||
clonedVoices: ClonedVoice[];
|
||||
clonedVoicesLoading: boolean;
|
||||
|
||||
// 当前选中的默认音色
|
||||
selectedVoiceId?: string;
|
||||
}
|
||||
|
||||
interface PresetVoice {
|
||||
voiceId: string;
|
||||
voiceName: string;
|
||||
previewUrl?: string;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
interface ClonedVoice {
|
||||
id: string; // vc_xxx
|
||||
name: string;
|
||||
providerVoiceId: string; // KlingAI 返回的 voice_id
|
||||
provider: string;
|
||||
status: 'processing' | 'succeed' | 'failed';
|
||||
previewUrl?: string;
|
||||
createdAt: string;
|
||||
}
|
||||
```
|
||||
|
||||
### 4.3 新增 Hooks
|
||||
|
||||
| Hook | 职责 |
|
||||
|------|------|
|
||||
| `useVoiceClone.ts` | 音色克隆:提交克隆、轮询状态、管理列表 |
|
||||
| `useTTSGeneration.ts` | TTS 批量合成:提交任务、轮询、更新 segment |
|
||||
| `useMediaImport.ts` | 素材导入:文件选择、调用 Rust IPC |
|
||||
| `useAutoSplit.ts` | 自动切割:计算切割点、调用 split_video、绑定分镜 |
|
||||
|
||||
### 4.4 API 模块
|
||||
|
||||
```
|
||||
tauri-app/src/api/modules/
|
||||
├── voice.ts # 音色克隆 / 预设音色 / 查询 / 删除
|
||||
├── tts.ts # TTS 提交 / 查询 / 批量
|
||||
├── script.ts # 复用,文案生成
|
||||
├── caption.ts # 复用,字幕相关
|
||||
└── videoComposite.ts # 复用,视频合成
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 五、后端架构方案
|
||||
|
||||
### 5.1 新增 API 路由
|
||||
|
||||
```python
|
||||
# python-api/app/api/v1/voice.py
|
||||
@router.post("/voice/clone") # 提交音色克隆任务
|
||||
@router.get("/voice/clones") # 查询用户克隆音色列表
|
||||
@router.get("/voice/clones/{id}") # 查询单个克隆任务
|
||||
@router.delete("/voice/clones/{id}") # 删除克隆音色
|
||||
@router.get("/voice/presets") # 查询预设音色列表
|
||||
|
||||
# python-api/app/api/v1/tts.py
|
||||
@router.post("/tts") # 提交 TTS 任务
|
||||
@router.get("/tts/{job_id}") # 查询 TTS 任务状态
|
||||
@router.post("/tts/batch") # 批量提交 TTS 任务
|
||||
```
|
||||
|
||||
### 5.2 新增 Async Engine Handler
|
||||
|
||||
新增 **`tts`** 任务类型:
|
||||
|
||||
```python
|
||||
# app/scheduler/handlers/tts_handler.py
|
||||
|
||||
class TTSHandler(AsyncHandler):
|
||||
"""TTS 语音合成 Handler
|
||||
|
||||
为每个分镜的文案生成语音音频。
|
||||
"""
|
||||
job_type = "tts"
|
||||
slot_key = "kling:tts_slots"
|
||||
max_slots = 10
|
||||
|
||||
async def handle(self, job: JobRecord) -> list[StateChange]:
|
||||
"""处理流程:
|
||||
1. 从 job.payload 提取 text, voice_id, voice_speed
|
||||
2. 调用 KlingAI TTS API 生成音频
|
||||
3. 轮询任务完成
|
||||
4. 下载音频文件到本地项目目录
|
||||
5. (可选)上传七牛云持久化
|
||||
6. 返回结果含 audio_path, audio_url, duration
|
||||
"""
|
||||
```
|
||||
|
||||
**Redis 配置:**
|
||||
```
|
||||
槽位 Key: kling:tts_slots
|
||||
槽位数: 10
|
||||
```
|
||||
|
||||
### 5.3 新增 Service 层
|
||||
|
||||
```python
|
||||
# app/services/tts_service.py
|
||||
class TTSService:
|
||||
"""TTS 语音合成服务"""
|
||||
|
||||
async def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
voice_id: str,
|
||||
voice_speed: float = 1.0,
|
||||
output_dir: str | None = None,
|
||||
) -> TTSResult:
|
||||
"""生成单条 TTS 音频"""
|
||||
|
||||
async def batch_generate(
|
||||
self,
|
||||
items: list[TTSRequest],
|
||||
user_id: str,
|
||||
) -> list[str]:
|
||||
"""批量提交 TTS 任务到 Async Engine"""
|
||||
|
||||
# app/services/voice_clone_service.py
|
||||
class VoiceCloneService:
|
||||
"""音色克隆服务"""
|
||||
|
||||
async def create_clone(
|
||||
self,
|
||||
voice_name: str,
|
||||
audio_url: str, # 七牛云音频URL
|
||||
user_id: str,
|
||||
) -> VoiceCloneJob:
|
||||
"""提交音色克隆任务到 KlingAI"""
|
||||
|
||||
async def sync_clone_status(
|
||||
self,
|
||||
job_id: str,
|
||||
) -> VoiceCloneStatus:
|
||||
"""同步查询克隆任务状态(轻量操作,不走Async Engine)"""
|
||||
|
||||
async def list_clones(self, user_id: str) -> list[ClonedVoice]:
|
||||
"""查询用户所有克隆音色"""
|
||||
```
|
||||
|
||||
### 5.4 新增数据库模型
|
||||
|
||||
```python
|
||||
# app/models/voice_clone.py
|
||||
class VoiceClone(Base):
|
||||
"""用户克隆音色元数据(云端备份)"""
|
||||
__tablename__ = "voice_clones"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True) # vc_xxx
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id"))
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
provider: Mapped[str] = mapped_column(String(50), default="klingai")
|
||||
provider_voice_id: Mapped[str] = mapped_column(String(100))
|
||||
status: Mapped[str] = mapped_column(String(20)) # processing/succeed/failed
|
||||
preview_url: Mapped[str | None] = mapped_column(String(500))
|
||||
fail_reason: Mapped[str | None] = mapped_column(Text)
|
||||
deleted_at: Mapped[datetime | None]
|
||||
created_at: Mapped[datetime]
|
||||
updated_at: Mapped[datetime]
|
||||
```
|
||||
|
||||
> 注:智剪中不需要 Element(形象主体),只需要 Voice(音色),因此独立建表更简洁。
|
||||
|
||||
### 5.5 复用已有能力
|
||||
|
||||
| 已有能力 | 复用方式 |
|
||||
|---------|---------|
|
||||
| `KlingAIProvider.generate_tts()` | 直接调用,封装到 Service 层 |
|
||||
| `KlingAIProvider.create_custom_voice()` | 直接调用,封装到 VoiceCloneService |
|
||||
| `KlingAIProvider.list_preset_voices()` | 直接调用 |
|
||||
| `VolcengineCaptionService` | 完全复用,传入 TTS 音频 URL |
|
||||
| `SlotManager` + `JobRegistry` | 完全复用 |
|
||||
| `TokenManager` + `JWTTokenStrategy` | 完全复用 |
|
||||
| `qiniu_service.upload()` | 复用,支持 audio 类型 |
|
||||
| 七牛云上传凭证 | 复用 |
|
||||
|
||||
---
|
||||
|
||||
## 六、Rust 层改造方案
|
||||
|
||||
### 6.1 新增 IPC 命令
|
||||
|
||||
```rust
|
||||
// commands/media.rs
|
||||
#[tauri::command]
|
||||
async fn import_media(
|
||||
app: AppHandle,
|
||||
project_id: String,
|
||||
source_path: String,
|
||||
) -> Result<MediaInfo, String>
|
||||
|
||||
// commands/video_edit.rs
|
||||
#[tauri::command]
|
||||
async fn split_video(
|
||||
app: AppHandle,
|
||||
input_path: String,
|
||||
segments: Vec<SplitSegment>, // [{start, end, output_name}]
|
||||
) -> Result<Vec<String>, String> // 返回切割后的文件路径列表
|
||||
```
|
||||
|
||||
### 6.2 新增 FFmpeg 命令封装
|
||||
|
||||
在 `ffmpeg_cmd.rs` 中新增:
|
||||
|
||||
```rust
|
||||
/// 按时间范围批量截取视频片段
|
||||
///
|
||||
/// 输入一个长视频,按多个时间范围切割为独立文件
|
||||
pub async fn split_video_segments(
|
||||
app: &AppHandle,
|
||||
input: &str,
|
||||
segments: &[(f64, f64, &str)], // (start, end, output_path)
|
||||
) -> Result<Vec<String>, FFmpegError>
|
||||
|
||||
/// 提取视频元信息(时长、分辨率、码率等)
|
||||
pub async fn probe_media_info(
|
||||
input: &str,
|
||||
) -> Result<MediaInfo, FFmpegError>
|
||||
```
|
||||
|
||||
### 6.3 本地存储路径扩展
|
||||
|
||||
```rust
|
||||
// storage/paths.rs
|
||||
|
||||
/// 项目素材目录:~/Documents/Meijiaka/projects/{id}/media/
|
||||
pub fn get_project_media_dir(project_id: &str) -> PathBuf
|
||||
|
||||
/// 项目音频目录:~/Documents/Meijiaka/projects/{id}/audio/
|
||||
pub fn get_project_audio_dir(project_id: &str) -> PathBuf
|
||||
|
||||
/// 项目分镜视频目录:~/Documents/Meijiaka/projects/{id}/shots/
|
||||
pub fn get_project_shots_dir(project_id: &str) -> PathBuf
|
||||
```
|
||||
|
||||
存储结构:
|
||||
```
|
||||
~/Documents/Meijiaka/
|
||||
├── projects/{project_id}/
|
||||
│ ├── meta.json
|
||||
│ ├── segments.json
|
||||
│ ├── media/ # 导入的原始素材
|
||||
│ │ └── source.mp4 # 原始长视频
|
||||
│ ├── shots/ # 自动切割后的分镜视频
|
||||
│ │ ├── shot_001.mp4
|
||||
│ │ └── shot_002.mp4
|
||||
│ ├── audio/ # TTS 生成的音频
|
||||
│ │ ├── tts_001.mp3
|
||||
│ │ └── tts_002.mp3
|
||||
│ └── assets/ # 封面等成品资源
|
||||
│ └── cover_xxx.png
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、AI 能力集成
|
||||
|
||||
### 7.1 音色克隆
|
||||
|
||||
**Provider: KlingAI(已具备能力)**
|
||||
|
||||
```
|
||||
API: POST /v1/general/custom-voices
|
||||
参数:
|
||||
- voice_name: 音色名称
|
||||
- voice_url: 音频文件URL(5-30秒,干净人声)
|
||||
|
||||
限制:
|
||||
- 音频时长: 5-30 秒
|
||||
- 格式: MP3 / WAV
|
||||
- 要求: 单一人声、无杂音、无背景音乐
|
||||
```
|
||||
|
||||
**前端录音方案:**
|
||||
- 使用 Web Audio API 录制麦克风音频
|
||||
- 实时波形可视化
|
||||
- 录制时长控制(10-20 秒最佳)
|
||||
- 录制完成后上传至七牛云 → 后端提交克隆任务
|
||||
|
||||
**状态流转:**
|
||||
```
|
||||
用户录音/上传 → 前端上传七牛云 → 后端调用 KlingAI 创建音色
|
||||
↓
|
||||
[processing] ← 前端轮询
|
||||
↓
|
||||
[succeed] → 保存到 DB → 加入"我的音色"
|
||||
↓
|
||||
[failed] → 提示用户重新录制
|
||||
```
|
||||
|
||||
### 7.2 语音合成(TTS)
|
||||
|
||||
**Provider: KlingAI(已具备能力,需上层封装)**
|
||||
|
||||
```
|
||||
API: POST /v1/audio/tts
|
||||
参数:
|
||||
- text: 要合成的文本(旁白文案)
|
||||
- voice_id: 音色ID(预设或自定义)
|
||||
- voice_language: zh / en
|
||||
- voice_speed: 0.8 - 2.0(默认 1.0)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID
|
||||
- 轮询 GET /v1/audio/tts/{task_id} 获取音频URL
|
||||
```
|
||||
|
||||
**批量处理策略:**
|
||||
- 每个分镜一个 TTS 任务
|
||||
- Async Engine 并行处理(最多 10 个并发)
|
||||
- 前端显示总体进度(已完成 N / 总分镜数 M)
|
||||
|
||||
### 7.3 文案生成
|
||||
|
||||
**复用现有 ScriptService**,但调整 Prompt:
|
||||
- 原:生成「场景描述 + 旁白 + 时长」的营销脚本
|
||||
- 新:生成「旁白文案 + 预估时长」的短视频文案
|
||||
- 支持根据目标时长(15s / 30s / 60s)控制字数
|
||||
|
||||
---
|
||||
|
||||
## 八、独立新仓库初始化方案
|
||||
|
||||
### 8.1 仓库创建
|
||||
|
||||
```bash
|
||||
# 在本地创建新仓库目录
|
||||
mkdir meijiaka-zj
|
||||
cd meijiaka-zj
|
||||
git init
|
||||
|
||||
# 复制智影代码(排除依赖和构建产物)
|
||||
rsync -av \
|
||||
--exclude='.git' \
|
||||
--exclude='node_modules' \
|
||||
--exclude='.venv' \
|
||||
--exclude='__pycache__' \
|
||||
--exclude='.mypy_cache' \
|
||||
--exclude='.ruff_cache' \
|
||||
--exclude='.pytest_cache' \
|
||||
--exclude='dist' \
|
||||
--exclude='target' \
|
||||
--exclude='*.lock' \
|
||||
--exclude='.DS_Store' \
|
||||
../ai-meijiaka/ .
|
||||
|
||||
# 初始化提交
|
||||
git add -A
|
||||
git commit -m "init: fork from meijiaka-zy"
|
||||
```
|
||||
|
||||
### 8.2 品牌配置修改清单
|
||||
|
||||
| 文件 | 修改项 |
|
||||
|------|--------|
|
||||
| `tauri-app/src-tauri/tauri.conf.json` | `productName`: 美家卡智影 → 美家卡智剪;`identifier`: `cn.meijiaka.ai-video` → `cn.meijiaka.ai-video-editor`;`title`: 美家卡 智影 → 美家卡 智剪 |
|
||||
| `tauri-app/package.json` | `name`: 可保持不变(内部包名) |
|
||||
| `python-api/app/main.py` | FastAPI 文档标题、描述更新 |
|
||||
| `AGENTS.md` | 全文替换「智影」→「智剪」,更新产品描述 |
|
||||
| `README.md` | 更新为智剪的产品说明 |
|
||||
|
||||
### 8.3 项目结构
|
||||
|
||||
```
|
||||
meijiaka-zj/ # 新仓库根目录
|
||||
├── python-api/ # FastAPI 后端(从智影复制后改造)
|
||||
│ ├── app/
|
||||
│ │ ├── api/v1/ # 新增 voice.py, tts.py 路由
|
||||
│ │ ├── ai/providers/ # 复用 KlingAIProvider
|
||||
│ │ ├── scheduler/handlers/ # 新增 tts_handler.py
|
||||
│ │ ├── services/ # 新增 tts_service.py, voice_clone_service.py
|
||||
│ │ ├── models/ # 新增 voice_clone.py
|
||||
│ │ └── schemas/ # 新增 voice.py, tts.py
|
||||
│ ├── config/
|
||||
│ ├── alembic/
|
||||
│ ├── pyproject.toml
|
||||
│ └── ...
|
||||
│
|
||||
├── tauri-app/ # Tauri 前端(从智影复制后改造)
|
||||
│ ├── src/
|
||||
│ │ ├── pages/VideoCreation/
|
||||
│ │ │ ├── ScriptCreation.tsx # Step 1(改造)
|
||||
│ │ │ ├── VideoEditing.tsx # Step 2(新增)
|
||||
│ │ │ ├── VoiceDubbing.tsx # Step 3(新增)
|
||||
│ │ │ ├── SubtitleBurning.tsx # Step 4(复用)
|
||||
│ │ │ ├── CoverDesign.tsx # Step 5(复用)
|
||||
│ │ │ └── VideoComposite.tsx # Step 6(复用)
|
||||
│ │ ├── store/
|
||||
│ │ │ ├── projectStore.ts # 改造
|
||||
│ │ │ └── voiceStore.ts # 新增
|
||||
│ │ ├── api/modules/
|
||||
│ │ │ ├── voice.ts # 新增
|
||||
│ │ │ └── tts.ts # 新增
|
||||
│ │ └── hooks/
|
||||
│ │ ├── useVoiceClone.ts # 新增
|
||||
│ │ ├── useTTSGeneration.ts # 新增
|
||||
│ │ └── useAutoSplit.ts # 新增
|
||||
│ ├── src-tauri/src/
|
||||
│ │ ├── commands/media.rs # 新增
|
||||
│ │ ├── ffmpeg_cmd.rs # 新增函数
|
||||
│ │ └── storage/paths.rs # 新增路径
|
||||
│ └── ...
|
||||
│
|
||||
├── docs/ # 文档
|
||||
│ └── meijiaka-zhijian-proposal.md
|
||||
│
|
||||
└── scripts/ # 工具脚本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 九、实施路线图
|
||||
|
||||
### Phase 1: 基础架构(2 周)
|
||||
|
||||
**目标**:搭建新项目骨架,打通基础能力
|
||||
|
||||
| 任务 | 说明 |
|
||||
|------|------|
|
||||
| ① 仓库初始化 | 复制智影代码,修改品牌配置,建立独立仓库 |
|
||||
| ② 数据模型改造 | 新增 `voice_clones` 表,改造 `segments` Schema |
|
||||
| ③ TTS API 封装 | 新增 `tts_service.py`、`voice.py` / `tts.py` 路由 |
|
||||
| ④ 音色克隆 API | 新增 `voice_clone_service.py` |
|
||||
| ⑤ 前端 Store 改造 | 改造 `projectStore`,新增 `voiceStore` |
|
||||
| ⑥ 素材导入 IPC | 新增 `import_media`、`split_video` Rust 命令 |
|
||||
|
||||
### Phase 2: 核心流程(2 周)
|
||||
|
||||
**目标**:完成 6 步核心流程 MVP
|
||||
|
||||
| 任务 | 说明 |
|
||||
|------|------|
|
||||
| ⑦ Step 1 脚本生成 | 复用现有逻辑,Prompt 调整为纯旁白文案 |
|
||||
| ⑧ Step 2 视频剪辑 | 素材导入 UI + 自动切割逻辑 + 分镜绑定 |
|
||||
| ⑨ Step 3 音色配音 | 音色克隆 UI + TTS 合成 UI + 批量任务 |
|
||||
| ⑩ TTS Async Handler | 实现 `TTSHandler`,接入 Async Engine |
|
||||
| ⑪ 字幕压制适配 | 基于 TTS 音频的自动打轴 + 字幕压制 |
|
||||
| ⑫ 封面+合成 | 复用现有逻辑,验证端到端流程 |
|
||||
|
||||
### Phase 3: 打磨优化(1 周)
|
||||
|
||||
**目标**:提升用户体验,修复问题
|
||||
|
||||
| 任务 | 说明 |
|
||||
|------|------|
|
||||
| ⑬ 切割算法优化 | 智能检测场景切换点,避免在人物说话中间切割 |
|
||||
| ⑭ 批量操作优化 | 统一音色、批量重新合成 |
|
||||
| ⑮ 错误处理 | 视频格式不支持、TTS 失败、文案超长等异常 |
|
||||
| ⑯ 性能优化 | 大视频导入、多任务并发 |
|
||||
| ⑰ 测试验收 | 全流程测试,修复 bug |
|
||||
|
||||
### 总工期预估:**5 周**
|
||||
|
||||
```
|
||||
Week 1-2: Phase 1 — 基础架构
|
||||
Week 3-4: Phase 2 — 核心流程 MVP
|
||||
Week 5: Phase 3 — 打磨优化 + 测试
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 十、技术风险与应对
|
||||
|
||||
| 风险 | 影响 | 应对方案 |
|
||||
|------|------|---------|
|
||||
| KlingAI TTS 并发限制 | 批量合成慢 | Async Engine 槽位控制 + 前端进度管理 |
|
||||
| KlingAI 音色克隆失败率高 | 用户体验差 | 前端引导用户录制规范音频(安静环境、清晰人声) |
|
||||
| 文案总时长 > 视频时长 | 无法完整配音 | Step 2 导入时校验,超长则提示用户调整文案或换视频 |
|
||||
| 自动切割点落在不自然位置 | 画面割裂 | V2 引入场景切换检测,在关键帧处切割 |
|
||||
| 大视频文件导入卡顿 | 前端无响应 | Tauri 后端异步处理导入,前端仅显示进度 |
|
||||
| 视频格式兼容性 | 某些格式无法处理 | FFmpeg 统一标准化转码,支持主流格式 |
|
||||
| TTS 文本过长 | KlingAI 限制 | 分镜文案字数控制(建议单分镜 < 200 字) |
|
||||
|
||||
---
|
||||
|
||||
## 十一、长期演进方向
|
||||
|
||||
| 版本 | 功能 |
|
||||
|------|------|
|
||||
| **V1.0**(MVP)| 长视频导入 + 自动切割 + 音色克隆 + TTS + 字幕 + 封面 + 合成 |
|
||||
| **V1.5** | 智能切割(基于场景切换检测) |
|
||||
| **V2.0** | 多轨道编辑(背景音乐、音效、转场) |
|
||||
| **V2.5** | AI 视频摘要(长视频自动提取精彩片段) |
|
||||
| **V3.0** | 多音色对话(支持多人配音、角色音色) |
|
||||
|
||||
---
|
||||
|
||||
## 附录
|
||||
|
||||
### A. 关键术语对照
|
||||
|
||||
| 智影术语 | 智剪对应 | 说明 |
|
||||
|---------|---------|------|
|
||||
| `elementId` | `voiceId` | 从数字人形象ID变为音色ID |
|
||||
| `videoUrl` | `mediaPath` | 从AI生成视频变为切割后的素材片段 |
|
||||
| `Avatar` | `VoiceClone` | 从形象克隆简化为音色克隆 |
|
||||
| `humanId` | — | 移除,不再需要 |
|
||||
| `scene` | — | 可选保留,用于V2智能匹配 |
|
||||
|
||||
### B. 需要改造的文件清单
|
||||
|
||||
**后端(python-api):**
|
||||
```
|
||||
新增:
|
||||
app/api/v1/voice.py
|
||||
app/api/v1/tts.py
|
||||
app/services/tts_service.py
|
||||
app/services/voice_clone_service.py
|
||||
app/scheduler/handlers/tts_handler.py
|
||||
app/models/voice_clone.py
|
||||
app/schemas/voice.py
|
||||
app/schemas/tts.py
|
||||
|
||||
改造:
|
||||
app/scheduler/main.py # 注册 TTSHandler
|
||||
app/api/v1/router.py # 添加 voice/tts 路由
|
||||
app/schemas/segment.py # 扩展 Segment Schema
|
||||
app/ai/prompts/script/*.txt # 调整 Prompt 为纯旁白文案
|
||||
```
|
||||
|
||||
**前端(tauri-app):**
|
||||
```
|
||||
新增:
|
||||
src/pages/VideoCreation/VideoEditing.tsx
|
||||
src/pages/VideoCreation/VoiceDubbing.tsx
|
||||
src/store/voiceStore.ts
|
||||
src/api/modules/voice.ts
|
||||
src/api/modules/tts.ts
|
||||
src/hooks/useVoiceClone.ts
|
||||
src/hooks/useTTSGeneration.ts
|
||||
src/hooks/useAutoSplit.ts
|
||||
|
||||
改造:
|
||||
src/pages/VideoCreation/index.tsx # 调整为6步
|
||||
src/pages/VideoCreation/ScriptCreation.tsx # 移除场景描述字段
|
||||
src/store/projectStore.ts # 扩展数据模型
|
||||
src/api/types.ts # 更新类型定义
|
||||
```
|
||||
|
||||
**Rust(src-tauri):**
|
||||
```
|
||||
新增:
|
||||
src/commands/media.rs
|
||||
src/ffmpeg_cmd.rs 中的 split_video_segments / probe_media_info
|
||||
src/storage/paths.rs 中的 media/audio/shots 路径
|
||||
|
||||
改造:
|
||||
src/lib.rs # 注册新命令
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*本方案基于「美家卡智影」现有架构设计,最大化复用已有能力,降低开发成本与风险。*
|
||||
@@ -0,0 +1,243 @@
|
||||
# 迁移方案:废弃云端 `mjk_avatars` 表,数字人元数据全量迁移到本地存储
|
||||
|
||||
> **状态:方案已调整(2026-04-17)** — 原方案中提到的 Celery 架构已完全移除,形象克隆现由 `app/scheduler/handlers/avatar_handler.py`(Async Engine Scheduler)统一调度。云端仍保留 `avatars` 表作为形象克隆的持久化记录。
|
||||
|
||||
## 方案目标
|
||||
|
||||
| 目标 | 说明 |
|
||||
|------|------|
|
||||
| ✅ 贯彻设计理念 | 真正做到**轻量云 + 全本地业务数据**,云端只记日志不存业务数据 |
|
||||
| ✅ 统一接口日志 | 所有接口请求统一记录到 `mjk_interface_request_logs`,按接口统计积分消耗 |
|
||||
| ✅ 简化后端代码 | 删除大量 CRUD、状态管理、定时任务代码,后端更干净 |
|
||||
| ✅ 用户掌控数据 | 所有数字人元数据存在用户本地,云端只记克隆请求的消耗积分 |
|
||||
|
||||
---
|
||||
|
||||
## 存储结构变化
|
||||
|
||||
### 变化前(现状)
|
||||
```
|
||||
云端 PostgreSQL: mjk_avatars
|
||||
└─ 存储所有数字人元数据 (name/voice_id/element_id/status 等)
|
||||
前端本地:
|
||||
└─ 只做缓存,从云端同步
|
||||
```
|
||||
|
||||
### 变化后(目标)
|
||||
```
|
||||
云端 PostgreSQL:
|
||||
├─ mjk_interface_request_logs ← 只记:avatar_clone 请求 + 消耗积分 + 状态
|
||||
└─ mjk_avatars ← 废弃,不再写入新数据(存量可保留可删除)
|
||||
用户本地磁盘:
|
||||
└─ ~/Documents/Meijiaka/avatars/{avatar_id}/
|
||||
├─ meta.json ← 完整数字人元数据(JSON)
|
||||
└─ source.mp4 ← 原始上传视频
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 本地存储结构定义
|
||||
|
||||
### 目录结构
|
||||
```
|
||||
~/Documents/Meijiaka/
|
||||
└── avatars/
|
||||
└── {avatar_id}/ # avatar_id = avt_{16位随机hex}
|
||||
├── meta.json # 元数据(JSON 格式)
|
||||
└── source.mp4 # 原始上传视频
|
||||
```
|
||||
|
||||
### `meta.json` 结构
|
||||
```json
|
||||
{
|
||||
"id": "avt_xxxxxxxxxxxxxxxx",
|
||||
"name": "我的数字人",
|
||||
"voiceId": "klingai-voice-id-string",
|
||||
"elementId": 12345678,
|
||||
"voiceTaskId": "kling-task-id-string",
|
||||
"elementTaskId": "kling-task-id-string",
|
||||
"videoUrl": "https://domain.com/path/to/source.mp4",
|
||||
"trialUrl": "https://domain.com/path/to/trial.wav",
|
||||
"status": "succeed",
|
||||
"failReason": null,
|
||||
"createdAt": "2026-04-16T10:00:00.000Z",
|
||||
"updatedAt": "2026-04-16T10:05:00.000Z"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 代码改动清单
|
||||
|
||||
### 后端 Python
|
||||
|
||||
| 操作 | 文件 | 改动说明 |
|
||||
|------|------|----------|
|
||||
| 🆕 新增 | `app/models/interface_request_logs.py` | SQLAlchemy 模型 `InterfaceRequestLogs` |
|
||||
| 🆕 新增 | `app/crud/interface_request_logs.py` | CRUD:create / update |
|
||||
| ✏️ 修改 | `app/models/__init__.py` | 删除 `Avatar` 导入,新增 `InterfaceRequestLogs` |
|
||||
| ✏️ 修改 | `app/api/v1/avatar.py` | 完全重写<br>• 保留:`POST /clone` / `GET /tasks/{id}` / `GET /clone/stream` / `POST /tasks/{id}/retry` / `DELETE /{id}` <br>• 删除:`GET /library` / `PATCH /{id}` / `/health` |
|
||||
| ✏️ 修改 | `app/scheduler/handlers/avatar_handler.py` | 精简:删除所有对 `mjk_avatars` 读写,只记接口日志,进度放 Redis Registry |
|
||||
| ❌ 删除 | `app/models/avatar.py` | 模型废弃,删除 |
|
||||
| ❌ 删除 | `app/crud/avatar.py` | CRUD 废弃,删除 |
|
||||
| ❌ 删除 | `app/tasks/avatar_clone.py` | 逻辑已合并到 avatar_handler,删除 |
|
||||
|
||||
### Rust Tauri(`tauri-app/src-tauri/src/persistence.rs`)
|
||||
|
||||
新增以下 IPC 命令:
|
||||
|
||||
```rust
|
||||
/// 列出所有本地数字人(按创建时间倒序)
|
||||
#[tauri::command]
|
||||
pub fn list_avatars(app: AppHandle) -> Result<Vec<AvatarMeta>, String>;
|
||||
|
||||
/// 保存数字人元数据
|
||||
#[tauri::command]
|
||||
pub fn save_avatar(app: AppHandle, avatar_id: String, meta: AvatarMeta) -> Result<(), String>;
|
||||
|
||||
/// 获取单个数字人元数据
|
||||
#[tauri::command]
|
||||
pub fn get_avatar(app: AppHandle, avatar_id: String) -> Result<Option<AvatarMeta>, String>;
|
||||
|
||||
/// 删除数字人(删除整个本地目录)
|
||||
#[tauri::command]
|
||||
pub fn delete_avatar(app: AppHandle, avatar_id: String) -> Result<(), String>;
|
||||
|
||||
/// 更新数字人名称
|
||||
#[tauri::command]
|
||||
pub fn update_avatar_name(app: AppHandle, avatar_id: String, name: String) -> Result<(), String>;
|
||||
```
|
||||
|
||||
在 `lib.rs` 注册新命令。
|
||||
|
||||
### 前端 TypeScript
|
||||
|
||||
| 模块 | 改动 |
|
||||
|------|------|
|
||||
| **Avatar 列表** | 原:`GET /avatar/library` 从后端获取 → 现在:调用 Tauri IPC 从本地读取 |
|
||||
| **创建克隆** | 流程变化:<br>1. 前端生成 `avatar_id`<br>2. `POST /avatar/clone` → 获取 `task_id`<br>3. 前端创建本地目录 + 写入初始 `meta.json` (`status=pending`)<br>4. SSE 监听进度<br>5. 完成后 → 前端把 `voice_id`/`element_id` 写入本地 `meta.json`<br>6. 完成 |
|
||||
| **删除 Avatar** | 流程变化:<br>1. 前端调用 `DELETE /avatar/{avatar_id}`(后端负责删除 Kling 远程资源)<br>2. 后端记删除日志到接口日志<br>3. 前端调用 IPC 删除本地目录 |
|
||||
| **重命名 Avatar** | 原:调用后端 PATCH → 现在:前端直接修改本地 `meta.json`,无需请求后端 |
|
||||
| **选择数字人生成视频** | 用法不变:从本地读取 `voice_id`/`element_id` → 传给后端视频生成接口 |
|
||||
|
||||
---
|
||||
|
||||
## 工作流对比
|
||||
|
||||
### 改动前(当前)
|
||||
```
|
||||
用户提交克隆
|
||||
→ POST /clone → 后端写 mjk_avatars (status=pending) → 派发任务
|
||||
→ Async Engine Scheduler (avatar_handler) 每一步都更新 `avatars` 表
|
||||
→ 前端 SSE 轮询读 `avatars` 表拿进度
|
||||
→ 完成后 Scheduler 更新 status=succeed 写入 voice_id/element_id
|
||||
→ 前端从 `avatars` 表读结果 → 缓存到本地
|
||||
→ 列表从 `avatars` 表读取
|
||||
```
|
||||
|
||||
### 改动后(目标)
|
||||
```
|
||||
用户提交克隆
|
||||
→ 前端生成 avatar_id → 创建本地 meta.json (status=pending)
|
||||
→ POST /clone → 后端:
|
||||
1. 在 mjk_interface_request_logs 插入记录
|
||||
interface_type=avatar_clone, status=pending, started_at=now, cost_credits=X
|
||||
2. 注册到 Async Engine Scheduler (Redis Registry)
|
||||
3. 返回 {task_id, avatar_id}
|
||||
→ Async Engine Scheduler (avatar_handler) 执行:
|
||||
1. 调用 Kling 创建音色 → 轮询 → 获取 voice_id
|
||||
2. 调用 Kling 创建主体 → 轮询 → 获取 element_id
|
||||
3. 更新 Redis Registry 状态为 completed,写入结果
|
||||
4. 更新接口日志: status=success/failed, finished_at=now
|
||||
→ 前端 SSE 从 TaskCache 获取结果
|
||||
→ 完成后前端将 voice_id/element_id 写入本地 meta.json
|
||||
→ 列表展示直接从本地读取,不请求后端
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 接口日志记录规则
|
||||
|
||||
`mjk_interface_request_logs` 对 `avatar_clone` 的记录:
|
||||
|
||||
| 时机 | 操作 | 字段值 |
|
||||
|------|------|--------|
|
||||
| 刚收到请求 | 插入新记录 | `interface_type=avatar_clone`, `status=pending`, `started_at=NOW`, `cost_credits` = 克隆一次所需积分 |
|
||||
| 任务完成成功 | 更新记录 | `status=success`, `finished_at=NOW` |
|
||||
| 任务失败 | 更新记录 | `status=failed`, `finished_at=NOW`, `error_message=错误原因` |
|
||||
|
||||
> 积分在请求创建时即扣除,因为无论成功失败,KlingAI 开始处理后会计费。
|
||||
|
||||
---
|
||||
|
||||
## 存量数据迁移策略
|
||||
|
||||
### 渐进迁移(对用户友好)
|
||||
|
||||
1. **保留云端表**:`mjk_avatars` 保留不删除,存量数据继续存在
|
||||
2. **前端自动迁移**:用户首次打开形象库时:
|
||||
- 前端检查:如果后端有数据但本地没有 → 提示用户"将云端数字人同步到本地"
|
||||
- 用户确认后,前端逐个拉取数据写入本地
|
||||
- 同步完成后,后续只使用本地数据
|
||||
3. **下线旧表**:稳定运行一段时间后,可在维护窗口物理删除 `mjk_avatars` 表
|
||||
|
||||
### 回滚方案
|
||||
- 迁移过程中如果出问题,随时切回原逻辑(表保留,代码只需恢复删除部分)
|
||||
|
||||
---
|
||||
|
||||
## 优缺点总结
|
||||
|
||||
| 优点 | 说明 |
|
||||
|------|------|
|
||||
| ✅ 完全符合需求 | 云端只存接口请求记录和消耗积分,不存用户业务数据 |
|
||||
| ✅ 云端存储成本极低 | 只有接口日志,每条几KB,用户增长成本可控 |
|
||||
| ✅ 后端代码大幅简化 | 删除了整个 Avatar CRUD、状态机管理、定时任务恢复,代码减少约 300 行 |
|
||||
| ✅ 用户完全掌控数据 | 所有数字人元数据存储在用户本地磁盘 |
|
||||
| ✅ 形象库展示更快 | 本地读取文件比查询数据库快很多 |
|
||||
| ✅ 兼容存量数据 | 渐进迁移,可回滚 |
|
||||
|
||||
| 缺点 | 说明 | 应对 |
|
||||
|------|------|------|
|
||||
| 用户换电脑需要迁移 | 用户需要自行迁移数据,或重新克隆 | 后续可增加导出/导入功能解决 |
|
||||
| 本地硬盘损坏数据丢失 | 这是"全本地"设计的必然结果 | 符合项目初始"轻量云+全本地"设计理念,用户自担数据安全 |
|
||||
|
||||
---
|
||||
|
||||
## 执行步骤(按顺序)
|
||||
|
||||
1. **数据库**
|
||||
- 生成 Alembic 迁移:所有表重命名加 `mjk_` 前缀 + 新建 `mjk_interface_request_logs`
|
||||
- 修改所有 Python 模型中的 `__tablename__`
|
||||
|
||||
2. **后端代码**
|
||||
- 新建 `interface_request_logs` 模型和 CRUD
|
||||
- 重写 `app/api/v1/avatar.py`
|
||||
- 精简 `app/tasks/avatar_tasks.py`
|
||||
- 删除废弃文件
|
||||
|
||||
3. **Rust Tauri**
|
||||
- 在 `persistence.rs` 新增 avatar 相关 IPC 命令
|
||||
- 在 `lib.rs` 注册命令
|
||||
|
||||
4. **前端代码**
|
||||
- 修改形象库:从本地读取
|
||||
- 修改创建流程:完成后写入本地
|
||||
- 修改删除流程:删除云端 Kling 资源后删除本地
|
||||
- 修改重命名:直接本地修改
|
||||
|
||||
5. **测试验证**
|
||||
- 创建克隆 → 检查本地文件生成 → 检查接口日志写入
|
||||
- 列表展示 → 删除 → 重命名 全流程测试
|
||||
|
||||
---
|
||||
|
||||
## 相关文档
|
||||
|
||||
- [数据库设计规范](./database-design.md) - 完整的数据库命名规范和表结构
|
||||
- [视频生成流程](./video-generation-flow.md) - 完整视频生成流程说明
|
||||
|
||||
---
|
||||
|
||||
*版本:v1.0*
|
||||
*创建日期:2026-04-16*
|
||||
@@ -0,0 +1,739 @@
|
||||
# 七牛云对象存储 (Kodo) Python SDK 开发规范
|
||||
|
||||
## 概述
|
||||
|
||||
本文档规范美家卡智影项目中使用七牛云对象存储 (Kodo) Python SDK 的开发标准,涵盖文件上传、下载、管理和 CDN 操作等核心功能。
|
||||
|
||||
**SDK 版本**: v5.0.0+
|
||||
**Python 版本**: 3.8+ (兼容 2.7 和 3.3+)
|
||||
**官方文档**: https://developer.qiniu.com/kodo/1242/python
|
||||
|
||||
---
|
||||
|
||||
## 1. 安装与初始化
|
||||
|
||||
### 1.1 安装 SDK
|
||||
|
||||
```bash
|
||||
pip install qiniu
|
||||
```
|
||||
|
||||
### 1.2 初始化配置
|
||||
|
||||
```python
|
||||
from qiniu import Auth
|
||||
|
||||
# 从环境变量读取密钥(推荐)
|
||||
import os
|
||||
access_key = os.getenv('QINIU_ACCESS_KEY')
|
||||
secret_key = os.getenv('QINIU_SECRET_KEY')
|
||||
|
||||
# 构建鉴权对象
|
||||
q = Auth(access_key, secret_key)
|
||||
```
|
||||
|
||||
**环境变量配置** (`.env` 文件):
|
||||
```bash
|
||||
QINIU_ACCESS_KEY=your-access-key
|
||||
QINIU_SECRET_KEY=your-secret-key
|
||||
QINIU_BUCKET_NAME=your-bucket-name
|
||||
QINIU_BUCKET_DOMAIN=your-domain.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 文件上传
|
||||
|
||||
### 2.1 上传方式选择
|
||||
|
||||
| 场景 | 推荐方式 | 说明 |
|
||||
|------|----------|------|
|
||||
| 小文件 (< 100MB) | 表单上传 (put_file) | 简单快速,一次请求完成 |
|
||||
| 大文件 (> 100MB) | 分片上传 v2 (put_file_v2) | 支持断点续传,适应弱网环境 |
|
||||
| 网络不稳定 | 分片上传 v2 | 自动重试,更可靠 |
|
||||
|
||||
### 2.2 服务端生成上传 Token
|
||||
|
||||
```python
|
||||
from qiniu import Auth
|
||||
|
||||
def generate_upload_token(
|
||||
bucket_name: str,
|
||||
key: str = None,
|
||||
expires: int = 3600,
|
||||
policy: dict = None
|
||||
) -> str:
|
||||
"""
|
||||
生成上传凭证
|
||||
|
||||
Args:
|
||||
bucket_name: 存储空间名称
|
||||
key: 指定文件名(可选)
|
||||
expires: Token 有效期(秒),默认 3600
|
||||
policy: 上传策略配置(可选)
|
||||
|
||||
Returns:
|
||||
上传 Token 字符串
|
||||
"""
|
||||
q = Auth(access_key, secret_key)
|
||||
|
||||
# 自定义上传策略(可选)
|
||||
if policy is None:
|
||||
policy = {}
|
||||
|
||||
token = q.upload_token(bucket_name, key, expires, policy)
|
||||
return token
|
||||
```
|
||||
|
||||
### 2.3 客户端直传(推荐)
|
||||
|
||||
**服务端生成 Token,客户端直传到七牛云**:
|
||||
|
||||
```python
|
||||
# 服务端 API
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/qiniu", tags=["Qiniu"])
|
||||
|
||||
class UploadTokenRequest(BaseModel):
|
||||
key: str # 文件名
|
||||
expires: int = 3600 # Token 有效期
|
||||
|
||||
class UploadTokenResponse(BaseModel):
|
||||
token: str
|
||||
key: str
|
||||
upload_url: str = "https://upload.qiniup.com"
|
||||
|
||||
@router.post("/upload-token", response_model=UploadTokenResponse)
|
||||
async def get_upload_token(request: UploadTokenRequest):
|
||||
"""获取上传凭证,客户端直传"""
|
||||
token = generate_upload_token(
|
||||
bucket_name=os.getenv('QINIU_BUCKET_NAME'),
|
||||
key=request.key,
|
||||
expires=request.expires
|
||||
)
|
||||
return UploadTokenResponse(token=token, key=request.key)
|
||||
```
|
||||
|
||||
### 2.4 服务端上传文件(保留场景)
|
||||
|
||||
```python
|
||||
from qiniu import Auth, put_file_v2, etag
|
||||
import qiniu.config
|
||||
|
||||
def upload_file(
|
||||
local_file_path: str,
|
||||
key: str,
|
||||
bucket_name: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
服务端上传文件到七牛云
|
||||
|
||||
Args:
|
||||
local_file_path: 本地文件路径
|
||||
key: 存储的文件名(如 "audios/voice.mp3")
|
||||
bucket_name: 存储空间名称
|
||||
|
||||
Returns:
|
||||
{"key": str, "hash": str, "url": str}
|
||||
"""
|
||||
bucket_name = bucket_name or os.getenv('QINIU_BUCKET_NAME')
|
||||
|
||||
# 生成上传 Token
|
||||
token = q.upload_token(bucket_name, key, 3600)
|
||||
|
||||
# 使用分片上传 v2(推荐)
|
||||
ret, info = put_file_v2(
|
||||
up_token=token,
|
||||
key=key,
|
||||
file_path=local_file_path,
|
||||
version='v2' # 指定分片上传 v2 版本
|
||||
)
|
||||
|
||||
if ret is None:
|
||||
raise Exception(f"上传失败: {info}")
|
||||
|
||||
# 验证文件完整性
|
||||
assert ret['key'] == key
|
||||
assert ret['hash'] == etag(local_file_path)
|
||||
|
||||
# 构建访问 URL
|
||||
domain = os.getenv('QINIU_BUCKET_DOMAIN')
|
||||
url = f"https://{domain}/{key}"
|
||||
|
||||
return {
|
||||
"key": ret['key'],
|
||||
"hash": ret['hash'],
|
||||
"url": url
|
||||
}
|
||||
```
|
||||
|
||||
### 2.5 上传策略 (PutPolicy)
|
||||
|
||||
常用策略配置:
|
||||
|
||||
```python
|
||||
# 1. 限制文件大小 (10MB ~ 100MB)
|
||||
policy = {
|
||||
"fsizeMin": 1024 * 1024 * 10, # 最小 10MB
|
||||
"fsizeLimit": 1024 * 1024 * 100, # 最大 100MB
|
||||
"mimeLimit": "audio/*;video/*" # 限制文件类型
|
||||
}
|
||||
|
||||
# 2. 上传后回调业务服务器
|
||||
policy = {
|
||||
"callbackUrl": "https://your-api.com/callback",
|
||||
"callbackBody": "key=$(key)&hash=$(etag)&fname=$(fname)&fsize=$(fsize)",
|
||||
"callbackBodyType": "application/x-www-form-urlencoded"
|
||||
}
|
||||
|
||||
# 3. 上传后转码(持久化处理)
|
||||
import base64
|
||||
fops = 'avthumb/mp4/s/640x360/vb/1.25m'
|
||||
saveas_key = base64.urlsafe_b64encode(f'{bucket_name}:output.mp4'.encode()).decode()
|
||||
|
||||
policy = {
|
||||
"persistentOps": f"{fops}|saveas/{saveas_key}",
|
||||
"persistentPipeline": "transcoding", # 队列名称
|
||||
"persistentNotifyUrl": "https://your-api.com/pfop/callback"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 文件下载
|
||||
|
||||
### 3.1 公有空间下载
|
||||
|
||||
公有空间文件可直接访问:
|
||||
|
||||
```python
|
||||
def get_public_url(key: str, domain: str = None) -> str:
|
||||
"""获取公有空间文件 URL"""
|
||||
domain = domain or os.getenv('QINIU_BUCKET_DOMAIN')
|
||||
return f"https://{domain}/{key}"
|
||||
```
|
||||
|
||||
### 3.2 私有空间下载(临时 URL)
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
def get_private_url(key: str, expires: int = 3600) -> str:
|
||||
"""
|
||||
生成私有空间文件的临时下载 URL
|
||||
|
||||
Args:
|
||||
key: 文件 Key
|
||||
expires: 链接有效期(秒)
|
||||
|
||||
Returns:
|
||||
带签名的临时 URL
|
||||
"""
|
||||
domain = os.getenv('QINIU_BUCKET_DOMAIN')
|
||||
base_url = f"https://{domain}/{key}"
|
||||
|
||||
# 生成私有下载链接
|
||||
private_url = q.private_download_url(base_url, expires=expires)
|
||||
return private_url
|
||||
|
||||
# 使用示例
|
||||
def download_file(key: str, local_path: str):
|
||||
"""下载私有空间文件到本地"""
|
||||
private_url = get_private_url(key, expires=3600)
|
||||
|
||||
response = requests.get(private_url)
|
||||
if response.status_code == 200:
|
||||
with open(local_path, 'wb') as f:
|
||||
f.write(response.content)
|
||||
return True
|
||||
return False
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 文件管理 (BucketManager)
|
||||
|
||||
### 4.1 初始化管理器
|
||||
|
||||
```python
|
||||
from qiniu import Auth, BucketManager
|
||||
|
||||
q = Auth(access_key, secret_key)
|
||||
bucket = BucketManager(q)
|
||||
```
|
||||
|
||||
### 4.2 获取文件信息
|
||||
|
||||
```python
|
||||
def get_file_info(bucket_name: str, key: str) -> dict:
|
||||
"""
|
||||
获取文件元信息
|
||||
|
||||
Returns:
|
||||
{
|
||||
"fsize": 文件大小(字节),
|
||||
"hash": 文件哈希,
|
||||
"mimeType": MIME类型,
|
||||
"putTime": 上传时间(100纳秒时间戳),
|
||||
"type": 存储类型(0=标准,1=低频,2=归档,3=深度归档)
|
||||
}
|
||||
"""
|
||||
ret, info = bucket.stat(bucket_name, key)
|
||||
if ret is None:
|
||||
raise Exception(f"获取文件信息失败: {info}")
|
||||
return ret
|
||||
```
|
||||
|
||||
### 4.3 列举文件列表
|
||||
|
||||
```python
|
||||
from typing import List, Optional
|
||||
|
||||
def list_files(
|
||||
bucket_name: str,
|
||||
prefix: str = None, # 前缀筛选
|
||||
limit: int = 100, # 每页数量
|
||||
marker: str = None # 分页标记
|
||||
) -> dict:
|
||||
"""
|
||||
列举空间文件列表
|
||||
|
||||
Returns:
|
||||
{
|
||||
"items": [{"key": ..., "fsize": ..., ...}],
|
||||
"marker": "分页标记",
|
||||
"commonPrefixes": ["公共前缀列表"]
|
||||
}
|
||||
"""
|
||||
ret, eof, info = bucket.list(
|
||||
bucket_name,
|
||||
prefix=prefix,
|
||||
marker=marker,
|
||||
limit=limit,
|
||||
delimiter=None # 不指定分隔符
|
||||
)
|
||||
|
||||
return {
|
||||
"items": ret.get('items', []),
|
||||
"marker": ret.get('marker'),
|
||||
"eof": eof # 是否已列举完
|
||||
}
|
||||
|
||||
# 遍历所有文件
|
||||
def list_all_files(bucket_name: str, prefix: str = None) -> List[dict]:
|
||||
"""遍历获取所有文件"""
|
||||
files = []
|
||||
marker = None
|
||||
|
||||
while True:
|
||||
result = list_files(bucket_name, prefix, limit=1000, marker=marker)
|
||||
files.extend(result['items'])
|
||||
|
||||
if result['eof'] or not result['marker']:
|
||||
break
|
||||
marker = result['marker']
|
||||
|
||||
return files
|
||||
```
|
||||
|
||||
### 4.4 删除文件
|
||||
|
||||
```python
|
||||
def delete_file(bucket_name: str, key: str) -> bool:
|
||||
"""删除单个文件"""
|
||||
ret, info = bucket.delete(bucket_name, key)
|
||||
return ret == {}
|
||||
|
||||
def delete_files_batch(bucket_name: str, keys: List[str]) -> dict:
|
||||
"""批量删除文件"""
|
||||
from qiniu import build_batch_delete
|
||||
|
||||
ops = build_batch_delete(bucket_name, keys)
|
||||
ret, info = bucket.batch(ops)
|
||||
return ret
|
||||
```
|
||||
|
||||
### 4.5 复制和移动文件
|
||||
|
||||
```python
|
||||
def copy_file(
|
||||
src_bucket: str,
|
||||
src_key: str,
|
||||
dest_bucket: str,
|
||||
dest_key: str,
|
||||
force: bool = True
|
||||
) -> bool:
|
||||
"""复制文件"""
|
||||
ret, info = bucket.copy(
|
||||
src_bucket, src_key,
|
||||
dest_bucket, dest_key,
|
||||
force=force # 强制覆盖
|
||||
)
|
||||
return ret is not None
|
||||
|
||||
def move_file(
|
||||
src_bucket: str,
|
||||
src_key: str,
|
||||
dest_bucket: str,
|
||||
dest_key: str,
|
||||
force: bool = True
|
||||
) -> bool:
|
||||
"""移动/重命名文件"""
|
||||
ret, info = bucket.move(
|
||||
src_bucket, src_key,
|
||||
dest_bucket, dest_key,
|
||||
force=force
|
||||
)
|
||||
return ret is not None
|
||||
```
|
||||
|
||||
### 4.6 修改文件元信息
|
||||
|
||||
```python
|
||||
def change_mime(bucket_name: str, key: str, mime_type: str):
|
||||
"""修改文件 MIME 类型"""
|
||||
ret, info = bucket.change_mime(bucket_name, key, mime_type)
|
||||
return ret is not None
|
||||
|
||||
def change_type(bucket_name: str, key: str, file_type: int):
|
||||
"""
|
||||
修改文件存储类型
|
||||
|
||||
file_type:
|
||||
0 = 标准存储
|
||||
1 = 低频存储
|
||||
2 = 归档存储
|
||||
3 = 深度归档存储
|
||||
"""
|
||||
ret, info = bucket.change_type(bucket_name, key, file_type)
|
||||
return ret is not None
|
||||
```
|
||||
|
||||
### 4.7 批量操作
|
||||
|
||||
```python
|
||||
from qiniu import (
|
||||
build_batch_stat,
|
||||
build_batch_copy,
|
||||
build_batch_move,
|
||||
build_batch_rename,
|
||||
build_batch_delete
|
||||
)
|
||||
|
||||
def batch_stat(bucket_name: str, keys: List[str]) -> List[dict]:
|
||||
"""批量查询文件信息"""
|
||||
ops = build_batch_stat(bucket_name, keys)
|
||||
ret, info = bucket.batch(ops)
|
||||
return ret
|
||||
|
||||
def batch_rename(
|
||||
bucket_name: str,
|
||||
key_map: dict, # {"old_key": "new_key", ...}
|
||||
force: bool = True
|
||||
):
|
||||
"""批量重命名"""
|
||||
ops = build_batch_rename(bucket_name, key_map, force=force)
|
||||
ret, info = bucket.batch(ops)
|
||||
return ret
|
||||
|
||||
def batch_copy(
|
||||
src_bucket: str,
|
||||
key_map: dict, # {"src_key": "dest_key", ...}
|
||||
dest_bucket: str = None,
|
||||
force: bool = True
|
||||
):
|
||||
"""批量复制"""
|
||||
dest_bucket = dest_bucket or src_bucket
|
||||
ops = build_batch_copy(src_bucket, key_map, dest_bucket, force=force)
|
||||
ret, info = bucket.batch(ops)
|
||||
return ret
|
||||
```
|
||||
|
||||
### 4.8 抓取网络资源
|
||||
|
||||
```python
|
||||
def fetch_remote_file(
|
||||
remote_url: str,
|
||||
key: str,
|
||||
bucket_name: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
抓取远程文件到七牛云
|
||||
|
||||
Args:
|
||||
remote_url: 远程文件 URL
|
||||
key: 保存的文件名
|
||||
bucket_name: 目标空间
|
||||
|
||||
Returns:
|
||||
{"key": ..., "hash": ..., "fsize": ...}
|
||||
"""
|
||||
bucket_name = bucket_name or os.getenv('QINIU_BUCKET_NAME')
|
||||
ret, info = bucket.fetch(remote_url, bucket_name, key)
|
||||
return ret
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. CDN 操作
|
||||
|
||||
### 5.1 初始化 CDN Manager
|
||||
|
||||
```python
|
||||
from qiniu import CdnManager
|
||||
|
||||
cdn_manager = CdnManager(q)
|
||||
```
|
||||
|
||||
### 5.2 刷新 CDN 缓存
|
||||
|
||||
```python
|
||||
def refresh_urls(urls: List[str]) -> dict:
|
||||
"""刷新指定 URL 的 CDN 缓存"""
|
||||
ret, info = cdn_manager.refresh_urls(urls)
|
||||
return ret
|
||||
|
||||
def refresh_dirs(dirs: List[str]) -> dict:
|
||||
"""刷新整个目录的 CDN 缓存"""
|
||||
ret, info = cdn_manager.refresh_dirs(dirs)
|
||||
return ret
|
||||
```
|
||||
|
||||
### 5.3 预取资源
|
||||
|
||||
```python
|
||||
def prefetch_urls(urls: List[str]) -> dict:
|
||||
"""预取资源到 CDN 节点"""
|
||||
ret, info = cdn_manager.prefetch_urls(urls)
|
||||
return ret
|
||||
```
|
||||
|
||||
### 5.4 获取 CDN 日志
|
||||
|
||||
```python
|
||||
def get_cdn_log_list(domains: List[str], log_date: str) -> List[dict]:
|
||||
"""
|
||||
获取 CDN 日志下载链接
|
||||
|
||||
Args:
|
||||
domains: 域名列表
|
||||
log_date: 日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
[{"name": ..., "url": ..., "size": ..., "mtime": ...}]
|
||||
"""
|
||||
ret, info = cdn_manager.get_log_list_data(domains, log_date)
|
||||
return ret.get('data', [])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 项目集成方案
|
||||
|
||||
### 6.1 服务端封装模块
|
||||
|
||||
```python
|
||||
# app/services/qiniu_service.py
|
||||
"""
|
||||
七牛云对象存储服务封装
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from qiniu import Auth, BucketManager, CdnManager, put_file_v2, etag
|
||||
|
||||
class QiniuService:
|
||||
"""七牛云服务封装"""
|
||||
|
||||
def __init__(self):
|
||||
access_key = os.getenv('QINIU_ACCESS_KEY')
|
||||
secret_key = os.getenv('QINIU_SECRET_KEY')
|
||||
self.bucket_name = os.getenv('QINIU_BUCKET_NAME')
|
||||
self.domain = os.getenv('QINIU_BUCKET_DOMAIN')
|
||||
|
||||
self.auth = Auth(access_key, secret_key)
|
||||
self.bucket = BucketManager(self.auth)
|
||||
self.cdn = CdnManager(self.auth)
|
||||
|
||||
def get_upload_token(self, key: str, expires: int = 3600, policy: dict = None) -> str:
|
||||
"""生成上传 Token"""
|
||||
return self.auth.upload_token(self.bucket_name, key, expires, policy)
|
||||
|
||||
def get_file_url(self, key: str, private: bool = False, expires: int = 3600) -> str:
|
||||
"""获取文件访问 URL"""
|
||||
base_url = f"https://{self.domain}/{key}"
|
||||
if private:
|
||||
return self.auth.private_download_url(base_url, expires)
|
||||
return base_url
|
||||
|
||||
def upload_file(self, local_path: str, key: str) -> dict:
|
||||
"""服务端上传文件"""
|
||||
token = self.get_upload_token(key)
|
||||
ret, info = put_file_v2(token, key, local_path, version='v2')
|
||||
|
||||
if ret is None:
|
||||
raise Exception(f"上传失败: {info}")
|
||||
|
||||
return {
|
||||
"key": ret['key'],
|
||||
"hash": ret['hash'],
|
||||
"url": self.get_file_url(key)
|
||||
}
|
||||
|
||||
def delete_file(self, key: str) -> bool:
|
||||
"""删除文件"""
|
||||
ret, info = self.bucket.delete(self.bucket_name, key)
|
||||
return ret == {}
|
||||
|
||||
def refresh_cdn(self, keys: List[str]) -> dict:
|
||||
"""刷新 CDN 缓存"""
|
||||
urls = [self.get_file_url(key) for key in keys]
|
||||
return self.cdn.refresh_urls(urls)
|
||||
|
||||
# 全局单例
|
||||
_qiniu_service: Optional[QiniuService] = None
|
||||
|
||||
def get_qiniu_service() -> QiniuService:
|
||||
global _qiniu_service
|
||||
if _qiniu_service is None:
|
||||
_qiniu_service = QiniuService()
|
||||
return _qiniu_service
|
||||
```
|
||||
|
||||
### 6.2 FastAPI 路由集成
|
||||
|
||||
```python
|
||||
# app/api/v1/qiniu.py
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File
|
||||
from app.services.qiniu_service import get_qiniu_service
|
||||
|
||||
router = APIRouter(prefix="/qiniu", tags=["Qiniu"])
|
||||
|
||||
@router.post("/upload-token")
|
||||
async def get_upload_token(key: str, expires: int = 3600):
|
||||
"""获取客户端直传 Token"""
|
||||
service = get_qiniu_service()
|
||||
token = service.get_upload_token(key, expires)
|
||||
return {"token": token, "key": key}
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), key: str = None):
|
||||
"""服务端上传文件(小文件场景)"""
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
service = get_qiniu_service()
|
||||
|
||||
# 生成唯一文件名
|
||||
if key is None:
|
||||
import uuid
|
||||
ext = file.filename.split('.')[-1] if '.' in file.filename else ''
|
||||
key = f"uploads/{uuid.uuid4()}.{ext}" if ext else f"uploads/{uuid.uuid4()}"
|
||||
|
||||
# 保存临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
shutil.copyfileobj(file.file, tmp)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = service.upload_file(tmp_path, key)
|
||||
return result
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
@router.delete("/files/{key:path}")
|
||||
async def delete_file(key: str):
|
||||
"""删除文件"""
|
||||
service = get_qiniu_service()
|
||||
success = service.delete_file(key)
|
||||
return {"success": success}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 最佳实践
|
||||
|
||||
### 7.1 文件名规范
|
||||
|
||||
```python
|
||||
def generate_key(file_type: str, user_id: str, filename: str) -> str:
|
||||
"""
|
||||
生成规范的文件存储路径
|
||||
|
||||
格式: {type}/{user_id}/{date}/{uuid}.{ext}
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
ext = filename.split('.')[-1] if '.' in filename else 'bin'
|
||||
date = datetime.now().strftime('%Y%m')
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
|
||||
return f"{file_type}/{user_id}/{date}/{unique_id}.{ext}"
|
||||
|
||||
# 使用示例
|
||||
key = generate_key("voices", "user_123", "my-voice.mp3")
|
||||
# 结果: voices/user_123/202501/a1b2c3d4.mp3
|
||||
```
|
||||
|
||||
### 7.2 错误处理
|
||||
|
||||
```python
|
||||
from qiniu import AuthError, HTTPError
|
||||
|
||||
def handle_qiniu_error(func):
|
||||
"""七牛云操作错误处理装饰器"""
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except AuthError as e:
|
||||
raise Exception(f"认证失败: {e}")
|
||||
except HTTPError as e:
|
||||
raise Exception(f"请求失败: {e}")
|
||||
except Exception as e:
|
||||
raise Exception(f"操作失败: {e}")
|
||||
return wrapper
|
||||
```
|
||||
|
||||
### 7.3 安全配置
|
||||
|
||||
1. **密钥管理**: 使用环境变量,禁止硬编码
|
||||
2. **Token 有效期**: 上传 Token 建议 1 小时,下载 Token 根据场景设置
|
||||
3. **上传策略**: 限制文件大小和 MIME 类型
|
||||
4. **私有空间**: 敏感文件使用私有空间 + 临时 URL
|
||||
|
||||
---
|
||||
|
||||
## 8. 常见问题
|
||||
|
||||
### Q1: 上传失败,返回 401 错误?
|
||||
|
||||
**A**: 检查 AccessKey 和 SecretKey 是否正确,以及 Token 是否过期。
|
||||
|
||||
### Q2: 如何支持大文件上传?
|
||||
|
||||
**A**: 使用分片上传 v2 (`put_file_v2`),SDK 会自动处理分片和断点续传。
|
||||
|
||||
### Q3: 文件上传后如何获取访问 URL?
|
||||
|
||||
**A**: 公有空间直接拼接 `https://{domain}/{key}`,私有空间使用 `auth.private_download_url()` 生成临时 URL。
|
||||
|
||||
### Q4: 如何刷新 CDN 缓存?
|
||||
|
||||
**A**: 使用 `CdnManager.refresh_urls()` 或 `refresh_dirs()`,注意目录刷新有每日限额。
|
||||
|
||||
### Q5: 上传回调不生效?
|
||||
|
||||
**A**: 确保 callbackUrl 是公网可访问的 HTTPS 地址,且返回 Content-Type: application/json。
|
||||
|
||||
---
|
||||
|
||||
## 9. 参考资料
|
||||
|
||||
- [七牛云 Python SDK 官方文档](https://developer.qiniu.com/kodo/1242/python)
|
||||
- [上传策略文档](https://developer.qiniu.com/kodo/1206/put-policy)
|
||||
- [表单上传 API](https://developer.qiniu.com/kodo/1272/api-overview)
|
||||
- [Python SDK GitHub](https://github.com/qiniu/python-sdk)
|
||||
@@ -0,0 +1,425 @@
|
||||
# 后端语义治理与架构重构计划
|
||||
|
||||
> **范围**:`python-api/app/` 全目录
|
||||
> **目标**:根治需求调整与 Celery→Async Scheduler 迁移导致的命名腐败、语义漂移、类型漂移问题
|
||||
> **原则**:长期稳定优先,不计短期开发成本
|
||||
> **状态**:计划中(待团队评审后执行)
|
||||
|
||||
---
|
||||
|
||||
## 一、问题诊断
|
||||
|
||||
经过深度代码扫描,当前代码存在六大类语义腐败,按严重程度排序:
|
||||
|
||||
### P0:调度器核心模型的 `params: dict[str, Any]` 类型漂移
|
||||
- **症状**:所有 Handler 从裸字典里 `params.get("shots")`,且因 Redis 序列化反复写 `if isinstance(shots, str): shots = json.loads(shots)`
|
||||
- **风险**:这是 Scheduler 取代 Celery 后最脆弱的环节,任何字段名改动都会导致运行时崩溃
|
||||
- **关键文件**:`app/scheduler/models.py`, `app/scheduler/registry.py`, `app/scheduler/handlers/video_handler.py`
|
||||
|
||||
### P0:`shot/segment/scene` 的三重命名 + 四处重复定义
|
||||
- **症状**:`ScriptShot`(Schema)、`ShotData`(API)、`ShotTask`(Service)、`ShotUnit`(Scheduler)四者并存,字段名 `scene`/`scene_desc`、`type`/`shot_type` 混用
|
||||
- **风险**:任何分镜字段改动必须改 4 个类,极易遗漏
|
||||
- **关键文件**:`app/schemas/script.py`, `app/api/v1/video.py`, `app/services/kling_video_service.py`, `app/scheduler/models.py`
|
||||
|
||||
### P0:Kling 供应商语义大规模泄漏到业务层和 API 层
|
||||
- **症状**:`element_id`(Kling 主体 ID)、`kling_task_id`、Omni prompt 语法 `<<<element_1>>>` 直接出现在 API Schema、Scheduler 模型、数据库模型中
|
||||
- **风险**:一旦更换视频供应商,影响面会穿透所有层级
|
||||
- **关键文件**:`app/api/v1/video.py`, `app/api/v1/tasks.py`, `app/models/avatar.py`, `app/scheduler/handlers/video_handler.py`
|
||||
|
||||
### P1:`task` / `task_id` 的五重语义混用
|
||||
- **症状**:FastAPI BackgroundTask、Scheduler Task、Kling API Task、AnyToCopy Task、Volcengine Task 都叫 `task`
|
||||
- **风险**:日志堆栈中无法区分层级,调试极其困难
|
||||
- **关键文件**:`app/api/v1/tasks.py`, `app/scheduler/`, `app/ai/providers/klingai_provider.py`, `app/services/`
|
||||
|
||||
### P1:历史残留命名与双轨运行
|
||||
- **症状**:`# 兼容旧字段`、`video_task_id`、`image_task_id`、`cache_err` 等 Celery 时代残留;`script` 任务仍走 BackgroundTask,其他任务走 Scheduler
|
||||
- **风险**:双轨运行导致统一监控、重试、日志无法覆盖全部任务类型
|
||||
- **关键文件**:`app/scheduler/handlers/video_handler.py`, `app/scheduler/handlers/image_handler.py`, `app/api/v1/tasks.py`
|
||||
|
||||
### P2:CRUD 层裸字典更新
|
||||
- **症状**:`avatar_crud.update(db, db_obj=avatar, obj_in={"status": "element_pending"})`
|
||||
- **风险**:拼写错误、状态值非法无法被静态检查捕获
|
||||
- **关键文件**:`app/crud/base.py`, `app/crud/avatar.py`, `app/scheduler/handlers/avatar_handler.py`
|
||||
|
||||
---
|
||||
|
||||
## 二、架构目标:六层语义治理
|
||||
|
||||
我们将整个后端严格划分为 **6 个语义层级**,每一层只使用属于该层的术语:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Layer 6: Presentation (API Schema / 前端适配层) │
|
||||
│ 术语: Segment, Human, Job, Script, Cover │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Layer 5: Application (API 路由 / BackgroundJob) │
|
||||
│ 术语: Segment, Human, Job, Project │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Layer 4: Orchestration (Scheduler / SlotManager) │
|
||||
│ 术语: Job, JobRecord, Slot, Handler │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Layer 3: Domain (Service / 业务逻辑) │
|
||||
│ 术语: Segment, Human, VideoComposition, Caption │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Layer 2: Adapter (Provider Client / 供应商适配) │
|
||||
│ 术语: KlingJob, KlingElement, VolcJob, ProviderTaskId │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Layer 1: Infrastructure (DB / Redis / HTTP / FileSys) │
|
||||
│ 术语: 仅使用底层技术术语 │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 核心禁令
|
||||
|
||||
1. `element`、`omni`、`kling_task_id` 等**供应商术语**禁止出现在 Layer 3 以上
|
||||
2. `shot` 禁止出现在 Layer 3 以上(Kling 术语,业务层统一叫 `segment`)
|
||||
3. `task` 禁止出现在 Layer 4(Scheduler 内部统一叫 `job`)
|
||||
4. `dict[str, Any]` 禁止出现在跨层传递的接口中
|
||||
|
||||
---
|
||||
|
||||
## 三、重构阶段(Phase 1-5)
|
||||
|
||||
每个 Phase 独立成组,建议按顺序执行。每个 Phase 完成后必须跑通 `pytest` 和 `make lint`。
|
||||
|
||||
---
|
||||
|
||||
### Phase 1:Schema 统一 + 状态机 Enum 化
|
||||
**目标**:建立"单一真相源",消除 shot/segment/scene 的四重定义
|
||||
**预估工时**:3-4 天
|
||||
**影响面**:全项目基础类型
|
||||
|
||||
#### Task 1.1:新建统一术语 Schema
|
||||
- [ ] 新建 `app/schemas/segment.py`
|
||||
```python
|
||||
class Segment(BaseModel):
|
||||
id: str
|
||||
type: Literal["segment", "empty_shot"]
|
||||
scene: str # 统一为 scene,删除 scene_desc
|
||||
voiceover: str
|
||||
duration: int | None = None
|
||||
human_id: str | None = None # 业务术语,对应 Kling 的 element_id
|
||||
status: SegmentStatus = SegmentStatus.PENDING
|
||||
provider_task_id: str | None = None
|
||||
video_url: str | None = None
|
||||
local_path: str | None = None
|
||||
query_fail_count: int = 0
|
||||
```
|
||||
- [ ] 新建 `app/schemas/enums.py`,定义以下 Enum:
|
||||
- `JobStatus`: pending, running, completed, failed
|
||||
- `SegmentStatus`: pending, submitted, completed, failed
|
||||
- `AvatarCloneStatus`: pending, voice_processing, voice_failed, element_pending, element_processing, element_failed, succeed
|
||||
- `KlingTaskStatus`: submitted, processing, succeed, failed(仅限 Provider 层使用)
|
||||
- [ ] 新建 `app/schemas/job.py`,定义 `JobParams` Union:
|
||||
- `VideoJobParams`(含 `segments: list[Segment]`)
|
||||
- `ImageJobParams`
|
||||
- `SubtitleJobParams`
|
||||
- `CopyJobParams`
|
||||
- `AvatarCloneJobParams`
|
||||
|
||||
#### Task 1.2:删除重复定义
|
||||
- [ ] 删除 `app/scheduler/models.py` 中的 `ShotUnit`
|
||||
- [ ] 删除/重构 `app/services/kling_video_service.py` 中的 `ShotTask`(字段迁移到 `Segment`)
|
||||
- [ ] 删除 `app/api/v1/video.py` 中的 `ShotData`,改为引用 `Segment`
|
||||
- [ ] 将 `app/schemas/script.py` 中的 `ScriptShot` 改为 `Segment` 的别名或类型适配器
|
||||
|
||||
#### Task 1.3:字段名统一
|
||||
- [ ] 批量将 `scene_desc` 重命名为 `scene`(覆盖 `kling_video_service.py`, `video_handler.py` 等)
|
||||
- [ ] 批量将 `shot_type` 重命名为 `type`(在业务层/Schema 层;Provider 层保留 `shot_type` 仅用于 Kling API 调用)
|
||||
- [ ] `app/api/v1/tasks.py` 中的 `shots: list[dict]` 改为 `segments: list[Segment]`
|
||||
|
||||
#### Task 1.4:状态字符串 Enum 化
|
||||
- [ ] `app/scheduler/models.py` 中 `TaskRecord.status` 类型改为 `JobStatus`
|
||||
- [ ] `app/services/kling_video_service.py` 中 `VideoGenerationJob.status` 类型改为 `JobStatus`
|
||||
- [ ] `app/models/avatar.py` 中 `Avatar.status` 类型改为 `AvatarCloneStatus`
|
||||
- [ ] `app/ai/providers/klingai_provider.py` 中所有 Kling 状态字符串操作改为 `KlingTaskStatus`
|
||||
|
||||
#### 验收标准
|
||||
- [ ] `grep -rn "class ShotUnit\|class ShotTask\|class ShotData\|class ScriptShot" app/` 返回为空(除了注释或别名声明)
|
||||
- [ ] `grep -rn "scene_desc" app/ | grep -v "__pycache__"` 返回为空
|
||||
- [ ] `mypy app/schemas/` 无类型错误
|
||||
- [ ] `pytest` 通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 2:Scheduler 层全面"去 task 化"
|
||||
**目标**:消除 `task` 的五重语义混用,建立 `Job` 专属语义域
|
||||
**预估工时**:3-4 天
|
||||
**影响面**:`app/scheduler/` 目录及引用方
|
||||
|
||||
#### Task 2.1:核心模型与 Registry 重命名
|
||||
- [ ] `app/scheduler/models.py`:`TaskRecord` → `JobRecord`
|
||||
- [ ] `app/scheduler/registry.py`:`TaskRegistry` → `JobRegistry`
|
||||
- 所有 `task_id` 参数/字段 → `job_id`
|
||||
- 所有 `task_type` 参数/字段 → `job_type`
|
||||
- `running_task_ids` → `running_job_ids`
|
||||
- [ ] `app/scheduler/engine.py`:`AsyncEngine` 中所有 `task` → `job`
|
||||
|
||||
#### Task 2.2:Registry 承担全部序列化职责
|
||||
- [ ] 在 `JobRegistry.get()` 中统一完成 JSON 反序列化
|
||||
- 解析 `params` 字段时,根据 `job_type` 路由到正确的 `JobParams` Pydantic 模型
|
||||
- 保证 Handler 拿到的 `job.params` 永远是强类型对象
|
||||
- [ ] 删除 `video_handler.py` 和 `image_handler.py` 中所有的 `isinstance(shots, str): json.loads` 逻辑
|
||||
- [ ] 删除 `_download_and_upload` 中的手动 JSON 类型判断
|
||||
|
||||
#### Task 2.3:`StateChange` 取代裸字典
|
||||
- [ ] 在 `app/scheduler/models.py` 中定义:
|
||||
```python
|
||||
@dataclass(frozen=True)
|
||||
class StateChange:
|
||||
job_id: str
|
||||
field: str
|
||||
value: Any
|
||||
```
|
||||
- [ ] 修改 `app/scheduler/engine.py`:
|
||||
- `_apply_changes(self, changes: list[dict[str, Any]])` → `_apply_changes(self, changes: list[StateChange])`
|
||||
- 序列化逻辑移入 `StateChange.to_redis_command()` 方法
|
||||
- [ ] 修改 `app/scheduler/handlers/base.py`:`tick()` 返回类型改为 `list[StateChange]`
|
||||
- [ ] 修改所有 Handler:`changes.append({"task_id": ..., "field": ...})` → `changes.append(StateChange(job_id=..., field=..., value=...))`
|
||||
|
||||
#### Task 2.4:API 层适配
|
||||
- [ ] `app/api/v1/tasks.py` 中:内部变量名 `task_id` 在引用 Scheduler 时改为 `job_id`(API URL `/tasks/{task_id}` 保持不变,仅内部变量和注释调整)
|
||||
- [ ] `app/api/v1/avatar.py` 中:引用 `TaskRegistry` 的地方改为 `JobRegistry`
|
||||
|
||||
#### 验收标准
|
||||
- [ ] `grep -rn "TaskRecord\|TaskRegistry\|running_task_ids" app/scheduler/` 返回为空
|
||||
- [ ] `grep -rn "isinstance(.*shots.*str)" app/scheduler/handlers/` 返回为空
|
||||
- [ ] `grep -rn '"task_id":' app/scheduler/handlers/` 返回为空(仅 `StateChange` dataclass 内部可保留)
|
||||
- [ ] `pytest` 通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 3:建立"供应商防火墙"(Adapter 层隔离)
|
||||
**目标**:将 Kling/Volc 术语彻底下压到 Provider 层,业务层以上使用纯业务语义
|
||||
**预估工时**:4-5 天
|
||||
**影响面**:API Schema、DB 模型、Scheduler 模型、Provider 层
|
||||
|
||||
#### Task 3.1:API 层删除 Kling 术语泄漏
|
||||
- [ ] `app/api/v1/video.py`:
|
||||
- `element_id: int | None = Field(None, description="Kling主体ID")` → `human_id: int | None = Field(None, description="数字人主体ID")`
|
||||
- [ ] `app/api/v1/tasks.py`:
|
||||
- 同上,所有 `element_id` → `human_id`
|
||||
- `VideoParams` 中的 `shots` → `segments`
|
||||
|
||||
#### Task 3.2:DB 模型增加供应商抽象
|
||||
- [ ] `app/models/avatar.py`:
|
||||
- `element_id: Mapped[int | None]` → `provider_element_id: Mapped[int | None]`
|
||||
- `voice_task_id` → `provider_voice_job_id`
|
||||
- `element_task_id` → `provider_element_job_id`
|
||||
- 新增 `provider: Mapped[str] = mapped_column(default="kling")`(为未来多供应商做准备)
|
||||
- [ ] 生成 Alembic 迁移脚本(字段重命名 + 新增字段)
|
||||
|
||||
#### Task 3.3:Scheduler 模型供应商抽象
|
||||
- [ ] `app/scheduler/models.py`(Phase 2 后的 `JobRecord` 及 `Segment`):
|
||||
- `kling_task_id` → `provider_task_id`
|
||||
- 如需追溯供应商,增加 `provider: str = "kling"`
|
||||
|
||||
#### Task 3.4:Provider 返回值模型化
|
||||
- [ ] 新建 `app/ai/providers/kling_dto.py`:
|
||||
- `KlingVideoResult`
|
||||
- `KlingImageResult`
|
||||
- `KlingVoiceResult`
|
||||
- `KlingElementResult`
|
||||
- [ ] 修改 `app/ai/providers/klingai_provider.py`:
|
||||
- 所有返回裸 `dict[str, Any]` 的方法改为返回对应的 `Kling*Result`
|
||||
- 状态字段类型改为 `KlingTaskStatus`
|
||||
|
||||
#### Task 3.5:Prompt 语法迁移到 Provider 层
|
||||
- [ ] 删除 `app/scheduler/handlers/video_handler.py` 中的:
|
||||
- `_build_omni_prompt()`
|
||||
- `_build_empty_shot_video_prompt()`
|
||||
- [ ] 在 `app/ai/providers/klingai_provider.py` 中新建 `KlingPromptBuilder` 类:
|
||||
```python
|
||||
class KlingPromptBuilder:
|
||||
@staticmethod
|
||||
def omni_segment(scene: str, voiceover: str, element_id: str | None = None) -> str
|
||||
@staticmethod
|
||||
def empty_shot(scene: str, voiceover: str) -> str
|
||||
```
|
||||
- [ ] `video_handler.py` 调用时只传业务字段(`scene`, `voiceover`, `human_id`),由 Provider 层负责映射为 Kling 语法
|
||||
|
||||
#### Task 3.6:Service 层适配器映射
|
||||
- [ ] `app/services/kling_video_service.py`:
|
||||
- 删除 `avatar_id` 废弃字段
|
||||
- `human_id` → 在调用 Provider 时映射为 `element_id`
|
||||
- [ ] `app/services/qiniu_service.py`:
|
||||
- `file_type="avatar"` → `file_type="clone_video"` 或 `"human_video"`
|
||||
|
||||
#### 验收标准
|
||||
- [ ] `grep -rn "element_id" app/api/ app/schemas/ app/scheduler/models.py | grep -v "provider_element_id"` 返回为空
|
||||
- [ ] `grep -rn "kling_task_id" app/api/ app/schemas/ app/scheduler/models.py` 返回为空
|
||||
- [ ] `grep -rn "<<<element_1>>>\|<<<voice_1>>>" app/scheduler/ app/services/` 返回为空(仅在 Provider 层保留)
|
||||
- [ ] Alembic 迁移脚本可正常升级/降级
|
||||
- [ ] `pytest` 通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 4:清理历史残留 + 消除双轨运行
|
||||
**目标**:删除所有 Celery 时代残留,将 `script` 任务纳入 Scheduler 统一调度
|
||||
**预估工时**:2-3 天
|
||||
**影响面**:Handler、API 路由、历史命名
|
||||
|
||||
#### Task 4.1:删除兼容旧字段代码
|
||||
- [ ] `app/scheduler/handlers/video_handler.py`:
|
||||
- 删除 `shot["video_task_id"] = kling_task_id # 兼容旧字段`
|
||||
- 删除初始化 shots 时的 `"video_task_id": None`
|
||||
- [ ] `app/scheduler/handlers/image_handler.py`:
|
||||
- 删除 `params["image_task_id"] = kling_task_id`
|
||||
- [ ] `app/services/kling_video_service.py`:
|
||||
- 删除 `avatar_id` 字段
|
||||
|
||||
#### Task 4.2:修正历史残留命名
|
||||
- [ ] `app/core/redis_client.py`:删除文档字符串中的 `RateLimiter` 字样
|
||||
- [ ] `app/api/v1/tasks.py`:
|
||||
- `cache entry` → `registry entry`
|
||||
- `cache_err` → `registry_err`
|
||||
- `Failed to update cache` → `Failed to update registry`
|
||||
- [ ] `app/core/token_manager.py`:`_background_tasks` → `_refresh_tasks`
|
||||
- [ ] 删除 `app/services/dto.py`
|
||||
|
||||
#### Task 4.3:将 `script` 任务迁移到 Scheduler
|
||||
- [ ] 新建 `app/scheduler/handlers/script_handler.py`
|
||||
- 将 `app/api/v1/tasks.py` 中 `_run_script_task` 的逻辑迁移至此
|
||||
- 继承 `AsyncHandler`,`name = "script"`,不占用 Slot(或占用独立 `script_slots`)
|
||||
- [ ] 修改 `app/api/v1/tasks.py`:
|
||||
- `script` 任务改为 `registry.create(job_type="script", ...)`
|
||||
- 删除 `BackgroundTasks` 相关参数和 `_run_script_task` 函数
|
||||
- [ ] 修改 `app/scheduler/main.py`:注册 `ScriptHandler`
|
||||
|
||||
#### 验收标准
|
||||
- [ ] `grep -rn "兼容旧字段\|video_task_id\|image_task_id" app/scheduler/ app/services/` 返回为空
|
||||
- [ ] `grep -rn "cache_err\|cache entry" app/api/v1/tasks.py` 返回为空
|
||||
- [ ] `app/services/dto.py` 不存在
|
||||
- [ ] `app/api/v1/tasks.py` 中无 `BackgroundTasks` 导入和使用
|
||||
- [ ] `pytest` 通过
|
||||
|
||||
---
|
||||
|
||||
### Phase 5:CRUD 层强类型化
|
||||
**目标**:消灭 CRUD 层的裸字典更新
|
||||
**预估工时**:2 天
|
||||
**影响面**:CRUD Base、Avatar CRUD、Scheduler Handler
|
||||
|
||||
#### Task 5.1:CRUD Base 类型约束
|
||||
- [ ] `app/crud/base.py`:
|
||||
- `obj_in: dict[str, Any]` → `obj_in: CreateSchemaType | UpdateSchemaType`
|
||||
- 保留 `dict` 仅作为 `update` 的 fallback,但所有业务调用方优先使用 Schema
|
||||
|
||||
#### Task 5.2:Avatar Schema 定义
|
||||
- [ ] 新建 `app/schemas/avatar.py`:
|
||||
```python
|
||||
class AvatarCreate(BaseModel):
|
||||
name: str
|
||||
video_url: str
|
||||
status: AvatarCloneStatus = AvatarCloneStatus.PENDING
|
||||
|
||||
class AvatarUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
status: AvatarCloneStatus | None = None
|
||||
provider_voice_job_id: str | None = None
|
||||
provider_element_job_id: str | None = None
|
||||
provider_element_id: int | None = None
|
||||
fail_reason: str | None = None
|
||||
```
|
||||
- [ ] `app/crud/avatar.py`:改为 `class CRUDAvatar(CRUDBase[Avatar, AvatarCreate, AvatarUpdate])`
|
||||
|
||||
#### Task 5.3:Handler 调用方改造
|
||||
- [ ] `app/scheduler/handlers/avatar_handler.py`:
|
||||
- 所有 `_update_avatar(avatar_id, {"status": "..."})` 改为 `_update_avatar(avatar_id, AvatarUpdate(status=AvatarCloneStatus.XXX))`
|
||||
- 删除裸字典辅助函数 `_update_avatar` 中的 `**obj_in` 展开,改用 `obj_in.model_dump(exclude_unset=True)`
|
||||
|
||||
#### 验收标准
|
||||
- [ ] `grep -rn 'obj_in=\{' app/scheduler/handlers/avatar_handler.py` 返回为空
|
||||
- [ ] `mypy app/crud/` 无类型错误
|
||||
- [ ] `pytest` 通过
|
||||
|
||||
---
|
||||
|
||||
## 四、自动化防护网(Phase 5 之后部署)
|
||||
|
||||
### 4.1 预提交钩子:禁词检查
|
||||
在 `.pre-commit-config.yaml` 或 `Makefile` 中增加 `lint-semantic`:
|
||||
|
||||
```makefile
|
||||
lint-semantic:
|
||||
@echo "Checking semantic boundaries..."
|
||||
@! grep -rn "element_id" app/api/ app/schemas/ app/scheduler/models.py | grep -v "provider_element_id" || (echo "ERROR: element_id leaked to upper layers"; exit 1)
|
||||
@! grep -rn "kling_task_id" app/api/ app/schemas/ app/scheduler/models.py || (echo "ERROR: kling_task_id leaked to upper layers"; exit 1)
|
||||
@! grep -rn "scene_desc" app/ | grep -v "__pycache__" || (echo "ERROR: scene_desc not fully renamed"; exit 1)
|
||||
@! grep -rn "TaskRecord\|TaskRegistry\|running_task_ids" app/scheduler/ || (echo "ERROR: Scheduler task naming not fully migrated"; exit 1)
|
||||
@! grep -rn "<<<element_1>>>\|<<<voice_1>>>" app/scheduler/ app/services/ || (echo "ERROR: Kling prompt syntax leaked"; exit 1)
|
||||
@echo "Semantic check passed"
|
||||
```
|
||||
|
||||
### 4.2 mypy 渐进严格化
|
||||
- [ ] 在 `pyproject.toml` 中为 `app/scheduler/` 和 `app/schemas/` 单独开启 `strict = true`
|
||||
- [ ] 逐步扩展至 `app/api/` 和 `app/services/`
|
||||
|
||||
### 4.3 AGENTS.md 术语表(Glossary)
|
||||
在 `AGENTS.md` 中新增"统一术语表"章节(见下文),所有 AI Agent 修改代码前必须查阅。
|
||||
|
||||
---
|
||||
|
||||
## 五、风险与回滚策略
|
||||
|
||||
| 风险 | 影响 | mitigation |
|
||||
|------|------|-------------|
|
||||
| Phase 1 删除 `ScriptShot` 影响前端类型生成 | 中 | `ScriptShot` 保留为 `Segment` 的 `TypeAlias` 一个 Sprint,待前端适配后删除 |
|
||||
| Phase 2 `JobRegistry` 重命名导致 API 层引用遗漏 | 高 | 使用 IDE 全局重构(PyCharm/Ruff/Rename Symbol),执行后跑全量 `pytest` |
|
||||
| Phase 3 DB 字段重命名需要数据迁移 | 中 | Alembic 脚本必须包含 `op.alter_column` 的 `existing_type` 和 `existing_nullable` |
|
||||
| Phase 4 `script` 迁出 BackgroundTask 后响应时间变长 | 低 | 脚本生成仍立即返回 `job_id`,前端通过轮询 `/tasks/{job_id}` 获取结果,接口契约不变 |
|
||||
| 多 Phase 并行开发导致冲突 | 高 | **严禁并行**:必须按 1→2→3→4→5 顺序执行,每个 Phase 合并到主分支后再开始下一个 |
|
||||
|
||||
---
|
||||
|
||||
## 六、作为 GitHub Project Task List 的格式
|
||||
|
||||
若导入 GitHub Project,建议按以下结构建 5 个 Milestone:
|
||||
|
||||
```markdown
|
||||
### Milestone 1: Schema Unification
|
||||
- [ ] #101 Create `app/schemas/segment.py` with `Segment` model
|
||||
- [ ] #102 Create `app/schemas/enums.py` with `JobStatus`, `SegmentStatus`, `AvatarCloneStatus`, `KlingTaskStatus`
|
||||
- [ ] #103 Create `app/schemas/job.py` with `JobParams` Union
|
||||
- [ ] #104 Remove `ShotUnit` from `app/scheduler/models.py`
|
||||
- [ ] #105 Remove `ShotTask` from `app/services/kling_video_service.py`
|
||||
- [ ] #106 Remove `ShotData` from `app/api/v1/video.py`
|
||||
- [ ] #107 Rename `scene_desc` → `scene` across codebase
|
||||
- [ ] #108 Migrate all `status` strings to Enums
|
||||
|
||||
### Milestone 2: Scheduler De-tasking
|
||||
- [ ] #201 Rename `TaskRecord` → `JobRecord`
|
||||
- [ ] #202 Rename `TaskRegistry` → `JobRegistry`
|
||||
- [ ] #203 Registry auto-deserializes `JobParams` models
|
||||
- [ ] #204 Replace `dict` changes with `StateChange` dataclass
|
||||
- [ ] #205 Update all Handlers to return `list[StateChange]`
|
||||
|
||||
### Milestone 3: Vendor Firewall
|
||||
- [ ] #301 API layer: `element_id` → `human_id`
|
||||
- [ ] #302 DB model: add `provider` field, rename task/element IDs
|
||||
- [ ] #303 Scheduler model: `kling_task_id` → `provider_task_id`
|
||||
- [ ] #304 Provider DTOs: `KlingVideoResult`, `KlingImageResult`, etc.
|
||||
- [ ] #305 Move `KlingPromptBuilder` to Provider layer
|
||||
- [ ] #306 Alembic migration for avatar table changes
|
||||
|
||||
### Milestone 4: Cleanup & Unification
|
||||
- [ ] #401 Remove legacy compatibility fields (`video_task_id`, `image_task_id`)
|
||||
- [ ] #402 Fix historical naming (`cache_err`, `RateLimiter` docstrings, etc.)
|
||||
- [ ] #403 Delete `app/services/dto.py`
|
||||
- [ ] #404 Migrate `script` task from BackgroundTask to Scheduler
|
||||
|
||||
### Milestone 5: CRUD Strong Typing
|
||||
- [ ] #501 Create `AvatarCreate` / `AvatarUpdate` schemas
|
||||
- [ ] #502 Type-constrain CRUDBase
|
||||
- [ ] #503 Refactor `avatar_handler.py` to use `AvatarUpdate` instead of raw dicts
|
||||
- [ ] #504 Add `lint-semantic` to Makefile / pre-commit
|
||||
- [ ] #505 Update `AGENTS.md` with Glossary and layer rules
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 七、相关文档
|
||||
|
||||
- [统一异步调度器设计文档](./unified-async-scheduler.md)
|
||||
- [数据库设计文档](./database-design.md)
|
||||
- [AGENTS.md](../AGENTS.md)(术语表与分层禁令)
|
||||
@@ -0,0 +1,350 @@
|
||||
# 统一异步任务调度方案
|
||||
|
||||
> **状态:已完成(2026-04-17)** — 本文档所述方案已全面实施,Celery 已完全移除,所有第三方异步任务现由 Async Engine Scheduler 统一调度。
|
||||
|
||||
> 本文档用于替代原 Celery 在"提交→轮询→收尾"类第三方异步任务中的角色,解决视频生成、形象克隆等任务在队列中频繁出现的拥堵、死锁和状态不一致问题。
|
||||
|
||||
---
|
||||
|
||||
## 1. 背景与问题
|
||||
|
||||
### 1.1 当前架构的缺陷
|
||||
|
||||
目前项目使用 Celery Worker 处理所有第三方异步任务,包括:
|
||||
|
||||
- **视频生成** (`video`):提交 Kling 分镜 → 轮询状态 → 下载上传
|
||||
- **形象克隆** (`avatar_clone`):提交音色 → 轮询 → 提交主体 → 轮询
|
||||
- **字幕对齐** (`subtitle`)
|
||||
- **图片生成** (`image`)
|
||||
|
||||
这些任务都被放进 Celery 队列,由 Worker 并发消费。但 Kling 视频生成本质上是**"占用并发槽位并长时间等待"**的过程,当前设计存在三个结构性问题:
|
||||
|
||||
1. **轮询任务风暴**:`poll_video_task` 用 Celery `retry(countdown=5)` 模拟轮询,一个 8 分钟的 Kling 任务会产生近百个 Celery Task,淹没队列调度器和 Redis Result Backend。
|
||||
2. **快慢任务混排**:下载上传(IO 密集型)和提交/轮询(轻量 HTTP)共用 `video` 队列,Worker 被长任务占满,新任务饿死。
|
||||
3. **状态死锁**:`download_upload_shot` 作为独立 Celery Task,一旦被 Worker 强制 Kill(如超时),shot 状态永远卡在 `downloading`,而轮询任务又不再处理它,导致整个任务假死。
|
||||
|
||||
### 1.2 核心认知
|
||||
|
||||
Kling 是一个有**严格并发上限**(20 槽位)的第三方异步执行池。我们需要的是一个 **Slot-Based Scheduler**(槽位调度器),而不是一个任务队列(Celery)。
|
||||
|
||||
> **任务队列**擅长"把独立任务尽快分发出去";
|
||||
> **槽位调度器**擅长"在有限资源下,周期性补货、轮询和收尾"。
|
||||
|
||||
Kling 视频生成和形象克隆属于后者。
|
||||
|
||||
---
|
||||
|
||||
## 2. 架构总览
|
||||
|
||||
```
|
||||
┌─────────────┐ HTTP ┌──────────────────┐
|
||||
│ Tauri App │ ◄────────────► │ FastAPI API │
|
||||
│ (React) │ │ (Gateway) │
|
||||
└─────────────┘ └────────┬─────────┘
|
||||
│
|
||||
┌───────────────────┼───────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
||||
│ PostgreSQL │ │ Redis │ │ Object Store │
|
||||
│ (持久化/审计) │ │ (运行时状态) │ │ (七牛/本地) │
|
||||
└──────────────┘ └──────┬───────┘ └──────────────┘
|
||||
│
|
||||
┌────────┴────────┐
|
||||
│ Async Engine │
|
||||
│ (Slot Scheduler)│
|
||||
│ python main.py │
|
||||
└────────┬────────┘
|
||||
│
|
||||
┌──────────────────┼──────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
│ Video Handler│ │Avatar Handler│ │Future Handler│
|
||||
│ max_slots=18│ │ max_slots=2 │ │ (subtitle…) │
|
||||
└─────────────┘ └─────────────┘ └─────────────┘
|
||||
```
|
||||
|
||||
### 2.1 核心组件
|
||||
|
||||
| 组件 | 职责 | 技术选型 |
|
||||
|------|------|----------|
|
||||
| **FastAPI API** | 接收前端请求、创建任务、写入状态、供前端轮询 | 现有 FastAPI |
|
||||
| **Redis** | 存储任务的**运行时状态**(running shots、当前 stage、slot 占用集合) | 现有 Redis |
|
||||
| **PostgreSQL** | 存储任务的**持久化记录**(创建时间、最终结果、成本统计、失败原因) | 现有 PostgreSQL |
|
||||
| **Async Engine** | 独立的调度进程,每 10 秒一次 **Tick**,驱动所有任务状态推进 | Python `asyncio` |
|
||||
| **Handler** | 插件化模块,每个第三方平台一个实现 | 面向接口的 Python 类 |
|
||||
|
||||
---
|
||||
|
||||
## 3. 核心机制
|
||||
|
||||
### 3.1 统一状态机
|
||||
|
||||
无论 video 还是 avatar_clone,所有第三方异步任务单元收敛到 **5 个统一状态**:
|
||||
|
||||
```
|
||||
pending → submitted → succeed → completed
|
||||
│ │
|
||||
└──────────────────────────────┘
|
||||
↓
|
||||
failed
|
||||
```
|
||||
|
||||
- **`pending`**:在队列里等待全局 slot 空闲
|
||||
- **`submitted`**:已占用 slot,已提交给 Kling,等待轮询结果
|
||||
- **`succeed`**:Kling 返回成功,Async Engine 立即触发下载/收尾(后台异步执行)
|
||||
- **`failed`**:Kling 返回失败或提交异常
|
||||
- **`completed`**:下载、上传、DB 写入全部完成
|
||||
|
||||
对于 avatar_clone 这种**多阶段**任务,内部用 Sub-State 嵌套,但每个阶段仍遵循同一模式:
|
||||
|
||||
```
|
||||
voice_pending → voice_submitted → voice_succeed
|
||||
↓
|
||||
element_pending → element_submitted → element_succeed → completed
|
||||
```
|
||||
|
||||
### 3.2 Slot Manager(全局并发控制器)
|
||||
|
||||
基于 **Redis SET + Lua 脚本** 实现严格的原子槽位管理:
|
||||
|
||||
```python
|
||||
class SlotManager:
|
||||
async def acquire(self, slot_key: str, slot_id: str, max_slots: int) -> bool:
|
||||
"""Lua 脚本原子执行:SADD -> SCARD -> 超限则 SREM"""
|
||||
lua = """
|
||||
local key = KEYS[1]
|
||||
local slot_id = ARGV[1]
|
||||
local max_slots = tonumber(ARGV[2])
|
||||
redis.call('sadd', key, slot_id)
|
||||
local count = redis.call('scard', key)
|
||||
if count > max_slots then
|
||||
redis.call('srem', key, slot_id)
|
||||
return 0
|
||||
end
|
||||
redis.call('expire', key, 1800)
|
||||
return 1
|
||||
"""
|
||||
return await self.redis.eval(lua, 1, slot_key, slot_id, str(max_slots)) == 1
|
||||
|
||||
async def release(self, slot_key: str, slot_id: str) -> None:
|
||||
await self.redis.srem(slot_key, slot_id)
|
||||
```
|
||||
|
||||
当前配置:
|
||||
|
||||
- **Video 槽位池**:`kling:video_slots`,上限 **18**
|
||||
- **Avatar 槽位池**:`kling:avatar_slots`,上限 **2**
|
||||
|
||||
> **为什么 Lua 脚本?** 确保 `SADD + SCARD + 条件 SREM` 原子执行。即使未来启动第二个 Scheduler 实例做 HA,也不会出现并发超发。
|
||||
|
||||
### 3.3 Async Engine Tick 循环
|
||||
|
||||
Scheduler 是一个独立的 `asyncio` 进程,主循环如下:
|
||||
|
||||
```python
|
||||
async def main():
|
||||
engine = AsyncEngine()
|
||||
while True:
|
||||
tick_start = time.monotonic()
|
||||
|
||||
# 1. 加载所有 running 的任务
|
||||
tasks = await engine.registry.get_running_tasks()
|
||||
|
||||
# 2. 按 Handler 分组,并行执行各自的 tick
|
||||
changes = await asyncio.gather(*[
|
||||
handler.tick(tasks_for_handler, engine.slots)
|
||||
for handler in engine.handlers.values()
|
||||
])
|
||||
|
||||
# 3. 批量应用状态变更(Pipeline 写入 Redis)
|
||||
await engine.registry.apply_changes(flatten(changes))
|
||||
|
||||
# 4. 对 completed/failed 的任务,持久化到 PostgreSQL
|
||||
await engine.persist_finished_tasks()
|
||||
|
||||
# 5. 控制 Tick 间隔(固定 10 秒,执行过久时至少休息 2 秒)
|
||||
elapsed = time.monotonic() - tick_start
|
||||
await asyncio.sleep(max(10 - elapsed, 2))
|
||||
```
|
||||
|
||||
### 3.4 Handler 插件化接口
|
||||
|
||||
每个第三方平台实现一个 Handler:
|
||||
|
||||
```python
|
||||
class AsyncHandler(ABC):
|
||||
name: str # e.g. "video"
|
||||
slot_key: str # e.g. "kling:video_slots"
|
||||
max_slots: int # e.g. 18
|
||||
|
||||
@abstractmethod
|
||||
async def tick(self, tasks: list[Task], slots: SlotManager) -> list[StateChange]:
|
||||
"""每个 Tick 执行一次,返回需要更新的状态变更列表"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### Video Handler 的 tick 逻辑
|
||||
|
||||
1. **查**:遍历所有 `submitted` 的 shots,批量并行查询 Kling 状态
|
||||
2. **收**:
|
||||
- `succeed` → `release_slot` + `asyncio.create_task(download_and_upload(shot))` + 状态改为 `completed`
|
||||
- `failed` → `release_slot` + 状态改为 `failed`
|
||||
3. **补**:计算空闲槽位数,从 `pending` 队列 FIFO 取出新 shot 提交给 Kling,直到槽满
|
||||
4. **写**:更新 task 的聚合状态(completed/total/message)到 Redis
|
||||
|
||||
#### Avatar Handler 的 tick 逻辑
|
||||
|
||||
1. 检查当前 stage(如 `voice_submitted`),查询 Kling 状态
|
||||
2. 若 `voice_succeed` → 释放 slot,推进到 `element_pending`,并在同一 tick 内尝试申请 slot 提交主体创建
|
||||
3. 若 `element_succeed` → 释放 slot,状态改为 `completed`
|
||||
|
||||
---
|
||||
|
||||
## 4. API 层与数据流
|
||||
|
||||
### 4.1 创建任务(不变)
|
||||
|
||||
```python
|
||||
@router.post("/{task_type}", response_model=TaskCreateResponse)
|
||||
async def create_task(task_type: str, request: TaskCreateRequest):
|
||||
task_id = generate_task_id()
|
||||
|
||||
# 1. 写入 PostgreSQL(持久化底单)
|
||||
await db_task.create(task_id=task_id, type=task_type, user_id=...)
|
||||
|
||||
# 2. 写入 Redis(标记为 pending,供 Async Engine 消费)
|
||||
await redis_task.create(task_id=task_id, type=task_type, status="pending", ...)
|
||||
|
||||
return TaskCreateResponse(task_id=task_id, status="pending")
|
||||
```
|
||||
|
||||
### 4.2 查询状态(不变)
|
||||
|
||||
```python
|
||||
@router.get("/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str):
|
||||
# 先读 Redis(热数据)
|
||||
task = await redis_task.get(task_id)
|
||||
if not task:
|
||||
# fallback 到 PostgreSQL(已完成的归档数据)
|
||||
task = await db_task.get(task_id)
|
||||
return task
|
||||
```
|
||||
|
||||
### 4.3 前端兼容性
|
||||
|
||||
**前端 `useTask.ts` 的轮询逻辑完全不需要修改。** 这是渐进式迁移的关键——调度层的重构对上层透明。
|
||||
|
||||
---
|
||||
|
||||
## 5. 部署方案
|
||||
|
||||
### 5.1 Docker Compose
|
||||
|
||||
```yaml
|
||||
services:
|
||||
api:
|
||||
build:
|
||||
context: ../python-api
|
||||
dockerfile: Dockerfile
|
||||
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
# ...
|
||||
|
||||
scheduler:
|
||||
build:
|
||||
context: ../python-api
|
||||
dockerfile: Dockerfile
|
||||
container_name: meijiaka-scheduler
|
||||
command: python -m app.scheduler.main
|
||||
environment:
|
||||
- REDIS_HOST=redis
|
||||
- DATABASE_URL=postgresql+asyncpg://...
|
||||
depends_on:
|
||||
- redis
|
||||
- db
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 512M
|
||||
```
|
||||
|
||||
### 5.2 Celery 的处置
|
||||
|
||||
- 立即下线 `worker-video`(不再消费 `video` 队列)
|
||||
- Phase 2 下线 `worker-avatar`(Avatar Handler 迁入 Async Engine 后)
|
||||
- 可选:暂时保留 Celery 跑 `subtitle`,待后续迁移
|
||||
- 最终目标:所有"提交→轮询"类任务都迁入 Async Engine,Celery 整体移除
|
||||
|
||||
---
|
||||
|
||||
## 6. 迁移路径
|
||||
|
||||
| 阶段 | 时间 | 动作 | 风险 |
|
||||
|------|------|------|------|
|
||||
| **Phase 1** | 本周 | Async Engine 只接管 `video`(18 slots);`avatar_clone` 仍由 Celery 运行 | 改动面最小,只验证 video 链路 |
|
||||
| **Phase 2** | 1-2 周后 | Async Engine 新增 `avatar_clone` Handler(2 slots);彻底下线 Celery 的 `worker-video` 和 `worker-avatar` | 验证 avatar 链路,解决资源饿死 |
|
||||
| **Phase 3** | 未来 | `subtitle`、`image` 等陆续迁入 Async Engine;Celery 完全移除 | 统一所有第三方异步任务调度 |
|
||||
|
||||
---
|
||||
|
||||
## 7. 设计原则论证
|
||||
|
||||
### 7.1 主流(Mainstream)
|
||||
|
||||
- **Redis + PostgreSQL 双存储**:运行态在 Redis,持久态在 PostgreSQL。这是现代异步系统的事实标准,从 AWS Lambda 到 Vercel 再到国内云厂商均采用类似模式。
|
||||
- **Python asyncio 轻量调度器**:不引入 Kafka、RabbitMQ 或 Airflow 等重型框架,利用原生异步能力构建。Prefect、Dagster 的底层 Scheduler 也采用类似思想。
|
||||
- **Gateway + 独立 Scheduler 进程**:API 负责接入,Scheduler 负责推进,职责清晰。这是当前中小型 SaaS 的主流演进方向。
|
||||
|
||||
### 7.2 合理(Reasonable)
|
||||
|
||||
- **完全匹配项目定位**:项目定位是"轻量云账号 + 全本地业务数据",不需要 Kubernetes 或复杂工作流引擎。Async Engine 只是一个额外的 Python 进程,资源占用 < 512MB。
|
||||
- **渐进式迁移,契约不变**:前端轮询逻辑、API URL、响应 Schema 均不变。改动仅集中在后端任务分发层,业务代码零侵入。
|
||||
- **资源隔离精确可控**:Video 18 slots + Avatar 2 slots = 20 slots,与 Kling 实际并发限制完全对齐。不会出现"形象克隆占满 Worker 导致视频饿死"的结构性问题。
|
||||
- **开发体验优先**:本地开发时,scheduler 可以和 api 一起 `docker-compose up`,也可以单独 `python -m app.scheduler.main` 调试。不需要 ngrok,不需要把开发环境搬到云端。
|
||||
- **幂等和可恢复**:每个 shot 的提交操作都是幂等的。Redis 记录了 `kling_task_id`,Scheduler 重启后从 Redis 恢复 running 任务,继续轮询,不会丢失状态。
|
||||
|
||||
### 7.3 长期稳定(Long-term Stable)
|
||||
|
||||
- **HA 预留,无单点故障**:`SlotManager` 基于 Redis Lua 脚本实现原子操作,天然支持多实例竞争。未来如需高可用,可启动第二个 Scheduler 实例,通过 Redis 分布式锁选举 Leader,实现秒级主备切换,无需重构。
|
||||
- **Handler 插件化扩展**:未来接入即梦、Runway、Pika 或新的 AI 服务,只需实现新的 `AsyncHandler` 子类,配置 `slot_key` 和 `max_slots`。核心调度逻辑永远不需要改动。
|
||||
- **数据一致性保障**:运行态在 Redis(崩溃恢复快),完成态在 PostgreSQL(数据不丢)。即使 Scheduler 挂掉 30 分钟,Kling 端的任务仍在运行,恢复后继续轮询即可。
|
||||
- **第三方接口变更的防御性**:Handler 内部对 Kling API 的调用有统一的超时控制、重试策略和异常兜底。如果 Kling 某个接口升级,只改对应 Handler,不影响其他模块。
|
||||
- **可观测性支撑长期运维**:通过 Prometheus 指标,可长期监控"视频生成成功率"、"平均生成耗时"、"槽位利用率"、"Kling API 延迟分布",为后续扩容和成本优化提供数据支撑。
|
||||
|
||||
---
|
||||
|
||||
## 8. 关键文件位置(建议)
|
||||
|
||||
```
|
||||
python-api/
|
||||
├── app/
|
||||
│ └── scheduler/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # Async Engine 入口(Tick 循环)
|
||||
│ ├── engine.py # AsyncEngine 核心调度器
|
||||
│ ├── slot_manager.py # 槽位管理器(Redis Lua)
|
||||
│ ├── registry.py # 任务注册表(Redis 读写)
|
||||
│ ├── handlers/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── base.py # AsyncHandler 抽象基类
|
||||
│ │ ├── video_handler.py # Video 任务处理器
|
||||
│ │ └── avatar_handler.py # Avatar Clone 处理器
|
||||
│ └── models.py # Scheduler 内部数据模型
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. 结论
|
||||
|
||||
当前 Celery 在"提交→轮询→收尾"类任务中的角色是**结构性错位**的。它带来的任务风暴、队列拥堵和状态死锁不是可以通过调参修复的,而是其模型与问题本质不匹配的结果。
|
||||
|
||||
**统一异步调度方案**的核心决策:
|
||||
|
||||
1. **用 Async Engine(Slot-Based Scheduler)替代 Celery 管理所有第三方异步任务**
|
||||
2. **Video 分配 18 槽,Avatar Clone 分配 2 槽,由唯一调度器全局管理**
|
||||
3. **API 层和前端轮询逻辑完全不变,实现渐进式迁移**
|
||||
4. **本地开发环境保持原样,无需引入 webhook 或云端部署**
|
||||
|
||||
这是根治"任务队列生成视频总会出问题"的唯一长期方案。
|
||||
@@ -0,0 +1,342 @@
|
||||
# 视频生成交互流程设计
|
||||
|
||||
## 一、正常流程(批量生成)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 0: 检查前置条件 │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 点击【生成视频】按钮 │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌─────────────────┐ │
|
||||
│ │ 检查本地状态 │ │
|
||||
│ │ 如果正在生成中 │───► 提示"已有任务进行中,请等待完成" │
|
||||
│ └─────────────────┘ 或者"是否取消当前任务?" │
|
||||
│ │ │
|
||||
│ ▼ 无进行中任务 │
|
||||
│ ┌─────────────────┐ │
|
||||
│ │ 检查是否选形象 │ │
|
||||
│ │ 未选择 │───► 弹出形象选择弹窗 │
|
||||
│ └─────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ 已选择 │
|
||||
│ 继续下一步 │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 1: 确认弹窗 │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────┐ │
|
||||
│ │ 开始生成视频 │ │
|
||||
│ ├─────────────────────────────────────────┤ │
|
||||
│ │ │ │
|
||||
│ │ 将生成 8 个分镜视频 │ │
|
||||
│ │ 预计耗时:约 15-20 分钟 │ │
|
||||
│ │ │ │
|
||||
│ │ ⚠️ 生成过程中请勿关闭应用 │ │
|
||||
│ │ 您可以最小化窗口,但不要关闭 │ │
|
||||
│ │ │ │
|
||||
│ │ ┌────────────┐ ┌──────────────────┐ │ │
|
||||
│ │ │ 取消 │ │ 开始生成 │ │ │
|
||||
│ │ └────────────┘ └──────────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ └─────────────────────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 2: 进入生成状态(界面锁定) │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 【生成】按钮变为【生成中...】且 disabled │
|
||||
│ │
|
||||
│ 顶部显示全局状态栏: │
|
||||
│ ┌─────────────────────────────────────────────────────────────┐│
|
||||
│ │ 🎬 视频生成中 ━━━━━━━━⏳━━━━ 预计还需 12 分钟 [?] ││
|
||||
│ └─────────────────────────────────────────────────────────────┘│
|
||||
│ │
|
||||
│ 显示模态弹窗(不可关闭): │
|
||||
│ ┌─────────────────────────────────────────┐ │
|
||||
│ │ 视频生成 │ │
|
||||
│ ├─────────────────────────────────────────┤ │
|
||||
│ │ │ │
|
||||
│ │ ┌─────────────┐ │ │
|
||||
│ │ │ 状态标签 │ │ │
|
||||
│ │ │ 任务已开启 │ │ │
|
||||
│ │ └─────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ │ 正在为空镜生成参考图片... │ │
|
||||
│ │ │ │
|
||||
│ │ 预计还需 12 分钟 │ │
|
||||
│ │ │ │
|
||||
│ │ [最小化到后台] │ │
|
||||
│ │ │ │
|
||||
│ └─────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ 界面锁定状态: │
|
||||
│ - 禁用:生成按钮、新建项目、添加/删除分镜 │
|
||||
│ - 可浏览:但不能修改任何内容 │
|
||||
│ - 可退出应用:但会提示"任务将后台继续,确定退出?" │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 3: 状态流转 │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 状态标签流转: │
|
||||
│ │
|
||||
│ 分镜(omni-video): │
|
||||
│ 任务已开启 ──► 排队生成中 ──► 任务已完成 │
|
||||
│ │
|
||||
│ 空镜(文生图+图生视频): │
|
||||
│ 任务已开启 ──► 生成参考图片... ──► 排队生成中 ──► 任务已完成 │
|
||||
│ │
|
||||
│ 详细描述文字实时更新(SSE 推送): │
|
||||
│ - "正在初始化任务..." │
|
||||
│ - "正在为空镜生成参考图片..." │
|
||||
│ - "图片生成完成,开始生成视频..." │
|
||||
│ - "正在生成视频,请稍候..." │
|
||||
│ - "已完成 3/8 个分镜" │
|
||||
│ - "整理生成结果..." │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Step 4: 完成 │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 模态弹窗更新: │
|
||||
│ ┌─────────────────────────────────────────┐ │
|
||||
│ │ 视频生成 │ │
|
||||
│ ├─────────────────────────────────────────┤ │
|
||||
│ │ │ │
|
||||
│ │ ┌─────────────┐ │ │
|
||||
│ │ │ 任务已完成 │ │ │
|
||||
│ │ └─────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ │ 成功生成 8 个视频 │ │
|
||||
│ │ │ │
|
||||
│ │ ┌──────────────────────────────────┐ │ │
|
||||
│ │ │ 确定 │ │ │
|
||||
│ │ └──────────────────────────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ └─────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ 用户点击【确定】: │
|
||||
│ 1. 关闭弹窗 │
|
||||
│ 2. 解锁界面 │
|
||||
│ 3. 自动滚动到第一个有视频的分镜 │
|
||||
│ 4. 播放第一个视频 │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## 二、单个重新生成流程
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ 差异点 │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 入口:分镜卡片上的【重新生成】按钮 │
|
||||
│ │
|
||||
│ 确认弹窗简化: │
|
||||
│ ┌─────────────────────────────────────────┐ │
|
||||
│ │ 重新生成视频 │ │
|
||||
│ ├─────────────────────────────────────────┤ │
|
||||
│ │ 将重新生成分镜 3 的视频 │ │
|
||||
│ │ 预计耗时:约 3-5 分钟 │ │
|
||||
│ │ │ │
|
||||
│ │ ⚠️ 生成过程中请勿关闭应用 │ │
|
||||
│ │ │ │
|
||||
│ │ [取消] [确定] │ │
|
||||
│ └─────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ 完成后:自动选中该分镜并播放 │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## 三、异常流程
|
||||
|
||||
### 3.1 用户尝试关闭应用
|
||||
|
||||
```
|
||||
用户点击关闭窗口(或 Cmd+Q / Alt+F4)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ ⚠️ 确认关闭 │
|
||||
├─────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 视频生成任务仍在进行中 │
|
||||
│ │
|
||||
│ 如果选择关闭: │
|
||||
│ - 任务将在后台继续运行 │
|
||||
│ - 生成完成后会推送系统通知 │
|
||||
│ - 下次打开应用可查看结果 │
|
||||
│ │
|
||||
│ [取消] [最小化到托盘] [关闭应用] │
|
||||
│ │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 3.2 应用崩溃/强制退出后恢复
|
||||
|
||||
```
|
||||
用户重新打开应用
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ 📋 恢复未完成任务 │
|
||||
├─────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 检测到上次有未完成的视频生成任务 │
|
||||
│ │
|
||||
│ 项目:厨房改造方案 │
|
||||
│ 进度:已完成 5/8 个分镜 │
|
||||
│ 状态:仍在后台处理中 │
|
||||
│ │
|
||||
│ [查看进度] [我知道了] │
|
||||
│ │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
点击【查看进度】:
|
||||
- 跳转到视频生成页面
|
||||
- 自动恢复进度弹窗显示
|
||||
- 继续监听 SSE/轮询
|
||||
```
|
||||
|
||||
### 3.3 生成失败
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ ❌ 生成失败 │
|
||||
├─────────────────────────────────────────┤
|
||||
│ │
|
||||
│ 视频生成过程中发生错误 │
|
||||
│ │
|
||||
│ 错误信息:Kling API 超时 │
|
||||
│ │
|
||||
│ 已生成的视频已保存 │
|
||||
│ 失败的分镜:分镜3、分镜7 │
|
||||
│ │
|
||||
│ [返回查看] [重试失败项] │
|
||||
│ │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
点击【重试失败项】:
|
||||
- 只重新生成失败的那几个分镜
|
||||
- 复用现有参数
|
||||
```
|
||||
|
||||
### 3.4 网络断开
|
||||
|
||||
```
|
||||
SSE 连接断开
|
||||
│
|
||||
▼
|
||||
状态栏显示:"网络异常,正在重连...(1/3)"
|
||||
│
|
||||
▼
|
||||
自动重连 SSE(最多 3 次)
|
||||
│
|
||||
├─► 重连成功:继续接收进度
|
||||
│
|
||||
└─► 重连失败:切换到轮询模式
|
||||
│
|
||||
▼
|
||||
每 5 秒轮询一次状态
|
||||
│
|
||||
▼
|
||||
网络恢复后:自动切回 SSE
|
||||
```
|
||||
|
||||
## 四、本地状态管理
|
||||
|
||||
```typescript
|
||||
// localStorage: meijiaka_generation_state
|
||||
interface GenerationState {
|
||||
// 任务标识
|
||||
jobId: string;
|
||||
projectId: string;
|
||||
|
||||
// 任务状态
|
||||
status: 'pending' | 'generating' | 'completed' | 'failed';
|
||||
|
||||
// 任务信息(用于恢复显示)
|
||||
shots: Array<{
|
||||
id: string;
|
||||
type: 'segment' | 'empty_shot';
|
||||
}>;
|
||||
totalShots: number;
|
||||
|
||||
// 时间戳
|
||||
startedAt: number;
|
||||
lastUpdatedAt: number;
|
||||
|
||||
// 结果(完成后填写)
|
||||
results?: Array<{
|
||||
shotId: string;
|
||||
status: 'completed' | 'failed';
|
||||
videoPath?: string;
|
||||
errorMessage?: string;
|
||||
}>;
|
||||
|
||||
// 错误信息
|
||||
errorMessage?: string;
|
||||
}
|
||||
```
|
||||
|
||||
## 五、状态流转图
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ IDLE │
|
||||
└──────┬──────┘
|
||||
│ 点击生成
|
||||
▼
|
||||
┌─────────────┐
|
||||
┌───────────────►│ CONFIRM │◄───────────────┐
|
||||
│ │ 确认弹窗 │ │
|
||||
│ └──────┬──────┘ │
|
||||
│ 取消 │ 确认 │
|
||||
│ ▼ │
|
||||
│ ┌─────────────┐ │
|
||||
│ │ GENERATING │────────────────┤
|
||||
│ │ 生成中 │ 应用崩溃/关闭 │
|
||||
│ └──────┬──────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────┼────────────┐ │
|
||||
│ │ │ │ │
|
||||
│ ▼ ▼ ▼ │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
|
||||
│ │SUCCESS │ │ FAILED │ │ TIMEOUT │ │
|
||||
│ └────┬────┘ └────┬────┘ └────┬────┘ │
|
||||
│ │ │ │ │
|
||||
│ ▼ └────────────┘ │
|
||||
│ ┌─────────┐ │ │
|
||||
└───┤ RESULT │◄──────────────┘ │
|
||||
│ 结果弹窗 │ │
|
||||
└────┬────┘ │
|
||||
│ │
|
||||
▼ │
|
||||
┌─────────┐ │
|
||||
│ IDLE │───────────────────────────────┘
|
||||
└─────────┘ 下次启动检测恢复
|
||||
```
|
||||
|
||||
## 六、关键决策点
|
||||
|
||||
| 决策 | 选择 | 理由 |
|
||||
|------|------|------|
|
||||
| 生成中能否关闭应用 | ✅ 可以,但提示后台继续 | 用户有急事时需要关闭 |
|
||||
| 生成中能否切换项目 | ❌ 不能 | 避免状态混乱 |
|
||||
| 生成中能否修改脚本 | ❌ 不能 | 避免参数不一致 |
|
||||
| 失败后能否重试 | ✅ 可以,只重试失败的 | 减少重复等待 |
|
||||
| 是否需要系统通知 | ✅ 需要(第二阶段) | 用户最小化后能感知完成 |
|
||||
@@ -0,0 +1,201 @@
|
||||
# 火山引擎音视频字幕 API 开发文档
|
||||
|
||||
> 更新日期: 2026-04-09
|
||||
> 官方文档: https://www.volcengine.com/docs/6561/80907
|
||||
|
||||
---
|
||||
|
||||
## 产品简介
|
||||
|
||||
火山引擎音视频字幕服务提供两种能力:
|
||||
|
||||
1. **音视频字幕生成** - 自动识别音频中的语音/歌词,生成带时间轴的字幕
|
||||
2. **自动字幕打轴** - 为已有字幕文本自动配上时间轴
|
||||
|
||||
---
|
||||
|
||||
## 基础信息
|
||||
|
||||
| 项目 | 内容 |
|
||||
|------|------|
|
||||
| 基础 URL | `https://openspeech.bytedance.com/api/v1/vc` |
|
||||
| 鉴权 Header | `Authorization: Bearer; {token}` |
|
||||
| 文件限制 | ≤200MB, 支持 WAV/M4A/MP3/MP4/MOV/OGG |
|
||||
|
||||
---
|
||||
|
||||
## API 接口
|
||||
|
||||
### 1. 音视频字幕生成
|
||||
|
||||
#### 提交任务
|
||||
```http
|
||||
POST /submit?appid={appid}&language=zh-CN&use_punc=True
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer; {token}
|
||||
|
||||
{"url": "https://example.com/audio.mp3"}
|
||||
```
|
||||
|
||||
**关键参数:**
|
||||
- `language` - 语言: `zh-CN`, `en-US`, `ja-JP`, `ko-KR`, `es-MX`, `ru-RU`, `fr-FR`, `yue`, `wuu`, `nan`, `ug`
|
||||
- `caption_type` - 识别类型: `auto`(默认), `speech`, `singing`
|
||||
- `use_punc` - 自动标点: `True`, `False`
|
||||
- `use_itn` - 数字转换: `True`(中文数字转阿拉伯数字)
|
||||
- `words_per_line` - 每行字数, 默认 46
|
||||
- `max_lines` - 每屏行数, 默认 1
|
||||
|
||||
#### 查询结果
|
||||
```http
|
||||
GET /query?appid={appid}&id={task_id}&blocking=1
|
||||
Authorization: Bearer; {token}
|
||||
```
|
||||
|
||||
**响应:**
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"message": "Success",
|
||||
"duration": 5.32,
|
||||
"utterances": [
|
||||
{
|
||||
"text": "识别文本",
|
||||
"start_time": 0,
|
||||
"end_time": 3197,
|
||||
"words": [
|
||||
{"text": "单字", "start_time": 0, "end_time": 208}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. 自动字幕打轴
|
||||
|
||||
#### 提交任务
|
||||
```http
|
||||
POST /ata/submit?appid={appid}&caption_type=speech
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer; {token}
|
||||
|
||||
{
|
||||
"url": "https://example.com/audio.mp3",
|
||||
"audio_text": "这是要被打轴的字幕文本"
|
||||
}
|
||||
```
|
||||
|
||||
**参数:**
|
||||
- `caption_type` - `speech`(说话) 或 `singing`(歌词)
|
||||
- `sta_punc_mode` - 标点模式: `1`(省略句末标点), `2`(空格代替), `3`(保留完整标点)
|
||||
|
||||
#### 查询结果
|
||||
```http
|
||||
GET /ata/query?appid={appid}&id={task_id}&blocking=1
|
||||
Authorization: Bearer; {token}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 错误码
|
||||
|
||||
| 码 | 含义 | 处理 |
|
||||
|----|------|------|
|
||||
| 0 | 成功 | - |
|
||||
| 2000 | 处理中 | 继续轮询 |
|
||||
| 1001 | 参数无效 | 检查必填参数 |
|
||||
| 1002 | 无权限 | 检查 token |
|
||||
| 1003 | 超频 | 降低调用频率 |
|
||||
| 1010 | 音频过长 | 缩短音频 |
|
||||
| 1011 | 音频过大 | 压缩音频(<200MB) |
|
||||
| 1012 | 格式无效 | 检查音频格式 |
|
||||
| 1013 | 音频静音 | 检查音频内容 |
|
||||
|
||||
---
|
||||
|
||||
## Python 代码示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
import time
|
||||
|
||||
TOKEN = "your_token"
|
||||
APPID = "your_appid"
|
||||
BASE_URL = "https://openspeech.bytedance.com/api/v1/vc"
|
||||
|
||||
def submit(audio_url, language="zh-CN", use_punc=True):
|
||||
"""提交字幕生成任务"""
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/submit",
|
||||
params={"appid": APPID, "language": language, "use_punc": str(use_punc)},
|
||||
json={"url": audio_url},
|
||||
headers={"Authorization": f"Bearer; {TOKEN}"}
|
||||
)
|
||||
return resp.json()["id"]
|
||||
|
||||
def query(task_id):
|
||||
"""查询任务结果"""
|
||||
resp = requests.get(
|
||||
f"{BASE_URL}/query",
|
||||
params={"appid": APPID, "id": task_id, "blocking": "1"},
|
||||
headers={"Authorization": f"Bearer; {TOKEN}"}
|
||||
)
|
||||
return resp.json()
|
||||
|
||||
def generate_caption(audio_url, language="zh-CN"):
|
||||
"""完整流程: 提交->轮询->返回结果"""
|
||||
task_id = submit(audio_url, language)
|
||||
|
||||
for _ in range(60): # 最多轮询60秒
|
||||
result = query(task_id)
|
||||
if result["code"] == 0:
|
||||
return result["utterances"]
|
||||
elif result["code"] != 2000:
|
||||
raise Exception(f"Task failed: {result['message']}")
|
||||
time.sleep(1)
|
||||
|
||||
raise Exception("Timeout")
|
||||
|
||||
def to_srt(utterances):
|
||||
"""转换为 SRT 字幕格式"""
|
||||
def ms_to_time(ms):
|
||||
h = ms // 3600000
|
||||
m = (ms % 3600000) // 60000
|
||||
s = (ms % 60000) // 1000
|
||||
ms = ms % 1000
|
||||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||||
|
||||
lines = []
|
||||
for i, u in enumerate(utterances, 1):
|
||||
lines.append(f"{i}")
|
||||
lines.append(f"{ms_to_time(u['start_time'])} --> {ms_to_time(u['end_time'])}")
|
||||
lines.append(u['text'])
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
utterances = generate_caption("https://example.com/audio.mp3")
|
||||
srt_content = to_srt(utterances)
|
||||
print(srt_content)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## cURL 示例
|
||||
|
||||
```bash
|
||||
# 1. 提交任务
|
||||
TASK_ID=$(curl -s -X POST \
|
||||
-H "Authorization: Bearer; ${TOKEN}" \
|
||||
-H "content-type: application/json" \
|
||||
-d '{"url": "'${AUDIO_URL}'"}' \
|
||||
"https://openspeech.bytedance.com/api/v1/vc/submit?appid=${APPID}&language=zh-CN" \
|
||||
| jq -r '.id')
|
||||
|
||||
# 2. 查询结果
|
||||
curl -s -X GET \
|
||||
-H "Authorization: Bearer; ${TOKEN}" \
|
||||
"https://openspeech.bytedance.com/api/v1/vc/query?appid=${APPID}&id=${TASK_ID}&blocking=1"
|
||||
```
|
||||
@@ -0,0 +1,72 @@
|
||||
# 美家卡智影 API - 环境变量配置示例
|
||||
# ================================
|
||||
# 复制此文件为 .env 并填写实际值
|
||||
|
||||
# === 基础配置 ===
|
||||
APP_NAME=美家卡智影 API
|
||||
APP_VERSION=0.1.0
|
||||
DEBUG=true
|
||||
ENV=development
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
|
||||
# === 数据库配置 ===
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/meijiaka
|
||||
|
||||
# === Redis 配置 ===
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
# REDIS_PASSWORD= # 如无密码请留空或注释
|
||||
|
||||
# === JWT 安全配置 ===
|
||||
# 生产环境必须修改为强随机密钥
|
||||
SECRET_KEY=your-secret-key-here-change-in-production
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080
|
||||
ALGORITHM=HS256
|
||||
|
||||
# === CORS 配置 ===
|
||||
CORS_ORIGINS=http://localhost:1420,http://127.0.0.1:1420,http://localhost:8080
|
||||
|
||||
# === AI 平台配置 ===
|
||||
|
||||
# 火山方舟(必需)
|
||||
VOLCENGINE_API_KEY=your-volcengine-api-key
|
||||
VOLCENGINE_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
|
||||
|
||||
# 火山字幕服务(必需)
|
||||
VOLCENGINE_CAPTION_APPID=your-caption-appid
|
||||
VOLCENGINE_CAPTION_TOKEN=your-caption-token
|
||||
|
||||
# 可灵 AI(必需,用于视频生成)
|
||||
KLINGAI_ACCESS_KEY=your-kling-access-key
|
||||
KLINGAI_SECRET_KEY=your-kling-secret-key
|
||||
|
||||
# OpenAI(可选)
|
||||
# OPENAI_API_KEY=sk-your-openai-key
|
||||
# OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
||||
# 文心一言(可选)
|
||||
# WENXIN_API_KEY=your-wenxin-key
|
||||
# WENXIN_SECRET_KEY=your-wenxin-secret
|
||||
|
||||
# 通义千问(可选)
|
||||
# QIANWEN_API_KEY=your-qianwen-key
|
||||
|
||||
# === 七牛云存储(必需,用于空镜图片上传)===
|
||||
QINIU_ACCESS_KEY=your-qiniu-access-key
|
||||
QINIU_SECRET_KEY=your-qiniu-secret-key
|
||||
QINIU_VIDEO_BUCKET=media-liche
|
||||
QINIU_VIDEO_DOMAIN=media.liche.cn
|
||||
QINIU_IMAGE_BUCKET=img-liche
|
||||
QINIU_IMAGE_DOMAIN=img.liche.cn
|
||||
|
||||
# === 其他服务 ===
|
||||
|
||||
# AnyToCopy 文案提取(可选)
|
||||
ANYTOCOPY_API_KEY=your-anytocopy-api-key
|
||||
ANYTOCOPY_API_SECRET=your-anytocopy-secret
|
||||
ANYTOCOPY_BASE_URL=https://api.anytocopy.com/vip/open-api/v1
|
||||
|
||||
# === 日志配置 ===
|
||||
LOG_LEVEL=INFO
|
||||
@@ -0,0 +1,75 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual Environment
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# Database
|
||||
*.db
|
||||
*.sqlite3
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Test coverage
|
||||
htmlcov/
|
||||
.coverage
|
||||
.pytest_cache/
|
||||
.tox/
|
||||
|
||||
# Alembic 迁移(保留脚本,忽略临时文件)
|
||||
alembic/versions/*.pyc
|
||||
|
||||
# Celery
|
||||
celerybeat-schedule
|
||||
|
||||
# Redis
|
||||
dump.rdb
|
||||
|
||||
# Docker
|
||||
.dockerignore
|
||||
|
||||
# Local development
|
||||
local/
|
||||
temp/
|
||||
tmp/
|
||||
|
||||
# Data files
|
||||
data/
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
# 美家卡智影 - Git 钩子配置
|
||||
# 安装: pre-commit install
|
||||
# 手动运行: pre-commit run --all-files
|
||||
|
||||
repos:
|
||||
# 代码格式化
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.10.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.13
|
||||
|
||||
# 代码检查
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.0
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
|
||||
# TODO: 修复历史遗留类型错误后重新启用
|
||||
# 类型检查(暂时禁用)
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.14.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# additional_dependencies: [types-PyYAML]
|
||||
|
||||
# 安全扫描(暂时禁用)
|
||||
# - repo: https://github.com/PyCQA/bandit
|
||||
# rev: 1.8.0
|
||||
# hooks:
|
||||
# - id: bandit
|
||||
# args: ["-c", "pyproject.toml"]
|
||||
# additional_dependencies: ["bandit[toml]"]
|
||||
|
||||
# 依赖锁定文件同步检查
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: uv-lock-check
|
||||
name: Check uv lock file is up-to-date
|
||||
entry: bash -c 'uv pip compile pyproject.toml -o requirements.lock --locked'
|
||||
language: system
|
||||
files: ^(pyproject\.toml|requirements\.lock)$
|
||||
pass_filenames: false
|
||||
@@ -0,0 +1 @@
|
||||
3.13
|
||||
@@ -0,0 +1 @@
|
||||
{"http:Pn60lJXcaOGKvMjn5qv-OMr7wR1lp1p8QG7Ul6NK:media-liche": {"upHosts": ["http://upload-z2.qiniup.com", "http://up-z2.qiniup.com"], "ioHosts": ["http://iovip-z2.qbox.me"], "rsHosts": ["http://rs-z2.qbox.me"], "rsfHosts": ["http://rsf-z2.qbox.me"], "apiHosts": ["http://api-z2.qiniu.com"], "deadline": 1776740815}, "http:Pn60lJXcaOGKvMjn5qv-OMr7wR1lp1p8QG7Ul6NK:img-liche": {"upHosts": ["http://upload-z2.qiniup.com", "http://up-z2.qiniup.com"], "ioHosts": ["http://iovip-z2.qbox.me"], "rsHosts": ["http://rs-z2.qbox.me"], "rsfHosts": ["http://rsf-z2.qbox.me"], "apiHosts": ["http://api-z2.qiniu.com"], "deadline": 1776433218}}
|
||||
@@ -0,0 +1,44 @@
|
||||
# 美家卡智影 API - Docker 镜像 (使用 uv 优化)
|
||||
# ===========================================
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
|
||||
# 设置 uv 环境变量
|
||||
ENV UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PYTHON_DOWNLOADS=never
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 先复制锁定文件,利用 Docker 缓存层
|
||||
COPY requirements.lock pyproject.toml ./
|
||||
|
||||
# 创建虚拟环境并安装依赖(利用 uv 的速度优势)
|
||||
RUN uv venv /opt/venv && \
|
||||
uv pip sync --python /opt/venv/bin/python requirements.lock
|
||||
|
||||
# 复制应用代码
|
||||
COPY app/ ./app/
|
||||
|
||||
# 安装应用本身(不安装 dev 依赖)
|
||||
RUN uv pip install --python /opt/venv/bin/python --no-deps -e .
|
||||
|
||||
# ===== 生产镜像 =====
|
||||
FROM python:3.13-slim AS production
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# 从 builder 复制虚拟环境
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制应用代码
|
||||
COPY app/ ./app/
|
||||
COPY pyproject.toml .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,140 @@
|
||||
# 美家卡智影 API - 常用命令
|
||||
# ==========================
|
||||
|
||||
.PHONY: help install dev install-hooks update-lock lint format test security clean docker
|
||||
|
||||
help: ## 显示帮助信息
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
# ========== 依赖管理 ==========
|
||||
|
||||
install: ## 安装生产依赖(使用 lock 文件)
|
||||
uv pip sync requirements.lock
|
||||
|
||||
dev: ## 安装开发依赖(包含 dev extras)
|
||||
uv pip install -e ".[dev]"
|
||||
pre-commit install
|
||||
|
||||
install-hooks: ## 安装 Git pre-commit 钩子
|
||||
pre-commit install
|
||||
|
||||
update-lock: ## 更新 requirements.lock(修改 pyproject.toml 后执行)
|
||||
uv pip compile pyproject.toml -o requirements.lock --upgrade
|
||||
|
||||
update-lock-no-upgrade: ## 重新生成 lock 文件(不升级版本)
|
||||
uv pip compile pyproject.toml -o requirements.lock
|
||||
|
||||
# ========== 代码质量 ==========
|
||||
|
||||
lint: ## 运行代码检查 (ruff + mypy)
|
||||
ruff check app/
|
||||
mypy app/
|
||||
|
||||
format: ## 格式化代码 (black + ruff)
|
||||
black app/
|
||||
ruff check --fix app/
|
||||
|
||||
format-check: ## 检查代码格式(不修改)
|
||||
black --check app/
|
||||
ruff check app/
|
||||
|
||||
# ========== 测试 ==========
|
||||
|
||||
test: ## 运行测试
|
||||
pytest -v
|
||||
|
||||
test-cov: ## 运行测试并生成覆盖率报告
|
||||
pytest --cov=app --cov-report=html --cov-report=term
|
||||
|
||||
# ========== 安全扫描 ==========
|
||||
|
||||
security: ## 运行安全扫描 (bandit + pip-audit)
|
||||
@echo "🔍 运行 Bandit 安全扫描..."
|
||||
bandit -r app/ -c pyproject.toml
|
||||
@echo "🔍 运行依赖漏洞扫描..."
|
||||
pip-audit
|
||||
|
||||
# ========== 开发服务器 ==========
|
||||
|
||||
run: ## 启动开发服务器
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
scheduler: ## 启动 Async Engine Scheduler
|
||||
python -m app.scheduler.main
|
||||
|
||||
# ========== Docker ==========
|
||||
|
||||
docker: ## 构建 Docker 镜像
|
||||
docker build -t meijiaka-api:latest .
|
||||
|
||||
docker-run: ## 使用 Docker Compose 启动全部服务
|
||||
docker-compose up -d
|
||||
|
||||
docker-logs: ## 查看 Docker 日志
|
||||
docker-compose logs -f
|
||||
|
||||
docker-down: ## 停止 Docker 服务
|
||||
docker-compose down
|
||||
|
||||
# ========== 清理 ==========
|
||||
|
||||
clean: ## 清理缓存文件
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
rm -rf htmlcov/ .coverage 2>/dev/null || true
|
||||
|
||||
# ========== 语义层防护网 ==========
|
||||
|
||||
lint-semantic: ## 语义层禁词检查(防止供应商术语泄漏到业务层)
|
||||
@echo "🔍 检查 Layer 3+ 是否泄漏供应商术语..."
|
||||
@# API 层(除 klingai Provider 代理)禁止 element_id 作为字段/参数名
|
||||
@errs=$$(grep -rn 'element_id' app/api --include='*.py' \
|
||||
| grep -v 'klingai.py' \
|
||||
| grep -v 'provider_element_id' \
|
||||
| grep -v '__pycache__' \
|
||||
| grep -v '#' \
|
||||
| grep -v '".*element_id.*"' \
|
||||
| grep -v "'.*element_id.*'"); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ API 层发现 element_id(应使用 provider_element_id 或 human_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# Scheduler 层禁止 task_id 作为内部变量/Redis key(读取 Provider 返回除外)
|
||||
@errs=$$(grep -rn '\btask_id\b' app/scheduler --include='*.py' \
|
||||
| grep -v 'job_id' \
|
||||
| grep -v '__pycache__' \
|
||||
| grep -v '\.get("task_id")' \
|
||||
| grep -v 'result.get("task_id")' \
|
||||
| grep -v 'task_type' \
|
||||
| grep -v '"task_id"' \
|
||||
| grep -v "'task_id'"); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ Scheduler 层发现 task_id(应使用 job_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# 全局禁止 kling_task_id 作为持久化字段
|
||||
@errs=$$(grep -rn 'kling_task_id' app --include='*.py' \
|
||||
| grep -v '__pycache__' \
|
||||
| grep -v 'providers/klingai'); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ 发现 kling_task_id(应使用 provider_task_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# Scheduler 层 Redis key 必须使用 job: 而非 task:
|
||||
@errs=$$(grep -rn 'redis.*task:' app/scheduler --include='*.py' \
|
||||
| grep -v '__pycache__'); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ Scheduler Redis key 使用 task:(应使用 job:)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "✅ 语义层检查通过"
|
||||
|
||||
# ========== CI 检查 ==========
|
||||
|
||||
ci: format-check lint lint-semantic test security ## 运行所有 CI 检查
|
||||
@@ -0,0 +1,166 @@
|
||||
# 美家卡智影 API
|
||||
|
||||
美家卡智影后端服务 - 基于 FastAPI + PostgreSQL + Redis 的 AI 视频创作 API。
|
||||
|
||||
## 技术栈
|
||||
|
||||
| 组件 | 技术 | 版本 |
|
||||
|------|------|------|
|
||||
| Web 框架 | FastAPI | ^0.110.0 |
|
||||
| 数据库 | PostgreSQL | 15+ |
|
||||
| ORM | SQLAlchemy | 2.0+ (异步) |
|
||||
| 缓存/状态 | Redis | 7.x |
|
||||
| 异步调度 | Async Engine (Slot Scheduler) | Python asyncio |
|
||||
| 部署 | Docker + Docker Compose | - |
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 环境准备
|
||||
|
||||
确保已安装:
|
||||
- Python 3.11+
|
||||
- Docker & Docker Compose(推荐)
|
||||
- 或本地 PostgreSQL + Redis
|
||||
|
||||
### 2. 使用 Docker Compose 启动(推荐)
|
||||
|
||||
```bash
|
||||
# 1. 克隆项目后进入目录
|
||||
cd python-api
|
||||
|
||||
# 2. 复制环境变量配置
|
||||
cp .env.example .env
|
||||
|
||||
# 3. 启动所有服务
|
||||
docker-compose up -d
|
||||
|
||||
# 4. 查看日志
|
||||
docker-compose logs -f api
|
||||
|
||||
# 5. 服务地址
|
||||
# API: http://localhost:8080
|
||||
# 文档: http://localhost:8080/docs
|
||||
```
|
||||
|
||||
### 3. 本地开发
|
||||
|
||||
```bash
|
||||
# 1. 创建虚拟环境
|
||||
python -m venv venv
|
||||
source venv/bin/activate # Windows: venv\Scripts\activate
|
||||
|
||||
# 2. 安装依赖
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# 3. 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env,修改数据库连接等配置
|
||||
|
||||
# 4. 启动 PostgreSQL 和 Redis(Docker)
|
||||
docker-compose up -d db redis
|
||||
|
||||
# 5. 启动开发服务器
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# 7. 启动 Async Engine Scheduler(另开终端)
|
||||
python -m app.scheduler.main
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
python-api/
|
||||
├── app/ # 主应用代码
|
||||
│ ├── api/v1/ # API 路由
|
||||
│ ├── core/ # 核心工具(安全、异常)
|
||||
│ ├── db/ # 数据库配置
|
||||
│ ├── models/ # SQLAlchemy 模型
|
||||
│ ├── schemas/ # Pydantic Schema
|
||||
│ ├── services/ # 业务逻辑
|
||||
│ ├── scheduler/ # Async Engine 异步任务调度
|
||||
│ ├── ai/ # AI 模型相关
|
||||
│ ├── utils/ # 工具函数
|
||||
│ ├── config.py # 配置管理
|
||||
│ └── main.py # FastAPI 入口
|
||||
├── docker-compose.yml # Docker 编排
|
||||
├── Dockerfile # Docker 镜像
|
||||
├── pyproject.toml # 项目依赖
|
||||
└── README.md # 本文档
|
||||
```
|
||||
|
||||
## 数据模型
|
||||
|
||||
### 核心实体
|
||||
|
||||
- **User** - 用户/设备(设备 ID + JWT 认证)
|
||||
- **Project** - 视频创作项目
|
||||
- **ScriptSegment** - 脚本分镜
|
||||
- **MediaAsset** - 媒体元数据(音频/视频/封面)
|
||||
- **TaskQueue** - 异步任务队列
|
||||
|
||||
## API 路由
|
||||
|
||||
### 已实现
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| POST | `/api/v1/auth/login` | 设备登录/注册 |
|
||||
| GET | `/api/v1/auth/me` | 获取当前用户 |
|
||||
| GET | `/api/v1/system/health` | 健康检查 |
|
||||
| GET | `/api/v1/system/version` | 版本信息 |
|
||||
|
||||
### 待实现(M2-M5)
|
||||
|
||||
- `/api/v1/script/*` - 脚本生成(SSE 流式)
|
||||
- `/api/v1/voice/*` - 语音合成(TTS)
|
||||
- `/api/v1/video/*` - 数字人视频(异步任务)
|
||||
- `/api/v1/project/*` - 项目云同步
|
||||
- `/api/v1/parser/*` - 视频链接解析(预留)
|
||||
|
||||
## 环境变量
|
||||
|
||||
见 `.env.example`,主要配置项:
|
||||
|
||||
| 变量 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| `DATABASE_URL` | PostgreSQL 连接字符串 | `postgresql+asyncpg://postgres:postgres@localhost:5432/meijiaka` |
|
||||
| `REDIS_URL` | Redis 连接字符串 | `redis://localhost:6379/0` |
|
||||
| `SECRET_KEY` | JWT 签名密钥 | 必须修改 |
|
||||
| `OPENAI_API_KEY` | OpenAI API Key | - |
|
||||
| `CORS_ORIGINS` | 允许的跨域来源 | `http://localhost:1420` |
|
||||
|
||||
## 开发规范
|
||||
|
||||
### 代码风格
|
||||
|
||||
```bash
|
||||
# 格式化
|
||||
black app/
|
||||
|
||||
# 检查
|
||||
ruff check app/
|
||||
mypy app/
|
||||
|
||||
# 测试
|
||||
pytest
|
||||
```
|
||||
|
||||
### 提交规范
|
||||
|
||||
- `feat:` 新功能
|
||||
- `fix:` 修复
|
||||
- `docs:` 文档
|
||||
- `refactor:` 重构
|
||||
- `test:` 测试
|
||||
|
||||
## 与前端集成
|
||||
|
||||
Tauri 前端默认连接 `http://127.0.0.1:8080/api/v1`。
|
||||
|
||||
云端部署后:
|
||||
1. 修改前端 `src/api/client.ts` 中的 `PYTHON_API_BASE_URL`
|
||||
2. 更新 `tauri.conf.json` CSP 配置,添加云端域名到 `connect-src`
|
||||
|
||||
## 许可
|
||||
|
||||
MIT
|
||||
@@ -0,0 +1,150 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
# 数据库 URL 从环境变量读取,在 env.py 中设置
|
||||
# sqlalchemy.url = postgresql://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Alembic 环境配置 - PostgreSQL
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# 加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 导入模型
|
||||
from app.db.session import Base
|
||||
from app.models.avatar import Avatar # noqa
|
||||
from app.models.model_usage import ModelUsageLog # noqa
|
||||
from app.models.user import User # noqa
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# 从环境变量读取数据库 URL
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
if database_url:
|
||||
# 将 asyncpg 转换为 psycopg2 用于 alembic (同步)
|
||||
sync_database_url = database_url.replace("+asyncpg", "")
|
||||
config.set_main_option("sqlalchemy.url", sync_database_url)
|
||||
|
||||
# 设置日志
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# 模型元数据
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,106 @@
|
||||
"""rename_avatar_vendor_fields_add_provider
|
||||
|
||||
Revision ID: 451756e6a43e
|
||||
Revises: d4bd9ad91607
|
||||
Create Date: 2026-04-17 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "451756e6a43e"
|
||||
down_revision: str | Sequence[str] | None = "d4bd9ad91607"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# Add provider column with default "kling"
|
||||
op.add_column(
|
||||
"avatars",
|
||||
sa.Column(
|
||||
"provider",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="kling",
|
||||
comment="供应商标识",
|
||||
),
|
||||
)
|
||||
|
||||
# Rename element_id -> provider_element_id
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"element_id",
|
||||
new_column_name="provider_element_id",
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# Rename voice_task_id -> provider_voice_job_id
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"voice_task_id",
|
||||
new_column_name="provider_voice_job_id",
|
||||
existing_type=sa.String(length=128),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# Rename element_task_id -> provider_element_job_id
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"element_task_id",
|
||||
new_column_name="provider_element_job_id",
|
||||
existing_type=sa.String(length=128),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# Rename indexes
|
||||
op.drop_index("ix_avatars_voice_task_id", table_name="avatars")
|
||||
op.drop_index("ix_avatars_element_task_id", table_name="avatars")
|
||||
op.create_index(
|
||||
"ix_avatars_provider_voice_job_id", "avatars", ["provider_voice_job_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_avatars_provider_element_job_id", "avatars", ["provider_element_job_id"], unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# Rename indexes back
|
||||
op.drop_index("ix_avatars_provider_element_job_id", table_name="avatars")
|
||||
op.drop_index("ix_avatars_provider_voice_job_id", table_name="avatars")
|
||||
op.create_index("ix_avatars_element_task_id", "avatars", ["element_task_id"], unique=False)
|
||||
op.create_index("ix_avatars_voice_task_id", "avatars", ["voice_task_id"], unique=False)
|
||||
|
||||
# Rename columns back
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"provider_element_job_id",
|
||||
new_column_name="element_task_id",
|
||||
existing_type=sa.String(length=128),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"provider_voice_job_id",
|
||||
new_column_name="voice_task_id",
|
||||
existing_type=sa.String(length=128),
|
||||
existing_nullable=True,
|
||||
)
|
||||
op.alter_column(
|
||||
"avatars",
|
||||
"provider_element_id",
|
||||
new_column_name="element_id",
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# Drop provider column
|
||||
op.drop_column("avatars", "provider")
|
||||
@@ -0,0 +1,55 @@
|
||||
"""add avatars table
|
||||
|
||||
Revision ID: d4bd9ad91607
|
||||
Revises: fb1be66e804a
|
||||
Create Date: 2026-04-06 21:51:36.225361
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'd4bd9ad91607'
|
||||
down_revision: Union[str, Sequence[str], None] = 'fb1be66e804a'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('avatars',
|
||||
sa.Column('id', sa.String(length=64), nullable=False, comment='形象唯一标识(Kling element_id 字符串)'),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False, comment='关联用户 ID'),
|
||||
sa.Column('name', sa.String(length=64), nullable=False, comment='形象展示名称'),
|
||||
sa.Column('voice_id', sa.String(length=64), nullable=True, comment='Kling 自定义音色 ID'),
|
||||
sa.Column('element_id', sa.BigInteger(), nullable=True, comment='Kling 主体 ID'),
|
||||
sa.Column('voice_task_id', sa.String(length=128), nullable=True, comment='Kling 自定义音色任务 ID'),
|
||||
sa.Column('element_task_id', sa.String(length=128), nullable=True, comment='Kling 主体创建任务 ID'),
|
||||
sa.Column('video_url', sa.Text(), nullable=False, comment='原始人物视频 URL'),
|
||||
sa.Column('trial_url', sa.Text(), nullable=True, comment='音色试听音频 URL'),
|
||||
sa.Column('status', sa.String(length=32), nullable=False, comment='状态: pending/voice_processing/voice_failed/element_processing/element_failed/succeed/timeout'),
|
||||
sa.Column('fail_reason', sa.Text(), nullable=True, comment='失败原因(中文可读)'),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True, comment='软删除时间,NULL 表示未删除'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, comment='记录创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, comment='记录更新时间'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_avatars_element_task_id'), 'avatars', ['element_task_id'], unique=False)
|
||||
op.create_index(op.f('ix_avatars_user_id'), 'avatars', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_avatars_voice_task_id'), 'avatars', ['voice_task_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_avatars_voice_task_id'), table_name='avatars')
|
||||
op.drop_index(op.f('ix_avatars_user_id'), table_name='avatars')
|
||||
op.drop_index(op.f('ix_avatars_element_task_id'), table_name='avatars')
|
||||
op.drop_table('avatars')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,38 @@
|
||||
"""replace device_id with mobile in users table
|
||||
|
||||
Revision ID: fb1be66e804a
|
||||
Revises:
|
||||
Create Date: 2026-04-03 10:22:30.465704
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fb1be66e804a'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('mobile', sa.String(length=20), nullable=False, comment='手机号'))
|
||||
op.drop_index(op.f('ix_users_device_id'), table_name='users')
|
||||
op.create_index(op.f('ix_users_mobile'), 'users', ['mobile'], unique=True)
|
||||
op.drop_column('users', 'device_id')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('device_id', sa.VARCHAR(length=64), autoincrement=False, nullable=False, comment='设备唯一标识'))
|
||||
op.drop_index(op.f('ix_users_mobile'), table_name='users')
|
||||
op.create_index(op.f('ix_users_device_id'), 'users', ['device_id'], unique=True)
|
||||
op.drop_column('users', 'mobile')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
AI 模型路由 V2 - 基于文件配置
|
||||
=================================
|
||||
|
||||
从 YAML 配置文件加载平台/模型配置,支持热重载。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
|
||||
from app.ai.providers.generic_llm_provider import MockProvider
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||||
from app.config import get_settings
|
||||
from app.core.config_loader import AIModelConfigLoader, get_config_loader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlatformInstance:
|
||||
"""平台实例包装器"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.provider = self._create_provider()
|
||||
|
||||
def _create_provider(self):
|
||||
"""根据平台类型创建 Provider
|
||||
|
||||
API Key 从 Settings 读取(符合配置规范)
|
||||
"""
|
||||
provider_type = self.config.get("provider", "mock")
|
||||
settings = get_settings()
|
||||
|
||||
if provider_type == "volcengine":
|
||||
# 从 Settings 读取 API Key
|
||||
api_key = settings.VOLCENGINE_API_KEY
|
||||
if not api_key:
|
||||
raise ProviderError(
|
||||
"Volcengine API Key 未配置,请在 .env 中设置 VOLCENGINE_API_KEY"
|
||||
)
|
||||
return VolcengineProvider(
|
||||
api_key=api_key,
|
||||
base_url=self.config.get("base_url") or settings.VOLCENGINE_BASE_URL,
|
||||
)
|
||||
elif provider_type == "klingai":
|
||||
# 从 Settings 读取 AK/SK
|
||||
access_key = settings.KLINGAI_ACCESS_KEY
|
||||
secret_key = settings.KLINGAI_SECRET_KEY
|
||||
if not access_key or not secret_key:
|
||||
raise ProviderError(
|
||||
"KlingAI Access/Secret Key 未配置,请在 .env 中设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY"
|
||||
)
|
||||
return KlingAIProvider(
|
||||
config={
|
||||
"access_key": access_key,
|
||||
"secret_key": secret_key,
|
||||
"base_url": self.config.get("base_url"),
|
||||
}
|
||||
)
|
||||
elif provider_type == "mock":
|
||||
return MockProvider()
|
||||
else:
|
||||
raise ProviderError(f"不支持的 Provider 类型: {provider_type}")
|
||||
|
||||
async def generate(
|
||||
self, model_name: str, prompt: str, **kwargs
|
||||
) -> GenerationResult:
|
||||
"""调用生成"""
|
||||
return await self.provider.generate(prompt=prompt, model=model_name, **kwargs)
|
||||
|
||||
async def generate_stream(
|
||||
self, model_name: str, prompt: str, **kwargs
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式生成"""
|
||||
async for chunk in self.provider.generate_stream(
|
||||
prompt=prompt, model=model_name, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def health_check(self, model_name: str | None = None) -> ModelHealth:
|
||||
"""健康检查"""
|
||||
return await self.provider.health_check(model_name)
|
||||
|
||||
|
||||
class ModelRouter:
|
||||
"""
|
||||
模型路由 V2 - 基于文件配置
|
||||
|
||||
支持:
|
||||
- 从 YAML 文件加载配置
|
||||
- 多平台配置
|
||||
- 每平台多模型
|
||||
- 模型自动选择
|
||||
- 故障降级
|
||||
- 配置热重载
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.platforms: dict[str, PlatformInstance] = {}
|
||||
self._config_loader: AIModelConfigLoader | None = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self, db_session=None):
|
||||
"""初始化路由(db_session 参数保留兼容性,实际不使用)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 从文件配置加载
|
||||
self._config_loader = get_config_loader()
|
||||
self._load_from_config()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"ModelRouter 初始化完成: {len(self.platforms)} 平台")
|
||||
|
||||
def _load_from_config(self):
|
||||
"""从配置文件加载平台和模型"""
|
||||
self.platforms = {}
|
||||
|
||||
# 加载平台
|
||||
for platform in self._config_loader.get_all_platforms():
|
||||
try:
|
||||
# PlatformInstance 自动从 Settings 读取 API Key
|
||||
self.platforms[platform.id] = PlatformInstance(
|
||||
{
|
||||
"id": platform.id,
|
||||
"name": platform.name,
|
||||
"provider": platform.provider,
|
||||
"base_url": platform.base_url,
|
||||
}
|
||||
)
|
||||
logger.info(f"平台 {platform.id} 初始化成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"平台 {platform.id} 初始化失败: {e}")
|
||||
|
||||
# 加载模型到 Provider(用于模型名称映射)
|
||||
volcengine_models = []
|
||||
for model in self._config_loader.get_enabled_models():
|
||||
if model.platform_id == "volcengine":
|
||||
volcengine_models.append(
|
||||
{
|
||||
"id": model.id,
|
||||
"model_name": model.model_name,
|
||||
}
|
||||
)
|
||||
|
||||
if volcengine_models:
|
||||
VolcengineProvider.load_models_from_config(volcengine_models)
|
||||
logger.info(f"已加载 {len(volcengine_models)} 个火山方舟模型到 Provider")
|
||||
|
||||
def reload_config(self) -> bool:
|
||||
"""重新加载配置"""
|
||||
if self._config_loader and self._config_loader.reload():
|
||||
self._load_from_config()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict | None:
|
||||
"""获取模型配置"""
|
||||
if self._config_loader:
|
||||
model = self._config_loader.get_model(model_id)
|
||||
if model:
|
||||
return {
|
||||
"id": model.id,
|
||||
"platform_id": model.platform_id,
|
||||
"model_name": model.model_name,
|
||||
"display_name": model.display_name,
|
||||
"capabilities": model.capabilities,
|
||||
"default_params": model.default_params,
|
||||
"cost_per_1k_input": model.cost_per_1k_input,
|
||||
"cost_per_1k_output": model.cost_per_1k_output,
|
||||
"max_tokens_limit": model.max_tokens_limit,
|
||||
}
|
||||
return None
|
||||
|
||||
def list_models(
|
||||
self, capability: str | None = None, platform_id: str | None = None
|
||||
) -> list[dict]:
|
||||
"""列出可用模型"""
|
||||
models = []
|
||||
|
||||
if self._config_loader:
|
||||
if capability:
|
||||
config_models = self._config_loader.get_models_by_capability(capability)
|
||||
elif platform_id:
|
||||
config_models = self._config_loader.get_models_by_platform(platform_id)
|
||||
else:
|
||||
config_models = self._config_loader.get_enabled_models()
|
||||
|
||||
for model in config_models:
|
||||
models.append(
|
||||
{
|
||||
"id": model.id,
|
||||
"platform_id": model.platform_id,
|
||||
"model_name": model.model_name,
|
||||
"display_name": model.display_name,
|
||||
"capabilities": model.capabilities,
|
||||
"default_params": model.default_params,
|
||||
"cost_per_1k_input": model.cost_per_1k_input,
|
||||
"cost_per_1k_output": model.cost_per_1k_output,
|
||||
"max_tokens_limit": model.max_tokens_limit,
|
||||
}
|
||||
)
|
||||
|
||||
return models
|
||||
|
||||
def list_platforms(self) -> list[dict]:
|
||||
"""列出所有平台"""
|
||||
if self._config_loader:
|
||||
return [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"provider": p.provider,
|
||||
}
|
||||
for p in self._config_loader.get_all_platforms()
|
||||
]
|
||||
return []
|
||||
|
||||
def select_model_for_task(self, task_type: str) -> str | None:
|
||||
"""根据任务类型选择最佳模型"""
|
||||
# 先检查任务默认配置
|
||||
if self._config_loader:
|
||||
default_model = self._config_loader.get_default_model_for_task(task_type)
|
||||
if default_model:
|
||||
model = self._config_loader.get_model(default_model)
|
||||
if model and model.is_enabled:
|
||||
return default_model
|
||||
|
||||
# 按能力匹配
|
||||
candidates = self._config_loader.get_models_by_capability(task_type)
|
||||
if candidates:
|
||||
return candidates[0].id
|
||||
|
||||
return None
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""
|
||||
生成文本
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_id: 指定模型 ID,None 则自动选择
|
||||
task_type: 任务类型(用于自动选模型)
|
||||
"""
|
||||
# 确定模型
|
||||
if model_id is None:
|
||||
if task_type:
|
||||
model_id = self.select_model_for_task(task_type)
|
||||
if model_id is None:
|
||||
# 使用第一个可用模型
|
||||
models = (
|
||||
self._config_loader.get_enabled_models()
|
||||
if self._config_loader
|
||||
else []
|
||||
)
|
||||
if models:
|
||||
model_id = models[0].id
|
||||
else:
|
||||
raise ProviderError("没有可用的模型")
|
||||
|
||||
if self._config_loader:
|
||||
model = self._config_loader.get_model(model_id)
|
||||
if not model:
|
||||
raise ProviderError(f"模型不存在: {model_id}")
|
||||
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if not platform:
|
||||
raise ProviderError(f"平台不存在: {model.platform_id}")
|
||||
|
||||
# 合并默认参数
|
||||
params = {**model.default_params, **kwargs}
|
||||
|
||||
# 调用生成
|
||||
try:
|
||||
result = await platform.generate(
|
||||
prompt=prompt, model_name=model.model_name, **params
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模型 {model_id} 生成失败: {e}")
|
||||
raise
|
||||
|
||||
async def generate_stream_with_progress(
|
||||
self,
|
||||
prompt: str,
|
||||
model_id: str | None = None,
|
||||
task_type: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
流式生成文本,带进度信息
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_id: 指定模型 ID
|
||||
task_type: 任务类型
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
dict: 包含 type, content, total_chars 等字段
|
||||
"""
|
||||
# 确定模型
|
||||
if model_id is None:
|
||||
if task_type:
|
||||
model_id = self.select_model_for_task(task_type)
|
||||
if model_id is None:
|
||||
models = (
|
||||
self._config_loader.get_enabled_models()
|
||||
if self._config_loader
|
||||
else []
|
||||
)
|
||||
if models:
|
||||
model_id = models[0].id
|
||||
else:
|
||||
raise ProviderError("没有可用的模型")
|
||||
|
||||
model = self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
if not model:
|
||||
raise ProviderError(f"模型不存在: {model_id}")
|
||||
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if not platform:
|
||||
raise ProviderError(f"平台不存在: {model.platform_id}")
|
||||
|
||||
# 合并默认参数
|
||||
params = {**model.default_params, **kwargs}
|
||||
|
||||
# 检查 provider 是否有 generate_stream_with_progress 方法
|
||||
provider = platform.provider
|
||||
if hasattr(provider, "generate_stream_with_progress"):
|
||||
async for chunk in provider.generate_stream_with_progress(
|
||||
prompt=prompt, model=model.model_name, **params
|
||||
):
|
||||
yield chunk
|
||||
else:
|
||||
# 降级到普通流式生成
|
||||
full_content = ""
|
||||
async for content in provider.generate_stream(
|
||||
prompt=prompt, model=model.model_name, **params
|
||||
):
|
||||
full_content += content
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"content": content,
|
||||
"total_chars": len(full_content),
|
||||
}
|
||||
|
||||
yield {
|
||||
"type": "usage",
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
}
|
||||
|
||||
async def health_check(self, model_id: str | None = None) -> dict[str, ModelHealth]:
|
||||
"""检查模型健康状态"""
|
||||
results = {}
|
||||
|
||||
if model_id:
|
||||
model = (
|
||||
self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
)
|
||||
if model:
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if platform:
|
||||
results[model_id] = await platform.health_check(model.model_name)
|
||||
else:
|
||||
# 检查所有模型
|
||||
if self._config_loader:
|
||||
for model in self._config_loader.get_enabled_models():
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
if platform:
|
||||
try:
|
||||
results[model.id] = await platform.health_check(
|
||||
model.model_name
|
||||
)
|
||||
except Exception as e:
|
||||
results[model.id] = ModelHealth(
|
||||
id=model.id,
|
||||
name=model.display_name,
|
||||
is_available=False,
|
||||
response_time=0,
|
||||
last_error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# 全局单例
|
||||
_model_router: ModelRouter | None = None
|
||||
_init_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_model_router(db_session=None) -> ModelRouter:
|
||||
"""获取 ModelRouter 单例(线程安全)
|
||||
|
||||
使用双重检查锁定模式确保并发安全。
|
||||
"""
|
||||
global _model_router
|
||||
if _model_router is None:
|
||||
async with _init_lock:
|
||||
# 双重检查,防止在获取锁期间其他协程已初始化
|
||||
if _model_router is None:
|
||||
logger.info("Initializing ModelRouter singleton...")
|
||||
_model_router = ModelRouter()
|
||||
await _model_router.initialize(db_session)
|
||||
logger.info("ModelRouter singleton initialized")
|
||||
return _model_router
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Prompt 模板系统
|
||||
================
|
||||
|
||||
家装行业 AI 视频脚本 Prompt 模板。
|
||||
所有 Prompt 存储在 txt 文件中,支持热更新。
|
||||
|
||||
使用示例:
|
||||
from app.ai.prompts import load_script_system, load_script_user
|
||||
|
||||
# 加载 System Prompt
|
||||
system = load_script_system()
|
||||
|
||||
# 加载并渲染 User Prompt
|
||||
user = load_script_user(
|
||||
topic="装修避坑",
|
||||
duration=45,
|
||||
script_type="干货型"
|
||||
)
|
||||
"""
|
||||
|
||||
from .loader import (
|
||||
SCRIPT_TYPES,
|
||||
VIDEO_STYLES,
|
||||
PolishPromptBuilder,
|
||||
ScriptPromptBuilder,
|
||||
load_polish_scene,
|
||||
load_polish_voiceover,
|
||||
load_prompt,
|
||||
load_script_system,
|
||||
load_script_user,
|
||||
render_template,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"load_prompt",
|
||||
"render_template",
|
||||
"load_script_system",
|
||||
"load_script_user",
|
||||
"load_polish_scene",
|
||||
"load_polish_voiceover",
|
||||
"ScriptPromptBuilder",
|
||||
"PolishPromptBuilder",
|
||||
"SCRIPT_TYPES",
|
||||
"VIDEO_STYLES",
|
||||
]
|
||||
@@ -0,0 +1 @@
|
||||
根据标题"{caption}"生成一张适合短视频封面的竖屏图片,画面精美、视觉冲击力强的营销风格,主体人物自然融入场景。
|
||||
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Prompt 简单加载器
|
||||
=================
|
||||
从文件加载 Prompt,支持热更新。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
|
||||
_PROMPTS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load_prompt(path: str) -> str:
|
||||
"""
|
||||
加载 Prompt 文件
|
||||
|
||||
Args:
|
||||
path: 相对路径,如 "script/system", "polish/scene"
|
||||
|
||||
Returns:
|
||||
Prompt 内容,文件不存在返回空字符串
|
||||
"""
|
||||
file_path = _PROMPTS_DIR / f"{path}.txt"
|
||||
if file_path.exists():
|
||||
return file_path.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
|
||||
def render_template(template: str, **kwargs) -> str:
|
||||
"""
|
||||
安全渲染模板变量
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
**kwargs: 变量值
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
"""
|
||||
try:
|
||||
# 转义 $ 符号防止用户输入干扰
|
||||
safe_kwargs = {k: str(v).replace("$", "$$") for k, v in kwargs.items()}
|
||||
return Template(template).substitute(**safe_kwargs)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"模板缺少变量: {e}")
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def load_script_system() -> str:
|
||||
"""加载脚本生成 System Prompt"""
|
||||
return load_prompt("script/system")
|
||||
|
||||
|
||||
def load_script_user(topic: str, duration: int, script_type: str) -> str:
|
||||
"""加载并渲染脚本生成 User Prompt"""
|
||||
template = load_prompt("script/user")
|
||||
return render_template(template, topic=topic, duration=duration, type=script_type)
|
||||
|
||||
|
||||
def load_polish_scene() -> str:
|
||||
"""加载画面润色 Prompt"""
|
||||
return load_prompt("polish/scene")
|
||||
|
||||
|
||||
def load_polish_voiceover() -> str:
|
||||
"""加载文案润色 Prompt"""
|
||||
return load_prompt("polish/voiceover")
|
||||
|
||||
|
||||
# 预定义的脚本类型和风格
|
||||
SCRIPT_TYPES = [
|
||||
{"id": "干货型", "name": "干货型", "description": "知识分享、技巧传授"},
|
||||
{"id": "故事型", "name": "故事型", "description": "案例故事、用户体验"},
|
||||
{"id": "对比型", "name": "对比型", "description": "产品对比、优劣分析"},
|
||||
{"id": "避坑型", "name": "避坑型", "description": "防骗指南、常见误区"},
|
||||
{"id": "测评型", "name": "测评型", "description": "产品测评、真实体验"},
|
||||
]
|
||||
|
||||
VIDEO_STYLES = [
|
||||
{"id": "口播", "name": "口播", "description": "真人出镜讲解"},
|
||||
{"id": "图文", "name": "图文", "description": "图片+文字+配音"},
|
||||
{"id": "混剪", "name": "混剪", "description": "素材混剪+配音"},
|
||||
{"id": "剧情", "name": "剧情", "description": "情景剧演绎"},
|
||||
{"id": "Vlog", "name": "Vlog", "description": "记录式视频"},
|
||||
]
|
||||
|
||||
|
||||
class ScriptPromptBuilder:
|
||||
"""
|
||||
脚本 Prompt 构建器
|
||||
|
||||
用于构建家装行业短视频脚本的 System Prompt。
|
||||
"""
|
||||
|
||||
def build(
|
||||
self,
|
||||
duration: int = 30,
|
||||
script_type: str = "干货型",
|
||||
video_style: str = "口播",
|
||||
industry: str = "家装",
|
||||
tone: str | None = None,
|
||||
custom_requirements: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建系统 Prompt
|
||||
|
||||
Args:
|
||||
duration: 视频时长(秒)
|
||||
script_type: 脚本类型(干货型、故事型等)
|
||||
video_style: 视频风格(口播、剧情等)
|
||||
industry: 行业(家装)
|
||||
tone: 语气风格
|
||||
custom_requirements: 自定义要求
|
||||
|
||||
Returns:
|
||||
完整的 System Prompt
|
||||
"""
|
||||
# 基础 System Prompt
|
||||
base_prompt = load_script_system()
|
||||
|
||||
# 构建上下文信息
|
||||
context_parts = [
|
||||
f"行业:{industry}",
|
||||
f"时长:{duration}秒",
|
||||
f"类型:{script_type}",
|
||||
f"风格:{video_style}",
|
||||
]
|
||||
|
||||
if tone:
|
||||
context_parts.append(f"语气:{tone}")
|
||||
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
# 构建完整 Prompt
|
||||
full_prompt = f"""{base_prompt}
|
||||
|
||||
【创作要求】
|
||||
{context}
|
||||
"""
|
||||
|
||||
if custom_requirements:
|
||||
full_prompt += f"""
|
||||
【特殊要求】
|
||||
{custom_requirements}
|
||||
"""
|
||||
|
||||
# 添加输出格式要求
|
||||
full_prompt += """
|
||||
【输出格式】
|
||||
请严格按照以下 JSON 数组格式返回,每个元素代表一个镜头:
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "segment",
|
||||
"scene": "画面描述",
|
||||
"voiceover": "配音文案",
|
||||
"duration": "5s"
|
||||
}
|
||||
]
|
||||
|
||||
type 可以是:
|
||||
- "segment": 分镜(有画面+配音)
|
||||
- "empty_shot": 空镜(纯画面,voiceover 可为空)
|
||||
|
||||
注意:
|
||||
1. 只返回 JSON 数组,不要有其他文字
|
||||
2. 确保 JSON 格式正确
|
||||
3. 总时长必须严格控制在要求范围内
|
||||
"""
|
||||
|
||||
return full_prompt
|
||||
|
||||
|
||||
class PolishPromptBuilder:
|
||||
"""
|
||||
润色 Prompt 构建器
|
||||
|
||||
用于构建润色文案或画面描述的 Prompt。
|
||||
"""
|
||||
|
||||
POLISH_TYPES = {
|
||||
"scene": "画面描述",
|
||||
"voiceover": "配音文案",
|
||||
"text": "文案内容",
|
||||
}
|
||||
|
||||
def build(self, polish_type: str = "voiceover") -> str:
|
||||
"""
|
||||
构建润色 Prompt
|
||||
|
||||
Args:
|
||||
polish_type: 润色类型(scene/voiceover/text)
|
||||
|
||||
Returns:
|
||||
System Prompt
|
||||
"""
|
||||
type_name = self.POLISH_TYPES.get(polish_type, "文案")
|
||||
|
||||
if polish_type == "scene":
|
||||
return self._build_scene_prompt()
|
||||
else:
|
||||
return self._build_voiceover_prompt()
|
||||
|
||||
def _build_scene_prompt(self) -> str:
|
||||
"""构建画面描述润色 Prompt"""
|
||||
return """你是一位专业的视频画面描述优化师。你的任务是优化画面描述,使其更加生动、具体、有画面感。
|
||||
|
||||
优化要求:
|
||||
1. 增加细节描写(光线、色彩、构图)
|
||||
2. 使用专业的影视语言
|
||||
3. 描述要具体可执行
|
||||
4. 保持简洁,不要过度渲染
|
||||
5. 适合 AI 视频生成模型理解
|
||||
|
||||
请直接返回优化后的画面描述,不要添加解释。"""
|
||||
|
||||
def _build_voiceover_prompt(self) -> str:
|
||||
"""构建配音文案润色 Prompt"""
|
||||
return """你是一位专业的短视频文案编辑。你的任务是优化口播文案,使其更加流畅、有吸引力。
|
||||
|
||||
优化要求:
|
||||
1. 语言口语化,适合朗读
|
||||
2. 增加节奏感和停顿
|
||||
3. 保留核心信息点
|
||||
4. 适当使用修辞手法
|
||||
5. 控制字数,不要过长
|
||||
|
||||
请直接返回优化后的文案,不要添加解释。"""
|
||||
@@ -0,0 +1,13 @@
|
||||
你是一位口播短视频专家。请润色以下空镜画面描述,使其更适合AI视频生成:
|
||||
|
||||
【原文】
|
||||
{content}
|
||||
|
||||
【要求】
|
||||
- 保持原意,优化细节
|
||||
- 重点强调场景环境、空间氛围、光影效果、材质质感
|
||||
- 可以描述静态景物、装修细节、空间布局
|
||||
- 不要有"镜头""特写""机位"等摄影术语
|
||||
- 控制好字数,字数不能与原文差距超过20个字
|
||||
|
||||
直接输出润色后的描述,不要添加任何说明:
|
||||
@@ -0,0 +1,13 @@
|
||||
你是一位【口播短视频】专家。请润色以下分镜画面描述,使其更适合AI视频生成:
|
||||
|
||||
【原文】
|
||||
{content}
|
||||
|
||||
【要求】
|
||||
- 保持原意,优化细节
|
||||
- 重点强调人物神态、表情、动作、姿态
|
||||
- 描述人物与镜头前观众的互动
|
||||
- 不要有"镜头""特写""机位"等摄影术语
|
||||
- 控制好字数,字数不能与原文差距超过20个字
|
||||
|
||||
直接输出润色后的描述,不要添加任何说明:
|
||||
@@ -0,0 +1,12 @@
|
||||
你是一位短视频口播文案专家。请润色以下配音文案,使其更适合短视频口播:
|
||||
|
||||
【原文】
|
||||
{content}
|
||||
|
||||
【要求】
|
||||
- 口语化,像跟朋友聊天
|
||||
- 字数不能与原文差距超过10个字
|
||||
- 增加感染力
|
||||
- 不要有"综上所述"等书面语
|
||||
|
||||
直接输出润色后的文案,不要添加任何说明:
|
||||
@@ -0,0 +1,96 @@
|
||||
你是一位专业的【口播类短视频】脚本创作专家,专注于家装/装修领域的抖音/视频号口播内容创作。
|
||||
|
||||
【平台适配要求】
|
||||
1. 竖屏拍摄(9:16比例),画面构图以人物为主体
|
||||
2. 台词口语化、接地气,像跟朋友聊天,避免"综上所述""研究表明"等书面语
|
||||
3. 语速稍快有节奏感,每句15-25字,一口气说完不换气,不拖沓
|
||||
4. 避免专业术语堆砌,用业主听得懂的大白话
|
||||
5. 符合新媒体用户观看习惯:3秒定生死,节奏紧凑
|
||||
|
||||
【画面描述标准 - 人物为主,环境为辅】
|
||||
画面描述以【人物状态、表情、动作、情绪】为主。
|
||||
不要写"镜头推近""特写""中景"等摄影术语。
|
||||
每句画面描述控制在 50-70 字,确保有足够细节用于 AI 视频生成。
|
||||
|
||||
❌ 差的示例:
|
||||
"中景竖屏,主播站在毛坯房中央,背景是一面待装修的空白墙面,自然光从右侧窗户照入,主播表情真诚略带焦急,直视镜头说话。"
|
||||
(问题:太多环境描写,太多镜头术语)
|
||||
|
||||
✅ 好的示例:
|
||||
"主播站在空旷的毛坯房里,右手拿着黄色卷尺,他缓缓抬头,表情严肃地看向你,身后是未装修的水泥墙面,神态专业务实。"
|
||||
(聚焦人物:在哪、拿什么、什么表情、看什么)
|
||||
|
||||
【黄金3秒法则 - 开场必须抓眼】
|
||||
- 杜绝铺垫!不要"大家好我是XX""今天给大家讲个事"
|
||||
- 直接击中业主痛点或好奇心,让手指停不下来
|
||||
- 钩子示例:
|
||||
* "装修被坑了8万的业主,昨天来找我哭诉..."
|
||||
* "为什么同样的户型,你家装修比别人贵5万?"
|
||||
* "停!先别急着签合同,这条视频能救你3万块钱"
|
||||
* "每年都有500位业主找我装修,只因为我说透了这一点..."
|
||||
|
||||
【中间内容要求 - 降低跳出率】
|
||||
- 有干货:给出具体数字、方法、避坑点
|
||||
- 有冲突:制造认知反差或情绪起伏
|
||||
- 有看点:适当加入真实案例、现场画面
|
||||
- 避免空洞:不说"我们专业靠谱",而是"我做了12年装修,见过387个踩坑案例..."
|
||||
|
||||
【最后7秒 - 留资引导(必须可落地)】
|
||||
- 必须有明确、可执行的动作指令
|
||||
- 给业主一个无法拒绝的理由(免费、限时、专属)
|
||||
- 示例话术:
|
||||
* "评论区扣'装修报价',免费领本地3套装修方案+精准报价单"
|
||||
* "私信'装修'两个字,预约设计师免费上门量房、出平面布局图"
|
||||
* "点击左下角小风车,一键获取你家专属装修预算,绝无隐形消费"
|
||||
* "前20名扣1的业主,送全屋水电VR存档,后期维修不砸墙"
|
||||
- ❌ 杜绝空泛引导:"需要装修的联系我们""想了解的私信我"
|
||||
|
||||
【分镜使用原则】
|
||||
- 分镜(segment)用于"主播”出镜的镜头
|
||||
- 【重要】分镜之间要保证画面的连贯性
|
||||
- 分镜 scene 示例:
|
||||
"主播缓缓竖起第三根手指,嘴角扬起一抹了然的笑意。他身体微微前倾,目光柔和地看向前方,仿佛正与屏幕对面的人分享一个轻松的秘密。手指在空中短暂停留,带着从容的节奏。"
|
||||
|
||||
【脚本类型说明】
|
||||
- 对比型:前后反差,制造冲击
|
||||
- 恐吓型:直击痛点,先吓再给解药
|
||||
- 干货型:输出实用方法,建立专业度
|
||||
- 共情型:说业主想说的话,引发共鸣
|
||||
- 挑战型:设定目标,增加悬念
|
||||
- 福利型:用福利钩子吸引停留和留资
|
||||
|
||||
【镜头数量参考】
|
||||
- 30秒短视频:5-7个分镜
|
||||
- 45秒短视频:7-9个分镜
|
||||
- 60秒短视频:10-12个分镜
|
||||
- 75秒短视频:12-15个分镜
|
||||
- 每个分镜时长不得少于3秒
|
||||
- 实际总时长不与用户所选差距超过3秒
|
||||
|
||||
【输出格式要求】
|
||||
请以 JSON 数组格式输出,每个元素包含:
|
||||
- id: 序号(从 1 开始)
|
||||
- type: "segment"(主播口播出镜)
|
||||
- scene: 画面描述(分镜聚焦人物:在哪、干什么、什么表情,什么动作,什么情绪,涉及道具不要出现掏出、拿出这类的动作,不要出现文字,不写镜头术语,不写环境细节;空镜聚焦场景、事物、氛围、环境;)
|
||||
- voiceover: 配音文案(必填,口语化15-25字/句)
|
||||
- duration: 时长(如 "5s")
|
||||
|
||||
【示例】
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "segment",
|
||||
"scene": "主播缓缓竖起第三根手指,嘴角扬起一抹了然的笑意。他身体微微前倾,目光柔和地看向前方,仿佛正与屏幕对面的人分享一个轻松的秘密。手指在空中短暂停留,带着从容的节奏。",
|
||||
"voiceover": "装修被坑了8万的业主,昨天来找我哭诉...",
|
||||
"duration": "5s"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "segment",
|
||||
"scene": "主播竖起第二根手指,眉头微皱,嘴角向下撇,眼神中带着一丝不满与无奈。他身体微微前倾,仿佛正对着镜头对面的观众倾诉,手指随着说话轻轻晃动,像是细数着那些令人头疼的业主经历。",
|
||||
"voiceover": "第一个坑,水电改造。很多人图便宜找游击队,结果漏水漏电!",
|
||||
"duration": "8s"
|
||||
}
|
||||
]
|
||||
|
||||
注意:只输出纯 JSON,不要包含 markdown 代码块或其他说明文字。
|
||||
@@ -0,0 +1,114 @@
|
||||
你是一位专业的【口播类短视频】脚本创作专家,专注于家装/装修领域的抖音/视频号口播内容创作。
|
||||
|
||||
【平台适配要求】
|
||||
1. 竖屏拍摄(9:16比例),画面构图以人物为主体
|
||||
2. 台词口语化、接地气,像跟朋友聊天,避免"综上所述""研究表明"等书面语
|
||||
3. 语速稍快有节奏感,每句15-25字,一口气说完不换气,不拖沓
|
||||
4. 避免专业术语堆砌,用业主听得懂的大白话
|
||||
5. 符合新媒体用户观看习惯:3秒定生死,节奏紧凑
|
||||
|
||||
【画面描述标准 - 人物为主,环境为辅】
|
||||
画面描述以【人物状态、表情、动作、情绪】为主。
|
||||
不要写"镜头推近""特写""中景"等摄影术语。
|
||||
每句画面描述控制在 50-70 字,确保有足够细节用于 AI 视频生成。
|
||||
|
||||
❌ 差的示例:
|
||||
"中景竖屏,主播站在毛坯房中央,背景是一面待装修的空白墙面,自然光从右侧窗户照入,主播表情真诚略带焦急,直视镜头说话。"
|
||||
(问题:太多环境描写,太多镜头术语)
|
||||
|
||||
✅ 好的示例:
|
||||
"主播站在空旷的毛坯房里,右手拿着黄色卷尺,他缓缓抬头,表情严肃地看向你,身后是未装修的水泥墙面,神态专业务实。"
|
||||
(聚焦人物:在哪、拿什么、什么表情、看什么)
|
||||
|
||||
【黄金3秒法则 - 开场必须抓眼】
|
||||
- 杜绝铺垫!不要"大家好我是XX""今天给大家讲个事"
|
||||
- 直接击中业主痛点或好奇心,让手指停不下来
|
||||
- 钩子示例:
|
||||
* "装修被坑了8万的业主,昨天来找我哭诉..."
|
||||
* "为什么同样的户型,你家装修比别人贵5万?"
|
||||
* "停!先别急着签合同,这条视频能救你3万块钱"
|
||||
* "每年都有500位业主找我装修,只因为我说透了这一点..."
|
||||
|
||||
【中间内容要求 - 降低跳出率】
|
||||
- 有干货:给出具体数字、方法、避坑点
|
||||
- 有冲突:制造认知反差或情绪起伏
|
||||
- 有看点:适当加入真实案例、现场画面
|
||||
- 避免空洞:不说"我们专业靠谱",而是"我做了12年装修,见过387个踩坑案例..."
|
||||
|
||||
【最后7秒 - 留资引导(必须可落地)】
|
||||
- 必须有明确、可执行的动作指令
|
||||
- 给业主一个无法拒绝的理由(免费、限时、专属)
|
||||
- 示例话术:
|
||||
* "评论区扣'装修报价',免费领本地3套装修方案+精准报价单"
|
||||
* "私信'装修'两个字,预约设计师免费上门量房、出平面布局图"
|
||||
* "点击左下角小风车,一键获取你家专属装修预算,绝无隐形消费"
|
||||
* "前20名扣1的业主,送全屋水电VR存档,后期维修不砸墙"
|
||||
- ❌ 杜绝空泛引导:"需要装修的联系我们""想了解的私信我"
|
||||
|
||||
【分镜使用原则】
|
||||
- 分镜(segment)用于"主播”出镜的镜头
|
||||
- 【重要】分镜之间要保证画面的连贯性
|
||||
- 分镜 scene 示例:
|
||||
"主播缓缓竖起第三根手指,嘴角扬起一抹了然的笑意。他身体微微前倾,目光柔和地看向前方,仿佛正与屏幕对面的人分享一个轻松的秘密。手指在空中短暂停留,带着从容的节奏。"
|
||||
|
||||
【空镜使用原则】
|
||||
- 空镜(empty_shot)用于"不需要主播出镜、但需要展示具体画面"的场景或者两个镜头的过渡切换
|
||||
- 空镜数量控制在 1-4 个即可
|
||||
- 【重要】空镜的 scene 字段要详细生动,包含:场景环境、光影氛围、物体细节、动作状态
|
||||
- 空镜 scene 示例:
|
||||
"现代简约客厅,落地窗外是城市夜景,暖黄色灯光从吊顶洒下,米色布艺沙发前是一张原木茶几,茶几上放着一杯冒着热气的咖啡,画面温馨舒适,景深效果突出主体"
|
||||
- 空镜 scene 示例(差):"客厅场景"(太简单,无法生成视频)
|
||||
- 空镜不需要主播出镜,所以不写"主播、也不要出现镜头字眼",而是写场景、物体、氛围
|
||||
- 空镜不要连续出现
|
||||
- 【重要】空镜也需要配音文案(voiceover),作为画外音旁白配合画面展示
|
||||
|
||||
【脚本类型说明】
|
||||
- 对比型:前后反差,制造冲击
|
||||
- 恐吓型:直击痛点,先吓再给解药
|
||||
- 干货型:输出实用方法,建立专业度
|
||||
- 共情型:说业主想说的话,引发共鸣
|
||||
- 挑战型:设定目标,增加悬念
|
||||
- 福利型:用福利钩子吸引停留和留资
|
||||
|
||||
【镜头数量参考】
|
||||
- 30秒短视频:5-7个分镜
|
||||
- 45秒短视频:7-9个分镜
|
||||
- 60秒短视频:10-12个分镜
|
||||
- 75秒短视频:12-15个分镜
|
||||
- 空镜固定时长5秒
|
||||
- 每个分镜时长不得少于3秒
|
||||
|
||||
【输出格式要求】
|
||||
请以 JSON 数组格式输出,每个元素包含:
|
||||
- id: 序号(从 1 开始)
|
||||
- type: "segment"(主播口播出镜)或 "empty_shot"(空镜补充)
|
||||
- scene: 画面描述(分镜聚焦人物:在哪、干什么、什么表情,什么动作,什么情绪,不写镜头术语,不写环境细节;空镜聚焦场景、事物、氛围、环境;)
|
||||
- voiceover: 配音文案(必填,口语化15-25字/句)
|
||||
- duration: 时长(如 "5s")
|
||||
|
||||
【示例】
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "segment",
|
||||
"scene": "主播缓缓竖起第三根手指,嘴角扬起一抹了然的笑意。他身体微微前倾,目光柔和地看向前方,仿佛正与屏幕对面的人分享一个轻松的秘密。手指在空中短暂停留,带着从容的节奏。",
|
||||
"voiceover": "装修被坑了8万的业主,昨天来找我哭诉...",
|
||||
"duration": "5s"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "segment",
|
||||
"scene": "主播竖起第二根手指,眉头微皱,嘴角向下撇,眼神中带着一丝不满与无奈。他身体微微前倾,仿佛正对着镜头对面的观众倾诉,手指随着说话轻轻晃动,像是细数着那些令人头疼的业主经历。",
|
||||
"voiceover": "第一个坑,水电改造。很多人图便宜找游击队,结果漏水漏电!",
|
||||
"duration": "8s"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "empty_shot",
|
||||
"scene": "现代装修施工现场,地面开槽露出整齐排列的PPR水管,蓝色水管与红色线管形成对比,专业工人戴白色安全帽手持热熔机作业,背景虚化突出管线细节,自然光从左上方窗户洒入,4K画质,浅景深,暖色调,镜头缓慢推进营造专业严谨氛围",
|
||||
"voiceover": "看,这就是专业的水电施工现场,每根管线都有标准",
|
||||
"duration": "5s"
|
||||
}
|
||||
]
|
||||
|
||||
注意:只输出纯 JSON,不要包含 markdown 代码块或其他说明文字。
|
||||
@@ -0,0 +1,10 @@
|
||||
请根据以下要求,创作一份口播类短视频分镜脚本:
|
||||
|
||||
【创作主题】
|
||||
$topic
|
||||
|
||||
【视频时长】
|
||||
约 $duration 秒,正负不超过3秒。
|
||||
|
||||
【脚本类型】
|
||||
$type
|
||||
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
LLM Provider 导出
|
||||
=================
|
||||
"""
|
||||
|
||||
from app.ai.providers.base import (
|
||||
GenerationResult,
|
||||
LLMProvider,
|
||||
ModelHealth,
|
||||
ModelUnavailableError,
|
||||
ProviderError,
|
||||
)
|
||||
from app.ai.providers.generic_llm_provider import GenericLLMProvider, MockProvider
|
||||
|
||||
# 火山方舟官方 SDK Provider
|
||||
# 需要: pip install 'volcengine-python-sdk[ark]'
|
||||
try:
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||||
|
||||
VOLCENGINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOLCENGINE_AVAILABLE = False
|
||||
VolcengineProvider = None
|
||||
|
||||
# 可灵 AI Provider
|
||||
# 需要: pip install pyjwt
|
||||
try:
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
|
||||
KLINGAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
KLINGAI_AVAILABLE = False
|
||||
KlingAIProvider = None
|
||||
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"GenerationResult",
|
||||
"ModelHealth",
|
||||
"ProviderError",
|
||||
"ModelUnavailableError",
|
||||
"GenericLLMProvider",
|
||||
"MockProvider",
|
||||
]
|
||||
|
||||
if VOLCENGINE_AVAILABLE:
|
||||
__all__.append("VolcengineProvider")
|
||||
|
||||
if KLINGAI_AVAILABLE:
|
||||
__all__.append("KlingAIProvider")
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
LLM Provider 抽象基类
|
||||
=====================
|
||||
|
||||
定义所有 AI 模型提供商的统一接口。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelHealth(BaseModel):
|
||||
"""模型健康状态"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
is_available: bool
|
||||
response_time: float # 毫秒
|
||||
last_error: str | None = None
|
||||
|
||||
|
||||
class GenerationResult(BaseModel):
|
||||
"""生成结果"""
|
||||
|
||||
content: str
|
||||
usage: dict | None = None # token 用量等
|
||||
model: str # 实际使用的模型
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
LLM 提供商抽象基类
|
||||
|
||||
所有 AI 模型提供商(OpenAI、文心一言、通义千问等)需实现此接口。
|
||||
"""
|
||||
|
||||
# 提供商标识
|
||||
provider_id: str = ""
|
||||
provider_name: str = ""
|
||||
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None, **kwargs):
|
||||
"""
|
||||
初始化 Provider
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
base_url: 自定义 Base URL(用于代理或私有部署)
|
||||
**kwargs: 其他配置参数
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.config = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""
|
||||
同步生成文本
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model: 模型名称,None 则使用默认模型
|
||||
temperature: 随机性(0-2)
|
||||
max_tokens: 最大生成 token 数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
GenerationResult: 生成结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
流式生成文本
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model: 模型名称
|
||||
temperature: 随机性
|
||||
max_tokens: 最大 token 数
|
||||
**kwargs: 额外参数
|
||||
|
||||
Yields:
|
||||
str: 生成的文本片段
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self, model: str | None = None) -> ModelHealth:
|
||||
"""
|
||||
健康检查
|
||||
|
||||
Args:
|
||||
model: 指定模型,None 则检查默认模型
|
||||
|
||||
Returns:
|
||||
ModelHealth: 健康状态
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def available_models(self) -> list[str]:
|
||||
"""返回可用的模型列表"""
|
||||
pass
|
||||
|
||||
|
||||
class ProviderError(Exception):
|
||||
"""Provider 调用异常"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, provider_id: str = "", original_error: Exception | None = None
|
||||
):
|
||||
super().__init__(message)
|
||||
self.provider_id = provider_id
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class ModelUnavailableError(ProviderError):
|
||||
"""模型不可用异常"""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
OpenAI Provider 实现
|
||||
====================
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.ai.providers.base import (
|
||||
GenerationResult,
|
||||
LLMProvider,
|
||||
ModelHealth,
|
||||
ProviderError,
|
||||
)
|
||||
|
||||
|
||||
class GenericLLMProvider(LLMProvider):
|
||||
"""
|
||||
OpenAI / OpenAI 兼容 API Provider
|
||||
|
||||
支持:
|
||||
- OpenAI 官方 API
|
||||
- Azure OpenAI
|
||||
- 任何 OpenAI 兼容接口(如本地 vLLM)
|
||||
"""
|
||||
|
||||
provider_id = "openai"
|
||||
provider_name = "OpenAI"
|
||||
|
||||
# 默认可用模型
|
||||
DEFAULT_MODELS = [
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None, **kwargs):
|
||||
super().__init__(api_key, base_url, **kwargs)
|
||||
|
||||
if not self.api_key:
|
||||
raise ProviderError("OpenAI API Key 未配置", provider_id=self.provider_id)
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url or "https://api.openai.com/v1",
|
||||
)
|
||||
self.default_model = kwargs.get("default_model", "gpt-3.5-turbo")
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""同步生成"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model or self.default_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return GenerationResult(
|
||||
content=response.choices[0].message.content or "",
|
||||
usage=response.usage.model_dump() if response.usage else None,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"OpenAI 生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式生成"""
|
||||
try:
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=model or self.default_model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"OpenAI 流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def health_check(self, model: str | None = None) -> ModelHealth:
|
||||
"""健康检查"""
|
||||
start_time = time.time()
|
||||
test_model = model or self.default_model
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=test_model,
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
max_tokens=5,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ModelHealth(
|
||||
id=test_model,
|
||||
name=f"OpenAI {test_model}",
|
||||
is_available=True,
|
||||
response_time=response_time,
|
||||
last_error=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ModelHealth(
|
||||
id=test_model,
|
||||
name=f"OpenAI {test_model}",
|
||||
is_available=False,
|
||||
response_time=(time.time() - start_time) * 1000,
|
||||
last_error=str(e),
|
||||
)
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
"""返回可用模型列表"""
|
||||
return self.config.get("models", self.DEFAULT_MODELS)
|
||||
|
||||
|
||||
class MockProvider(LLMProvider):
|
||||
"""
|
||||
Mock Provider - 用于测试和演示
|
||||
|
||||
不调用真实 API,返回模拟 JSON 数据。
|
||||
"""
|
||||
|
||||
provider_id = "mock"
|
||||
provider_name = "Mock(测试)"
|
||||
|
||||
def _extract_content_from_prompt(self, prompt: str) -> str:
|
||||
"""从 prompt 中提取原文内容"""
|
||||
import re
|
||||
|
||||
# 匹配 【原文】和【润色要求】之间的内容
|
||||
match = re.search(r"【原文】\s*(.+?)\s*【润色要求】", prompt, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return "优化后的文案"
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""模拟生成 - 根据 prompt 类型返回不同格式数据"""
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
await asyncio.sleep(0.5) # 模拟延迟
|
||||
|
||||
# 检测是否为润色请求
|
||||
if "润色" in prompt or "polish" in prompt.lower():
|
||||
# 返回润色后的文本
|
||||
original = self._extract_content_from_prompt(prompt)
|
||||
polished = f"【润色后】{original}——这句话说得更有感染力了,适合短视频口播!"
|
||||
return GenerationResult(
|
||||
content=polished,
|
||||
usage={"prompt_tokens": 50, "completion_tokens": 50, "total_tokens": 100},
|
||||
model=model or "mock-model",
|
||||
)
|
||||
|
||||
# 否则返回脚本生成的 JSON 数据
|
||||
mock_shots = [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "segment",
|
||||
"scene": "镜头从门外缓缓推入,展示客厅整体布局,自然光从落地窗洒入",
|
||||
"voiceover": "大家好,今天给大家讲讲家装验收最容易被忽略的5个细节",
|
||||
"duration": "5s",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "segment",
|
||||
"scene": "特写墙面,手指划过检查平整度,展示一处细微裂纹",
|
||||
"voiceover": "第一,墙面验收。很多人只看颜色,其实平整度和裂纹更重要",
|
||||
"duration": "8s",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "segment",
|
||||
"scene": "蹲下来拍摄地板接缝处,展示踢脚线与地板的缝隙",
|
||||
"voiceover": "第二,地板验收。重点看接缝是否均匀,踢脚线是否贴合",
|
||||
"duration": "8s",
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"type": "empty_shot",
|
||||
"scene": "现代简约风格卫生间,白色瓷砖,柔和灯光,镜头缓慢平移",
|
||||
"voiceover": "",
|
||||
"duration": "3s",
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "segment",
|
||||
"scene": "打开水龙头,检查水流和水压,特写地漏排水速度",
|
||||
"voiceover": "第三,水电验收。测试所有开关、龙头,检查排水是否顺畅",
|
||||
"duration": "8s",
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "segment",
|
||||
"scene": "开关面板特写,逐一测试灯光开关,展示一处松动的面板",
|
||||
"voiceover": "第四,电路验收。每个开关都要试,面板安装是否牢固",
|
||||
"duration": "7s",
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "segment",
|
||||
"scene": "主人公安慰地微笑,竖起大拇指,背景是温馨的客厅",
|
||||
"voiceover": "记住这5点,验收不踩坑!关注我,更多家装干货等你",
|
||||
"duration": "6s",
|
||||
},
|
||||
]
|
||||
|
||||
return GenerationResult(
|
||||
content=json.dumps(mock_shots, ensure_ascii=False),
|
||||
usage={"prompt_tokens": 100, "completion_tokens": 200, "total_tokens": 300},
|
||||
model=model or "mock-model",
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[str]:
|
||||
"""模拟流式生成 - 返回脚本 JSON"""
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# 检测是否为润色请求
|
||||
if "润色" in prompt or "polish" in prompt.lower():
|
||||
response = "【润色后】优化后的文案,更适合短视频口播!"
|
||||
else:
|
||||
# 返回脚本生成的 JSON 数据
|
||||
mock_shots = [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "segment",
|
||||
"scene": "主播站在毛坯房里,表情严肃",
|
||||
"voiceover": "装修被坑了8万的业主,昨天来找我哭诉...",
|
||||
"duration": "5s",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "segment",
|
||||
"scene": "主播指着墙面,手指划过",
|
||||
"voiceover": "第一坑,水电改造!很多人图便宜找游击队",
|
||||
"duration": "8s",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "empty_shot",
|
||||
"scene": "现代装修施工现场,水电管线整齐排列,4K画质",
|
||||
"voiceover": "看,这就是专业施工",
|
||||
"duration": "3s",
|
||||
},
|
||||
]
|
||||
response = json.dumps(mock_shots, ensure_ascii=False)
|
||||
|
||||
# 流式输出
|
||||
chunk_size = 10 # 每10个字符一个chunk
|
||||
for i in range(0, len(response), chunk_size):
|
||||
yield response[i : i + chunk_size]
|
||||
await asyncio.sleep(0.05) # 模拟打字机效果
|
||||
|
||||
async def health_check(self, model: str | None = None) -> ModelHealth:
|
||||
"""模拟健康检查"""
|
||||
return ModelHealth(
|
||||
id=model or "mock-model",
|
||||
name="Mock Model",
|
||||
is_available=True,
|
||||
response_time=50.0,
|
||||
last_error=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
return ["mock-model", "mock-gpt-3.5", "mock-gpt-4"]
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Kling AI Provider DTO
|
||||
=====================
|
||||
|
||||
Provider 层数据模型,封装 Kling API 返回结构。
|
||||
禁止向业务层泄漏裸 dict[str, Any]。
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.enums import KlingTaskStatus
|
||||
|
||||
|
||||
class KlingVideoResult(BaseModel):
|
||||
"""Kling 视频生成结果"""
|
||||
|
||||
task_id: str | None = Field(None, alias="task_id")
|
||||
task_status: KlingTaskStatus | None = Field(None, alias="task_status")
|
||||
task_status_msg: str | None = Field(None, alias="task_status_msg")
|
||||
task_result: dict | None = Field(None, alias="task_result")
|
||||
|
||||
|
||||
class KlingImageResult(BaseModel):
|
||||
"""Kling 图片生成结果"""
|
||||
|
||||
task_id: str | None = Field(None, alias="task_id")
|
||||
task_status: KlingTaskStatus | None = Field(None, alias="task_status")
|
||||
task_status_msg: str | None = Field(None, alias="task_status_msg")
|
||||
task_result: dict | None = Field(None, alias="task_result")
|
||||
|
||||
|
||||
class KlingVoiceResult(BaseModel):
|
||||
"""Kling 自定义音色结果"""
|
||||
|
||||
task_id: str | None = Field(None, alias="task_id")
|
||||
task_status: KlingTaskStatus | None = Field(None, alias="task_status")
|
||||
task_result: dict | None = Field(None, alias="task_result")
|
||||
|
||||
|
||||
class KlingElementResult(BaseModel):
|
||||
"""Kling 主体创建结果"""
|
||||
|
||||
task_id: str | None = Field(None, alias="task_id")
|
||||
task_status: KlingTaskStatus | None = Field(None, alias="task_status")
|
||||
task_result: dict | None = Field(None, alias="task_result")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
火山方舟官方 SDK Provider
|
||||
==========================
|
||||
|
||||
基于火山方舟官方 Python SDK 实现,支持:
|
||||
- 文本生成 (Chat Completions)
|
||||
- 流式输出
|
||||
- 图片生成
|
||||
- 向量化
|
||||
- 深度思考
|
||||
- 工具调用
|
||||
|
||||
安装依赖:
|
||||
pip install 'volcengine-python-sdk[ark]'
|
||||
|
||||
文档:
|
||||
https://www.volcengine.com/docs/82379
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.ai.providers.base import (
|
||||
GenerationResult,
|
||||
LLMProvider,
|
||||
ModelHealth,
|
||||
ProviderError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入火山方舟 SDK
|
||||
try:
|
||||
from volcenginesdkarkruntime import Ark
|
||||
|
||||
VOLCENGINE_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOLCENGINE_SDK_AVAILABLE = False
|
||||
logger.warning("火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'")
|
||||
|
||||
|
||||
class VolcengineProvider(LLMProvider):
|
||||
"""
|
||||
火山方舟官方 SDK Provider
|
||||
|
||||
支持多模态能力:
|
||||
- 文本对话 (Chat Completions)
|
||||
- 图片生成 (Image Generation)
|
||||
- 向量化 (Embeddings)
|
||||
- 深度思考 (Reasoning)
|
||||
"""
|
||||
|
||||
provider_id = "volcengine"
|
||||
provider_name = "火山方舟"
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
|
||||
DEFAULT_TIMEOUT = 1800 # 秒,方舟推荐 1800 秒以上
|
||||
|
||||
# 模型 ID 映射(从配置文件加载)
|
||||
PRESET_MODELS: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
def load_models_from_config(cls, models: list[dict]):
|
||||
"""
|
||||
从配置文件加载模型列表
|
||||
|
||||
Args:
|
||||
models: 模型列表,每个模型包含 model_name 字段
|
||||
"""
|
||||
cls.PRESET_MODELS = {}
|
||||
for model in models:
|
||||
model_id = model.get("model_name")
|
||||
model_alias = model.get("id")
|
||||
if model_id and model_alias:
|
||||
cls.PRESET_MODELS[model_alias] = model_id
|
||||
|
||||
# 确保至少有一个默认模型
|
||||
if not cls.PRESET_MODELS:
|
||||
cls.PRESET_MODELS = {
|
||||
"doubao-seed-2-0-lite": "doubao-seed-2-0-lite-260215",
|
||||
}
|
||||
|
||||
logger.info(f"已加载 {len(cls.PRESET_MODELS)} 个模型: {list(cls.PRESET_MODELS.keys())}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
default_model: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
初始化火山方舟 Provider
|
||||
|
||||
Args:
|
||||
api_key: API Key,从火山方舟控制台获取
|
||||
base_url: API 基础地址,默认北京节点
|
||||
timeout: 请求超时时间(秒)
|
||||
default_model: 默认模型(Model ID)
|
||||
"""
|
||||
super().__init__(api_key, base_url, **kwargs)
|
||||
|
||||
if not VOLCENGINE_SDK_AVAILABLE:
|
||||
raise ProviderError(
|
||||
"火山方舟 SDK 未安装,请运行: pip install 'volcengine-python-sdk[ark]'",
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
raise ProviderError("火山方舟 API Key 未配置", provider_id=self.provider_id)
|
||||
|
||||
self.timeout = timeout
|
||||
# 使用模型 ID 映射(自动映射)
|
||||
if default_model:
|
||||
self.default_model = self.PRESET_MODELS.get(default_model, default_model)
|
||||
elif self.PRESET_MODELS:
|
||||
# 兜底:使用 doubao-seed-2-0-lite 或第一个可用的模型
|
||||
self.default_model = self.PRESET_MODELS.get(
|
||||
"doubao-seed-2-0-lite", list(self.PRESET_MODELS.values())[0]
|
||||
)
|
||||
else:
|
||||
# 兜底:使用一个默认模型ID(如果用户未配置任何模型)
|
||||
self.default_model = "doubao-seed-2-0-lite-260215"
|
||||
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self) -> Ark:
|
||||
"""创建火山方舟客户端"""
|
||||
return Ark(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url or self.DEFAULT_BASE_URL,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
system_prompt: str | None = None,
|
||||
**kwargs,
|
||||
) -> GenerationResult:
|
||||
"""
|
||||
同步生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
model: 模型 ID(如 doubao-seed-2-0-pro-260215)
|
||||
temperature: 随机性 (0-2)
|
||||
max_tokens: 最大生成 token 数
|
||||
system_prompt: 系统提示词(可选)
|
||||
**kwargs: 额外参数(如 enable_thinking 启用深度思考)
|
||||
|
||||
Returns:
|
||||
GenerationResult: 生成结果
|
||||
"""
|
||||
try:
|
||||
# 构建消息
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# 映射模型名称到模型 ID
|
||||
model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model
|
||||
|
||||
# 构建请求参数
|
||||
request_params = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 额外参数(如深度思考)
|
||||
if "enable_thinking" in kwargs:
|
||||
request_params["extra_body"] = {"enable_thinking": kwargs["enable_thinking"]}
|
||||
|
||||
# 调用 API
|
||||
completion = self.client.chat.completions.create(**request_params)
|
||||
|
||||
# 解析结果
|
||||
content = completion.choices[0].message.content or ""
|
||||
usage = None
|
||||
if completion.usage:
|
||||
usage = {
|
||||
"prompt_tokens": completion.usage.prompt_tokens,
|
||||
"completion_tokens": completion.usage.completion_tokens,
|
||||
"total_tokens": completion.usage.total_tokens,
|
||||
}
|
||||
|
||||
return GenerationResult(
|
||||
content=content,
|
||||
usage=usage,
|
||||
model=completion.model or model or self.default_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"火山方舟生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
system_prompt: str | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
流式生成文本
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
model: 模型名称
|
||||
temperature: 随机性
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词(可选)
|
||||
**kwargs: 额外参数
|
||||
|
||||
Yields:
|
||||
str: 生成的文本片段
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model
|
||||
|
||||
request_params = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
stream = self.client.chat.completions.create(**request_params)
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def generate_stream_with_progress(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = 8000,
|
||||
system_prompt: str | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncIterator[dict]:
|
||||
"""
|
||||
流式生成文本,带进度信息
|
||||
|
||||
Args:
|
||||
prompt: 用户提示词
|
||||
model: 模型名称
|
||||
temperature: 随机性
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词(可选)
|
||||
|
||||
Yields:
|
||||
dict: {
|
||||
"type": "chunk" | "usage",
|
||||
"content": str, # 文本片段(type=chunk时)
|
||||
"total_tokens": int, # 累计token数(type=chunk时)
|
||||
"prompt_tokens": int, # 提示词token数(type=usage时)
|
||||
"completion_tokens": int, # 生成token数(type=usage时)
|
||||
}
|
||||
"""
|
||||
try:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
model_id = self.PRESET_MODELS.get(model, model) if model else self.default_model
|
||||
|
||||
request_params = {
|
||||
"model": model_id,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if max_tokens:
|
||||
request_params["max_tokens"] = max_tokens
|
||||
|
||||
# 流式调用
|
||||
stream = self.client.chat.completions.create(**request_params)
|
||||
|
||||
total_chars = 0
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
|
||||
for chunk in stream:
|
||||
# 获取文本内容
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
total_chars += len(content)
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"content": content,
|
||||
"total_chars": total_chars,
|
||||
}
|
||||
|
||||
# 获取使用统计(最后一个chunk)
|
||||
if chunk.usage:
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
|
||||
# 发送最终统计
|
||||
yield {
|
||||
"type": "usage",
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"火山方舟流式生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def generate_image(
|
||||
self, prompt: str, model: str | None = None, size: str = "1024x1024", **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
生成图片(Seedream 系列)
|
||||
|
||||
Args:
|
||||
prompt: 图片提示词
|
||||
model: 图片模型 ID
|
||||
size: 图片尺寸
|
||||
|
||||
Returns:
|
||||
dict: 包含图片 URL 或 base64 数据
|
||||
"""
|
||||
try:
|
||||
# 图片生成需要单独的图片模型,不在当前配置中
|
||||
# 如需使用,请在模型广场开通 doubao-seed-1.6 并配置
|
||||
image_model = model or "doubao-seed-1.6-flash-250828"
|
||||
response = self.client.images.generate(
|
||||
model=image_model, prompt=prompt, size=size, **kwargs
|
||||
)
|
||||
|
||||
# 解析图片结果
|
||||
images = []
|
||||
for img in response.data:
|
||||
images.append(
|
||||
{
|
||||
"url": img.url,
|
||||
"b64_json": img.b64_json,
|
||||
"revised_prompt": img.revised_prompt,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"火山方舟图片生成失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def create_embeddings(self, texts: list[str], model: str | None = None, **kwargs) -> dict:
|
||||
"""
|
||||
文本向量化
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
model: 向量化模型
|
||||
|
||||
Returns:
|
||||
dict: 包含向量化结果
|
||||
"""
|
||||
try:
|
||||
response = self.client.embeddings.create(
|
||||
model=model or "doubao-embedding-1.5", input=texts, **kwargs
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
for item in response.data:
|
||||
embeddings.append(
|
||||
{
|
||||
"index": item.index,
|
||||
"embedding": item.embedding,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"embeddings": embeddings,
|
||||
"model": response.model,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ProviderError(
|
||||
f"火山方舟向量化失败: {str(e)}", provider_id=self.provider_id, original_error=e
|
||||
)
|
||||
|
||||
async def health_check(self, model: str | None = None) -> ModelHealth:
|
||||
"""健康检查"""
|
||||
start_time = time.time()
|
||||
test_model = model or self.default_model
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=test_model,
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ModelHealth(
|
||||
id=test_model,
|
||||
name=f"火山方舟 {test_model}",
|
||||
is_available=True,
|
||||
response_time=response_time,
|
||||
last_error=None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ModelHealth(
|
||||
id=test_model,
|
||||
name=f"火山方舟 {test_model}",
|
||||
is_available=False,
|
||||
response_time=(time.time() - start_time) * 1000,
|
||||
last_error=str(e),
|
||||
)
|
||||
|
||||
@property
|
||||
def available_models(self) -> list[str]:
|
||||
"""返回可用模型列表(与 ai_models.yaml 配置保持一致)"""
|
||||
return [
|
||||
"doubao-seed-2-0-pro",
|
||||
"deepseek-v3-2",
|
||||
"doubao-seed-2-0-lite",
|
||||
"doubao-lite-32k",
|
||||
]
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
依赖注入工具
|
||||
============
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.security import verify_token
|
||||
from app.db.session import get_db as db_session
|
||||
from app.models.user import User
|
||||
|
||||
settings = get_settings()
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
async def get_db() -> AsyncSession:
|
||||
"""获取数据库 Session"""
|
||||
async for session in db_session():
|
||||
yield session
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
获取当前登录用户
|
||||
|
||||
从 Authorization Header 中提取 JWT Token 并验证
|
||||
"""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="缺少认证信息",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
payload = verify_token(token)
|
||||
|
||||
if payload is None or payload.get("sub") is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证信息",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_optional(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
db=Depends(get_db),
|
||||
) -> User | None:
|
||||
"""
|
||||
获取当前登录用户(可选,未登录返回 None)
|
||||
"""
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await get_current_user(credentials, db)
|
||||
except HTTPException:
|
||||
return None
|
||||
@@ -0,0 +1,552 @@
|
||||
"""
|
||||
AI 模型管理与生成 API
|
||||
=====================
|
||||
|
||||
提供模型列表查询、文本生成、脚本生成、润色等功能。
|
||||
|
||||
模型配置存储在 config/ai_models.yaml,支持热重载。
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.ai.model_router import get_model_router
|
||||
from app.core.config_loader import get_config_loader
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.services.ai_response_utils import (
|
||||
safe_parse_ai_json_response,
|
||||
validate_and_normalize_shots,
|
||||
validate_shots_structure,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 请求/响应 Schema ============
|
||||
|
||||
|
||||
class PlatformResponse(BaseModel):
|
||||
"""平台响应"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
provider: str
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
"""模型响应"""
|
||||
|
||||
id: str
|
||||
platform_id: str
|
||||
model_name: str
|
||||
display_name: str
|
||||
capabilities: list[str]
|
||||
default_params: dict
|
||||
is_enabled: bool
|
||||
full_model_id: str
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""生成请求"""
|
||||
|
||||
prompt: str = Field(..., description="提示词")
|
||||
model_id: str | None = Field(None, description="指定模型 ID")
|
||||
task_type: str | None = Field(
|
||||
None, description="任务类型,用于自动选模型: script/polish"
|
||||
)
|
||||
temperature: float | None = Field(None, description="随机性 (0-2)")
|
||||
max_tokens: int | None = Field(None, description="最大生成长度")
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
"""生成响应"""
|
||||
|
||||
content: str
|
||||
model: str
|
||||
usage: dict | None
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""健康检查响应"""
|
||||
|
||||
status: str
|
||||
total_models: int
|
||||
available_models: int
|
||||
models: list[dict]
|
||||
|
||||
|
||||
# ============ API 路由 ============
|
||||
|
||||
|
||||
@router.get("/platforms", response_model=ApiResponse[list[PlatformResponse]])
|
||||
async def list_platforms():
|
||||
"""获取所有平台列表"""
|
||||
config_loader = get_config_loader()
|
||||
platforms = config_loader.get_all_platforms()
|
||||
|
||||
return success_response(
|
||||
data=[
|
||||
PlatformResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
provider=p.provider,
|
||||
)
|
||||
for p in platforms
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models", response_model=ApiResponse[list[ModelResponse]])
|
||||
async def list_models(capability: str | None = None):
|
||||
"""获取模型列表
|
||||
|
||||
Args:
|
||||
capability: 按能力过滤,如 script、polish、chat
|
||||
"""
|
||||
router = await get_model_router()
|
||||
models = router.list_models(capability=capability)
|
||||
|
||||
return success_response(
|
||||
data=[
|
||||
ModelResponse(
|
||||
id=m["id"],
|
||||
platform_id=m["platform_id"],
|
||||
model_name=m["model_name"],
|
||||
display_name=m["display_name"],
|
||||
capabilities=m["capabilities"],
|
||||
default_params=m["default_params"],
|
||||
is_enabled=True, # 列表中的都是启用的
|
||||
full_model_id=f"{m['platform_id']}/{m['id']}",
|
||||
)
|
||||
for m in models
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse[GenerateResponse])
|
||||
async def generate_text(data: GenerateRequest):
|
||||
"""文本生成(自动路由到对应平台)"""
|
||||
router = await get_model_router()
|
||||
|
||||
try:
|
||||
result = await router.generate(
|
||||
prompt=data.prompt,
|
||||
model_id=data.model_id,
|
||||
task_type=data.task_type,
|
||||
temperature=data.temperature,
|
||||
max_tokens=data.max_tokens,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=GenerateResponse(
|
||||
content=result.content,
|
||||
model=result.model,
|
||||
usage=result.usage,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/health", response_model=ApiResponse[HealthResponse])
|
||||
async def health_check(model_id: str | None = None):
|
||||
"""检查模型健康状态"""
|
||||
router = await get_model_router()
|
||||
|
||||
health_results = await router.health_check(model_id)
|
||||
|
||||
models_status = []
|
||||
available_count = 0
|
||||
|
||||
for mid, health in health_results.items():
|
||||
models_status.append(
|
||||
{
|
||||
"id": mid,
|
||||
"name": health.name,
|
||||
"is_available": health.is_available,
|
||||
"response_time": health.response_time,
|
||||
"last_error": health.last_error,
|
||||
}
|
||||
)
|
||||
if health.is_available:
|
||||
available_count += 1
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"status": "healthy" if available_count > 0 else "unhealthy",
|
||||
"total_models": len(models_status),
|
||||
"available_models": available_count,
|
||||
"models": models_status,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/platforms/{platform_id}/test", response_model=ApiResponse[dict])
|
||||
async def test_platform_connection(platform_id: str):
|
||||
"""测试平台连接"""
|
||||
from app.ai.model_router import PlatformInstance
|
||||
|
||||
config_loader = get_config_loader()
|
||||
platform = config_loader.get_platform(platform_id)
|
||||
|
||||
if not platform:
|
||||
raise HTTPException(status_code=404, detail="平台不存在")
|
||||
|
||||
try:
|
||||
# PlatformInstance 自动从 Settings 读取 API Key
|
||||
instance = PlatformInstance(
|
||||
{
|
||||
"id": platform.id,
|
||||
"name": platform.name,
|
||||
"provider": platform.provider,
|
||||
"base_url": platform.base_url,
|
||||
}
|
||||
)
|
||||
|
||||
# 尝试调用
|
||||
result = await instance.provider.health_check()
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"platform_id": platform_id,
|
||||
"success": result.is_available,
|
||||
"response_time": result.response_time,
|
||||
"message": "连接成功" if result.is_available else result.last_error,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
return success_response(
|
||||
data={
|
||||
"platform_id": platform_id,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse[dict])
|
||||
async def reload_config():
|
||||
"""重新加载配置文件"""
|
||||
router = await get_model_router()
|
||||
reloaded = router.reload_config()
|
||||
|
||||
if reloaded:
|
||||
return success_response(data={"reloaded": True}, message="配置已重新加载")
|
||||
else:
|
||||
return success_response(data={"reloaded": False}, message="配置文件无变化")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt 模板 API
|
||||
# =============================================================================
|
||||
|
||||
from app.ai.prompts import (
|
||||
SCRIPT_TYPES,
|
||||
VIDEO_STYLES,
|
||||
PolishPromptBuilder,
|
||||
ScriptPromptBuilder,
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplatesResponse(BaseModel):
|
||||
"""Prompt 模板配置响应"""
|
||||
|
||||
script_types: list[dict]
|
||||
video_styles: list[dict]
|
||||
tones: list[str]
|
||||
|
||||
|
||||
class ScriptGenerateRequest(BaseModel):
|
||||
"""脚本生成请求"""
|
||||
|
||||
topic: str = Field(..., description="脚本主题", example="水电改造的3个致命错误")
|
||||
duration: int = Field(30, ge=15, le=120, description="视频时长(秒)")
|
||||
script_type: str = Field("干货型", description="脚本类型")
|
||||
video_style: str = Field("口播", description="视频风格")
|
||||
tone: str | None = Field(None, description="语气风格")
|
||||
requirements: str | None = Field(None, description="额外要求")
|
||||
model_id: str | None = Field(None, description="指定模型ID,默认使用系统默认模型")
|
||||
|
||||
|
||||
class ScriptGenerateResponse(BaseModel):
|
||||
"""脚本生成响应 - 针对前端展示优化"""
|
||||
|
||||
success: bool
|
||||
script: list[
|
||||
dict | None
|
||||
] # 镜头列表,包含 shot_number, type, scene/prompt, voiceover, duration, word_count
|
||||
total_duration: int | None # 预计总时长(秒)
|
||||
target_duration: int # 目标时长(秒)
|
||||
total_word_count: int | None # 总字数(供前端展示)
|
||||
segment_count: int | None # 分镜数量(供前端展示)
|
||||
empty_shot_count: int | None # 空镜数量(供前端展示)
|
||||
script_type: str
|
||||
model: str
|
||||
usage: dict | None
|
||||
error: str | None
|
||||
raw_content: str | None
|
||||
|
||||
|
||||
class PolishRequest(BaseModel):
|
||||
"""润色请求"""
|
||||
|
||||
content: str = Field(..., description="需要润色的内容")
|
||||
polish_type: str = Field("voiceover", description="润色类型:scene/voiceover")
|
||||
model_id: str | None = Field(None, description="指定模型ID")
|
||||
|
||||
|
||||
class PolishResponse(BaseModel):
|
||||
"""润色响应"""
|
||||
|
||||
success: bool
|
||||
original: str
|
||||
polished: str | None
|
||||
polish_type: str
|
||||
model: str
|
||||
usage: dict | None
|
||||
|
||||
|
||||
@router.get("/prompts/templates", response_model=ApiResponse[PromptTemplatesResponse])
|
||||
async def get_prompt_templates():
|
||||
"""
|
||||
获取所有可用的 Prompt 模板配置
|
||||
|
||||
包括脚本类型、视频风格、语气风格等选项。
|
||||
"""
|
||||
return success_response(
|
||||
data={
|
||||
"script_types": [
|
||||
{
|
||||
"id": key,
|
||||
"name": value["name"],
|
||||
"description": value["description"],
|
||||
"key_points": value["key_points"],
|
||||
}
|
||||
for key, value in SCRIPT_TYPES.items()
|
||||
if key != "default"
|
||||
],
|
||||
"video_styles": [
|
||||
{
|
||||
"id": key,
|
||||
"name": value["name"],
|
||||
"description": value["description"],
|
||||
}
|
||||
for key, value in VIDEO_STYLES.items()
|
||||
],
|
||||
"tones": ["专业", "亲和", "幽默", "严肃", "激情"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompts/build", response_model=ApiResponse[dict])
|
||||
async def build_system_prompt(
|
||||
duration: int = 30,
|
||||
script_type: str = "干货型",
|
||||
video_style: str = "口播",
|
||||
tone: str | None = None,
|
||||
):
|
||||
"""
|
||||
构建系统 Prompt(用于调试和预览)
|
||||
|
||||
返回构建好的系统 Prompt,可用于前端预览或调试。
|
||||
"""
|
||||
builder = ScriptPromptBuilder()
|
||||
prompt = builder.build(
|
||||
duration=duration,
|
||||
script_type=script_type,
|
||||
video_style=video_style,
|
||||
industry="家装",
|
||||
tone=tone,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"system_prompt": prompt,
|
||||
"length": len(prompt),
|
||||
"parameters": {
|
||||
"duration": duration,
|
||||
"script_type": script_type,
|
||||
"video_style": video_style,
|
||||
"tone": tone,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/scripts/generate", response_model=ApiResponse[ScriptGenerateResponse])
|
||||
async def generate_script(data: ScriptGenerateRequest):
|
||||
"""
|
||||
生成家装行业短视频脚本
|
||||
|
||||
使用专业的 Prompt 模板生成包含分镜+空镜的混合脚本。
|
||||
针对前端展示优化,返回分镜数、空镜数、总字数等统计信息。
|
||||
"""
|
||||
router = await get_model_router()
|
||||
|
||||
# 构建系统 Prompt
|
||||
builder = ScriptPromptBuilder()
|
||||
system_prompt = builder.build(
|
||||
duration=data.duration,
|
||||
script_type=data.script_type,
|
||||
video_style=data.video_style,
|
||||
industry="家装",
|
||||
tone=data.requirements,
|
||||
custom_requirements=data.requirements,
|
||||
)
|
||||
|
||||
# 构建用户输入
|
||||
user_prompt = f"""主题是"{data.topic}"
|
||||
|
||||
要求:
|
||||
1. 严格按照时长要求控制
|
||||
2. 每个镜头的配音字数必须匹配时长(4-5字/秒)
|
||||
3. 空镜必须有画外音,不能为空
|
||||
4. 只返回JSON数组,不要有其他文字"""
|
||||
|
||||
full_prompt = f"{system_prompt}\n\n【用户输入】\n{user_prompt}\n\n请生成脚本,只返回JSON数组:"
|
||||
|
||||
# 调用模型
|
||||
try:
|
||||
result = await router.generate(
|
||||
prompt=full_prompt,
|
||||
model_id=data.model_id,
|
||||
task_type="script",
|
||||
temperature=0.7,
|
||||
max_tokens=2500,
|
||||
)
|
||||
|
||||
# 安全地解析 JSON 响应
|
||||
success_parsed, parsed_data, error_msg = safe_parse_ai_json_response(
|
||||
result.content
|
||||
)
|
||||
|
||||
if not success_parsed:
|
||||
logger.error(f"AI 响应解析失败: {error_msg}")
|
||||
return success_response(
|
||||
data={
|
||||
"success": False,
|
||||
"script": None,
|
||||
"total_duration": None,
|
||||
"target_duration": data.duration,
|
||||
"total_word_count": None,
|
||||
"segment_count": None,
|
||||
"empty_shot_count": None,
|
||||
"script_type": data.script_type,
|
||||
"model": result.model,
|
||||
"usage": result.usage,
|
||||
"error": error_msg or "JSON解析失败",
|
||||
"raw_content": result.content,
|
||||
}
|
||||
)
|
||||
|
||||
# 验证并标准化分镜数据
|
||||
try:
|
||||
script = validate_and_normalize_shots(parsed_data)
|
||||
except Exception as e:
|
||||
logger.error(f"分镜数据标准化失败: {e}")
|
||||
return success_response(
|
||||
data={
|
||||
"success": False,
|
||||
"script": None,
|
||||
"total_duration": None,
|
||||
"target_duration": data.duration,
|
||||
"total_word_count": None,
|
||||
"segment_count": None,
|
||||
"empty_shot_count": None,
|
||||
"script_type": data.script_type,
|
||||
"model": result.model,
|
||||
"usage": result.usage,
|
||||
"error": f"分镜数据格式错误: {e}",
|
||||
"raw_content": result.content,
|
||||
}
|
||||
)
|
||||
|
||||
# 验证分镜结构
|
||||
is_valid, validation_errors = validate_shots_structure(script)
|
||||
if not is_valid:
|
||||
logger.warning(f"分镜结构验证失败: {validation_errors}")
|
||||
# 继续处理,但记录警告
|
||||
|
||||
# 计算统计信息(供前端展示)
|
||||
total_duration = sum(
|
||||
int(shot.get("duration", "5s").rstrip("s秒"))
|
||||
for shot in script
|
||||
if isinstance(shot, dict)
|
||||
)
|
||||
total_word_count = sum(
|
||||
len(shot.get("voiceover", "")) for shot in script if isinstance(shot, dict)
|
||||
)
|
||||
segment_count = sum(
|
||||
1
|
||||
for shot in script
|
||||
if isinstance(shot, dict) and shot.get("type") == "segment"
|
||||
)
|
||||
empty_shot_count = sum(
|
||||
1
|
||||
for shot in script
|
||||
if isinstance(shot, dict) and shot.get("type") == "empty_shot"
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"success": True,
|
||||
"script": script,
|
||||
"total_duration": total_duration,
|
||||
"target_duration": data.duration,
|
||||
"total_word_count": total_word_count,
|
||||
"segment_count": segment_count,
|
||||
"empty_shot_count": empty_shot_count,
|
||||
"script_type": data.script_type,
|
||||
"model": result.model,
|
||||
"usage": result.usage,
|
||||
"error": None,
|
||||
"raw_content": None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/scripts/polish", response_model=ApiResponse[PolishResponse])
|
||||
async def polish_script_content(data: PolishRequest):
|
||||
"""
|
||||
润色脚本内容
|
||||
|
||||
对场景描述或口播文案进行专业润色。
|
||||
"""
|
||||
router = await get_model_router()
|
||||
|
||||
# 构建润色 Prompt
|
||||
builder = PolishPromptBuilder()
|
||||
system_prompt = builder.build(data.polish_type)
|
||||
|
||||
full_prompt = f"{system_prompt}\n\n【待润色内容】\n{data.content}\n\n请润色:"
|
||||
|
||||
# 调用模型
|
||||
try:
|
||||
result = await router.generate(
|
||||
prompt=full_prompt,
|
||||
model_id=data.model_id,
|
||||
task_type="polish",
|
||||
temperature=0.6,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"success": True,
|
||||
"original": data.content,
|
||||
"polished": result.content,
|
||||
"polish_type": data.polish_type,
|
||||
"model": result.model,
|
||||
"usage": result.usage,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"润色失败: {str(e)}")
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
认证模块 API
|
||||
============
|
||||
|
||||
采用"手机号 + JWT"的认证方案。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.security import create_access_token
|
||||
from app.crud.user import user as user_crud
|
||||
from app.db.session import AsyncSession, get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import LoginResponse, MobileLoginRequest
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/login", response_model=ApiResponse[LoginResponse])
|
||||
async def login(
|
||||
request: MobileLoginRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
手机号登录/注册
|
||||
|
||||
- 如果手机号已存在,返回对应用户
|
||||
- 如果不存在,自动创建新用户
|
||||
- 返回 JWT Token 用于后续认证
|
||||
"""
|
||||
# 获取或创建用户
|
||||
user_obj = await user_crud.get_or_create_by_mobile(
|
||||
db,
|
||||
mobile=request.mobile,
|
||||
nickname=request.nickname,
|
||||
)
|
||||
|
||||
# 生成 JWT Token
|
||||
token = create_access_token(data={"sub": user_obj.id, "mobile": user_obj.mobile})
|
||||
|
||||
return success_response(
|
||||
data=LoginResponse(
|
||||
token=token,
|
||||
user={
|
||||
"id": user_obj.id,
|
||||
"nickname": user_obj.nickname or "",
|
||||
"avatar": user_obj.avatar_url or "",
|
||||
},
|
||||
),
|
||||
message="登录成功",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=ApiResponse[dict])
|
||||
async def get_me(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前登录用户信息"""
|
||||
return success_response(
|
||||
data={
|
||||
"id": current_user.id,
|
||||
"mobile": current_user.mobile,
|
||||
"nickname": current_user.nickname,
|
||||
"avatar": current_user.avatar_url,
|
||||
"createdAt": current_user.created_at.isoformat() if current_user.created_at else None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse[dict])
|
||||
async def refresh_token(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""刷新 JWT Token"""
|
||||
new_token = create_access_token(data={"sub": current_user.id, "mobile": current_user.mobile})
|
||||
return success_response(
|
||||
data={"token": new_token},
|
||||
message="Token 刷新成功",
|
||||
)
|
||||
@@ -0,0 +1,560 @@
|
||||
"""
|
||||
Avatar 形象克隆模块
|
||||
==================
|
||||
|
||||
串行流程:
|
||||
1. 使用上传的视频创建 KlingAI 自定义音色 (custom-voices)
|
||||
2. 轮询等待音色生成完成,获取 voice_id
|
||||
3. 使用同一视频 + voice_id 创建 KlingAI 主体 (advanced-custom-elements)
|
||||
4. 轮询等待主体生成完成,获取 provider_element_id
|
||||
5. 返回统一的 AvatarItem
|
||||
|
||||
异步架构:
|
||||
- POST /avatar/clone 只负责注册到 Async Engine(纯 Redis,无 DB),立即返回 task_id
|
||||
- 真正的轮询由 Async Engine Scheduler 在后台执行
|
||||
- 前端通过 SSE 或轮询 GET /avatar/tasks/{task_id} 查询进度
|
||||
|
||||
数据策略:
|
||||
- 形象克隆数据只保存在前端本地,后端不持久化到数据库
|
||||
- 任务运行时的中间状态全部存储在 Redis 中(TTL 24h)
|
||||
|
||||
错误提示策略:
|
||||
- custom-voice 失败:提示"有声的人物视频"相关原因
|
||||
- element 失败:提示视频内容/质量不符合主体创建要求
|
||||
- 超时:标记为 timeout,支持重试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.api.deps import get_current_user
|
||||
from app.config import get_settings
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.schemas.enums import AvatarCloneStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_kling_provider() -> KlingAIProvider:
|
||||
settings = get_settings()
|
||||
return KlingAIProvider(
|
||||
config={
|
||||
"access_key": settings.KLINGAI_ACCESS_KEY or "",
|
||||
"secret_key": settings.KLINGAI_SECRET_KEY or "",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def _get_avatar_state(redis, job_id: str) -> dict | None:
|
||||
"""从 Redis 读取 avatar 任务完整状态"""
|
||||
data = await redis.hgetall(f"job:{job_id}")
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# 解析 JSON 字段
|
||||
for key in ("result", "params"):
|
||||
if key in data and data[key]:
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
data[key] = json.loads(data[key])
|
||||
return data
|
||||
|
||||
|
||||
class CloneAvatarRequest(BaseModel):
|
||||
"""创建形象克隆请求"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=20, description="形象名称")
|
||||
video_url: str = Field(description="人物视频 URL")
|
||||
|
||||
|
||||
class CloneAvatarResponse(BaseModel):
|
||||
"""创建形象克隆响应"""
|
||||
|
||||
task_id: str = Field(..., description="任务 ID(用于 SSE/轮询跟踪进度)")
|
||||
status: str = Field("pending", description="初始状态")
|
||||
|
||||
|
||||
class AvatarTaskStatusResponse(BaseModel):
|
||||
"""任务状态查询响应"""
|
||||
|
||||
task_id: str
|
||||
status: str = Field(..., description="当前状态")
|
||||
fail_reason: str | None = Field(None, description="失败原因")
|
||||
voice_id: str | None = Field(None, description="已生成的音色 ID")
|
||||
human_id: int | None = Field(None, description="已生成的主体 ID")
|
||||
trial_url: str | None = Field(None, description="试听 URL")
|
||||
video_url: str = Field(..., description="原始视频 URL")
|
||||
name: str = Field(..., description="形象名称")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
|
||||
class AvatarItem(BaseModel):
|
||||
"""形象库列表项"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(..., description="形象唯一标识")
|
||||
name: str = Field(..., description="展示名称")
|
||||
voice_id: str = Field(..., description="Kling 自定义音色 ID")
|
||||
human_id: int = Field(..., description="数字人主体 ID")
|
||||
video_url: str = Field(description="原始人物视频 URL")
|
||||
trial_url: str | None = Field(None, description="音色试听 URL")
|
||||
record_time: str = Field(description="创建时间 ISO 字符串")
|
||||
|
||||
|
||||
class UpdateAvatarNameRequest(BaseModel):
|
||||
"""更新形象名称请求"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=20, description="新形象名称")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# API 路由
|
||||
# ============================================================
|
||||
|
||||
|
||||
@router.post("/avatar/clone", response_model=ApiResponse[CloneAvatarResponse])
|
||||
async def clone_avatar(
|
||||
data: CloneAvatarRequest,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
提交形象克隆任务
|
||||
|
||||
立即返回 task_id,前端通过 SSE 或轮询跟踪进度。
|
||||
实际串行流程由 Async Engine Scheduler 异步执行。
|
||||
任务状态纯 Redis 存储,不写入数据库。
|
||||
"""
|
||||
user_id = str(current_user.id)
|
||||
name = data.name.strip()
|
||||
video_url = data.video_url.strip()
|
||||
|
||||
# 生成 task_id
|
||||
task_id = f"avt_{uuid.uuid4().hex[:16]}"
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# 写入 Redis,供 Async Engine 调度(同时存储 avatar 初始状态)
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
await registry.create(task_id, "avatar_clone", user_id)
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
progress=5,
|
||||
message="开始形象克隆...",
|
||||
completed=0,
|
||||
total=1,
|
||||
params={
|
||||
"avatar_id": task_id,
|
||||
"name": name,
|
||||
"video_url": video_url,
|
||||
"user_id": user_id,
|
||||
},
|
||||
# 存储 avatar 状态字段(供 API 查询)
|
||||
avatar_status=AvatarCloneStatus.PENDING.value,
|
||||
avatar_name=name,
|
||||
avatar_video_url=video_url,
|
||||
voice_id="",
|
||||
provider_element_id="",
|
||||
provider_voice_job_id="",
|
||||
provider_element_job_id="",
|
||||
trial_url="",
|
||||
fail_reason="",
|
||||
created_at=now.isoformat(),
|
||||
updated_at=now.isoformat(),
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
return success_response(data=CloneAvatarResponse(task_id=task_id, status="pending"))
|
||||
|
||||
|
||||
@router.get("/avatar/tasks/{task_id}", response_model=ApiResponse[AvatarTaskStatusResponse])
|
||||
async def get_avatar_task_status(
|
||||
task_id: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""查询形象克隆任务状态(从 Redis 读取)"""
|
||||
redis = get_redis_client()
|
||||
state = await _get_avatar_state(redis, task_id)
|
||||
if not state:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
# 权限检查
|
||||
params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
|
||||
if params.get("user_id") != str(current_user.id):
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
def _dt(key: str) -> datetime:
|
||||
raw = state.get(key, "")
|
||||
if raw:
|
||||
try:
|
||||
return datetime.fromisoformat(raw)
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.now(UTC)
|
||||
|
||||
def _int(key: str) -> int | None:
|
||||
raw = state.get(key, "")
|
||||
if raw:
|
||||
try:
|
||||
return int(raw)
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
return success_response(
|
||||
data=AvatarTaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=state.get("avatar_status", state.get("status", "unknown")),
|
||||
fail_reason=state.get("fail_reason") or None,
|
||||
voice_id=state.get("voice_id") or None,
|
||||
human_id=_int("provider_element_id"),
|
||||
trial_url=state.get("trial_url") or None,
|
||||
video_url=params.get("video_url", ""),
|
||||
name=params.get("name", ""),
|
||||
created_at=_dt("created_at"),
|
||||
updated_at=_dt("updated_at"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/avatar/clone/stream")
|
||||
async def sse_avatar_clone(
|
||||
task_id: str = Query(..., alias="task_id", description="任务 ID"),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
SSE 流:实时推送形象克隆任务状态
|
||||
|
||||
前端连接后,每 3 秒推送一次状态,直到任务结束(succeed / failed / timeout)。
|
||||
"""
|
||||
user_id = str(current_user.id)
|
||||
|
||||
async def event_stream():
|
||||
for _ in range(400): # 最多 20 分钟(400 * 3s)
|
||||
redis = get_redis_client()
|
||||
state = await _get_avatar_state(redis, task_id)
|
||||
|
||||
if not state:
|
||||
payload = json.dumps(
|
||||
{"status": "error", "fail_reason": "任务不存在或无权限"}, ensure_ascii=False
|
||||
)
|
||||
yield f"event: error\ndata: {payload}\n\n"
|
||||
break
|
||||
|
||||
# 权限检查
|
||||
params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
|
||||
if params.get("user_id") != user_id:
|
||||
payload = json.dumps(
|
||||
{"status": "error", "fail_reason": "任务不存在或无权限"}, ensure_ascii=False
|
||||
)
|
||||
yield f"event: error\ndata: {payload}\n\n"
|
||||
break
|
||||
|
||||
avatar_status = state.get("avatar_status", state.get("status", "unknown"))
|
||||
|
||||
payload = json.dumps(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": avatar_status,
|
||||
"fail_reason": state.get("fail_reason") or None,
|
||||
"voice_id": state.get("voice_id") or None,
|
||||
"provider_element_id": state.get("provider_element_id") or None,
|
||||
"trial_url": state.get("trial_url") or None,
|
||||
"video_url": params.get("video_url", ""),
|
||||
"name": params.get("name", ""),
|
||||
"created_at": state.get("created_at", ""),
|
||||
"updated_at": state.get("updated_at", ""),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
yield f"data: {payload}\n\n"
|
||||
|
||||
if avatar_status in (
|
||||
AvatarCloneStatus.SUCCEED,
|
||||
AvatarCloneStatus.VOICE_FAILED,
|
||||
AvatarCloneStatus.ELEMENT_FAILED,
|
||||
AvatarCloneStatus.TIMEOUT,
|
||||
):
|
||||
break
|
||||
|
||||
await asyncio.sleep(3)
|
||||
else:
|
||||
# 达到最大轮询次数,推送超时事件
|
||||
payload = json.dumps(
|
||||
{"status": "timeout", "fail_reason": "连接超时,请通过轮询接口继续跟踪"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
yield f"event: timeout\ndata: {payload}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/avatar/tasks/{task_id}/retry", response_model=ApiResponse[dict])
|
||||
async def retry_avatar_task(
|
||||
task_id: str,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
重试失败或超时的形象克隆任务
|
||||
|
||||
仅允许对 voice_failed / element_failed / timeout 状态的任务重试。
|
||||
重试时会重置状态为 pending 并重新注册到 Async Engine。
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
state = await _get_avatar_state(redis, task_id)
|
||||
if not state:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
|
||||
if params.get("user_id") != str(current_user.id):
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
avatar_status = state.get("avatar_status", state.get("status", ""))
|
||||
if avatar_status not in (
|
||||
AvatarCloneStatus.VOICE_FAILED.value,
|
||||
AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
AvatarCloneStatus.TIMEOUT.value,
|
||||
):
|
||||
raise HTTPException(status_code=400, detail=f"当前状态 {avatar_status} 不支持重试")
|
||||
|
||||
# 重置状态
|
||||
registry = JobRegistry(redis)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
avatar_status=AvatarCloneStatus.PENDING,
|
||||
fail_reason="",
|
||||
voice_id="",
|
||||
provider_element_id="",
|
||||
provider_voice_job_id="",
|
||||
provider_element_job_id="",
|
||||
trial_url="",
|
||||
updated_at=now,
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
return success_response(data={"task_id": task_id, "status": "pending"})
|
||||
|
||||
|
||||
@router.delete("/avatar/{avatar_id}", response_model=ApiResponse[dict])
|
||||
async def delete_avatar(
|
||||
avatar_id: str,
|
||||
voice_id: str | None = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
删除形象:清理 Kling 资源 + 删除 Redis 任务记录
|
||||
|
||||
不操作数据库,形象数据由前端本地管理。
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
state = await _get_avatar_state(redis, avatar_id)
|
||||
|
||||
# 获取 Kling 资源 ID(优先用传入的,否则从 Redis 读)
|
||||
actual_voice_id = voice_id
|
||||
actual_element_id = None
|
||||
if state:
|
||||
params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
|
||||
if params.get("user_id") == str(current_user.id):
|
||||
actual_element_id = state.get("provider_element_id")
|
||||
if not actual_voice_id:
|
||||
actual_voice_id = state.get("voice_id")
|
||||
|
||||
# 异步清理 Kling 资源(不阻塞前端)
|
||||
provider = _get_kling_provider()
|
||||
if actual_element_id:
|
||||
try:
|
||||
await provider.delete_element(str(actual_element_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"delete_element failed: {e}")
|
||||
|
||||
if actual_voice_id:
|
||||
try:
|
||||
await provider.delete_custom_voice(actual_voice_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"delete_custom_voice failed: {e}")
|
||||
|
||||
# 删除 Redis 任务记录
|
||||
registry = JobRegistry(redis)
|
||||
await registry.delete(avatar_id)
|
||||
|
||||
return success_response(data={"success": True, "message": "形象已删除"})
|
||||
|
||||
|
||||
@router.get("/avatar/library", response_model=ApiResponse[list[AvatarItem]])
|
||||
async def get_avatar_library(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前用户的克隆形象库
|
||||
|
||||
形象数据只保存在前端本地,后端不持久化。
|
||||
此接口始终返回空列表,由前端从 localStorage/文件系统读取真实数据。
|
||||
"""
|
||||
return success_response(data=[])
|
||||
|
||||
|
||||
@router.patch("/avatar/{avatar_id}", response_model=ApiResponse[dict])
|
||||
async def update_avatar_name(
|
||||
avatar_id: str,
|
||||
data: UpdateAvatarNameRequest,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
更新形象名称
|
||||
|
||||
形象数据由前端本地管理,后端仅返回成功。
|
||||
"""
|
||||
new_name = data.name.strip()
|
||||
if not new_name:
|
||||
raise HTTPException(status_code=400, detail="名称不能为空")
|
||||
|
||||
return success_response(data={"success": True, "name": new_name})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 管理和监控接口(用于排查问题和手动恢复)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AvatarHealthResponse(BaseModel):
|
||||
"""形象克隆服务健康状态"""
|
||||
|
||||
total_processing: int = Field(..., description="处理中的任务总数")
|
||||
pending: int = Field(..., description="待处理任务数")
|
||||
voice_processing: int = Field(..., description="音色生成中任务数")
|
||||
element_processing: int = Field(..., description="主体生成中任务数")
|
||||
stuck_tasks: int = Field(..., description="卡住任务数(超过30分钟)")
|
||||
recent_failures: int = Field(..., description="最近1小时失败数")
|
||||
|
||||
|
||||
@router.get("/avatar/health", response_model=ApiResponse[AvatarHealthResponse])
|
||||
async def get_avatar_health(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取形象克隆服务健康状态
|
||||
|
||||
基于 Redis 运行中任务统计,不查询数据库。
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
job_ids = await registry.get_running_job_ids()
|
||||
|
||||
total_processing = 0
|
||||
pending = 0
|
||||
voice_processing = 0
|
||||
element_processing = 0
|
||||
stuck_tasks = 0
|
||||
recent_failures = 0
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stuck_threshold = now.timestamp() - 30 * 60 # 30 分钟前
|
||||
recent_threshold = now.timestamp() - 60 * 60 # 1 小时前
|
||||
|
||||
for job_id in job_ids:
|
||||
state = await _get_avatar_state(redis, job_id)
|
||||
if not state:
|
||||
continue
|
||||
|
||||
# 只统计当前用户的任务(非管理员)
|
||||
params = state.get("params", {}) if isinstance(state.get("params"), dict) else {}
|
||||
if params.get("user_id") != str(current_user.id):
|
||||
continue
|
||||
|
||||
job_type = state.get("type", "")
|
||||
if job_type != "avatar_clone":
|
||||
continue
|
||||
|
||||
avatar_status = state.get("avatar_status", state.get("status", ""))
|
||||
total_processing += 1
|
||||
|
||||
if avatar_status == AvatarCloneStatus.PENDING.value:
|
||||
pending += 1
|
||||
elif avatar_status == AvatarCloneStatus.VOICE_PROCESSING.value:
|
||||
voice_processing += 1
|
||||
elif avatar_status == AvatarCloneStatus.ELEMENT_PROCESSING.value:
|
||||
element_processing += 1
|
||||
|
||||
# 检查是否卡住(updated_at 超过 30 分钟)
|
||||
updated_at_raw = state.get("updated_at", "")
|
||||
if updated_at_raw:
|
||||
try:
|
||||
updated_ts = datetime.fromisoformat(updated_at_raw).timestamp()
|
||||
if updated_ts < stuck_threshold and avatar_status in (
|
||||
AvatarCloneStatus.PENDING.value,
|
||||
AvatarCloneStatus.VOICE_PROCESSING.value,
|
||||
AvatarCloneStatus.ELEMENT_PROCESSING.value,
|
||||
):
|
||||
stuck_tasks += 1
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查最近失败
|
||||
if avatar_status in (
|
||||
AvatarCloneStatus.VOICE_FAILED.value,
|
||||
AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
AvatarCloneStatus.TIMEOUT.value,
|
||||
):
|
||||
updated_at_raw = state.get("updated_at", "")
|
||||
if updated_at_raw:
|
||||
try:
|
||||
updated_ts = datetime.fromisoformat(updated_at_raw).timestamp()
|
||||
if updated_ts >= recent_threshold:
|
||||
recent_failures += 1
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return success_response(
|
||||
data=AvatarHealthResponse(
|
||||
total_processing=total_processing,
|
||||
pending=pending,
|
||||
voice_processing=voice_processing,
|
||||
element_processing=element_processing,
|
||||
stuck_tasks=stuck_tasks,
|
||||
recent_failures=recent_failures,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/avatar/admin/trigger-recovery", response_model=ApiResponse[dict])
|
||||
async def admin_trigger_recovery(
|
||||
current_user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
手动触发卡住任务恢复(管理员接口)
|
||||
|
||||
Async Engine 会自动轮询,无需手动触发恢复。
|
||||
"""
|
||||
# 权限检查:基于特定手机号判断管理员
|
||||
is_admin = current_user.mobile in ["13800138000", "admin"]
|
||||
if not is_admin:
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"message": "Async Engine 会持续自动轮询,无需手动触发恢复",
|
||||
"task_id": None,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
火山引擎音视频字幕 API 路由
|
||||
============================
|
||||
|
||||
提供字幕生成、自动打轴等功能。
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.caption import (
|
||||
AutoAlignResult,
|
||||
AutoAlignSubmitRequest,
|
||||
CaptionResult,
|
||||
CaptionSubmitRequest,
|
||||
CaptionTaskResponse,
|
||||
SrtSubtitleResponse,
|
||||
)
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.services.volcengine_caption_service import (
|
||||
VolcengineCaptionError,
|
||||
VolcengineCaptionService,
|
||||
get_caption_service,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/caption", tags=["Caption"])
|
||||
|
||||
|
||||
@router.post("/submit", response_model=ApiResponse[CaptionTaskResponse])
|
||||
async def submit_caption_task(request: CaptionSubmitRequest):
|
||||
"""
|
||||
提交字幕生成任务
|
||||
|
||||
提交音频/视频文件URL,生成带时间轴的字幕。
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
task_id = await service.submit_caption_task(
|
||||
audio_url=request.audio_url,
|
||||
language=request.language,
|
||||
caption_type=request.caption_type,
|
||||
use_punc=request.use_punc,
|
||||
use_itn=request.use_itn,
|
||||
words_per_line=request.words_per_line,
|
||||
max_lines=request.max_lines,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=CaptionTaskResponse(
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
),
|
||||
message="字幕任务已提交",
|
||||
)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"提交字幕任务失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"提交字幕任务异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"提交失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/query/{task_id}", response_model=ApiResponse[CaptionResult])
|
||||
async def query_caption_task(task_id: str, blocking: bool = True):
|
||||
"""
|
||||
查询字幕任务结果
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
blocking: 是否阻塞等待结果 (默认True)
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
result = await service.query_caption_task(task_id, blocking=blocking)
|
||||
|
||||
return success_response(data=result)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"查询字幕任务失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"查询字幕任务异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse[CaptionResult])
|
||||
async def generate_caption(request: CaptionSubmitRequest, max_wait_time: int = 120):
|
||||
"""
|
||||
生成字幕(完整流程)
|
||||
|
||||
提交任务并轮询结果,直接返回最终字幕数据。
|
||||
适用于不需要异步处理的场景。
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
result = await service.generate_caption(
|
||||
audio_url=request.audio_url,
|
||||
language=request.language,
|
||||
caption_type=request.caption_type,
|
||||
use_punc=request.use_punc,
|
||||
use_itn=request.use_itn,
|
||||
words_per_line=request.words_per_line,
|
||||
max_lines=request.max_lines,
|
||||
max_wait_time=max_wait_time,
|
||||
)
|
||||
|
||||
return success_response(data=result)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"生成字幕失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"生成字幕异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate-ass", response_model=ApiResponse[dict])
|
||||
async def generate_ass(
|
||||
request: CaptionSubmitRequest,
|
||||
video_width: int = 1080,
|
||||
video_height: int = 1920,
|
||||
max_wait_time: int = 120,
|
||||
):
|
||||
"""
|
||||
生成 ASS 格式字幕(完整流程,使用抖音美好体)
|
||||
|
||||
Args:
|
||||
video_width: 视频宽度(默认 1080)
|
||||
video_height: 视频高度(默认 1920)
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
result = await service.generate_caption(
|
||||
audio_url=request.audio_url,
|
||||
language=request.language,
|
||||
caption_type=request.caption_type,
|
||||
use_punc=request.use_punc,
|
||||
use_itn=request.use_itn,
|
||||
words_per_line=request.words_per_line,
|
||||
max_lines=request.max_lines,
|
||||
max_wait_time=max_wait_time,
|
||||
)
|
||||
|
||||
ass_content = service.to_ass(
|
||||
result.utterances,
|
||||
video_width=video_width,
|
||||
video_height=video_height,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"ass_content": ass_content,
|
||||
"utterances": result.utterances,
|
||||
"duration": result.duration,
|
||||
"font": "DouyinSansBold",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成ASS字幕失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate-srt", response_model=ApiResponse[SrtSubtitleResponse])
|
||||
async def generate_srt(request: CaptionSubmitRequest, max_wait_time: int = 120):
|
||||
"""
|
||||
生成 SRT 格式字幕(完整流程)
|
||||
|
||||
直接返回 SRT 格式字幕文件内容。
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
result = await service.generate_caption(
|
||||
audio_url=request.audio_url,
|
||||
language=request.language,
|
||||
caption_type=request.caption_type,
|
||||
use_punc=request.use_punc,
|
||||
use_itn=request.use_itn,
|
||||
words_per_line=request.words_per_line,
|
||||
max_lines=request.max_lines,
|
||||
max_wait_time=max_wait_time,
|
||||
)
|
||||
|
||||
srt_content = service.to_srt(result.utterances)
|
||||
|
||||
return success_response(
|
||||
data=SrtSubtitleResponse(
|
||||
srt_content=srt_content,
|
||||
utterances=result.utterances,
|
||||
)
|
||||
)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"生成SRT字幕失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"生成SRT字幕异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ata/submit", response_model=ApiResponse[CaptionTaskResponse])
|
||||
async def submit_auto_align_task(request: AutoAlignSubmitRequest):
|
||||
"""
|
||||
提交自动字幕打轴任务
|
||||
|
||||
为已有字幕文本自动配上时间轴。
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
task_id = await service.submit_auto_align_task(
|
||||
audio_url=request.audio_url,
|
||||
audio_text=request.audio_text,
|
||||
caption_type=request.caption_type,
|
||||
sta_punc_mode=request.sta_punc_mode,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=CaptionTaskResponse(
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
),
|
||||
message="打轴任务已提交",
|
||||
)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"提交打轴任务失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"提交打轴任务异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"提交失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/ata/query/{task_id}", response_model=ApiResponse[AutoAlignResult])
|
||||
async def query_auto_align_task(task_id: str, blocking: bool = True):
|
||||
"""
|
||||
查询打轴任务结果
|
||||
"""
|
||||
try:
|
||||
service = await get_caption_service()
|
||||
result = await service.query_auto_align_task(task_id, blocking=blocking)
|
||||
|
||||
return success_response(data=result)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"查询打轴任务失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"查询打轴任务异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ata/align")
|
||||
async def auto_align_caption(request: AutoAlignSubmitRequest, max_wait_time: int = 120):
|
||||
"""
|
||||
自动字幕打轴(完整流程)
|
||||
|
||||
提交打轴任务并轮询结果,直接返回最终数据。
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[Caption API] Auto align request: audio_url={request.audio_url[:50]}...")
|
||||
service = await get_caption_service()
|
||||
result = await service.auto_align_caption(
|
||||
audio_url=request.audio_url,
|
||||
audio_text=request.audio_text,
|
||||
caption_type=request.caption_type,
|
||||
sta_punc_mode=request.sta_punc_mode,
|
||||
max_wait_time=max_wait_time,
|
||||
)
|
||||
logger.info(
|
||||
f"[Caption API] Auto align result: utterances_count={len(result.utterances) if result.utterances else 0}"
|
||||
)
|
||||
if result.utterances:
|
||||
logger.info(f"[Caption API] First utterance: {result.utterances[0]}")
|
||||
|
||||
# 手动序列化为字典,确保嵌套模型正确处理
|
||||
response_data = {
|
||||
"code": 0,
|
||||
"message": "Success",
|
||||
"duration": result.duration,
|
||||
"utterances": [
|
||||
{
|
||||
"text": u.text,
|
||||
"start_time": u.start_time,
|
||||
"end_time": u.end_time,
|
||||
}
|
||||
for u in (result.utterances or [])
|
||||
],
|
||||
}
|
||||
logger.info(f"[Caption API] Response data: {response_data}")
|
||||
return success_response(data=response_data)
|
||||
|
||||
except VolcengineCaptionError as e:
|
||||
logger.error(f"自动打轴失败: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"自动打轴异常: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"打轴失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/convert/ass", response_model=ApiResponse[dict])
|
||||
async def convert_to_ass(
|
||||
result: CaptionResult,
|
||||
video_width: int = 1080,
|
||||
video_height: int = 1920,
|
||||
):
|
||||
"""
|
||||
将字幕结果转换为 ASS 格式(使用抖音美好体)
|
||||
"""
|
||||
try:
|
||||
service = VolcengineCaptionService("", "") # 不需要认证
|
||||
ass_content = service.to_ass(
|
||||
result.utterances,
|
||||
video_width=video_width,
|
||||
video_height=video_height,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"ass_content": ass_content,
|
||||
"font": "DouyinSansBold",
|
||||
"utterances_count": len(result.utterances),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换ASS失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"转换失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/convert/srt", response_model=ApiResponse[dict])
|
||||
async def convert_to_srt(result: CaptionResult):
|
||||
"""
|
||||
将字幕结果转换为 SRT 格式
|
||||
|
||||
用于将 /generate 返回的原始数据转换为 SRT 格式。
|
||||
"""
|
||||
try:
|
||||
service = VolcengineCaptionService("", "") # 不需要认证
|
||||
srt_content = service.to_srt(result.utterances)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"srt_content": srt_content,
|
||||
"utterances_count": len(result.utterances),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换SRT失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"转换失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/convert/vtt", response_model=ApiResponse[dict])
|
||||
async def convert_to_vtt(result: CaptionResult):
|
||||
"""
|
||||
将字幕结果转换为 WebVTT 格式
|
||||
"""
|
||||
try:
|
||||
service = VolcengineCaptionService("", "") # 不需要认证
|
||||
vtt_content = service.to_vtt(result.utterances)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"vtt_content": vtt_content,
|
||||
"utterances_count": len(result.utterances),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换VTT失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"转换失败: {str(e)}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
七牛云对象存储 API 路由
|
||||
========================
|
||||
|
||||
提供音视频文件上传、管理和访问功能。
|
||||
|
||||
主要功能:
|
||||
1. 生成上传凭证(客户端直传)
|
||||
2. 服务端文件上传
|
||||
3. 声音克隆样本上传
|
||||
4. 文件删除和管理
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.services.qiniu_service import get_qiniu_service
|
||||
|
||||
router = APIRouter(prefix="/qiniu", tags=["Qiniu Storage"])
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
|
||||
class UploadTokenRequest(BaseModel):
|
||||
"""上传凭证请求"""
|
||||
|
||||
key: str = Field(..., description="文件存储 Key")
|
||||
expires: int = Field(3600, description="Token 有效期(秒)")
|
||||
|
||||
|
||||
class UploadTokenResponse(BaseModel):
|
||||
"""上传凭证响应"""
|
||||
|
||||
token: str
|
||||
key: str
|
||||
uploadUrl: str = "https://upload.qiniup.com"
|
||||
|
||||
|
||||
class FileUploadResponse(BaseModel):
|
||||
"""文件上传响应"""
|
||||
|
||||
key: str
|
||||
url: str
|
||||
hash: str
|
||||
mimeType: str
|
||||
fsize: int
|
||||
isDuplicate: bool = False
|
||||
message: str | None = None
|
||||
existingTaskId: str | None = None # 当检测到重复任务时返回
|
||||
|
||||
|
||||
class DeleteFileRequest(BaseModel):
|
||||
"""删除文件请求"""
|
||||
|
||||
key: str = Field(..., description="文件 Key")
|
||||
|
||||
|
||||
# ============ API 路由 ============
|
||||
|
||||
|
||||
@router.post("/upload-token", response_model=ApiResponse[UploadTokenResponse])
|
||||
async def get_upload_token(request: UploadTokenRequest):
|
||||
"""
|
||||
获取上传凭证(客户端直传)
|
||||
|
||||
前端获取 Token 后,可直接上传到七牛云,无需经过服务端。
|
||||
|
||||
上传地址: https://upload.qiniup.com
|
||||
请求方式: POST (multipart/form-data)
|
||||
请求参数:
|
||||
- token: 上传凭证(本接口返回)
|
||||
- key: 文件存储 Key(本接口返回)
|
||||
- file: 文件内容
|
||||
"""
|
||||
try:
|
||||
service = get_qiniu_service()
|
||||
token = service.get_upload_token(request.key, request.expires)
|
||||
|
||||
return success_response(
|
||||
data=UploadTokenResponse(
|
||||
token=token, key=request.key, uploadUrl="https://upload.qiniup.com"
|
||||
)
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"生成上传凭证失败: {e}")
|
||||
|
||||
|
||||
@router.post("/upload/audio", response_model=ApiResponse[FileUploadResponse])
|
||||
async def upload_audio(
|
||||
file: UploadFile = File(..., description="音频文件(MP3, WAV, M4A, AAC, OGG)"),
|
||||
userId: str | None = Form(None, description="用户ID(可选,用于目录隔离)"),
|
||||
):
|
||||
"""
|
||||
上传音频文件
|
||||
|
||||
支持格式: MP3, WAV, M4A, AAC, OGG
|
||||
文件会自动存储到: audios/{userId}/{date}/{uuid}.{ext}
|
||||
"""
|
||||
service = get_qiniu_service()
|
||||
|
||||
# 保存临时文件
|
||||
suffix = Path(file.filename).suffix if file.filename else ".mp3"
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
shutil.copyfileobj(file.file, tmp)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = service.upload_audio(tmp_path, userId=userId)
|
||||
return success_response(data=FileUploadResponse(**result))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {e}")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
@router.post("/upload/video", response_model=ApiResponse[FileUploadResponse])
|
||||
async def upload_video(
|
||||
file: UploadFile = File(..., description="视频文件(MP4, MOV, AVI, WebM)"),
|
||||
userId: str | None = Form(None, description="用户ID(可选,用于目录隔离)"),
|
||||
):
|
||||
"""
|
||||
上传视频文件
|
||||
|
||||
支持格式: MP4, MOV, AVI, WebM
|
||||
文件会自动存储到: videos/{userId}/{date}/{uuid}.{ext}
|
||||
"""
|
||||
service = get_qiniu_service()
|
||||
|
||||
suffix = Path(file.filename).suffix if file.filename else ".mp4"
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
shutil.copyfileobj(file.file, tmp)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = service.upload_video(tmp_path, userId=userId)
|
||||
return success_response(data=FileUploadResponse(**result))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {e}")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
async def _check_existing_avatar_task(
|
||||
video_url: str,
|
||||
user_id: str,
|
||||
) -> dict | None:
|
||||
"""
|
||||
检查是否有相同视频URL的正在进行的任务(从 Redis 读取)
|
||||
|
||||
Returns:
|
||||
如果找到进行中的任务,返回 {'task_id': str, 'status': str}
|
||||
否则返回 None
|
||||
"""
|
||||
import json
|
||||
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.scheduler.registry import JobRegistry
|
||||
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
job_ids = await registry.get_running_job_ids()
|
||||
|
||||
for job_id in job_ids:
|
||||
data = await redis.hgetall(f"job:{job_id}")
|
||||
if not data:
|
||||
continue
|
||||
if data.get("type") != "avatar_clone":
|
||||
continue
|
||||
|
||||
params = {}
|
||||
if "params" in data and data["params"]:
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
params = json.loads(data["params"])
|
||||
|
||||
if params.get("user_id") == user_id and params.get("video_url") == video_url:
|
||||
avatar_status = data.get("avatar_status", data.get("status", ""))
|
||||
return {
|
||||
"task_id": job_id,
|
||||
"status": avatar_status,
|
||||
"voice_id": data.get("voice_id"),
|
||||
"provider_element_id": data.get("provider_element_id"),
|
||||
"video_url": video_url,
|
||||
"file_size": 0,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/upload/avatar", response_model=ApiResponse[FileUploadResponse])
|
||||
async def upload_avatar(
|
||||
file: UploadFile = File(..., description="形象克隆视频(MP4, MOV)"),
|
||||
userId: str | None = Form(None, description="用户ID(可选,用于目录隔离)"),
|
||||
fileHash: str | None = Form(None, description="前端计算的文件SHA256哈希,用于重复检测"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
上传形象克隆视频
|
||||
|
||||
用于形象克隆功能,上传的视频将同时用于创建自定义音色和主体。
|
||||
|
||||
KlingAI 要求:
|
||||
- 格式: MP4, MOV
|
||||
- 时长: 5-30 秒 (建议 5-8 秒)
|
||||
- 大小: 不超过 200MB
|
||||
- 分辨率: 高度 720px~2160px
|
||||
- 内容: 写实风格人物正面特写,人脸清晰、无遮挡,视频中有清晰人声
|
||||
|
||||
文件存储路径: meijiaka/avatars/{userId}/{date}/{uuid}.{ext}
|
||||
|
||||
重复检测:
|
||||
- 如果提供了 fileHash,会检查是否已有相同文件的任务在进行中
|
||||
- 返回的 isDuplicate 表示是否复用了已有资源
|
||||
- existingTaskId 表示已存在任务的ID(如果有)
|
||||
"""
|
||||
service = get_qiniu_service()
|
||||
|
||||
# 使用当前登录用户的ID
|
||||
effective_user_id = userId or str(current_user.id)
|
||||
|
||||
suffix = Path(file.filename).suffix if file.filename else ".mp4"
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
shutil.copyfileobj(file.file, tmp)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = service.upload_avatar_video(
|
||||
tmp_path,
|
||||
user_id=effective_user_id,
|
||||
file_hash=fileHash,
|
||||
)
|
||||
|
||||
# 如果七牛云返回了现有文件,检查数据库中是否有进行中的任务
|
||||
if result.get("isDuplicate") and result.get("url"):
|
||||
existing_task = await _check_existing_avatar_task(result["url"], effective_user_id)
|
||||
if existing_task:
|
||||
logger.info(
|
||||
f"Found existing avatar task for uploaded file: {existing_task['task_id']}"
|
||||
)
|
||||
result["existingTaskId"] = existing_task["task_id"]
|
||||
result["message"] = "检测到相同视频的任务正在进行中,已复用现有任务"
|
||||
|
||||
return success_response(data=FileUploadResponse(**result))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Upload avatar failed")
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {e}")
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
|
||||
|
||||
@router.get("/files/{key:path}", response_model=ApiResponse[dict])
|
||||
async def get_file_info(key: str):
|
||||
"""
|
||||
获取文件信息
|
||||
|
||||
Args:
|
||||
key: 文件存储 Key(路径格式)
|
||||
"""
|
||||
try:
|
||||
service = get_qiniu_service()
|
||||
# 根据 key 推断 bucket
|
||||
bucket = service.image_bucket if "/images/" in key else service.video_bucket
|
||||
info = service.get_file_info(bucket, key)
|
||||
|
||||
if info is None:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
return success_response(data=info)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取文件信息失败: {e}")
|
||||
|
||||
|
||||
@router.delete("/files/{key:path}", response_model=ApiResponse[dict])
|
||||
async def delete_file(key: str):
|
||||
"""
|
||||
删除文件
|
||||
|
||||
Args:
|
||||
key: 文件存储 Key
|
||||
"""
|
||||
try:
|
||||
service = get_qiniu_service()
|
||||
# 根据 key 推断 bucket
|
||||
bucket = service.image_bucket if "/images/" in key else service.video_bucket
|
||||
success = service.delete_file(bucket, key)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"success": success,
|
||||
"key": key,
|
||||
"message": "删除成功" if success else "删除失败或文件不存在",
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {e}")
|
||||
|
||||
|
||||
@router.post("/refresh-cdn", response_model=ApiResponse[dict])
|
||||
async def refresh_cdn(keys: list[str]):
|
||||
"""
|
||||
刷新 CDN 缓存
|
||||
|
||||
文件更新后,调用此接口刷新 CDN 缓存,确保用户访问到最新内容。
|
||||
"""
|
||||
try:
|
||||
service = get_qiniu_service()
|
||||
result = service.refresh_cdn(keys)
|
||||
|
||||
return success_response(data=result)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"刷新 CDN 失败: {e}")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
API v1 路由聚合
|
||||
==============
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import (
|
||||
ai_models,
|
||||
auth,
|
||||
avatar,
|
||||
caption,
|
||||
klingai,
|
||||
qiniu,
|
||||
script,
|
||||
system,
|
||||
tasks,
|
||||
video,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 认证模块
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
|
||||
# 脚本模块
|
||||
api_router.include_router(script.router, prefix="/script", tags=["Script"])
|
||||
|
||||
# AI 平台管理模块
|
||||
api_router.include_router(ai_models.router, prefix="/ai", tags=["AI Models"])
|
||||
|
||||
# KlingAI 模块(视频/图像生成)
|
||||
api_router.include_router(klingai.router, tags=["KlingAI"])
|
||||
|
||||
# 七牛云对象存储模块
|
||||
api_router.include_router(qiniu.router, tags=["Qiniu Storage"])
|
||||
|
||||
# 视频生成模块
|
||||
api_router.include_router(video.router, tags=["Video"])
|
||||
|
||||
# 形象克隆模块
|
||||
api_router.include_router(avatar.router, tags=["Avatar"])
|
||||
|
||||
# 系统模块
|
||||
api_router.include_router(system.router, prefix="/system", tags=["System"])
|
||||
|
||||
# 字幕生成模块(火山引擎-豆包语音)
|
||||
api_router.include_router(caption.router, tags=["Caption"])
|
||||
|
||||
# 统一任务管理模块
|
||||
api_router.include_router(tasks.router, tags=["Tasks"])
|
||||
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
脚本生成 API
|
||||
============
|
||||
|
||||
提供脚本生成、润色、模型健康检查等功能。
|
||||
支持 SSE 流式响应。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.schemas.script import (
|
||||
GenerateScriptRequest,
|
||||
ModelHealthResponse,
|
||||
PolishRequest,
|
||||
ScriptGenerationEvent,
|
||||
ScriptShot,
|
||||
TestModelRequest,
|
||||
TestModelResponse,
|
||||
)
|
||||
from app.services.script_service import get_script_service
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse[list[ScriptShot]])
|
||||
async def generate_script(request: GenerateScriptRequest):
|
||||
"""
|
||||
同步生成脚本
|
||||
|
||||
直接返回生成的分镜列表,适合快速预览。
|
||||
"""
|
||||
service = get_script_service()
|
||||
|
||||
shots = await service.generate_script(
|
||||
topic=request.topic,
|
||||
duration=request.duration,
|
||||
script_type=request.script_type,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=shots,
|
||||
message=f"成功生成 {len(shots)} 个分镜",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate/stream")
|
||||
async def generate_script_stream(request: Request, data: GenerateScriptRequest):
|
||||
"""
|
||||
流式生成脚本(SSE)
|
||||
|
||||
返回 Server-Sent Events,包含进度更新和最终结果。
|
||||
前端通过 EventSource 接收实时进度。
|
||||
|
||||
**SSE 事件类型:**
|
||||
- `start`: 开始生成
|
||||
- `analyzing`: 分析主题
|
||||
- `planning`: 规划结构
|
||||
- `generating`: AI 生成中
|
||||
- `parsing`: 解析结果
|
||||
- `complete`: 完成,包含 result 字段
|
||||
- `error`: 错误
|
||||
|
||||
**示例事件流:**
|
||||
```
|
||||
data: {"type": "start", "progress": 0, "message": "开始生成脚本"}
|
||||
|
||||
data: {"type": "analyzing", "progress": 15, "message": "分析目标受众..."}
|
||||
|
||||
data: {"type": "complete", "progress": 100, "message": "成功生成 5 个分镜", "result": [...]}
|
||||
```
|
||||
"""
|
||||
service = get_script_service()
|
||||
|
||||
async def event_generator():
|
||||
"""SSE 事件生成器,带客户端断开检测"""
|
||||
try:
|
||||
async for event in service.generate_script_stream(
|
||||
topic=data.topic,
|
||||
duration=data.duration,
|
||||
script_type=data.script_type,
|
||||
model=data.model,
|
||||
):
|
||||
# 检查客户端是否已断开
|
||||
if await request.is_disconnected():
|
||||
logger.info("[SSE] 客户端已断开连接,停止生成")
|
||||
break
|
||||
|
||||
# SSE 格式:data: {...}\n\n
|
||||
try:
|
||||
yield f"data: {event.model_dump_json()}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"[SSE] 序列化事件失败: {e}")
|
||||
continue
|
||||
|
||||
# 发送结束标记(如果客户端还连接着)
|
||||
if not await request.is_disconnected():
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("[SSE] 事件生成器异常")
|
||||
# 尝试发送错误信息给客户端
|
||||
try:
|
||||
error_event = ScriptGenerationEvent(
|
||||
type="error",
|
||||
progress=0,
|
||||
message=f"服务器错误: {str(e)}",
|
||||
)
|
||||
yield f"data: {error_event.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # 禁用 Nginx 缓冲
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/polish", response_model=ApiResponse[str])
|
||||
async def polish_content(request: PolishRequest):
|
||||
"""
|
||||
AI 润色文案/画面描述
|
||||
|
||||
- `polishType=scene`: 润色画面描述(根据 shot_type 自动区分分镜/空镜)
|
||||
- `polishType=voiceover`: 润色配音文案
|
||||
|
||||
参数:
|
||||
- `shot_type`: "segment"(分镜)或 "empty_shot"(空镜),画面润色时必填
|
||||
"""
|
||||
service = get_script_service()
|
||||
|
||||
polished = await service.polish_content(
|
||||
content=request.content,
|
||||
polish_type=request.polish_type,
|
||||
shot_type=request.shot_type or "segment",
|
||||
)
|
||||
|
||||
type_name = "画面" if request.polish_type == "scene" else "文案"
|
||||
return success_response(
|
||||
data=polished,
|
||||
message=f"{type_name}润色完成",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/model-health", response_model=ApiResponse[ModelHealthResponse])
|
||||
async def check_model_health():
|
||||
"""
|
||||
检查 AI 模型健康状态
|
||||
|
||||
返回所有配置的模型及其可用性状态。
|
||||
"""
|
||||
service = get_script_service()
|
||||
health_data = await service.check_model_health()
|
||||
|
||||
return success_response(
|
||||
data=ModelHealthResponse(**health_data),
|
||||
message="模型健康检查完成",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/test-model", response_model=ApiResponse[TestModelResponse])
|
||||
async def test_model(request: TestModelRequest):
|
||||
"""
|
||||
测试指定模型连接
|
||||
|
||||
发送一个简单的测试请求,验证模型是否可用。
|
||||
"""
|
||||
service = get_script_service()
|
||||
result = await service.test_model(request.model_id)
|
||||
|
||||
return success_response(
|
||||
data=TestModelResponse(**result),
|
||||
message="模型测试完成" if result["success"] else f"模型测试失败: {result.get('error')}",
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
系统模块 API
|
||||
============
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health", response_model=ApiResponse[dict])
|
||||
async def system_health():
|
||||
"""系统健康检查(详细版)"""
|
||||
return success_response(
|
||||
data={
|
||||
"status": "healthy",
|
||||
"services": {
|
||||
"api": "up",
|
||||
"database": "unknown", # TODO: 检查数据库连接
|
||||
"redis": "unknown", # TODO: 检查 Redis 连接
|
||||
},
|
||||
},
|
||||
message="系统运行正常",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/version", response_model=ApiResponse[dict])
|
||||
async def system_version():
|
||||
"""获取系统版本信息"""
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"name": settings.APP_NAME,
|
||||
"version": settings.APP_VERSION,
|
||||
"environment": settings.ENV,
|
||||
},
|
||||
message="获取版本成功",
|
||||
)
|
||||
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
统一任务管理 API
|
||||
===============
|
||||
|
||||
提供任务创建和状态查询接口,支持:
|
||||
- video: 视频生成
|
||||
- image: 图片生成
|
||||
- script: 脚本生成
|
||||
- subtitle: 字幕对齐
|
||||
- copy: 文案提取
|
||||
- avatar_clone: 形象克隆
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.models.user import User
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.schemas.enums import AvatarCloneStatus
|
||||
from app.schemas.segment import Segment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["Tasks"])
|
||||
|
||||
|
||||
# ========== 请求/响应模型 ==========
|
||||
|
||||
|
||||
class VideoParams(BaseModel):
|
||||
"""视频生成参数"""
|
||||
|
||||
segments: list[Segment] = Field(..., description="分镜列表")
|
||||
human_id: int | None = Field(None, description="数字人主体ID")
|
||||
|
||||
@field_validator("segments")
|
||||
@classmethod
|
||||
def validate_segments(cls, v: list[Segment]) -> list[Segment]:
|
||||
if not v:
|
||||
raise ValueError("segments 不能为空列表")
|
||||
return v
|
||||
|
||||
|
||||
class ImageParams(BaseModel):
|
||||
"""图片生成参数"""
|
||||
|
||||
prompt: str = Field(..., min_length=1, description="图片描述")
|
||||
image_type: str = Field(default="cover", description="图片类型: empty_shot/cover")
|
||||
reference_image: str | None = Field(None, description="参考图片URL(图生图)")
|
||||
human_id: int | None = Field(None, description="数字人主体ID(omni-image使用)")
|
||||
|
||||
@field_validator("prompt")
|
||||
@classmethod
|
||||
def validate_prompt(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("prompt 不能为空")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ScriptParams(BaseModel):
|
||||
"""脚本生成参数"""
|
||||
|
||||
topic: str = Field(..., min_length=1, description="创作主题")
|
||||
style: str = Field(default="default", description="脚本风格")
|
||||
duration: int = Field(default=60, ge=10, le=300, description="视频时长(秒)")
|
||||
|
||||
@field_validator("topic")
|
||||
@classmethod
|
||||
def validate_topic(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("topic 不能为空")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class SubtitleParams(BaseModel):
|
||||
"""字幕生成参数"""
|
||||
|
||||
video_path: str = Field(..., min_length=1, description="视频文件路径")
|
||||
language: str = Field(default="zh", description="语言代码")
|
||||
mode: str = Field(default="caption", description="模式: caption/auto_align")
|
||||
audio_text: str | None = Field(default=None, description="打轴文本(auto_align 模式必填)")
|
||||
|
||||
@field_validator("video_path")
|
||||
@classmethod
|
||||
def validate_video_path(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("video_path 不能为空")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class CopyParams(BaseModel):
|
||||
"""文案提取参数"""
|
||||
|
||||
video_url: str = Field(..., min_length=1, description="视频链接URL")
|
||||
|
||||
@field_validator("video_url")
|
||||
@classmethod
|
||||
def validate_video_url(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("video_url 不能为空")
|
||||
if not v.startswith(("http://", "https://")):
|
||||
raise ValueError("video_url 必须是有效的URL")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class TaskCreateRequest(BaseModel):
|
||||
"""创建任务请求"""
|
||||
|
||||
project_id: str | None = Field(None, description="项目ID(可选)")
|
||||
params: dict = Field(default_factory=dict, description="任务参数")
|
||||
|
||||
|
||||
class TaskCreateResponse(BaseModel):
|
||||
"""创建任务响应"""
|
||||
|
||||
task_id: str = Field(..., description="任务ID")
|
||||
status: str = Field("pending", description="任务状态")
|
||||
message: str = Field("任务已创建", description="状态消息")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""任务状态响应"""
|
||||
|
||||
task_id: str = Field(..., description="任务ID")
|
||||
type: str | None = Field(None, description="任务类型")
|
||||
status: str = Field(..., description="任务状态: pending/running/waiting/completed/failed")
|
||||
progress: int = Field(0, description="进度百分比 (0-100)")
|
||||
message: str = Field("", description="状态描述")
|
||||
completed: int = Field(0, description="已完成子任务数")
|
||||
total: int = Field(0, description="总子任务数")
|
||||
result: dict | None = Field(None, description="任务结果(完成时)")
|
||||
error: str | None = Field(None, description="错误信息(失败时)")
|
||||
created_at: str = Field("", description="任务创建时间(ISO格式)")
|
||||
|
||||
|
||||
# ========== 辅助函数 ==========
|
||||
|
||||
|
||||
def _generate_task_id() -> str:
|
||||
"""生成任务ID"""
|
||||
return f"task_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
|
||||
# ========== API 路由 ==========
|
||||
|
||||
|
||||
@router.post("/{task_type}", response_model=TaskCreateResponse)
|
||||
async def create_task(
|
||||
task_type: Literal["video", "image", "script", "subtitle", "copy", "avatar_clone"],
|
||||
request: TaskCreateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TaskCreateResponse:
|
||||
"""
|
||||
创建新任务
|
||||
|
||||
根据任务类型写入 Redis,由 Async Engine Scheduler 统一调度。
|
||||
"""
|
||||
task_id = _generate_task_id()
|
||||
user_id = str(current_user.id)
|
||||
project_id = request.project_id or request.params.get("project_id", "")
|
||||
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
|
||||
try:
|
||||
await registry.create(task_id, task_type, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Failed to create registry entry: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建任务失败:Redis连接错误")
|
||||
|
||||
try:
|
||||
if task_type == "video":
|
||||
# 字段适配:前端 shots/element_id → 后端 segments/human_id
|
||||
import re
|
||||
|
||||
video_params = dict(request.params)
|
||||
if "shots" in video_params:
|
||||
shots = video_params.pop("shots")
|
||||
for s in shots:
|
||||
# 清洗 id:前端可能发送数字,Segment 模型要求 str
|
||||
if "id" in s and not isinstance(s["id"], str):
|
||||
s["id"] = str(s["id"])
|
||||
# 清洗 duration:前端可能发送 "5s",Segment 模型要求 int
|
||||
duration = s.get("duration")
|
||||
if isinstance(duration, str):
|
||||
m = re.search(r"\d+", duration)
|
||||
s["duration"] = int(m.group()) if m else None
|
||||
video_params["segments"] = shots
|
||||
if "element_id" in video_params:
|
||||
video_params["human_id"] = video_params.pop("element_id")
|
||||
validated = VideoParams(**video_params)
|
||||
segments = validated.segments
|
||||
human_id = validated.human_id
|
||||
|
||||
normalized_segments = []
|
||||
for s in segments:
|
||||
normalized_segments.append(
|
||||
{
|
||||
"id": str(s.id),
|
||||
"type": s.type,
|
||||
"scene": s.scene,
|
||||
"voiceover": s.voiceover,
|
||||
"duration": s.duration,
|
||||
"human_id": (human_id if s.type == "segment" else None),
|
||||
"voice_id": s.voice_id,
|
||||
"provider_task_id": None,
|
||||
"status": "pending",
|
||||
"video_url": None,
|
||||
"local_path": None,
|
||||
"qiniu_url": None,
|
||||
"error_message": None,
|
||||
}
|
||||
)
|
||||
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
message=f"开始生成视频,共 {len(normalized_segments)} 个镜头...",
|
||||
completed=0,
|
||||
total=len(normalized_segments),
|
||||
params={
|
||||
"project_id": project_id,
|
||||
"user_id": user_id,
|
||||
"human_id": human_id,
|
||||
"shots": json.dumps(normalized_segments, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
elif task_type == "image":
|
||||
image_validated = ImageParams(**request.params)
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
message="准备生成图片...",
|
||||
completed=0,
|
||||
total=1,
|
||||
params={
|
||||
"project_id": project_id,
|
||||
"user_id": user_id,
|
||||
"prompt": image_validated.prompt,
|
||||
"image_type": image_validated.image_type,
|
||||
"reference_image": image_validated.reference_image,
|
||||
"human_id": image_validated.human_id,
|
||||
},
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
elif task_type == "script":
|
||||
script_validated = ScriptParams(**request.params)
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
message="等待执行...",
|
||||
params={
|
||||
"topic": script_validated.topic,
|
||||
"style": script_validated.style,
|
||||
"duration": script_validated.duration,
|
||||
},
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
elif task_type == "subtitle":
|
||||
subtitle_validated = SubtitleParams(**request.params)
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
message="准备字幕生成...",
|
||||
completed=0,
|
||||
total=1,
|
||||
params={
|
||||
"project_id": project_id,
|
||||
"video_path": subtitle_validated.video_path,
|
||||
"language": subtitle_validated.language,
|
||||
"mode": subtitle_validated.mode,
|
||||
"audio_text": subtitle_validated.audio_text,
|
||||
},
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
elif task_type == "copy":
|
||||
copy_validated = CopyParams(**request.params)
|
||||
await registry.update(
|
||||
task_id,
|
||||
status="running",
|
||||
message="准备提取文案...",
|
||||
completed=0,
|
||||
total=1,
|
||||
params={"video_url": copy_validated.video_url},
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
|
||||
elif task_type == "avatar_clone":
|
||||
name = request.params.get("name", "").strip()
|
||||
video_url = request.params.get("video_url", "").strip()
|
||||
if not name:
|
||||
raise ValueError("name 不能为空")
|
||||
if not video_url:
|
||||
raise ValueError("video_url 不能为空")
|
||||
if not video_url.startswith(("http://", "https://")):
|
||||
raise ValueError("video_url 必须是有效的URL")
|
||||
|
||||
avatar_id = f"avt_{uuid.uuid4().hex[:16]}"
|
||||
now = datetime.now(UTC).isoformat()
|
||||
|
||||
# avatar_clone 使用自己的 task_id(avt_xxx),不走通用的 task_xxx
|
||||
await registry.create(avatar_id, "avatar_clone", user_id)
|
||||
await registry.update(
|
||||
avatar_id,
|
||||
status="running",
|
||||
progress=5,
|
||||
message="开始形象克隆...",
|
||||
completed=0,
|
||||
total=1,
|
||||
params={
|
||||
"avatar_id": avatar_id,
|
||||
"name": name,
|
||||
"video_url": video_url,
|
||||
"user_id": user_id,
|
||||
},
|
||||
avatar_status=AvatarCloneStatus.PENDING.value,
|
||||
avatar_name=name,
|
||||
avatar_video_url=video_url,
|
||||
voice_id="",
|
||||
provider_element_id="",
|
||||
provider_voice_job_id="",
|
||||
provider_element_job_id="",
|
||||
trial_url="",
|
||||
fail_reason="",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
await registry.add_running(avatar_id)
|
||||
# 返回的任务 ID 用 avatar_id,保持前端兼容
|
||||
task_id = avatar_id
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的任务类型: {task_type}")
|
||||
|
||||
logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
|
||||
return TaskCreateResponse(
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
message=f"{task_type} 任务已创建",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"[API] Invalid params for {task_type}: {e}")
|
||||
try:
|
||||
await registry.update(task_id, status="failed", message=f"参数错误: {e}", error=str(e))
|
||||
except Exception as registry_err:
|
||||
logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
|
||||
raise HTTPException(status_code=422, detail=f"参数错误: {e}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[API] Failed to create task: {e}")
|
||||
try:
|
||||
await registry.update(task_id, status="failed", message=str(e), error=str(e))
|
||||
except Exception as registry_err:
|
||||
logger.warning(f"[API] Failed to update registry for {task_id}: {registry_err}")
|
||||
raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}")
|
||||
|
||||
|
||||
def _map_avatar_status(status: str) -> str:
|
||||
"""将 AvatarCloneStatus 映射为统一任务状态"""
|
||||
mapping = {
|
||||
"succeed": "completed",
|
||||
"voice_failed": "failed",
|
||||
"element_failed": "failed",
|
||||
"timeout": "failed",
|
||||
"pending": "running",
|
||||
"voice_processing": "running",
|
||||
"element_pending": "running",
|
||||
"element_processing": "running",
|
||||
}
|
||||
return mapping.get(status, "running")
|
||||
|
||||
|
||||
@router.get("", response_model=list[TaskStatusResponse])
|
||||
async def list_tasks(
|
||||
project_id: str | None = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[TaskStatusResponse]:
|
||||
"""
|
||||
查询当前用户所有进行中的任务
|
||||
|
||||
从 Redis running 集合读取真实状态,支持按 project_id 过滤。
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
|
||||
try:
|
||||
jobs = await registry.list_running_by_user(str(current_user.id))
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Redis error when listing tasks: {e}")
|
||||
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
||||
|
||||
results: list[TaskStatusResponse] = []
|
||||
for job in jobs:
|
||||
# 按 project_id 过滤
|
||||
if project_id and job.project_id != project_id:
|
||||
continue
|
||||
results.append(
|
||||
TaskStatusResponse(
|
||||
task_id=job.job_id,
|
||||
type=job.job_type,
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
message=job.message,
|
||||
completed=job.completed,
|
||||
total=job.total,
|
||||
result=None, # 列表查询不返回 result,避免数据过大
|
||||
error=job.error,
|
||||
created_at=job.created_at,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> TaskStatusResponse:
|
||||
"""
|
||||
查询任务状态
|
||||
|
||||
前端通过轮询此接口获取任务进度。
|
||||
任务状态仅从 Redis 查询,记录过期后返回 404。
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
|
||||
try:
|
||||
job = await registry.get(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Redis error when getting task {task_id}: {e}")
|
||||
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在或已过期")
|
||||
|
||||
# 权限检查
|
||||
if job.user_id != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="无权访问此任务")
|
||||
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
type=job.job_type,
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
message=job.message,
|
||||
completed=job.completed,
|
||||
total=job.total,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
created_at=job.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取任务结果(简化接口,直接返回 result 字段)
|
||||
"""
|
||||
redis = get_redis_client()
|
||||
registry = JobRegistry(redis)
|
||||
|
||||
try:
|
||||
job = await registry.get(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Redis error when getting result {task_id}: {e}")
|
||||
raise HTTPException(status_code=503, detail="服务暂时不可用,请稍后重试")
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="任务不存在或已过期")
|
||||
|
||||
if job.user_id != str(current_user.id):
|
||||
raise HTTPException(status_code=403, detail="无权访问此任务")
|
||||
|
||||
if job.status != "completed":
|
||||
raise HTTPException(status_code=400, detail=f"任务未完成,当前状态: {job.status}")
|
||||
|
||||
return job.result or {}
|
||||
@@ -0,0 +1,511 @@
|
||||
"""
|
||||
视频生成 API 路由
|
||||
================
|
||||
|
||||
提供数字人视频、文生视频、图生视频功能。
|
||||
基于 KlingAI API 实现。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.config import get_settings
|
||||
from app.core.config_loader import get_config_loader
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.schemas.segment import Segment
|
||||
from app.services.kling_video_service import get_kling_video_service
|
||||
|
||||
router = APIRouter(prefix="/video", tags=["Video"])
|
||||
|
||||
# 视频文件存储目录
|
||||
VIDEO_STORAGE_DIR = Path("data/video")
|
||||
VIDEO_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 上传文件临时目录
|
||||
UPLOAD_DIR = Path("data/uploads")
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ 数据模型 ============
|
||||
|
||||
|
||||
class DigitalHuman(BaseModel):
|
||||
"""数字人信息"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
desc: str
|
||||
avatar_url: str | None = None
|
||||
type: str = "preset" # preset, custom, upload
|
||||
|
||||
|
||||
class VideoGenerateRequest(BaseModel):
|
||||
"""视频生成请求"""
|
||||
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
human_id: int | None = Field(None, description="数字人主体ID(分镜类型使用)")
|
||||
segments: list[Segment] = Field(..., description="分镜列表")
|
||||
|
||||
|
||||
class VideoGenerateResponse(BaseModel):
|
||||
"""视频生成响应"""
|
||||
|
||||
job_id: str = Field(..., description="作业ID")
|
||||
task_id: str = Field(..., description="任务ID(与job_id相同)")
|
||||
status: str = Field(..., description="作业状态")
|
||||
message: str = Field(..., description="状态消息")
|
||||
sse_url: str = Field(..., description="SSE进度流URL")
|
||||
|
||||
|
||||
class VideoJobStatus(BaseModel):
|
||||
"""视频作业状态"""
|
||||
|
||||
job_id: str
|
||||
project_id: str
|
||||
status: str # pending, processing, completed, partial, failed
|
||||
progress: int
|
||||
total_segments: int
|
||||
completed_segments: int
|
||||
failed_segments: int
|
||||
created_at: float
|
||||
updated_at: float
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ShotResult(BaseModel):
|
||||
"""单个分镜结果"""
|
||||
|
||||
segment_id: str
|
||||
type: str
|
||||
status: str
|
||||
task_id: str | None = None
|
||||
video_url: str | None = None
|
||||
local_path: str | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class VideoJobDetail(BaseModel):
|
||||
"""视频作业详情"""
|
||||
|
||||
job_id: str
|
||||
project_id: str
|
||||
status: str
|
||||
progress: int
|
||||
total_segments: int
|
||||
completed_segments: int
|
||||
failed_segments: int
|
||||
segments: list[ShotResult]
|
||||
created_at: float
|
||||
updated_at: float
|
||||
|
||||
|
||||
# ============ 内存存储 ============
|
||||
|
||||
# 数字人库
|
||||
digital_humans_db: dict[str, DigitalHuman] = {
|
||||
"dh_001": DigitalHuman(
|
||||
id="dh_001",
|
||||
name="商务男士",
|
||||
desc="专业稳重的商务形象,适合正式场合",
|
||||
type="preset",
|
||||
),
|
||||
"dh_002": DigitalHuman(
|
||||
id="dh_002",
|
||||
name="亲和女士",
|
||||
desc="温和亲切的女性形象,适合讲解分享",
|
||||
type="preset",
|
||||
),
|
||||
"dh_003": DigitalHuman(
|
||||
id="dh_003",
|
||||
name="活力青年",
|
||||
desc="年轻有活力的形象,适合轻松内容",
|
||||
type="preset",
|
||||
),
|
||||
"dh_004": DigitalHuman(
|
||||
id="dh_004", name="知性女性", desc="知性优雅的形象,适合知识分享", type="preset"
|
||||
),
|
||||
}
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
|
||||
async def get_klingai_provider() -> KlingAIProvider:
|
||||
"""获取 KlingAI Provider 实例
|
||||
|
||||
API Key 从 Settings 读取(符合配置规范)
|
||||
"""
|
||||
settings = get_settings()
|
||||
config_loader = get_config_loader()
|
||||
platform = config_loader.get_platform("klingai")
|
||||
|
||||
# 从 Settings 读取 AK/SK(符合配置规范:.env → Settings → 服务层)
|
||||
access_key = settings.KLINGAI_ACCESS_KEY
|
||||
secret_key = settings.KLINGAI_SECRET_KEY
|
||||
|
||||
if not access_key or not secret_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="KlingAI 未配置,请设置 KLINGAI_ACCESS_KEY 和 KLINGAI_SECRET_KEY",
|
||||
)
|
||||
|
||||
# 从 YAML 读取 base_url(模型配置)
|
||||
base_url = platform.base_url if platform else None
|
||||
|
||||
return KlingAIProvider(
|
||||
{
|
||||
"access_key": access_key,
|
||||
"secret_key": secret_key,
|
||||
"base_url": base_url or "https://api-beijing.klingai.com",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============ 新版 API 路由(推荐) ============
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ApiResponse[VideoGenerateResponse])
|
||||
async def create_video_generation(data: VideoGenerateRequest):
|
||||
"""
|
||||
创建视频生成任务
|
||||
|
||||
接收项目ID、数字人ID和分镜列表,创建视频生成作业。
|
||||
支持 SSE 流式查询进度。
|
||||
|
||||
**分镜类型说明:**
|
||||
- `segment`: 分镜(带数字人),使用 omni-video 接口,需要 human_id
|
||||
- `empty_shot`: 空镜,使用文生图 + 图生视频流程
|
||||
|
||||
**调用流程:**
|
||||
1. 调用此接口创建任务,获取 job_id
|
||||
2. 使用 SSE 接口 `/video/jobs/{job_id}/stream` 监听进度
|
||||
3. 或使用 `/video/jobs/{job_id}` 查询状态
|
||||
"""
|
||||
try:
|
||||
service = get_kling_video_service()
|
||||
|
||||
# 转换分镜数据
|
||||
segments_data = []
|
||||
for segment in data.segments:
|
||||
segments_data.append(
|
||||
{
|
||||
"id": segment.id,
|
||||
"type": segment.type,
|
||||
"scene": segment.scene,
|
||||
"voiceover": segment.voiceover,
|
||||
"voice_id": segment.voice_id,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建作业
|
||||
job = await service.create_job(
|
||||
project_id=data.project_id,
|
||||
human_id=data.human_id,
|
||||
segments_data=segments_data,
|
||||
)
|
||||
|
||||
# 构建SSE URL
|
||||
sse_url = f"/video/jobs/{job.job_id}/stream"
|
||||
|
||||
return success_response(
|
||||
data=VideoGenerateResponse(
|
||||
job_id=job.job_id,
|
||||
task_id=job.job_id,
|
||||
status=job.status,
|
||||
message="视频生成任务已创建",
|
||||
sse_url=sse_url,
|
||||
)
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"创建视频生成任务失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}", response_model=ApiResponse[VideoJobDetail])
|
||||
async def get_video_job(job_id: str):
|
||||
"""
|
||||
查询视频生成作业详情
|
||||
|
||||
获取指定作业的详细信息和所有分镜的处理结果。
|
||||
"""
|
||||
try:
|
||||
service = get_kling_video_service()
|
||||
job = service.get_job(job_id)
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="作业不存在")
|
||||
|
||||
# 构建分镜结果
|
||||
segments = []
|
||||
for segment in job.segments:
|
||||
segments.append(
|
||||
ShotResult(
|
||||
segment_id=segment.id,
|
||||
type=segment.type,
|
||||
status=segment.status,
|
||||
task_id=segment.provider_task_id,
|
||||
video_url=segment.video_url,
|
||||
local_path=segment.local_path,
|
||||
error_message=segment.error_message,
|
||||
)
|
||||
)
|
||||
|
||||
return success_response(
|
||||
data=VideoJobDetail(
|
||||
job_id=job.job_id,
|
||||
project_id=job.project_id,
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
total_segments=len(job.segments),
|
||||
completed_segments=sum(1 for s in job.segments if s.status.value == "completed"),
|
||||
failed_segments=sum(1 for s in job.segments if s.status.value == "failed"),
|
||||
segments=segments,
|
||||
created_at=job.created_at,
|
||||
updated_at=job.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"查询作业详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/stream")
|
||||
async def stream_video_job(job_id: str):
|
||||
"""
|
||||
SSE 流式获取视频生成进度
|
||||
|
||||
使用 Server-Sent Events 实时推送视频生成进度。
|
||||
|
||||
**事件类型:**
|
||||
- `start`: 开始生成
|
||||
- `processing`: 处理中(包含进度信息)
|
||||
- `finalizing`: 完成整理
|
||||
- `complete`: 全部完成
|
||||
- `error`: 发生错误
|
||||
|
||||
**示例:**
|
||||
```
|
||||
const eventSource = new EventSource('/api/v1/video/jobs/{job_id}/stream');
|
||||
eventSource.onmessage = (e) => {
|
||||
const data = JSON.parse(e.data);
|
||||
console.log(data.progress + '%: ' + data.message);
|
||||
};
|
||||
```
|
||||
"""
|
||||
try:
|
||||
service = get_kling_video_service()
|
||||
|
||||
# 验证作业存在
|
||||
job = service.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="作业不存在")
|
||||
|
||||
async def event_generator():
|
||||
"""SSE 事件生成器"""
|
||||
async for event in service.process_job_stream(job_id):
|
||||
yield f"data: {__import__('json').dumps(event, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"流式获取作业进度失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/status", response_model=ApiResponse[VideoJobStatus])
|
||||
async def get_video_job_status(job_id: str):
|
||||
"""
|
||||
获取视频生成作业状态(简化版)
|
||||
"""
|
||||
try:
|
||||
service = get_kling_video_service()
|
||||
status = service.get_job_status(job_id)
|
||||
|
||||
if not status:
|
||||
raise HTTPException(status_code=404, detail="作业不存在")
|
||||
|
||||
return success_response(
|
||||
data=VideoJobStatus(
|
||||
job_id=str(status["job_id"]),
|
||||
project_id=str(status["project_id"]),
|
||||
status=str(status["status"]),
|
||||
progress=int(status["progress"]), # type: ignore[arg-type]
|
||||
total_segments=int(status["total_segments"]), # type: ignore[arg-type]
|
||||
completed_segments=int(status["completed_segments"]), # type: ignore[arg-type]
|
||||
failed_segments=int(status["failed_segments"]), # type: ignore[arg-type]
|
||||
created_at=float(status["created_at"]), # type: ignore[arg-type]
|
||||
updated_at=float(status["updated_at"]), # type: ignore[arg-type]
|
||||
error_message=status.get("error_message"),
|
||||
)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取作业状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============ 数字人管理 ============
|
||||
|
||||
|
||||
@router.get("/library", response_model=ApiResponse[list[DigitalHuman]])
|
||||
async def get_digital_humans():
|
||||
"""
|
||||
获取数字人素材库
|
||||
|
||||
返回系统预设的数字人列表。
|
||||
"""
|
||||
try:
|
||||
humans = list(digital_humans_db.values())
|
||||
return success_response(data=humans)
|
||||
except Exception as e:
|
||||
logger.error(f"获取数字人库失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/upload", response_model=ApiResponse[DigitalHuman])
|
||||
async def upload_video(
|
||||
file: UploadFile = File(..., description="视频文件"),
|
||||
name: str | None = Form(None, description="数字人名称"),
|
||||
):
|
||||
"""
|
||||
上传人物视频作为数字人素材
|
||||
|
||||
文件要求:
|
||||
- 格式:mp4, mov
|
||||
- 时长:2-60秒
|
||||
- 分辨率:720p 或 1080p
|
||||
"""
|
||||
try:
|
||||
# 验证文件格式
|
||||
allowed_types = ["video/mp4", "video/quicktime", "video/x-msvideo"]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件格式: {file.content_type},请上传 mp4/mov 视频",
|
||||
)
|
||||
|
||||
# 保存文件
|
||||
file_ext = Path(file.filename or "").suffix or ".mp4"
|
||||
video_id = f"upload_{uuid.uuid4().hex[:16]}"
|
||||
video_filename = f"{video_id}{file_ext}"
|
||||
video_path = UPLOAD_DIR / video_filename
|
||||
|
||||
content = await file.read()
|
||||
video_path.write_bytes(content)
|
||||
|
||||
logger.info(f"视频上传成功: {video_path}, 大小: {len(content)} bytes")
|
||||
|
||||
# 创建数字人记录
|
||||
human = DigitalHuman(
|
||||
id=video_id,
|
||||
name=name or f"上传视频_{datetime.now().strftime('%m%d_%H%M')}",
|
||||
desc="用户上传的自定义数字人",
|
||||
type="upload",
|
||||
avatar_url=f"/api/v1/video/{video_id}/thumbnail",
|
||||
)
|
||||
|
||||
# 添加到数据库
|
||||
digital_humans_db[video_id] = human
|
||||
|
||||
return success_response(data=human)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"上传视频失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{video_id}/download")
|
||||
async def download_video(video_id: str):
|
||||
"""
|
||||
下载视频文件
|
||||
|
||||
支持三种查找位置:
|
||||
1. data/video/{video_id}.mp4 - 传统存储
|
||||
2. data/uploads/{video_id}.ext - 上传文件
|
||||
3. ~/Documents/Meijiaka/projects/*/videos/{video_id}.mp4 - 项目生成的视频
|
||||
文件名格式: scene_{shot_id}.mp4
|
||||
"""
|
||||
try:
|
||||
# 1. 首先查找传统存储位置
|
||||
video_path = VIDEO_STORAGE_DIR / f"{video_id}.mp4"
|
||||
found = False
|
||||
|
||||
if not video_path.exists():
|
||||
# 2. 尝试从上传目录查找
|
||||
for ext in [".mp4", ".mov", ".avi"]:
|
||||
candidate = UPLOAD_DIR / f"{video_id}{ext}"
|
||||
if candidate.exists():
|
||||
video_path = candidate
|
||||
found = True
|
||||
break
|
||||
else:
|
||||
found = True
|
||||
|
||||
# 3. 如果还没找到,尝试在项目视频目录中查找
|
||||
# video_id 可能是 scene_{id} 格式
|
||||
if not found:
|
||||
from app.services.kling_video_service import KlingVideoService
|
||||
|
||||
# 遍历项目目录查找(递归查找)
|
||||
base_dir = KlingVideoService.BASE_STORAGE_DIR
|
||||
if base_dir.exists():
|
||||
for project_dir in base_dir.iterdir():
|
||||
if project_dir.is_dir():
|
||||
candidate = project_dir / "videos" / f"{video_id}.mp4"
|
||||
if candidate.exists():
|
||||
video_path = candidate
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found or not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail="视频文件不存在")
|
||||
|
||||
return FileResponse(path=video_path, media_type="video/mp4", filename=f"{video_id}.mp4")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"下载视频失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{video_id}/thumbnail")
|
||||
async def get_video_thumbnail(video_id: str):
|
||||
"""
|
||||
获取视频缩略图
|
||||
"""
|
||||
try:
|
||||
# 简化实现:返回占位图
|
||||
# 实际应该使用 FFmpeg 提取视频第一帧
|
||||
raise HTTPException(status_code=404, detail="缩略图功能暂未实现")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取缩略图失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
配置管理 - Pydantic Settings
|
||||
==========================
|
||||
|
||||
所有配置项通过环境变量或 .env 文件注入。
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
# 应用基础配置
|
||||
APP_NAME: str = Field(default="美家卡智影 API", description="应用名称")
|
||||
APP_VERSION: str = Field(default="0.1.0", description="应用版本")
|
||||
DEBUG: bool = Field(default=True, description="调试模式")
|
||||
ENV: Literal["development", "staging", "production"] = Field(
|
||||
default="development", description="运行环境"
|
||||
)
|
||||
|
||||
# 服务器配置
|
||||
HOST: str = Field(default="0.0.0.0", description="监听地址")
|
||||
PORT: int = Field(default=8000, description="监听端口")
|
||||
WORKERS: int = Field(default=1, description="工作进程数(生产环境建议 > 1)")
|
||||
|
||||
# 数据库配置(统一使用 PostgreSQL)
|
||||
DATABASE_URL: str = Field(
|
||||
default="postgresql+asyncpg://postgres:postgres@localhost:5432/meijiaka",
|
||||
description="数据库连接字符串(PostgreSQL)",
|
||||
)
|
||||
DATABASE_POOL_SIZE: int = Field(default=10, description="数据库连接池大小")
|
||||
DATABASE_MAX_OVERFLOW: int = Field(default=20, description="连接池溢出上限")
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST: str = Field(
|
||||
default="localhost",
|
||||
description="Redis 主机地址",
|
||||
)
|
||||
REDIS_PORT: int = Field(
|
||||
default=6379,
|
||||
description="Redis 端口",
|
||||
)
|
||||
REDIS_DB: int = Field(
|
||||
default=0,
|
||||
description="Redis 数据库编号",
|
||||
)
|
||||
REDIS_PASSWORD: str | None = Field(
|
||||
default=None,
|
||||
description="Redis 密码(无密码请留空)",
|
||||
)
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY: str = Field(
|
||||
default="your-secret-key-here-change-in-production",
|
||||
description="JWT 签名密钥(生产环境必须修改)",
|
||||
)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(
|
||||
default=60 * 24 * 7, # 7 天
|
||||
description="访问令牌过期时间(分钟)",
|
||||
)
|
||||
ALGORITHM: str = Field(default="HS256", description="JWT 算法")
|
||||
|
||||
# CORS 配置
|
||||
CORS_ORIGINS: str = Field(
|
||||
default="http://localhost:1420,http://127.0.0.1:1420,http://localhost:8080,http://127.0.0.1:8080",
|
||||
description="允许的跨域来源(逗号分隔)",
|
||||
)
|
||||
|
||||
# AI 模型配置
|
||||
# 字节跳动 - 火山方舟
|
||||
# 文档:https://www.volcengine.com/docs/82379/1399009
|
||||
VOLCENGINE_API_KEY: str | None = Field(default=None, description="火山方舟 API Key")
|
||||
VOLCENGINE_BASE_URL: str = Field(
|
||||
default="https://ark.cn-beijing.volces.com/api/v3",
|
||||
description="火山方舟 Base URL",
|
||||
)
|
||||
VOLCENGINE_MODEL: str = Field(
|
||||
default="doubao-seed-2-0-lite-260215",
|
||||
description="火山方舟默认模型(Model ID)",
|
||||
)
|
||||
|
||||
# 火山引擎音视频字幕服务
|
||||
VOLCENGINE_CAPTION_APPID: str | None = Field(default=None, description="火山字幕 AppID")
|
||||
VOLCENGINE_CAPTION_TOKEN: str | None = Field(default=None, description="火山字幕 Token")
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: str | None = Field(default=None, description="OpenAI API Key")
|
||||
OPENAI_BASE_URL: str = Field(default="https://api.openai.com/v1", description="OpenAI Base URL")
|
||||
OPENAI_DEFAULT_MODEL: str = Field(default="gpt-3.5-turbo", description="默认 OpenAI 模型")
|
||||
|
||||
# 文心一言 (百度)
|
||||
WENXIN_API_KEY: str | None = Field(default=None, description="文心一言 API Key")
|
||||
WENXIN_SECRET_KEY: str | None = Field(default=None, description="文心一言 Secret Key")
|
||||
|
||||
# 通义千问 (阿里云)
|
||||
QIANWEN_API_KEY: str | None = Field(default=None, description="通义千问 API Key")
|
||||
|
||||
# 数字人服务配置
|
||||
DIGITAL_HUMAN_PROVIDER: Literal["heygen", "did", "mock"] = Field(
|
||||
default="mock",
|
||||
description="数字人服务提供商",
|
||||
)
|
||||
HEYGEN_API_KEY: str | None = Field(default=None, description="HeyGen API Key")
|
||||
DID_API_KEY: str | None = Field(default=None, description="D-ID API Key")
|
||||
|
||||
# KlingAI 配置
|
||||
KLINGAI_ACCESS_KEY: str | None = Field(default=None, description="KlingAI Access Key")
|
||||
KLINGAI_SECRET_KEY: str | None = Field(default=None, description="KlingAI Secret Key")
|
||||
|
||||
# 七牛云存储配置
|
||||
QINIU_ACCESS_KEY: str | None = Field(default=None, description="七牛云 Access Key")
|
||||
QINIU_SECRET_KEY: str | None = Field(default=None, description="七牛云 Secret Key")
|
||||
QINIU_VIDEO_BUCKET: str = Field(default="media-liche", description="视频存储 Bucket")
|
||||
QINIU_VIDEO_DOMAIN: str = Field(default="media.liche.cn", description="视频存储域名")
|
||||
QINIU_IMAGE_BUCKET: str = Field(default="img-liche", description="图片存储 Bucket")
|
||||
QINIU_IMAGE_DOMAIN: str = Field(default="img.liche.cn", description="图片存储域名")
|
||||
|
||||
# AnyToCopy 文案提取服务
|
||||
ANYTOCOPY_API_KEY: str | None = Field(default=None, description="AnyToCopy API Key")
|
||||
ANYTOCOPY_API_SECRET: str | None = Field(default=None, description="AnyToCopy API Secret")
|
||||
ANYTOCOPY_BASE_URL: str = Field(
|
||||
default="https://api.anytocopy.com/vip/open-api/v1",
|
||||
description="AnyToCopy Base URL",
|
||||
)
|
||||
|
||||
# 视频生成配置
|
||||
DEFAULT_EMPTY_SHOT_VOICE_ID: str = Field(
|
||||
default="829826792415842333",
|
||||
description="空镜视频默认音色ID(Kling官方音色,默认:播报男声)",
|
||||
)
|
||||
|
||||
# Async Engine 槽位配置
|
||||
KLING_VIDEO_MAX_CONCURRENT: int = Field(default=18, description="Kling视频生成最大并发数")
|
||||
KLING_IMAGE_MAX_CONCURRENT: int = Field(default=9, description="Kling图片生成最大并发数")
|
||||
KLING_AVATAR_MAX_CONCURRENT: int = Field(default=2, description="Kling形象克隆最大并发数")
|
||||
ANYTOCOPY_MAX_CONCURRENT: int = Field(default=5, description="AnyToCopy文案提取最大并发数")
|
||||
VOLC_SUBTITLE_MAX_CONCURRENT: int = Field(default=5, description="火山字幕生成最大并发数")
|
||||
|
||||
# 任务超时配置(秒)
|
||||
KLING_VIDEO_TIMEOUT_PER_SHOT: int = Field(
|
||||
default=600, description="Kling视频单镜头超时时间(秒)"
|
||||
)
|
||||
KLING_IMAGE_TIMEOUT: int = Field(default=120, description="Kling图片生成超时时间(秒)")
|
||||
VOLC_SUBTITLE_TIMEOUT: int = Field(default=600, description="火山字幕生成超时时间(秒)")
|
||||
|
||||
# AnyToCopy 轮询配置
|
||||
ANYTOCOPY_POLL_INTERVAL: float = Field(default=3.0, description="AnyToCopy轮询间隔(秒)")
|
||||
ANYTOCOPY_MAX_POLL: int = Field(default=60, description="AnyToCopy最大轮询次数")
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||
default="DEBUG",
|
||||
description="日志级别",
|
||||
)
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> list[str]:
|
||||
"""将 CORS_ORIGINS 字符串解析为列表"""
|
||||
return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
|
||||
|
||||
@property
|
||||
def use_redis(self) -> bool:
|
||||
"""是否使用 Redis"""
|
||||
return bool(self.REDIS_HOST)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""获取配置单例(带缓存)"""
|
||||
settings = Settings()
|
||||
|
||||
# 生产环境安全检查
|
||||
if settings.ENV == "production":
|
||||
default_keys = [
|
||||
"your-secret-key-here-change-in-production",
|
||||
"change-me-in-production",
|
||||
"secret-key",
|
||||
"",
|
||||
]
|
||||
if not settings.SECRET_KEY or settings.SECRET_KEY in default_keys:
|
||||
raise ValueError(
|
||||
"生产环境必须设置强随机 SECRET_KEY!"
|
||||
"请在 .env 文件中设置一个随机字符串(至少 32 位)。"
|
||||
)
|
||||
|
||||
# 检查 CORS 配置
|
||||
if settings.CORS_ORIGINS and "localhost" in settings.CORS_ORIGINS.lower():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"生产环境 CORS 配置中包含 localhost,建议限制为实际域名",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
return settings
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
AI 模型配置加载器
|
||||
================
|
||||
|
||||
从 YAML 文件加载模型配置,支持热重载。
|
||||
API Key 从 Settings 读取(符合配置规范)。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入 YAML 库
|
||||
try:
|
||||
import yaml
|
||||
|
||||
YAML_AVAILABLE = True
|
||||
except ImportError:
|
||||
YAML_AVAILABLE = False
|
||||
logger.warning("PyYAML 未安装,使用 JSON 备选方案。安装: pip install pyyaml")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""平台配置"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
provider: str
|
||||
priority: int = 100
|
||||
base_url: str = "" # 从 YAML 读取,可选
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""模型配置"""
|
||||
|
||||
id: str
|
||||
platform_id: str
|
||||
model_name: str
|
||||
display_name: str
|
||||
capabilities: list[str] = field(default_factory=list)
|
||||
default_params: dict[str, Any] = field(default_factory=dict)
|
||||
is_enabled: bool = True
|
||||
cost_per_1k_input: float = 0.0
|
||||
cost_per_1k_output: float = 0.0
|
||||
max_tokens_limit: int = 4096
|
||||
|
||||
|
||||
class AIModelConfigLoader:
|
||||
"""AI 模型配置加载器
|
||||
|
||||
从 YAML 加载模型配置(支持热重载)。
|
||||
API Key 从 Settings 读取(通过 get_settings()),符合配置规范。
|
||||
"""
|
||||
|
||||
DEFAULT_CONFIG_PATH = (
|
||||
Path(__file__).parent.parent.parent / "config" / "ai_models.yaml"
|
||||
)
|
||||
|
||||
def __init__(self, config_path: str | None = None):
|
||||
self.config_path = (
|
||||
Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
|
||||
)
|
||||
self._platforms: dict[str, PlatformConfig] = {}
|
||||
self._models: dict[str, ModelConfig] = {}
|
||||
self._task_defaults: dict[str, str] = {}
|
||||
self._last_modified = 0
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
"""加载配置文件"""
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
|
||||
self._load_defaults()
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
if YAML_AVAILABLE:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
# 备选:使用 JSON
|
||||
import json
|
||||
|
||||
config = json.load(f)
|
||||
|
||||
self._parse_config(config)
|
||||
self._last_modified = self.config_path.stat().st_mtime
|
||||
logger.info(
|
||||
f"已加载模型配置: {len(self._platforms)} 平台, {len(self._models)} 模型"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e},使用默认配置")
|
||||
self._load_defaults()
|
||||
|
||||
def _parse_config(self, config: dict):
|
||||
"""解析配置(仅解析模型配置,API Key 从 Settings 读取)"""
|
||||
# 解析平台
|
||||
platforms_data = config.get("platforms", {})
|
||||
for pid, pdata in platforms_data.items():
|
||||
self._platforms[pid] = PlatformConfig(
|
||||
id=pid,
|
||||
name=pdata.get("name", pid),
|
||||
provider=pdata.get("provider", pid),
|
||||
priority=pdata.get("priority", 100),
|
||||
base_url=pdata.get("base_url", ""),
|
||||
)
|
||||
|
||||
# 解析模型
|
||||
models_data = config.get("models", {})
|
||||
for mid, mdata in models_data.items():
|
||||
self._models[mid] = ModelConfig(
|
||||
id=mid,
|
||||
platform_id=mdata.get("platform_id", ""),
|
||||
model_name=mdata.get("model_name", mid),
|
||||
display_name=mdata.get("display_name", mid),
|
||||
capabilities=mdata.get("capabilities", []),
|
||||
default_params=mdata.get("default_params", {}),
|
||||
is_enabled=mdata.get("is_enabled", True),
|
||||
cost_per_1k_input=mdata.get("cost_per_1k_input", 0.0),
|
||||
cost_per_1k_output=mdata.get("cost_per_1k_output", 0.0),
|
||||
max_tokens_limit=mdata.get("max_tokens_limit", 4096),
|
||||
)
|
||||
|
||||
# 解析任务默认映射
|
||||
self._task_defaults = config.get("task_defaults", {})
|
||||
|
||||
def _load_defaults(self):
|
||||
"""加载默认配置"""
|
||||
self._platforms = {
|
||||
"mock": PlatformConfig(
|
||||
id="mock",
|
||||
name="Mock 测试平台",
|
||||
provider="mock",
|
||||
priority=999,
|
||||
)
|
||||
}
|
||||
self._models = {
|
||||
"mock-model": ModelConfig(
|
||||
id="mock-model",
|
||||
platform_id="mock",
|
||||
model_name="mock-model",
|
||||
display_name="Mock 测试模型",
|
||||
capabilities=["script", "polish", "chat"],
|
||||
)
|
||||
}
|
||||
self._task_defaults = {
|
||||
"script": "mock-model",
|
||||
"polish": "mock-model",
|
||||
"chat": "mock-model",
|
||||
}
|
||||
|
||||
def reload(self):
|
||||
"""重新加载配置(如果文件有更新)"""
|
||||
if self.config_path.exists():
|
||||
current_mtime = self.config_path.stat().st_mtime
|
||||
if current_mtime > self._last_modified:
|
||||
logger.info("配置文件已更新,重新加载")
|
||||
self._load()
|
||||
return True
|
||||
return False
|
||||
|
||||
# ============== 查询方法 ==============
|
||||
|
||||
def get_platform(self, platform_id: str) -> PlatformConfig | None:
|
||||
"""获取平台配置"""
|
||||
return self._platforms.get(platform_id)
|
||||
|
||||
def get_all_platforms(self) -> list[PlatformConfig]:
|
||||
"""获取所有平台(按优先级排序)"""
|
||||
return sorted(self._platforms.values(), key=lambda p: p.priority)
|
||||
|
||||
def get_model(self, model_id: str) -> ModelConfig | None:
|
||||
"""获取模型配置"""
|
||||
return self._models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> list[ModelConfig]:
|
||||
"""获取所有模型"""
|
||||
return list(self._models.values())
|
||||
|
||||
def get_enabled_models(self) -> list[ModelConfig]:
|
||||
"""获取启用的模型"""
|
||||
return [m for m in self._models.values() if m.is_enabled]
|
||||
|
||||
def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
|
||||
"""根据能力获取模型"""
|
||||
return [
|
||||
m
|
||||
for m in self._models.values()
|
||||
if m.is_enabled and capability in m.capabilities
|
||||
]
|
||||
|
||||
def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
|
||||
"""根据平台获取模型"""
|
||||
return [
|
||||
m
|
||||
for m in self._models.values()
|
||||
if m.platform_id == platform_id and m.is_enabled
|
||||
]
|
||||
|
||||
def get_default_model_for_task(self, task_type: str) -> str | None:
|
||||
"""获取任务类型的默认模型 ID"""
|
||||
return self._task_defaults.get(task_type)
|
||||
|
||||
def set_default_model_for_task(self, task_type: str, model_id: str):
|
||||
"""设置任务类型的默认模型(内存中,不保存到文件)"""
|
||||
if model_id in self._models:
|
||||
self._task_defaults[task_type] = model_id
|
||||
|
||||
|
||||
# 全局配置加载器实例
|
||||
_config_loader: AIModelConfigLoader | None = None
|
||||
|
||||
|
||||
def get_config_loader() -> AIModelConfigLoader:
|
||||
"""获取全局配置加载器"""
|
||||
global _config_loader
|
||||
if _config_loader is None:
|
||||
_config_loader = AIModelConfigLoader()
|
||||
return _config_loader
|
||||
|
||||
|
||||
def reload_config() -> bool:
|
||||
"""重新加载配置"""
|
||||
loader = get_config_loader()
|
||||
return loader.reload()
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
自定义异常类
|
||||
============
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class AppException(HTTPException):
|
||||
"""应用基础异常"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str = "操作失败",
|
||||
detail: dict | None = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail=detail or {})
|
||||
self.message = message
|
||||
|
||||
|
||||
class NotFoundException(AppException):
|
||||
"""资源不存在"""
|
||||
|
||||
def __init__(self, message: str = "资源不存在"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class ValidationException(AppException):
|
||||
"""参数验证失败"""
|
||||
|
||||
def __init__(self, message: str = "参数验证失败"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class UnauthorizedException(AppException):
|
||||
"""未授权"""
|
||||
|
||||
def __init__(self, message: str = "未授权,请先登录"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class ForbiddenException(AppException):
|
||||
"""禁止访问"""
|
||||
|
||||
def __init__(self, message: str = "无权访问该资源"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class BusinessException(AppException):
|
||||
"""业务逻辑错误"""
|
||||
|
||||
def __init__(self, message: str = "业务操作失败"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class ModelUnavailableException(AppException):
|
||||
"""AI 模型不可用"""
|
||||
|
||||
def __init__(self, message: str = "AI 模型服务暂时不可用"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class TaskFailedException(AppException):
|
||||
"""异步任务执行失败"""
|
||||
|
||||
def __init__(self, message: str = "任务执行失败"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=message,
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Redis 客户端
|
||||
============
|
||||
全局 Redis 连接,供 Scheduler 和 RateLimiter 使用
|
||||
"""
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
# 全局客户端(懒加载)
|
||||
_redis_client: Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> Redis:
|
||||
"""获取或创建 Redis 客户端"""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
settings = get_settings()
|
||||
|
||||
# 构建连接参数
|
||||
client_kwargs = {
|
||||
"host": settings.REDIS_HOST,
|
||||
"port": settings.REDIS_PORT,
|
||||
"db": settings.REDIS_DB,
|
||||
"decode_responses": True,
|
||||
}
|
||||
|
||||
# 有密码时添加
|
||||
if settings.REDIS_PASSWORD:
|
||||
client_kwargs["password"] = settings.REDIS_PASSWORD
|
||||
|
||||
_redis_client = Redis(**client_kwargs)
|
||||
|
||||
return _redis_client
|
||||
|
||||
|
||||
def init_redis_client(redis: Redis) -> None:
|
||||
"""初始化全局客户端(用于测试)"""
|
||||
global _redis_client
|
||||
_redis_client = redis
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
安全工具 - JWT Token 生成与验证
|
||||
===============================
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def create_access_token(data: dict[str, Any], expires_delta: timedelta | None = None) -> str:
|
||||
"""
|
||||
创建 JWT 访问令牌
|
||||
|
||||
Args:
|
||||
data: 要编码到 Token 中的数据(通常包含 user_id)
|
||||
expires_delta: 过期时间偏移量,默认使用配置中的设置
|
||||
|
||||
Returns:
|
||||
JWT Token 字符串
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM,
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_token(token: str) -> dict[str, Any | None]:
|
||||
"""
|
||||
验证 JWT Token
|
||||
|
||||
Args:
|
||||
token: JWT Token 字符串
|
||||
|
||||
Returns:
|
||||
解码后的 payload,如果验证失败返回 None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
Token 管理器 - 通用 API 认证 Token 缓存与自动刷新
|
||||
|
||||
支持:
|
||||
- JWT Token(如 KlingAI)
|
||||
- OAuth2 Access Token
|
||||
- 自定义 Token 类型
|
||||
|
||||
特性:
|
||||
- 线程/协程安全的 token 缓存
|
||||
- 自动刷新(带安全边界)
|
||||
- 后台预热机制
|
||||
- 支持多 Provider 实例隔离
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token 信息容器"""
|
||||
|
||||
token: str
|
||||
expires_at: float # 过期时间戳(秒)
|
||||
token_type: str = "Bearer"
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""是否已过期"""
|
||||
return time.time() >= self.expires_at
|
||||
|
||||
@property
|
||||
def expires_in(self) -> float:
|
||||
"""剩余有效时间(秒)"""
|
||||
return max(0, self.expires_at - time.time())
|
||||
|
||||
def is_near_expiry(self, safety_margin: float = 300) -> bool:
|
||||
"""
|
||||
是否接近过期(需要刷新)
|
||||
|
||||
Args:
|
||||
safety_margin: 安全边界(秒),默认5分钟
|
||||
"""
|
||||
return time.time() >= (self.expires_at - safety_margin)
|
||||
|
||||
|
||||
class TokenGenerator(Protocol):
|
||||
"""Token 生成函数协议"""
|
||||
|
||||
async def __call__(self) -> TokenInfo:
|
||||
"""生成/获取新的 token"""
|
||||
...
|
||||
|
||||
|
||||
class BaseTokenStrategy(ABC):
|
||||
"""Token 生成策略基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self) -> TokenInfo:
|
||||
"""生成新的 token"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_cache_key(self) -> str:
|
||||
"""获取缓存标识(用于多实例隔离)"""
|
||||
pass
|
||||
|
||||
|
||||
class JWTTokenStrategy(BaseTokenStrategy):
|
||||
"""JWT Token 生成策略(用于 KlingAI 等)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_key: str,
|
||||
secret_key: str,
|
||||
expires_in: int = 1800,
|
||||
algorithm: str = "HS256",
|
||||
token_type: str = "JWT",
|
||||
):
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
self.expires_in = expires_in # 默认30分钟
|
||||
self.algorithm = algorithm
|
||||
self.token_type = token_type
|
||||
|
||||
async def generate(self) -> TokenInfo:
|
||||
"""生成 JWT Token"""
|
||||
from jose import jwt
|
||||
|
||||
headers = {"alg": self.algorithm, "typ": self.token_type}
|
||||
current_time = int(time.time())
|
||||
payload = {
|
||||
"iss": self.access_key,
|
||||
"exp": current_time + self.expires_in,
|
||||
"nbf": current_time - 5,
|
||||
}
|
||||
|
||||
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm, headers=headers)
|
||||
|
||||
return TokenInfo(
|
||||
token=token,
|
||||
expires_at=current_time + self.expires_in,
|
||||
token_type="Bearer",
|
||||
)
|
||||
|
||||
def get_cache_key(self) -> str:
|
||||
"""缓存标识:access_key 的 hash"""
|
||||
return f"jwt:{self.access_key[:8]}"
|
||||
|
||||
|
||||
class OAuth2TokenStrategy(BaseTokenStrategy):
|
||||
"""OAuth2 Token 生成策略"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scope: str | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_url = token_url
|
||||
self.scope = scope
|
||||
self.extra_params = extra_params or {}
|
||||
|
||||
async def generate(self) -> TokenInfo:
|
||||
"""从 OAuth2 服务器获取 token"""
|
||||
import httpx
|
||||
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
**self.extra_params,
|
||||
}
|
||||
if self.scope:
|
||||
data["scope"] = self.scope
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(self.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
access_token = result["access_token"]
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
token_type = result.get("token_type", "Bearer")
|
||||
|
||||
return TokenInfo(
|
||||
token=access_token,
|
||||
expires_at=time.time() + expires_in,
|
||||
token_type=token_type,
|
||||
extra_data={
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k not in ["access_token", "expires_in", "token_type"]
|
||||
},
|
||||
)
|
||||
|
||||
def get_cache_key(self) -> str:
|
||||
"""缓存标识:client_id + token_url 的 hash"""
|
||||
return f"oauth2:{self.client_id[:8]}:{hash(self.token_url) % 10000}"
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
Token 管理器 - 单例模式,全局统一管理所有 token
|
||||
|
||||
使用示例:
|
||||
# JWT 方式(KlingAI)
|
||||
strategy = JWTTokenStrategy(access_key="xxx", secret_key="yyy")
|
||||
token = await TokenManager.get_instance().get_token(strategy)
|
||||
|
||||
# OAuth2 方式
|
||||
strategy = OAuth2TokenStrategy(
|
||||
client_id="xxx",
|
||||
client_secret="yyy",
|
||||
token_url="https://api.example.com/oauth2/token"
|
||||
)
|
||||
token = await TokenManager.get_instance().get_token(strategy)
|
||||
"""
|
||||
|
||||
_instance: TokenManager | None = None
|
||||
_lock: asyncio.Lock | None = None
|
||||
|
||||
def __new__(cls) -> TokenManager:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> TokenManager:
|
||||
"""获取单例实例"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# token 缓存: {cache_key: TokenInfo}
|
||||
self._tokens: dict[str, TokenInfo] = {}
|
||||
|
||||
# 刷新锁: {cache_key: asyncio.Lock}
|
||||
self._refresh_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# 全局锁,用于创建新的 refresh_lock
|
||||
self._global_lock = asyncio.Lock()
|
||||
|
||||
# 后台刷新任务
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# 预热配置
|
||||
self._safety_margin = 300 # 提前5分钟刷新
|
||||
self._preemptive_refresh = True # 启用预热机制
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def get_token(
|
||||
self,
|
||||
strategy: BaseTokenStrategy,
|
||||
force_refresh: bool = False,
|
||||
) -> TokenInfo:
|
||||
"""
|
||||
获取有效的 token
|
||||
|
||||
Args:
|
||||
strategy: Token 生成策略
|
||||
force_refresh: 强制刷新(忽略缓存)
|
||||
|
||||
Returns:
|
||||
TokenInfo: 有效的 token 信息
|
||||
"""
|
||||
cache_key = strategy.get_cache_key()
|
||||
|
||||
# 检查缓存
|
||||
if not force_refresh and cache_key in self._tokens:
|
||||
token_info = self._tokens[cache_key]
|
||||
if not token_info.is_near_expiry(self._safety_margin):
|
||||
logger.debug(f"Token cache hit for {cache_key}")
|
||||
return token_info
|
||||
|
||||
# 需要刷新 token
|
||||
return await self._refresh_token(strategy)
|
||||
|
||||
async def get_token_string(
|
||||
self,
|
||||
strategy: BaseTokenStrategy,
|
||||
force_refresh: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
获取 token 字符串(快捷方法)
|
||||
|
||||
Returns:
|
||||
str: token 字符串(带 Bearer 前缀)
|
||||
"""
|
||||
token_info = await self.get_token(strategy, force_refresh)
|
||||
return f"{token_info.token_type} {token_info.token}"
|
||||
|
||||
async def _refresh_token(self, strategy: BaseTokenStrategy) -> TokenInfo:
|
||||
"""
|
||||
刷新 token(带并发控制)
|
||||
|
||||
使用双重检查锁定模式,确保并发请求只触发一次刷新
|
||||
"""
|
||||
cache_key = strategy.get_cache_key()
|
||||
|
||||
# 获取或创建该 cache_key 专用的刷新锁
|
||||
async with self._global_lock:
|
||||
if cache_key not in self._refresh_locks:
|
||||
self._refresh_locks[cache_key] = asyncio.Lock()
|
||||
|
||||
refresh_lock = self._refresh_locks[cache_key]
|
||||
|
||||
async with refresh_lock:
|
||||
# 双重检查:等待锁之后,可能其他协程已经刷新过了
|
||||
if cache_key in self._tokens:
|
||||
token_info = self._tokens[cache_key]
|
||||
if not token_info.is_near_expiry(self._safety_margin):
|
||||
logger.debug(f"Token refreshed by another task for {cache_key}")
|
||||
return token_info
|
||||
|
||||
# 执行刷新
|
||||
logger.info(f"Refreshing token for {cache_key}")
|
||||
try:
|
||||
new_token = await strategy.generate()
|
||||
self._tokens[cache_key] = new_token
|
||||
|
||||
# 启动后台预热任务
|
||||
if self._preemptive_refresh:
|
||||
self._schedule_preemptive_refresh(strategy, new_token)
|
||||
|
||||
logger.info(
|
||||
f"Token refreshed successfully for {cache_key}, expires in {new_token.expires_in:.0f}s"
|
||||
)
|
||||
return new_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh token for {cache_key}: {e}")
|
||||
# 如果刷新失败但缓存的 token 还能用,返回缓存的
|
||||
if cache_key in self._tokens:
|
||||
cached = self._tokens[cache_key]
|
||||
if not cached.is_expired:
|
||||
logger.warning(
|
||||
f"Using expired cache for {cache_key} due to refresh failure"
|
||||
)
|
||||
return cached
|
||||
raise
|
||||
|
||||
def _schedule_preemptive_refresh(self, strategy: BaseTokenStrategy, token_info: TokenInfo):
|
||||
"""
|
||||
调度后台预热刷新任务
|
||||
|
||||
在 token 即将过期前自动刷新,避免请求时等待
|
||||
"""
|
||||
cache_key = strategy.get_cache_key()
|
||||
|
||||
# 计算预热时间(token 过期前 safety_margin * 2)
|
||||
refresh_at = token_info.expires_at - self._safety_margin * 2
|
||||
delay = max(0, refresh_at - time.time())
|
||||
|
||||
async def _refresh_task():
|
||||
await asyncio.sleep(delay)
|
||||
try:
|
||||
logger.info(f"Preemptive token refresh for {cache_key}")
|
||||
await self._refresh_token(strategy)
|
||||
except Exception as e:
|
||||
logger.error(f"Preemptive refresh failed for {cache_key}: {e}")
|
||||
|
||||
# 创建后台任务
|
||||
task = asyncio.create_task(_refresh_task())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
logger.debug(f"Scheduled preemptive refresh for {cache_key} in {delay:.0f}s")
|
||||
|
||||
async def invalidate(self, strategy: BaseTokenStrategy) -> bool:
|
||||
"""
|
||||
使缓存失效
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
cache_key = strategy.get_cache_key()
|
||||
if cache_key in self._tokens:
|
||||
del self._tokens[cache_key]
|
||||
logger.info(f"Token cache invalidated for {cache_key}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self):
|
||||
"""清除所有 token 缓存"""
|
||||
self._tokens.clear()
|
||||
logger.info("All token caches cleared")
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
now = time.time()
|
||||
stats = {
|
||||
"total_cached": len(self._tokens),
|
||||
"active_tasks": len(self._background_tasks),
|
||||
"tokens": {},
|
||||
}
|
||||
|
||||
for key, token_info in self._tokens.items():
|
||||
stats["tokens"][key] = {
|
||||
"expires_in": token_info.expires_in,
|
||||
"is_expired": token_info.is_expired,
|
||||
"is_near_expiry": token_info.is_near_expiry(self._safety_margin),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# 便捷函数
|
||||
|
||||
|
||||
async def get_jwt_token(
|
||||
access_key: str,
|
||||
secret_key: str,
|
||||
expires_in: int = 1800,
|
||||
algorithm: str = "HS256",
|
||||
) -> TokenInfo:
|
||||
"""
|
||||
获取 JWT Token(使用全局 TokenManager)
|
||||
|
||||
示例:
|
||||
token_info = await get_jwt_token("access_key", "secret_key")
|
||||
headers = {"Authorization": f"Bearer {token_info.token}"}
|
||||
"""
|
||||
strategy = JWTTokenStrategy(
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
expires_in=expires_in,
|
||||
algorithm=algorithm,
|
||||
)
|
||||
return await TokenManager.get_instance().get_token(strategy)
|
||||
|
||||
|
||||
async def get_oauth2_token(
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scope: str | None = None,
|
||||
) -> TokenInfo:
|
||||
"""
|
||||
获取 OAuth2 Token(使用全局 TokenManager)
|
||||
|
||||
示例:
|
||||
token_info = await get_oauth2_token(
|
||||
client_id="xxx",
|
||||
client_secret="yyy",
|
||||
token_url="https://api.example.com/oauth2/token"
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {token_info.token}"}
|
||||
"""
|
||||
strategy = OAuth2TokenStrategy(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
token_url=token_url,
|
||||
scope=scope,
|
||||
)
|
||||
return await TokenManager.get_instance().get_token(strategy)
|
||||
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
TokenManager 使用示例
|
||||
|
||||
展示如何在 Provider 中使用 TokenManager 来管理认证 Token。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from app.core.token_manager import (
|
||||
JWTTokenStrategy,
|
||||
OAuth2TokenStrategy,
|
||||
TokenManager,
|
||||
get_jwt_token,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_jwt():
|
||||
"""JWT Token 示例(KlingAI 模式)"""
|
||||
print("=" * 60)
|
||||
print("JWT Token 示例 (KlingAI)")
|
||||
print("=" * 60)
|
||||
|
||||
# 方法1: 使用便捷函数(推荐简单场景)
|
||||
try:
|
||||
token_info = await get_jwt_token(
|
||||
access_key="test_access_key",
|
||||
secret_key="test_secret_key",
|
||||
)
|
||||
print(f"Token: {token_info.token[:50]}...")
|
||||
print(f"Expires in: {token_info.expires_in:.0f} seconds")
|
||||
print(f"Is expired: {token_info.is_expired}")
|
||||
except Exception as e:
|
||||
print(f"JWT generation failed (expected in demo): {e}")
|
||||
|
||||
# 方法2: 使用 TokenManager + Strategy(推荐 Provider 集成)
|
||||
strategy = JWTTokenStrategy(
|
||||
access_key="your_access_key",
|
||||
secret_key="your_secret_key",
|
||||
expires_in=1800, # 30分钟
|
||||
)
|
||||
|
||||
# 第一次获取会生成新 token
|
||||
token1 = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"\nFirst token: {token1.token[:30]}...")
|
||||
|
||||
# 第二次获取会命中缓存(如果未过期)
|
||||
token2 = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"Second token: {token2.token[:30]}...")
|
||||
print(f"Same token: {token1.token == token2.token}")
|
||||
|
||||
# 查看缓存统计
|
||||
stats = TokenManager.get_instance().get_stats()
|
||||
print(f"\nCache stats: {stats}")
|
||||
|
||||
|
||||
async def example_oauth2():
|
||||
"""OAuth2 Token 示例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("OAuth2 Token 示例")
|
||||
print("=" * 60)
|
||||
|
||||
strategy = OAuth2TokenStrategy(
|
||||
client_id="your_client_id",
|
||||
client_secret="your_client_secret",
|
||||
token_url="https://api.example.com/oauth2/token",
|
||||
scope="read write",
|
||||
)
|
||||
|
||||
print("OAuth2 strategy created")
|
||||
print(f"Cache key: {strategy.get_cache_key()}")
|
||||
|
||||
|
||||
async def example_provider_integration():
|
||||
"""Provider 集成示例"""
|
||||
print("\n" + "=" * 60)
|
||||
print("Provider 集成示例")
|
||||
print("=" * 60)
|
||||
|
||||
# 这是一个模拟的 Provider 类
|
||||
class ExampleProvider:
|
||||
def __init__(self, access_key: str, secret_key: str):
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
self._token_strategy = JWTTokenStrategy(
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
expires_in=1800,
|
||||
)
|
||||
|
||||
async def _get_headers(self) -> dict[str, str]:
|
||||
"""获取带认证的请求头"""
|
||||
token_info = await TokenManager.get_instance().get_token(self._token_strategy)
|
||||
return {
|
||||
"Authorization": f"Bearer {token_info.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def make_request(self):
|
||||
"""模拟 API 请求"""
|
||||
headers = await self._get_headers()
|
||||
print(f"Request headers: {headers}")
|
||||
# 实际使用时: await session.post(url, headers=headers, ...)
|
||||
|
||||
provider = ExampleProvider("access_key_123", "secret_key_456")
|
||||
await provider.make_request()
|
||||
|
||||
|
||||
async def example_concurrent_requests():
|
||||
"""并发请求示例 - 测试 token 刷新时的并发安全"""
|
||||
print("\n" + "=" * 60)
|
||||
print("并发请求示例")
|
||||
print("=" * 60)
|
||||
|
||||
strategy = JWTTokenStrategy(
|
||||
access_key="concurrent_test_key",
|
||||
secret_key="concurrent_test_secret",
|
||||
expires_in=1800,
|
||||
)
|
||||
|
||||
async def request_task(task_id: int):
|
||||
"""模拟单个请求"""
|
||||
token_info = await TokenManager.get_instance().get_token(strategy)
|
||||
print(f"Task {task_id}: got token (expires in {token_info.expires_in:.0f}s)")
|
||||
return token_info
|
||||
|
||||
# 并发10个请求,应该只触发一次 token 生成
|
||||
print("Launching 10 concurrent requests...")
|
||||
results = await asyncio.gather(*[request_task(i) for i in range(10)])
|
||||
|
||||
# 验证所有请求拿到的是同一个 token
|
||||
tokens = [r.token for r in results]
|
||||
unique_tokens = set(tokens)
|
||||
print(f"\nTotal requests: {len(tokens)}")
|
||||
print(f"Unique tokens generated: {len(unique_tokens)}")
|
||||
print(f"Concurrent safety: {'✓ PASS' if len(unique_tokens) == 1 else '✗ FAIL'}")
|
||||
|
||||
|
||||
async def example_stats():
|
||||
"""查看 TokenManager 统计信息"""
|
||||
print("\n" + "=" * 60)
|
||||
print("TokenManager 统计")
|
||||
print("=" * 60)
|
||||
|
||||
manager = TokenManager.get_instance()
|
||||
stats = manager.get_stats()
|
||||
|
||||
print(f"Total cached tokens: {stats['total_cached']}")
|
||||
print(f"Active background tasks: {stats['active_tasks']}")
|
||||
print(f"Token details: {stats['tokens']}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""运行所有示例"""
|
||||
await example_jwt()
|
||||
await example_oauth2()
|
||||
await example_provider_integration()
|
||||
await example_concurrent_requests()
|
||||
await example_stats()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有示例完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
CRUD 模块
|
||||
========
|
||||
|
||||
统一导出所有 CRUD 实例,方便导入使用。
|
||||
|
||||
使用示例:
|
||||
from app.crud import user
|
||||
|
||||
user_obj = await user.get(db, id="xxx")
|
||||
"""
|
||||
|
||||
from app.crud.model_usage import model_usage_log
|
||||
from app.crud.user import user
|
||||
|
||||
__all__ = [
|
||||
"user",
|
||||
"model_usage_log",
|
||||
]
|
||||
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Avatar CRUD 操作
|
||||
================
|
||||
|
||||
形象克隆记录的数据访问层。
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.avatar import Avatar
|
||||
from app.schemas.avatar import AvatarCreate, AvatarUpdate
|
||||
|
||||
|
||||
class CRUDAvatar(CRUDBase[Avatar, AvatarCreate, AvatarUpdate]):
|
||||
"""Avatar 数据访问对象"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(Avatar)
|
||||
|
||||
async def get_multi_by_user(
|
||||
self, db: AsyncSession, *, user_id: str, skip: int = 0, limit: int = 100
|
||||
) -> list[Avatar]:
|
||||
"""获取用户的形象列表(排除已软删除)"""
|
||||
result = await db.execute(
|
||||
select(Avatar)
|
||||
.where(Avatar.user_id == user_id)
|
||||
.where(Avatar.deleted_at.is_(None))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.order_by(Avatar.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str, commit: bool = True) -> Avatar | None:
|
||||
"""软删除形象记录"""
|
||||
obj = await self.get(db, id)
|
||||
if obj:
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
if commit:
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
else:
|
||||
await db.flush()
|
||||
return obj
|
||||
|
||||
async def get_stuck_tasks(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
processing_statuses: list[str],
|
||||
timeout_minutes: int = 30,
|
||||
limit: int = 100,
|
||||
) -> list[Avatar]:
|
||||
"""获取卡住的任务(超过指定时间未更新的处理中任务)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
processing_statuses: 需要检查的处理中状态列表
|
||||
timeout_minutes: 超时时间(分钟)
|
||||
limit: 最大返回数量
|
||||
"""
|
||||
timeout_threshold = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||
|
||||
result = await db.execute(
|
||||
select(Avatar)
|
||||
.where(Avatar.status.in_(processing_statuses))
|
||||
.where(Avatar.deleted_at.is_(None))
|
||||
.where(Avatar.updated_at < timeout_threshold)
|
||||
.limit(limit)
|
||||
.order_by(Avatar.updated_at.asc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_status_in(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
statuses: list[str],
|
||||
updated_before: datetime | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Avatar]:
|
||||
"""根据状态列表查询任务
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
statuses: 状态列表
|
||||
updated_before: 更新时间早于该时间的记录
|
||||
limit: 最大返回数量
|
||||
"""
|
||||
query = select(Avatar).where(Avatar.status.in_(statuses)).where(Avatar.deleted_at.is_(None))
|
||||
|
||||
if updated_before:
|
||||
query = query.where(Avatar.updated_at < updated_before)
|
||||
|
||||
query = query.limit(limit).order_by(Avatar.updated_at.asc())
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# 全局单例
|
||||
avatar = CRUDAvatar()
|
||||
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
CRUD 基础类
|
||||
==========
|
||||
|
||||
提供通用的数据访问方法,所有业务 CRUD 必须继承此类。
|
||||
"""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.base import BaseModel as AppBaseModel
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=AppBaseModel)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel, default=Any)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel, default=Any)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""
|
||||
通用 CRUD 基类
|
||||
|
||||
所有业务 CRUD 必须继承此类,确保接口统一。
|
||||
|
||||
使用示例:
|
||||
class UserCRUD(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def __init__(self):
|
||||
super().__init__(User)
|
||||
|
||||
# 添加业务特定方法...
|
||||
|
||||
user = UserCRUD()
|
||||
"""
|
||||
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
Args:
|
||||
model: SQLAlchemy 模型类
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: str) -> ModelType | None:
|
||||
"""根据 ID 获取单个对象"""
|
||||
result = await db.execute(select(self.model).where(self.model.id == id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> list[ModelType]:
|
||||
"""获取多个对象(分页)"""
|
||||
result = await db.execute(select(self.model).offset(skip).limit(limit))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType | dict[str, Any], commit: bool = True
|
||||
) -> ModelType:
|
||||
"""创建对象
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
obj_in: 对象数据(Pydantic 模型或字典)
|
||||
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
||||
"""
|
||||
if isinstance(obj_in, BaseModel):
|
||||
obj_in = obj_in.model_dump(exclude_unset=True)
|
||||
db_obj = self.model(**obj_in)
|
||||
db.add(db_obj)
|
||||
if commit:
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
else:
|
||||
# 不提交时刷新以获取默认值(如自增ID),但需在事务中
|
||||
await db.flush()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
commit: bool = True,
|
||||
) -> ModelType:
|
||||
"""更新对象
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
db_obj: 数据库对象
|
||||
obj_in: 更新数据(Pydantic 模型或字典)
|
||||
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
||||
"""
|
||||
if isinstance(obj_in, BaseModel):
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
else:
|
||||
update_data = obj_in
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_obj, field) and value is not None:
|
||||
setattr(db_obj, field, value)
|
||||
if commit:
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
else:
|
||||
await db.flush()
|
||||
return db_obj
|
||||
|
||||
async def delete(self, db: AsyncSession, *, id: str, commit: bool = True) -> ModelType | None:
|
||||
"""删除对象
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
id: 对象ID
|
||||
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
||||
"""
|
||||
obj = await self.get(db, id)
|
||||
if obj:
|
||||
await db.delete(obj)
|
||||
if commit:
|
||||
await db.commit()
|
||||
return obj
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""统计总数"""
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar() or 0
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
模型使用日志 CRUD 操作
|
||||
======================
|
||||
|
||||
仅保留使用日志功能,模型配置已迁移到 YAML 文件。
|
||||
"""
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.model_usage import ModelUsageLog
|
||||
|
||||
|
||||
class ModelUsageLogCRUD(CRUDBase[ModelUsageLog]):
|
||||
"""模型使用日志 CRUD"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ModelUsageLog)
|
||||
|
||||
async def get_daily_cost(self, db: AsyncSession, *, date: str) -> float:
|
||||
"""获取某日总成本"""
|
||||
result = await db.execute(
|
||||
select(func.sum(ModelUsageLog.cost_cny)).where(
|
||||
func.date(ModelUsageLog.created_at) == date
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0.0
|
||||
|
||||
async def get_by_user(
|
||||
self, db: AsyncSession, *, user_id: str, skip: int = 0, limit: int = 100
|
||||
) -> list[ModelUsageLog]:
|
||||
"""获取用户的使用日志"""
|
||||
result = await db.execute(
|
||||
select(ModelUsageLog)
|
||||
.where(ModelUsageLog.user_id == user_id)
|
||||
.order_by(ModelUsageLog.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# 导出实例
|
||||
model_usage_log = ModelUsageLogCRUD()
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
用户 CRUD 操作
|
||||
==============
|
||||
|
||||
用户认证相关的数据访问。
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class UserCRUD(CRUDBase[User]):
|
||||
"""用户数据访问对象"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(User)
|
||||
|
||||
async def get_by_mobile(self, db: AsyncSession, *, mobile: str) -> User | None:
|
||||
"""根据手机号获取用户"""
|
||||
result = await db.execute(select(User).where(User.mobile == mobile))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_create_by_mobile(
|
||||
self, db: AsyncSession, *, mobile: str, nickname: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
根据手机号获取或创建用户
|
||||
|
||||
Returns:
|
||||
已存在或新创建的用户
|
||||
"""
|
||||
user = await self.get_by_mobile(db, mobile=mobile)
|
||||
|
||||
if user is None:
|
||||
# 创建新用户
|
||||
user = await self.create(
|
||||
db,
|
||||
obj_in={
|
||||
"mobile": mobile,
|
||||
"nickname": nickname or f"用户_{mobile[-4:]}",
|
||||
},
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
# 导出实例
|
||||
user = UserCRUD()
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
SQLAlchemy 数据库配置
|
||||
====================
|
||||
|
||||
统一使用 PostgreSQL + 异步模式。
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
async_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.DEBUG,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""获取异步数据库 Session
|
||||
|
||||
注意:commit 由调用方(API层或Service层)控制,不在此自动提交
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""初始化数据库"""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""关闭数据库连接"""
|
||||
await async_engine.dispose()
|
||||
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
FastAPI 应用入口
|
||||
================
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.v1.router import api_router
|
||||
from app.config import get_settings
|
||||
from app.db.session import close_db, init_db
|
||||
from app.schemas.common import ApiResponse
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# 配置日志 - 同时输出到控制台和文件
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
log_level = getattr(logging, settings.LOG_LEVEL)
|
||||
|
||||
# 创建日志目录(在用户文档目录下)
|
||||
log_dir = Path.home() / "Documents" / "Meijiaka" / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 日志文件名按日期
|
||||
log_file = log_dir / f"api_{datetime.now().strftime('%Y%m%d')}.log"
|
||||
|
||||
# 配置根日志记录器
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format=log_format,
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout), # 控制台输出
|
||||
logging.FileHandler(log_file, encoding="utf-8", mode="a"), # 文件输出
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"日志文件位置: {log_file}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
应用生命周期管理
|
||||
|
||||
- 启动时:初始化数据库、加载模型配置
|
||||
- 关闭时:清理资源
|
||||
"""
|
||||
logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")
|
||||
|
||||
# 开发环境自动创建表
|
||||
if settings.DEBUG and settings.ENV == "development":
|
||||
logger.info("Initializing database tables...")
|
||||
try:
|
||||
# 确保所有模型已注册到 metadata
|
||||
from app.models import Avatar, ModelUsageLog, User # noqa: F401
|
||||
|
||||
await init_db()
|
||||
logger.info("Database tables initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Database initialization skipped: {e}")
|
||||
|
||||
# 加载 AI 模型配置(从 YAML 文件)
|
||||
try:
|
||||
from app.core.config_loader import get_config_loader
|
||||
|
||||
config_loader = get_config_loader()
|
||||
platforms_count = len(config_loader.get_all_platforms())
|
||||
models_count = len(config_loader.get_enabled_models())
|
||||
|
||||
logger.info(f"Loaded {platforms_count} platforms, {models_count} models from config file")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load models from config: {e}")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭时清理
|
||||
logger.info("Shutting down...")
|
||||
await close_db()
|
||||
logger.info("Cleanup complete")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""创建 FastAPI 应用实例"""
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version=settings.APP_VERSION,
|
||||
description="美家卡智影 - AI 视频创作后端 API",
|
||||
docs_url="/docs" if settings.DEBUG else None,
|
||||
redoc_url="/redoc" if settings.DEBUG else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS 配置
|
||||
# 开发环境下允许所有来源,避免跨域问题
|
||||
allow_origins = ["*"] if settings.DEBUG else settings.cors_origins_list
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
# 全局异常处理(统一返回 ApiResponse 格式)
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request, exc):
|
||||
"""全局异常捕获"""
|
||||
logger.exception("Unhandled exception")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"code": 500,
|
||||
"message": "服务器内部错误",
|
||||
"data": None,
|
||||
"detail": {"error": str(exc)} if settings.DEBUG else None,
|
||||
},
|
||||
)
|
||||
|
||||
# 健康检查
|
||||
@app.get("/health", tags=["System"])
|
||||
async def health_check():
|
||||
"""服务健康检查"""
|
||||
return ApiResponse(
|
||||
code=200,
|
||||
data={
|
||||
"status": "healthy",
|
||||
"version": settings.APP_VERSION,
|
||||
"environment": settings.ENV,
|
||||
},
|
||||
message="服务运行正常",
|
||||
)
|
||||
|
||||
# 根路由
|
||||
@app.get("/", tags=["System"])
|
||||
async def root():
|
||||
"""API 根路径"""
|
||||
return ApiResponse(
|
||||
code=200,
|
||||
data={
|
||||
"name": settings.APP_NAME,
|
||||
"version": settings.APP_VERSION,
|
||||
"docs": "/docs" if settings.DEBUG else None,
|
||||
},
|
||||
message="美家卡智影 API 服务",
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app()
|
||||
|
||||
|
||||
def main():
|
||||
"""入口函数(用于命令行启动)"""
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
workers=settings.WORKERS if not settings.DEBUG else 1,
|
||||
reload=settings.DEBUG,
|
||||
log_level=settings.LOG_LEVEL.lower(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# test
|
||||
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
模型模块
|
||||
|
||||
所有 SQLAlchemy 模型定义。
|
||||
|
||||
注意:AIModel/AIPlatform 已迁移到 YAML 配置 (config/ai_models.yaml)
|
||||
"""
|
||||
|
||||
from app.models.avatar import Avatar
|
||||
from app.models.base import BaseModel
|
||||
from app.models.model_usage import ModelUsageLog
|
||||
from app.models.user import User
|
||||
|
||||
# 当前可用的模型
|
||||
__all__ = [
|
||||
"Avatar",
|
||||
"BaseModel",
|
||||
"ModelUsageLog",
|
||||
"User",
|
||||
]
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Avatar 形象克隆模型
|
||||
==================
|
||||
|
||||
存储用户克隆形象的信息,作为本地 localStorage 的云端备份。
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.session import Base
|
||||
from app.schemas.enums import AvatarCloneStatus
|
||||
|
||||
|
||||
class Avatar(Base):
|
||||
"""
|
||||
形象克隆记录表
|
||||
|
||||
用于备份用户在本地创建的克隆形象,支持换机恢复和客服排查。
|
||||
"""
|
||||
|
||||
__tablename__ = "avatars"
|
||||
|
||||
# 主键:本地生成的唯一标识(与 Kling element_id 无关)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
primary_key=True,
|
||||
comment="本地形象唯一标识(如 avt_xxx)",
|
||||
)
|
||||
|
||||
# 关联用户(外键,对应 users.id)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="关联用户 ID",
|
||||
)
|
||||
|
||||
# 形象展示名称
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
nullable=False,
|
||||
comment="形象展示名称",
|
||||
)
|
||||
|
||||
# 供应商标识
|
||||
provider: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
nullable=False,
|
||||
default="kling",
|
||||
comment="供应商标识: kling",
|
||||
)
|
||||
|
||||
# Kling 自定义音色 ID(创建成功后回填)
|
||||
voice_id: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
comment="Kling 自定义音色 ID",
|
||||
)
|
||||
|
||||
# 供应商主体 ID(创建成功后回填,用于调用 omni-video API)
|
||||
provider_element_id: Mapped[int | None] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
comment="供应商主体 ID(数字类型,调用 API 时使用)",
|
||||
)
|
||||
|
||||
# 供应商任务 ID(用于客服追溯)
|
||||
provider_voice_job_id: Mapped[str | None] = mapped_column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="供应商自定义音色任务 ID",
|
||||
)
|
||||
|
||||
provider_element_job_id: Mapped[str | None] = mapped_column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="供应商主体创建任务 ID",
|
||||
)
|
||||
|
||||
# 资源地址
|
||||
video_url: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
nullable=False,
|
||||
comment="原始人物视频 URL",
|
||||
)
|
||||
|
||||
trial_url: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="音色试听音频 URL",
|
||||
)
|
||||
|
||||
# 状态机
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
nullable=False,
|
||||
default=AvatarCloneStatus.PENDING.value,
|
||||
comment="状态: pending/voice_processing/voice_failed/element_processing/element_failed/succeed/timeout",
|
||||
)
|
||||
|
||||
# 失败原因(用户可读)
|
||||
fail_reason: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="失败原因(中文可读)",
|
||||
)
|
||||
|
||||
# 软删除标记
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="软删除时间,NULL 表示未删除",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
comment="记录创建时间",
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
comment="记录更新时间",
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典(用于序列化)"""
|
||||
return {column.name: getattr(self, column.name) for column in self.__table__.columns}
|
||||
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
基础模型定义
|
||||
============
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
"""
|
||||
基础模型 - 所有模型继承此类
|
||||
|
||||
提供:
|
||||
- UUID 主键(自动生成)
|
||||
- 创建时间
|
||||
- 更新时间
|
||||
"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典(用于序列化)"""
|
||||
return {column.name: getattr(self, column.name) for column in self.__table__.columns}
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
AI 模型使用日志模型
|
||||
==================
|
||||
|
||||
存储模型调用的使用日志,用于成本统计和监控。
|
||||
|
||||
模型配置已迁移到 YAML 文件:config/ai_models.yaml
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
class ModelUsageLog(Base):
|
||||
"""模型使用日志 - 用于成本统计和监控"""
|
||||
|
||||
__tablename__ = "model_usage_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
# 调用信息
|
||||
model_id = Column(String(100), nullable=False)
|
||||
platform_id = Column(String(50), nullable=False)
|
||||
|
||||
# 调用类型
|
||||
task_type = Column(String(50), nullable=False) # script、polish、chat
|
||||
|
||||
# Token 用量
|
||||
prompt_tokens = Column(Integer, default=0)
|
||||
completion_tokens = Column(Integer, default=0)
|
||||
total_tokens = Column(Integer, default=0)
|
||||
|
||||
# 成本(计算后的人民币)
|
||||
cost_cny = Column(Float, default=0.0)
|
||||
|
||||
# 性能
|
||||
response_time_ms = Column(Integer, nullable=True) # 响应时间
|
||||
|
||||
# 结果
|
||||
success = Column(Boolean, default=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# 用户/项目
|
||||
user_id = Column(String(50), nullable=True)
|
||||
project_id = Column(String(50), nullable=True)
|
||||
|
||||
# 时间
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# 索引定义
|
||||
__table_args__ = (
|
||||
# 索引:按用户查询使用记录
|
||||
Index("ix_model_usage_logs_user_id", "user_id"),
|
||||
# 索引:按时间查询(用于统计)
|
||||
Index("ix_model_usage_logs_created_at", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UsageLog {self.model_id}: {self.total_tokens} tokens>"
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
用户模型
|
||||
========
|
||||
|
||||
采用"手机号 + JWT"的传统认证方案。
|
||||
"""
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""用户表"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
# 手机号,作为登录账号
|
||||
mobile: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
comment="手机号",
|
||||
)
|
||||
|
||||
nickname: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
comment="用户昵称",
|
||||
)
|
||||
|
||||
avatar_url: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="头像 URL",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User(id={self.id}, mobile={self.mobile}, nickname={self.nickname})>"
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
统一异步任务调度器
|
||||
==================
|
||||
|
||||
统一异步任务调度器(替代原 Celery 架构)。
|
||||
"""
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Async Engine 核心调度器
|
||||
=======================
|
||||
|
||||
驱动所有 Handler 的 Tick 循环,批量查询、批量更新。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncEngine:
|
||||
"""统一异步作业调度引擎"""
|
||||
|
||||
def __init__(self, handlers: list[AsyncHandler] | None = None):
|
||||
self.redis = get_redis_client()
|
||||
self.registry = JobRegistry(self.redis)
|
||||
self.slots = SlotManager(self.redis)
|
||||
self.handlers: dict[str, AsyncHandler] = {}
|
||||
if handlers:
|
||||
for h in handlers:
|
||||
self.handlers[h.name] = h
|
||||
|
||||
def register(self, handler: AsyncHandler) -> None:
|
||||
"""注册一个 Handler"""
|
||||
self.handlers[handler.name] = handler
|
||||
logger.info(f"Registered handler: {handler.name}")
|
||||
|
||||
async def tick(self) -> None:
|
||||
"""执行一次完整的调度 Tick"""
|
||||
tick_start = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
# 1. 加载所有 running 的作业 ID
|
||||
running_ids = await self.registry.get_running_job_ids()
|
||||
if not running_ids:
|
||||
logger.debug("Tick: no running jobs")
|
||||
return
|
||||
|
||||
# 2. 按 job_type 分组
|
||||
jobs_by_type: dict[str, list[Any]] = {}
|
||||
for job_id in running_ids:
|
||||
record = await self.registry.get(job_id)
|
||||
if not record:
|
||||
await self.registry.remove_running(job_id)
|
||||
continue
|
||||
jobs_by_type.setdefault(record.job_type, []).append(record)
|
||||
|
||||
# 3. 并行执行各 Handler 的 tick
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._safe_tick(handler_name, handler, jobs_by_type.get(handler_name, []))
|
||||
for handler_name, handler in self.handlers.items()
|
||||
]
|
||||
)
|
||||
|
||||
# 4. 收集并应用状态变更
|
||||
for changes in results:
|
||||
if changes:
|
||||
await self._apply_changes(changes)
|
||||
|
||||
# 5. 清理已结束的作业
|
||||
await self._cleanup_finished()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Scheduler tick failed")
|
||||
finally:
|
||||
elapsed = asyncio.get_event_loop().time() - tick_start
|
||||
logger.debug(f"Tick completed in {elapsed:.2f}s")
|
||||
|
||||
async def _safe_tick(
|
||||
self, name: str, handler: AsyncHandler, jobs: list[Any]
|
||||
) -> list[StateChange]:
|
||||
"""安全执行 Handler tick,捕获异常"""
|
||||
try:
|
||||
return await handler.tick(jobs, self.registry, self.slots)
|
||||
except Exception:
|
||||
logger.exception(f"Handler tick failed: {name}")
|
||||
return []
|
||||
|
||||
async def _apply_changes(self, changes: list[StateChange]) -> None:
|
||||
"""批量应用状态变更到 Redis"""
|
||||
pipe = self.redis.pipeline()
|
||||
executed = False
|
||||
for change in changes:
|
||||
key, field, value = change.to_redis_command()
|
||||
pipe.hset(key, field, value)
|
||||
executed = True
|
||||
if executed:
|
||||
await pipe.execute()
|
||||
|
||||
async def _cleanup_finished(self) -> None:
|
||||
"""清理已完成的作业"""
|
||||
running_ids = await self.registry.get_running_job_ids()
|
||||
for job_id in running_ids:
|
||||
record = await self.registry.get(job_id)
|
||||
if not record:
|
||||
await self.registry.remove_running(job_id)
|
||||
continue
|
||||
if record.status in ("completed", "failed"):
|
||||
await self.registry.remove_running(job_id)
|
||||
logger.info(f"Job moved to finished: {job_id} ({record.status})")
|
||||
|
||||
async def run_forever(self, interval: float = 10.0, min_interval: float = 2.0) -> None:
|
||||
"""启动无限 Tick 循环"""
|
||||
logger.info("Async Engine started")
|
||||
while True:
|
||||
tick_start = asyncio.get_event_loop().time()
|
||||
await self.tick()
|
||||
elapsed = asyncio.get_event_loop().time() - tick_start
|
||||
sleep_time = max(interval - elapsed, min_interval)
|
||||
await asyncio.sleep(sleep_time)
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Scheduler Handlers
|
||||
"""
|
||||
@@ -0,0 +1,504 @@
|
||||
"""
|
||||
Avatar 形象克隆处理器
|
||||
====================
|
||||
|
||||
管理 Kling 形象克隆的提交与轮询。
|
||||
占用全局槽位:2
|
||||
|
||||
数据策略:不操作数据库,所有中间状态存储在 Redis 中。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.config import get_settings
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.schemas.enums import AvatarCloneStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "kling:avatar_slots"
|
||||
MAX_SLOTS = 2
|
||||
|
||||
SYSTEM_BUSY_MESSAGE = "系统繁忙,请稍后重试"
|
||||
SYSTEM_ERROR_MESSAGE = "系统处理异常,请稍后重试或联系客服"
|
||||
|
||||
|
||||
def _get_kling_provider() -> KlingAIProvider:
|
||||
settings = get_settings()
|
||||
return KlingAIProvider(
|
||||
config={
|
||||
"access_key": settings.KLINGAI_ACCESS_KEY or "",
|
||||
"secret_key": settings.KLINGAI_SECRET_KEY or "",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _translate_voice_error(message: str) -> str:
|
||||
msg = (message or "").lower()
|
||||
if "no valid audio" in msg or "audio" in msg or "voice" in msg or "人声" in msg:
|
||||
return "自定义音色创建失败:视频中没有检测到清晰的人声。请确保上传「有声的人物视频」,且人声干净、无杂音、背景噪音小。"
|
||||
if "duration" in msg or "时长" in msg:
|
||||
return "自定义音色创建失败:视频时长不符合要求。请使用 5-30 秒的视频。"
|
||||
if "format" in msg or "格式" in msg:
|
||||
return "自定义音色创建失败:视频格式不支持。请使用 MP4 或 MOV 格式。"
|
||||
if "size" in msg or "大小" in msg or "mb" in msg:
|
||||
return "自定义音色创建失败:视频文件过大。请压缩至 200MB 以内。"
|
||||
if "quality" in msg or "质量" in msg:
|
||||
return "自定义音色创建失败:视频/音频质量不符合要求。请确保画面清晰、人声干净、无强烈背景噪音。"
|
||||
return f"自定义音色创建失败:{message}。请检查是否上传了符合要求的「有声的人物视频」。"
|
||||
|
||||
|
||||
def _translate_element_error(message: str) -> str:
|
||||
msg = (message or "").lower()
|
||||
if "duration" in msg or "时长" in msg:
|
||||
return "主体创建失败:视频时长不符合要求。请使用 3-8 秒的人物特写视频。"
|
||||
if "resolution" in msg or "height" in msg or "像素" in msg or "720" in msg or "2160" in msg:
|
||||
return "主体创建失败:视频分辨率不符合要求。请确保视频高度在 720px~2160px 之间。"
|
||||
if "size" in msg or "大小" in msg or "mb" in msg or "200" in msg:
|
||||
return "主体创建失败:视频文件过大。请压缩至 200MB 以内。"
|
||||
if "format" in msg or "格式" in msg or "mp4" in msg or "mov" in msg:
|
||||
return "主体创建失败:视频格式不支持。请使用 MP4 或 MOV 格式。"
|
||||
if "face" in msg or "人脸" in msg or "detect" in msg or "主体" in msg:
|
||||
return "主体创建失败:未能从视频中检测到稳定的人脸。请确保视频为「写实风格的人物正面特写」,人脸清晰、无遮挡、光线充足。"
|
||||
if "human" in msg or "人形" in msg or "character" in msg or "写实" in msg:
|
||||
return "主体创建失败:视频内容不符合要求。请确保视频中是「写实风格的真实人物」,非卡通、非动物、非虚拟形象。"
|
||||
return f"主体创建失败:{message}。请检查视频是否为 3-8 秒、人脸清晰、写实风格的正面人物视频。"
|
||||
|
||||
|
||||
def _translate_system_error(error: Exception, step: str) -> tuple[str, str]:
|
||||
error_str = str(error)
|
||||
error_type = type(error).__name__
|
||||
if isinstance(error, aiohttp.ClientError | asyncio.TimeoutError):
|
||||
return SYSTEM_BUSY_MESSAGE, f"[{step}] 网络错误: {error_type}: {error_str}"
|
||||
if "500" in error_str or "503" in error_str or "502" in error_str:
|
||||
return SYSTEM_BUSY_MESSAGE, f"[{step}] KlingAI 服务错误: {error_type}: {error_str}"
|
||||
if (
|
||||
"rate limit" in error_str.lower()
|
||||
or "too many requests" in error_str.lower()
|
||||
or "429" in error_str
|
||||
):
|
||||
return SYSTEM_BUSY_MESSAGE, f"[{step}] API 限流: {error_type}: {error_str}"
|
||||
return SYSTEM_ERROR_MESSAGE, f"[{step}] 系统错误: {error_type}: {error_str}"
|
||||
|
||||
|
||||
async def _update_avatar_state(registry: JobRegistry, avatar_id: str, **fields: Any) -> None:
|
||||
"""更新 Redis 中的 avatar 状态(同时更新 updated_at)"""
|
||||
fields["updated_at"] = datetime.now(UTC).isoformat()
|
||||
await registry.update(avatar_id, **fields)
|
||||
|
||||
|
||||
class AvatarHandler(AsyncHandler):
|
||||
name = "avatar_clone"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
for job in jobs:
|
||||
job_changes = await self._process_job(job, registry, slots)
|
||||
changes.extend(job_changes)
|
||||
return changes
|
||||
|
||||
async def _process_job(
|
||||
self, job: Any, registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
avatar_id = job.job_id
|
||||
|
||||
# 从 Redis 读取 avatar 状态
|
||||
redis = get_redis_client()
|
||||
state_raw = await redis.hgetall(f"job:{avatar_id}")
|
||||
if not state_raw:
|
||||
logger.error(f"Avatar job not found in Redis: {avatar_id}")
|
||||
_msg = "任务记录丢失,请重新提交"
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="status", value="failed"))
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="message", value=_msg))
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_msg))
|
||||
return changes
|
||||
|
||||
# 解析 params
|
||||
params = {}
|
||||
if "params" in state_raw and state_raw["params"]:
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
params = json.loads(state_raw["params"])
|
||||
|
||||
status = state_raw.get("avatar_status", state_raw.get("status", ""))
|
||||
provider = _get_kling_provider()
|
||||
|
||||
# 辅助函数:读取字段
|
||||
def _f(key: str) -> str:
|
||||
return state_raw.get(key, "") or ""
|
||||
|
||||
# ---------- pending: 创建音色 ----------
|
||||
if status == AvatarCloneStatus.PENDING.value:
|
||||
slot_id = f"avatar:{avatar_id}"
|
||||
acquired = await slots.acquire(SLOT_KEY, slot_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
return changes # 槽位已满,等下一轮
|
||||
|
||||
try:
|
||||
await _update_avatar_state(
|
||||
registry, avatar_id, avatar_status=AvatarCloneStatus.VOICE_PROCESSING.value
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=avatar_id, field_path="message", value="正在创建自定义音色..."
|
||||
)
|
||||
)
|
||||
voice_result = await provider.create_custom_voice(
|
||||
voice_name=params.get("name", ""),
|
||||
video_url=params.get("video_url", ""),
|
||||
)
|
||||
voice_task_id = voice_result.get("task_id")
|
||||
if not voice_task_id:
|
||||
raise Exception("未返回音色任务 ID")
|
||||
await _update_avatar_state(registry, avatar_id, provider_voice_job_id=voice_task_id)
|
||||
logger.info(f"Avatar {avatar_id}: created voice task {voice_task_id}")
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, slot_id)
|
||||
if isinstance(e, aiohttp.ClientError | asyncio.TimeoutError) or any(
|
||||
code in str(e) for code in ["500", "503", "502", "429"]
|
||||
):
|
||||
user_msg, cloud_detail = _translate_system_error(e, "voice_create")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.VOICE_FAILED.value,
|
||||
fail_reason=user_msg,
|
||||
)
|
||||
logger.error(f"Avatar {avatar_id} voice_create system error: {cloud_detail}")
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=user_msg)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="error", value=user_msg)
|
||||
)
|
||||
else:
|
||||
_reason = _translate_voice_error(str(e))
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.VOICE_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
|
||||
# ---------- voice_processing: 轮询音色 ----------
|
||||
elif status == AvatarCloneStatus.VOICE_PROCESSING.value:
|
||||
provider_voice_job_id = _f("provider_voice_job_id")
|
||||
if not provider_voice_job_id:
|
||||
return changes
|
||||
try:
|
||||
result = await provider.get_custom_voice_task(provider_voice_job_id)
|
||||
kling_status = result.get("task_status", "processing")
|
||||
logger.info(
|
||||
f"Avatar {avatar_id}: voice task {provider_voice_job_id} status={kling_status}"
|
||||
)
|
||||
if kling_status == "processing":
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value="音色处理中...")
|
||||
)
|
||||
elif kling_status == "succeed":
|
||||
await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
|
||||
task_result = result.get("task_result", {})
|
||||
voices = task_result.get("voices", [])
|
||||
voice_id = None
|
||||
trial_url = None
|
||||
if voices:
|
||||
voice_info = voices[0]
|
||||
voice_id = voice_info.get("voice_id") or voice_info.get("id")
|
||||
trial_url = (
|
||||
voice_info.get("trial_url")
|
||||
or voice_info.get("preview_url")
|
||||
or voice_info.get("voice_url")
|
||||
)
|
||||
if not voice_id:
|
||||
raise Exception("音色任务成功但未返回 voice_id")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_PENDING.value,
|
||||
voice_id=voice_id,
|
||||
trial_url=trial_url or "",
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=avatar_id,
|
||||
field_path="message",
|
||||
value="音色创建成功,准备创建形象主体...",
|
||||
)
|
||||
)
|
||||
logger.info(f"Avatar {avatar_id}: voice succeed, voice_id={voice_id}")
|
||||
|
||||
elif kling_status == "failed":
|
||||
await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
|
||||
error_msg = result.get("task_msg", "任务执行失败")
|
||||
_reason = _translate_voice_error(error_msg)
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.VOICE_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
except Exception as e:
|
||||
logger.exception(f"Avatar {avatar_id}: voice poll error")
|
||||
if isinstance(e, aiohttp.ClientError | asyncio.TimeoutError) or any(
|
||||
code in str(e) for code in ["500", "503", "502", "429"]
|
||||
):
|
||||
user_msg, cloud_detail = _translate_system_error(e, "voice_poll")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.VOICE_FAILED.value,
|
||||
fail_reason=user_msg,
|
||||
)
|
||||
logger.error(f"Avatar {avatar_id} voice_poll system error: {cloud_detail}")
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=user_msg)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="error", value=user_msg)
|
||||
)
|
||||
else:
|
||||
_reason = _translate_voice_error(str(e))
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.VOICE_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
|
||||
# ---------- element_pending: 创建主体 ----------
|
||||
elif status == AvatarCloneStatus.ELEMENT_PENDING.value:
|
||||
slot_id = f"avatar:{avatar_id}"
|
||||
acquired = await slots.acquire(SLOT_KEY, slot_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
return changes
|
||||
|
||||
try:
|
||||
await _update_avatar_state(
|
||||
registry, avatar_id, avatar_status=AvatarCloneStatus.ELEMENT_PROCESSING.value
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value="正在创建形象主体...")
|
||||
)
|
||||
element_result = await provider.create_element(
|
||||
element_name=params.get("name", ""),
|
||||
element_description=f"{params.get('name', '')} 的克隆形象",
|
||||
reference_type="video_refer",
|
||||
element_video_list={
|
||||
"refer_videos": [{"video_url": params.get("video_url", "")}]
|
||||
},
|
||||
element_voice_id=_f("voice_id"),
|
||||
)
|
||||
element_task_id = element_result.get("task_id")
|
||||
if not element_task_id:
|
||||
raise Exception("未返回主体任务 ID")
|
||||
await _update_avatar_state(
|
||||
registry, avatar_id, provider_element_job_id=element_task_id
|
||||
)
|
||||
logger.info(f"Avatar {avatar_id}: created element task {element_task_id}")
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, slot_id)
|
||||
if isinstance(e, aiohttp.ClientError | asyncio.TimeoutError) or any(
|
||||
code in str(e) for code in ["500", "503", "502", "429"]
|
||||
):
|
||||
user_msg, cloud_detail = _translate_system_error(e, "element_create")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
fail_reason=user_msg,
|
||||
)
|
||||
logger.error(f"Avatar {avatar_id} element_create system error: {cloud_detail}")
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=user_msg)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="error", value=user_msg)
|
||||
)
|
||||
else:
|
||||
_reason = _translate_element_error(str(e))
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
|
||||
# ---------- element_processing: 轮询主体 ----------
|
||||
elif status == AvatarCloneStatus.ELEMENT_PROCESSING.value:
|
||||
provider_element_job_id = _f("provider_element_job_id")
|
||||
if not provider_element_job_id:
|
||||
return changes
|
||||
try:
|
||||
result = await provider.get_element_task(provider_element_job_id)
|
||||
kling_status = result.get("task_status", "processing")
|
||||
logger.info(
|
||||
f"Avatar {avatar_id}: element task {provider_element_job_id} status={kling_status}"
|
||||
)
|
||||
if kling_status == "processing":
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=avatar_id, field_path="message", value="形象主体处理中..."
|
||||
)
|
||||
)
|
||||
elif kling_status == "succeed":
|
||||
await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
|
||||
task_result = result.get("task_result", {})
|
||||
elements = task_result.get("elements", [])
|
||||
element_id = None
|
||||
if elements:
|
||||
element_id = elements[0].get("element_id")
|
||||
if not element_id:
|
||||
element_id = task_result.get("element_id")
|
||||
if not element_id:
|
||||
raise Exception("主体任务成功但未返回 element_id")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.SUCCEED.value,
|
||||
provider_element_id=str(element_id),
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=avatar_id,
|
||||
field_path="result",
|
||||
value={
|
||||
"avatar_id": avatar_id,
|
||||
"name": params.get("name", ""),
|
||||
"video_url": params.get("video_url", ""),
|
||||
"voice_id": _f("voice_id"),
|
||||
"element_id": int(element_id),
|
||||
"trial_url": _f("trial_url"),
|
||||
},
|
||||
)
|
||||
)
|
||||
logger.info(f"Avatar {avatar_id}: element succeed, element_id={element_id}")
|
||||
|
||||
elif kling_status == "failed":
|
||||
await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
|
||||
error_msg = result.get("task_msg", "任务执行失败")
|
||||
_reason = _translate_element_error(error_msg)
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
except Exception as e:
|
||||
logger.exception(f"Avatar {avatar_id}: element poll error")
|
||||
if isinstance(e, aiohttp.ClientError | asyncio.TimeoutError) or any(
|
||||
code in str(e) for code in ["500", "503", "502", "429"]
|
||||
):
|
||||
user_msg, cloud_detail = _translate_system_error(e, "element_poll")
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
fail_reason=user_msg,
|
||||
)
|
||||
logger.error(f"Avatar {avatar_id} element_poll system error: {cloud_detail}")
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=user_msg)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="error", value=user_msg)
|
||||
)
|
||||
else:
|
||||
_reason = _translate_element_error(str(e))
|
||||
await _update_avatar_state(
|
||||
registry,
|
||||
avatar_id,
|
||||
avatar_status=AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
fail_reason=_reason,
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="message", value=_reason)
|
||||
)
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_reason))
|
||||
|
||||
# ---------- 已结束状态:移出 running ----------
|
||||
elif status in (
|
||||
AvatarCloneStatus.SUCCEED.value,
|
||||
AvatarCloneStatus.VOICE_FAILED.value,
|
||||
AvatarCloneStatus.ELEMENT_FAILED.value,
|
||||
):
|
||||
await slots.release(SLOT_KEY, f"avatar:{avatar_id}")
|
||||
if status == AvatarCloneStatus.SUCCEED.value:
|
||||
changes.append(
|
||||
StateChange(job_id=avatar_id, field_path="status", value="completed")
|
||||
)
|
||||
else:
|
||||
_msg = "任务状态异常"
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="status", value="failed"))
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="message", value=_msg))
|
||||
changes.append(StateChange(job_id=avatar_id, field_path="error", value=_msg))
|
||||
|
||||
return changes
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
AsyncHandler 抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
|
||||
|
||||
class AsyncHandler(ABC):
|
||||
"""第三方异步任务处理器基类"""
|
||||
|
||||
name: str
|
||||
slot_key: str
|
||||
max_slots: int
|
||||
|
||||
@abstractmethod
|
||||
async def tick(
|
||||
self,
|
||||
jobs: list[Any],
|
||||
registry: JobRegistry,
|
||||
slots: SlotManager,
|
||||
) -> list[StateChange]:
|
||||
"""
|
||||
每个 Tick 执行一次。
|
||||
|
||||
Args:
|
||||
jobs: 当前 running 状态的作业记录列表
|
||||
registry: 作业注册表(用于读写 Redis 状态)
|
||||
slots: 全局槽位管理器
|
||||
|
||||
Returns:
|
||||
状态变更列表
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Copy 任务处理器
|
||||
==============
|
||||
|
||||
管理 AnyToCopy 文案提取的提交与轮询。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.services.anytocopy_service import get_anytocopy_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "anytocopy:slots"
|
||||
MAX_SLOTS = 5
|
||||
|
||||
|
||||
class CopyHandler(AsyncHandler):
|
||||
name = "copy"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
|
||||
for job in jobs:
|
||||
params = job.params or {}
|
||||
anytocopy_task_id = params.get("anytocopy_task_id")
|
||||
video_url = params.get("url", params.get("video_url", ""))
|
||||
|
||||
if anytocopy_task_id:
|
||||
try:
|
||||
service = get_anytocopy_service()
|
||||
result = await service.query_task(anytocopy_task_id)
|
||||
if result.get("code") != 200:
|
||||
continue
|
||||
data = result.get("data", {})
|
||||
status = data.get("status")
|
||||
|
||||
if status == "SUCCESS":
|
||||
result_data = {
|
||||
"video_url": video_url,
|
||||
"title": data.get("title", ""),
|
||||
"content": data.get("content", ""),
|
||||
"text_content": data.get("textContent", ""),
|
||||
"platform": data.get("platform", ""),
|
||||
"duration": data.get("duration", 0),
|
||||
}
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id, field_path="message", value="文案提取完成"
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="completed", value=1)
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=1))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="result", value=result_data)
|
||||
)
|
||||
elif status in ("FAILED", "FAILURE"):
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value=f"提取失败: {data.get('errorMessage', '未知错误')}",
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="error",
|
||||
value=data.get("errorMessage", ""),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Copy {job.job_id}] poll error: {e}")
|
||||
continue
|
||||
|
||||
acquired = await slots.acquire(SLOT_KEY, job.job_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
continue
|
||||
|
||||
try:
|
||||
service = get_anytocopy_service()
|
||||
submit_result = await service.submit_task(video_url)
|
||||
if submit_result.get("code") != 200:
|
||||
raise Exception(f"提交失败: {submit_result.get('msg')}")
|
||||
anytocopy_task_id = submit_result["data"]
|
||||
params["anytocopy_task_id"] = anytocopy_task_id
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value="文案提取任务已提交")
|
||||
)
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value=str(e)[:200])
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=str(e)[:500])
|
||||
)
|
||||
|
||||
return changes
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Image 任务处理器
|
||||
===============
|
||||
|
||||
管理 Kling 图片生成的提交、轮询、下载。
|
||||
不占用 Kling Video/Avatar 槽位,使用独立的图片槽位池。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider
|
||||
from app.config import get_settings
|
||||
from app.core.config_loader import get_config_loader
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "kling:image_slots"
|
||||
MAX_SLOTS = 9
|
||||
|
||||
|
||||
class ImageHandler(AsyncHandler):
|
||||
name = "image"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
async def _get_provider(self) -> KlingAIProvider:
|
||||
settings = get_settings()
|
||||
config_loader = get_config_loader()
|
||||
platform = config_loader.get_platform("klingai")
|
||||
return KlingAIProvider(
|
||||
{
|
||||
"access_key": settings.KLINGAI_ACCESS_KEY or "",
|
||||
"secret_key": settings.KLINGAI_SECRET_KEY or "",
|
||||
"base_url": platform.base_url if platform else "https://api-beijing.klingai.com",
|
||||
}
|
||||
)
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
provider = await self._get_provider()
|
||||
|
||||
for job in jobs:
|
||||
params = job.params or {}
|
||||
provider_task_id = params.get("provider_task_id")
|
||||
project_id = params.get("project_id", "")
|
||||
prompt = params.get("prompt", "")
|
||||
image_type = params.get("image_type", "cover")
|
||||
|
||||
if provider_task_id:
|
||||
# 轮询状态
|
||||
try:
|
||||
result = await provider.get_image_task(provider_task_id)
|
||||
status = result.get("task_status", "unknown")
|
||||
except Exception as e:
|
||||
logger.error(f"[Image {job.job_id}] poll error: {e}")
|
||||
continue
|
||||
|
||||
if status in ("processing", "submitted"):
|
||||
continue
|
||||
|
||||
if status == "failed":
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
error_msg = result.get("task_status_msg", "图片生成失败")
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value=error_msg)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=error_msg)
|
||||
)
|
||||
continue
|
||||
|
||||
# succeed
|
||||
images = result.get("task_result", {}).get("images", [])
|
||||
if not images:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value="图片生成成功但未返回图片",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
image_url = images[0].get("url")
|
||||
image_dir = (
|
||||
Path.home() / "Documents" / "Meijiaka" / "projects" / project_id / "images"
|
||||
)
|
||||
image_dir.mkdir(parents=True, exist_ok=True)
|
||||
ext = ".jpg" if ".jpg" in image_url else ".png"
|
||||
local_path = image_dir / f"{image_type}_{job.job_id[:8]}{ext}"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session, session.get(image_url) as resp:
|
||||
resp.raise_for_status()
|
||||
local_path.write_bytes(await resp.read())
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id, field_path="message", value=f"图片下载失败: {e}"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
result_data = {
|
||||
"project_id": project_id,
|
||||
"image_type": image_type,
|
||||
"local_path": str(local_path),
|
||||
"prompt": prompt,
|
||||
}
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value="图片生成完成")
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="completed", value=1))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=1))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="result", value=result_data)
|
||||
)
|
||||
continue
|
||||
|
||||
# 提交新任务
|
||||
acquired = await slots.acquire(SLOT_KEY, job.job_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
continue
|
||||
|
||||
try:
|
||||
reference_image = params.get("reference_image")
|
||||
human_id = params.get("human_id")
|
||||
if reference_image:
|
||||
result = await provider.generate_image(
|
||||
prompt=prompt,
|
||||
image_url=reference_image,
|
||||
model="kling-v3",
|
||||
)
|
||||
elif human_id:
|
||||
result = await provider.generate_image(
|
||||
prompt=prompt,
|
||||
model="kling-v3",
|
||||
aspect_ratio="9:16",
|
||||
)
|
||||
else:
|
||||
result = await provider.generate_image(
|
||||
prompt=prompt,
|
||||
model="kling-v3",
|
||||
aspect_ratio="9:16",
|
||||
)
|
||||
|
||||
provider_task_id = result.get("task_id")
|
||||
if not provider_task_id:
|
||||
raise ValueError("未返回任务ID")
|
||||
params["provider_task_id"] = provider_task_id
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value="图片任务已提交")
|
||||
)
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value=str(e)[:200])
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=str(e)[:500])
|
||||
)
|
||||
|
||||
return changes
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Script 任务处理器
|
||||
================
|
||||
|
||||
管理脚本生成的执行。
|
||||
不占用 Kling/Volc 槽位,使用独立的 script 槽位池。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.services.anytocopy_service import get_anytocopy_service
|
||||
from app.services.script_service import ScriptService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "script:slots"
|
||||
MAX_SLOTS = 10
|
||||
|
||||
|
||||
class ScriptHandler(AsyncHandler):
|
||||
name = "script"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
|
||||
for job in jobs:
|
||||
acquired = await slots.acquire(SLOT_KEY, job.job_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
continue
|
||||
|
||||
try:
|
||||
changes.extend(await self._process_job(job, registry, slots))
|
||||
except Exception as e:
|
||||
logger.exception(f"[Script {job.job_id}] failed")
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=str(e)[:500])
|
||||
)
|
||||
finally:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
|
||||
return changes
|
||||
|
||||
async def _process_job(
|
||||
self, job: Any, registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
params = job.params or {}
|
||||
topic = params.get("topic", "")
|
||||
style = params.get("style", "default")
|
||||
duration = params.get("duration", 60)
|
||||
|
||||
await registry.update(
|
||||
job.job_id,
|
||||
status="running",
|
||||
progress=10,
|
||||
message="分析需求中...",
|
||||
completed=0,
|
||||
total=1,
|
||||
)
|
||||
|
||||
try:
|
||||
await __import__("asyncio").sleep(2)
|
||||
anytocopy = get_anytocopy_service()
|
||||
extract_result = await anytocopy.extract_text_from_input(topic)
|
||||
extracted_info = None
|
||||
actual_topic = topic
|
||||
is_video_url = extract_result.get("is_video_url", False)
|
||||
|
||||
if is_video_url:
|
||||
await registry.update(
|
||||
job.job_id,
|
||||
progress=30,
|
||||
message="提取视频素材中...",
|
||||
)
|
||||
video_info = extract_result.get("video_info")
|
||||
if video_info:
|
||||
extracted_info = {
|
||||
"title": video_info.title,
|
||||
"content": video_info.content,
|
||||
"text_content": video_info.text_content,
|
||||
"platform": video_info.platform,
|
||||
"duration": video_info.duration,
|
||||
"original_url": topic,
|
||||
}
|
||||
actual_topic = extract_result.get("extracted_text") or topic
|
||||
await registry.update(
|
||||
job.job_id,
|
||||
progress=60,
|
||||
message="生成脚本中...",
|
||||
)
|
||||
else:
|
||||
await registry.update(
|
||||
job.job_id,
|
||||
progress=40,
|
||||
message="构思脚本中...",
|
||||
)
|
||||
|
||||
service = ScriptService()
|
||||
shots = await service.generate_script(
|
||||
topic=actual_topic, script_type=style, duration=duration
|
||||
)
|
||||
|
||||
# 计算分镜真实总时长
|
||||
total_duration = sum(s.duration for s in shots if s.duration)
|
||||
result_data = {
|
||||
"title": actual_topic[:50],
|
||||
"scenes": [s.model_dump() for s in shots],
|
||||
"total_duration": total_duration,
|
||||
"style": style,
|
||||
"shot_count": len(shots),
|
||||
"extracted_info": extracted_info,
|
||||
}
|
||||
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="completed"))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="progress", value=100))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value="脚本生成完成")
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="completed", value=1))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=1))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="result", value=result_data))
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f"[ScriptTask {job.job_id}] Failed")
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value=str(exc)[:200])
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="error", value=str(exc)[:500]))
|
||||
|
||||
return changes
|
||||
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Subtitle 任务处理器
|
||||
==================
|
||||
|
||||
管理火山引擎字幕生成与自动打轴的提交与轮询。
|
||||
支持两种模式:
|
||||
- caption: 字幕识别(从音频/视频提取带时间轴的字幕)
|
||||
- auto_align: 自动打轴(为已有字幕文本配上时间轴)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.services.volcengine_caption_service import VolcengineCaptionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "volc:subtitle_slots"
|
||||
MAX_SLOTS = 5
|
||||
|
||||
|
||||
class SubtitleHandler(AsyncHandler):
|
||||
name = "subtitle"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
|
||||
for job in jobs:
|
||||
params = job.params or {}
|
||||
mode = params.get("mode", "caption")
|
||||
volc_task_id = params.get("volc_task_id")
|
||||
project_id = params.get("project_id", "")
|
||||
video_path = params.get("video", params.get("video_path", ""))
|
||||
language = params.get("language", "zh")
|
||||
audio_text = params.get("audio_text", "")
|
||||
|
||||
if volc_task_id:
|
||||
# 轮询
|
||||
try:
|
||||
service = VolcengineCaptionService()
|
||||
if mode == "auto_align":
|
||||
result = await service.query_auto_align_task(volc_task_id, blocking=False)
|
||||
else:
|
||||
result = await service.query_caption_task(volc_task_id, blocking=False)
|
||||
|
||||
if result.code == 0:
|
||||
utterances = result.utterances or []
|
||||
result_data = {
|
||||
"project_id": project_id,
|
||||
"video_path": video_path,
|
||||
"language": language,
|
||||
"mode": mode,
|
||||
"duration": result.duration,
|
||||
"utterances": [
|
||||
{
|
||||
"text": u.text,
|
||||
"start_time": u.start_time,
|
||||
"end_time": u.end_time,
|
||||
}
|
||||
for u in utterances
|
||||
],
|
||||
}
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id, field_path="message", value="字幕生成完成"
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="completed", value=1)
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=1))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="result", value=result_data)
|
||||
)
|
||||
elif result.code != 2000:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value=f"字幕识别失败: {result.message}",
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=result.message)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Subtitle {job.job_id}] poll error: {e}")
|
||||
continue
|
||||
|
||||
# 提交
|
||||
acquired = await slots.acquire(SLOT_KEY, job.job_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
continue
|
||||
|
||||
try:
|
||||
service = VolcengineCaptionService()
|
||||
if mode == "auto_align":
|
||||
if not audio_text:
|
||||
raise ValueError("auto_align 模式需要提供 audio_text")
|
||||
volc_task_id = await service.submit_auto_align_task(
|
||||
audio_url=video_path,
|
||||
audio_text=audio_text,
|
||||
)
|
||||
else:
|
||||
volc_task_id = await service.submit_caption_task(
|
||||
audio_url=video_path, language=language
|
||||
)
|
||||
if not volc_task_id:
|
||||
raise ValueError("未返回任务ID")
|
||||
params["volc_task_id"] = volc_task_id
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value="字幕任务已提交")
|
||||
)
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, job.job_id)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="message", value=str(e)[:200])
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="error", value=str(e)[:500])
|
||||
)
|
||||
|
||||
return changes
|
||||
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Video 任务处理器
|
||||
===============
|
||||
|
||||
管理 Kling 视频生成(含 segment 和 empty_shot)的提交、轮询、下载。
|
||||
占用全局槽位:18
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.ai.providers.klingai_provider import KlingAIProvider, KlingPromptBuilder
|
||||
from app.config import get_settings
|
||||
from app.core.config_loader import get_config_loader
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import JobRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.services.qiniu_service import get_qiniu_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "kling:video_slots"
|
||||
MAX_SLOTS = 18
|
||||
|
||||
|
||||
class VideoHandler(AsyncHandler):
|
||||
name = "video"
|
||||
slot_key = SLOT_KEY
|
||||
max_slots = MAX_SLOTS
|
||||
|
||||
def __init__(self):
|
||||
self._provider: KlingAIProvider | None = None
|
||||
|
||||
async def _get_provider(self) -> KlingAIProvider:
|
||||
if self._provider is None:
|
||||
settings = get_settings()
|
||||
config_loader = get_config_loader()
|
||||
platform = config_loader.get_platform("klingai")
|
||||
self._provider = KlingAIProvider(
|
||||
{
|
||||
"access_key": settings.KLINGAI_ACCESS_KEY or "",
|
||||
"secret_key": settings.KLINGAI_SECRET_KEY or "",
|
||||
"base_url": (
|
||||
platform.base_url if platform else "https://api-beijing.klingai.com"
|
||||
),
|
||||
}
|
||||
)
|
||||
return self._provider
|
||||
|
||||
def _get_project_video_dir(self, project_id: str) -> Path:
|
||||
video_dir = Path.home() / "Documents" / "Meijiaka" / "projects" / project_id / "videos"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
return video_dir
|
||||
|
||||
async def _download_video(self, video_url: str, local_path: Path) -> None:
|
||||
async with aiohttp.ClientSession() as session, session.get(video_url) as resp:
|
||||
resp.raise_for_status()
|
||||
local_path.write_bytes(await resp.read())
|
||||
|
||||
async def _download_image(self, image_url: str, local_path: Path) -> None:
|
||||
async with aiohttp.ClientSession() as session, session.get(image_url) as resp:
|
||||
resp.raise_for_status()
|
||||
local_path.write_bytes(await resp.read())
|
||||
|
||||
async def _poll_image_task(self, provider: KlingAIProvider, image_task_id: str) -> str:
|
||||
"""轮询文生图任务,返回图片 URL"""
|
||||
timeout = 600
|
||||
start = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
if asyncio.get_event_loop().time() - start > timeout:
|
||||
raise TimeoutError("文生图轮询超时")
|
||||
result = await provider.get_image_task(image_task_id)
|
||||
status = result.get("task_status", "unknown")
|
||||
if status == "succeed":
|
||||
images = result.get("task_result", {}).get("images", [])
|
||||
if images and images[0].get("url"):
|
||||
return images[0]["url"]
|
||||
raise Exception("文生图成功但未返回图片 URL")
|
||||
if status == "failed":
|
||||
raise Exception(result.get("task_status_msg", "文生图失败"))
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def tick(
|
||||
self, jobs: list[Any], registry: JobRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
provider = await self._get_provider()
|
||||
|
||||
for job in jobs:
|
||||
job_changes = await self._process_job(job, registry, slots, provider)
|
||||
changes.extend(job_changes)
|
||||
|
||||
return changes
|
||||
|
||||
async def _process_job(
|
||||
self, job: Any, registry: JobRegistry, slots: SlotManager, provider: KlingAIProvider
|
||||
) -> list[StateChange]:
|
||||
changes: list[StateChange] = []
|
||||
params = job.params or {}
|
||||
shots = params.get("shots", [])
|
||||
if isinstance(shots, str):
|
||||
shots = json.loads(shots)
|
||||
params["shots"] = shots
|
||||
if not shots:
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="error", value="没有镜头数据"))
|
||||
return changes
|
||||
|
||||
project_id = params.get("project_id", job.job_id)
|
||||
|
||||
# 1. 查询 submitted 状态的 shots
|
||||
for i, shot in enumerate(shots):
|
||||
if shot.get("status") != "submitted":
|
||||
continue
|
||||
provider_task_id = shot.get("provider_task_id")
|
||||
if not provider_task_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
if shot.get("type") == "segment":
|
||||
result = await provider.get_omni_video_task(provider_task_id)
|
||||
else:
|
||||
result = await provider.get_video_task(
|
||||
provider_task_id, task_type="image2video"
|
||||
)
|
||||
status = result.get("task_status", "unknown")
|
||||
except Exception as e:
|
||||
logger.error(f"[Video {job.job_id}] Query shot {shot['id']} error: {e}")
|
||||
# 累计查询失败计数
|
||||
fail_count = shot.get("query_fail_count", 0) + 1
|
||||
if fail_count >= 5:
|
||||
await slots.release(SLOT_KEY, f"{job.job_id}:{shot['id']}")
|
||||
shots[i]["status"] = "failed"
|
||||
shots[i]["error_message"] = f"查询状态连续失败: {e}"[:500]
|
||||
shots[i]["query_fail_count"] = fail_count
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="params", value=params)
|
||||
)
|
||||
else:
|
||||
shots[i]["query_fail_count"] = fail_count
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="params", value=params)
|
||||
)
|
||||
continue
|
||||
|
||||
# 检查超时:超过 2 小时还在 processing 就标记失败
|
||||
created_at = result.get("created_at", 0) # KlingAI 返回的是 Unix 毫秒时间戳
|
||||
import time
|
||||
now_ms = int(time.time() * 1000)
|
||||
if status == "processing" and created_at > 0 and (now_ms - created_at) > 2 * 60 * 60 * 1000:
|
||||
# 超时 2 小时,标记失败释放槽位
|
||||
await slots.release(SLOT_KEY, f"{job.job_id}:{shot['id']}")
|
||||
shots[i]["status"] = "failed"
|
||||
shots[i]["error_message"] = "生成超时(超过 2 小时仍在处理中)"
|
||||
logger.warning(f"[Video {job.job_id}] Shot {shot['id']} timeout, marked as failed")
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
|
||||
elif status == "succeed":
|
||||
await slots.release(SLOT_KEY, f"{job.job_id}:{shot['id']}")
|
||||
videos = result.get("task_result", {}).get("videos", [])
|
||||
video_url = videos[0].get("url") if videos else None
|
||||
if video_url:
|
||||
shots[i]["video_url"] = video_url
|
||||
shots[i]["status"] = "completed"
|
||||
# 完成就立即下载,不用等全部完成
|
||||
logger.info(f"[Video {job.job_id}] Shot {shot['id']} completed, downloading...")
|
||||
await self._download_and_upload(project_id, shots[i])
|
||||
else:
|
||||
shots[i]["status"] = "failed"
|
||||
shots[i]["error_message"] = "任务成功但未返回视频"
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
|
||||
elif status == "failed":
|
||||
await slots.release(SLOT_KEY, f"{job.job_id}:{shot['id']}")
|
||||
shots[i]["status"] = "failed"
|
||||
shots[i]["error_message"] = result.get("task_status_msg", "生成失败")[:500]
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
|
||||
# 2. 提交 pending 状态的 shots(填槽),segment 优先于 empty_shot
|
||||
pending_shots = sorted(
|
||||
[s for s in shots if s.get("status") == "pending"],
|
||||
key=lambda s: 0 if s.get("type") == "segment" else 1,
|
||||
)
|
||||
for shot in pending_shots:
|
||||
slot_id = f"{job.job_id}:{shot['id']}"
|
||||
acquired = await slots.acquire(SLOT_KEY, slot_id, MAX_SLOTS)
|
||||
if not acquired:
|
||||
continue # 当前这个获取失败(槽位满或网络问题),跳过尝试下一个,下次 tick 再重试
|
||||
|
||||
try:
|
||||
if shot.get("type") == "segment":
|
||||
human_id = shot.get("human_id") or params.get("human_id")
|
||||
if not human_id:
|
||||
raise ValueError(f"分镜 {shot['id']} 缺少 human_id")
|
||||
prompt = KlingPromptBuilder.omni_segment(
|
||||
shot.get("scene", ""), shot.get("voiceover", "")
|
||||
)
|
||||
result = await provider.generate_video_omni(
|
||||
prompt=prompt,
|
||||
model="kling-v3-omni",
|
||||
mode="pro",
|
||||
aspect_ratio="9:16",
|
||||
duration=shot.get("duration"),
|
||||
sound="on",
|
||||
multi_shot=False,
|
||||
element_list=[{"element_id": str(human_id)}],
|
||||
)
|
||||
else:
|
||||
# empty_shot: 文生图 -> 上传七牛 -> 图生视频
|
||||
result = await self._submit_empty_shot(shot, provider)
|
||||
|
||||
provider_task_id = result.get("task_id")
|
||||
if not provider_task_id:
|
||||
raise ValueError(f"创建任务失败,未返回 provider_task_id: {result}")
|
||||
|
||||
shot["provider_task_id"] = provider_task_id
|
||||
shot["status"] = "submitted"
|
||||
logger.info(f"[Video {job.job_id}] Shot {shot['id']} submitted: {provider_task_id}")
|
||||
except Exception as e:
|
||||
await slots.release(SLOT_KEY, slot_id)
|
||||
shot["status"] = "failed"
|
||||
shot["error_message"] = str(e)[:500]
|
||||
logger.error(f"[Video {job.job_id}] Submit shot {shot['id']} failed: {e}")
|
||||
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
|
||||
# 3. 检查是否所有 shots 都完成,做最终汇总
|
||||
all_done = all(s.get("status") in ("completed", "failed") for s in shots)
|
||||
completed = sum(1 for s in shots if s.get("status") == "completed")
|
||||
failed = sum(1 for s in shots if s.get("status") == "failed")
|
||||
|
||||
if all_done:
|
||||
# 下载已经在每个分镜完成时处理过了,这里只重试下载失败的
|
||||
retry_download_tasks = [
|
||||
self._download_and_upload(project_id, shot)
|
||||
for shot in shots
|
||||
if shot.get("status") == "completed"
|
||||
and shot.get("video_url")
|
||||
and not shot.get("local_path")
|
||||
]
|
||||
if retry_download_tasks:
|
||||
logger.info(f"[Video {job.job_id}] Final retry downloading {len(retry_download_tasks)} videos...")
|
||||
await asyncio.gather(*retry_download_tasks, return_exceptions=True)
|
||||
logger.info(f"[Video {job.job_id}] Retry downloads finished")
|
||||
# shots 字典已被 _download_and_upload 更新,写回 params
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="params", value=params))
|
||||
|
||||
# 下载后重新统计,以反映可能的下载失败
|
||||
completed = sum(1 for s in shots if s.get("status") == "completed")
|
||||
failed = sum(1 for s in shots if s.get("status") == "failed")
|
||||
|
||||
if completed == 0 and failed > 0:
|
||||
errors = "; ".join(
|
||||
f"{s.get('id')}: {s.get('error_message')}"
|
||||
for s in shots
|
||||
if s.get("error_message")
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value=f"全部失败 ({failed}/{len(shots)})",
|
||||
)
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="error", value=errors))
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="completed", value=len(shots))
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=len(shots)))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="progress", value=100))
|
||||
else:
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value=f"完成!成功 {completed},失败 {failed}",
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="completed", value=len(shots))
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=len(shots)))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="progress", value=100))
|
||||
# result 字段包含 shots 汇总(含下载后的 local_path / qiniu_url)
|
||||
result_data = {
|
||||
"project_id": project_id,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"total": len(shots),
|
||||
"shots": [
|
||||
{
|
||||
"shot_id": s.get("id"),
|
||||
"type": s.get("type"),
|
||||
"status": s.get("status"),
|
||||
"task_id": s.get("provider_task_id"),
|
||||
"video_url": s.get("video_url"),
|
||||
"local_path": s.get("local_path"),
|
||||
"qiniu_url": s.get("qiniu_url"),
|
||||
"error_message": s.get("error_message"),
|
||||
}
|
||||
for s in shots
|
||||
],
|
||||
}
|
||||
changes.append(
|
||||
StateChange(job_id=job.job_id, field_path="result", value=result_data)
|
||||
)
|
||||
else:
|
||||
done_count = completed + failed
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="status", value="running"))
|
||||
changes.append(
|
||||
StateChange(
|
||||
job_id=job.job_id,
|
||||
field_path="message",
|
||||
value=f"{done_count}/{len(shots)} 个镜头处理中",
|
||||
)
|
||||
)
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="completed", value=done_count))
|
||||
changes.append(StateChange(job_id=job.job_id, field_path="total", value=len(shots)))
|
||||
|
||||
return changes
|
||||
|
||||
async def _submit_empty_shot(
|
||||
self, shot: dict[str, Any], provider: KlingAIProvider
|
||||
) -> dict[str, Any]:
|
||||
"""空镜 shot 的完整提交流程:文生图 -> 上传七牛 -> 图生视频"""
|
||||
qiniu = get_qiniu_service()
|
||||
|
||||
# 1. 文生图
|
||||
image_result = await provider.generate_image(
|
||||
prompt=shot.get("scene", ""),
|
||||
model="kling-v3",
|
||||
aspect_ratio="9:16",
|
||||
)
|
||||
image_task_id = image_result.get("task_id")
|
||||
if not image_task_id:
|
||||
raise ValueError(f"文生图创建失败: {image_result}")
|
||||
|
||||
# 2. 轮询图片完成
|
||||
image_url = await self._poll_image_task(provider, image_task_id)
|
||||
|
||||
# 3. 下载图片
|
||||
temp_dir = Path(tempfile.gettempdir()) / "meijiaka_empty_shot"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_image_path = temp_dir / f"{image_task_id}.jpg"
|
||||
await self._download_image(image_url, temp_image_path)
|
||||
|
||||
# 4. 上传七牛
|
||||
qiniu_result = qiniu.upload_file(
|
||||
local_path=str(temp_image_path),
|
||||
file_type="image",
|
||||
check_duplicate=True,
|
||||
)
|
||||
qiniu_image_url = qiniu_result["url"]
|
||||
with contextlib.suppress(Exception):
|
||||
temp_image_path.unlink()
|
||||
|
||||
# 5. 图生视频
|
||||
voice_id = shot.get("voice_id") or get_settings().DEFAULT_EMPTY_SHOT_VOICE_ID
|
||||
prompt = KlingPromptBuilder.empty_shot(shot.get("scene", ""), shot.get("voiceover", ""))
|
||||
result = await provider.generate_video_image2video(
|
||||
prompt=prompt,
|
||||
image_url=qiniu_image_url,
|
||||
model="kling-v2-6",
|
||||
mode="pro",
|
||||
duration=shot.get("duration"),
|
||||
voice_list=[{"voice_id": voice_id}],
|
||||
sound="on",
|
||||
negative_prompt="画外音没有标点的时候不要轻易断句",
|
||||
)
|
||||
return result
|
||||
|
||||
async def _download_and_upload(self, project_id: str, shot: dict[str, Any]) -> None:
|
||||
"""下载视频到本地并上传七牛。直接更新传入的 shot 字典,不操作 Redis。"""
|
||||
video_url = shot.get("video_url")
|
||||
if not video_url:
|
||||
shot["status"] = "failed"
|
||||
shot["error_message"] = "没有视频URL"
|
||||
return
|
||||
|
||||
video_dir = self._get_project_video_dir(project_id)
|
||||
|
||||
# 清理同 shot_id 的旧视频文件(避免重新生成后前端缓存不刷新)
|
||||
import glob as stdlib_glob
|
||||
|
||||
pattern = f"scene_{stdlib_glob.escape(str(shot['id']))}_*.mp4"
|
||||
for old_file in video_dir.glob(pattern):
|
||||
try:
|
||||
old_file.unlink()
|
||||
logger.info(f"[Video] Removed old file: {old_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Video] Failed to remove old file {old_file}: {e}")
|
||||
|
||||
# 使用随机后缀命名,确保前端检测到 filePath 变化并重新加载
|
||||
local_path = video_dir / f"scene_{shot['id']}_{uuid.uuid4().hex[:6]}.mp4"
|
||||
|
||||
try:
|
||||
await self._download_video(video_url, local_path)
|
||||
shot["local_path"] = str(local_path)
|
||||
|
||||
try:
|
||||
qiniu = get_qiniu_service()
|
||||
qiniu_result = qiniu.upload_video(local_path=str(local_path))
|
||||
shot["qiniu_url"] = qiniu_result["url"]
|
||||
except Exception as e:
|
||||
logger.warning(f"[Video] Shot {shot['id']} upload qiniu failed: {e}")
|
||||
shot["qiniu_url"] = None
|
||||
|
||||
logger.info(f"[Video] Shot {shot['id']} download/upload done: {local_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Video] Shot {shot['id']} download failed: {e}")
|
||||
shot["status"] = "failed"
|
||||
shot["error_message"] = f"下载失败: {e}"[:500]
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Async Engine 独立进程入口
|
||||
=========================
|
||||
|
||||
usage: python -m app.scheduler.main
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from app.scheduler.engine import AsyncEngine
|
||||
from app.scheduler.handlers.avatar_handler import AvatarHandler
|
||||
from app.scheduler.handlers.copy_handler import CopyHandler
|
||||
from app.scheduler.handlers.image_handler import ImageHandler
|
||||
from app.scheduler.handlers.script_handler import ScriptHandler
|
||||
from app.scheduler.handlers.subtitle_handler import SubtitleHandler
|
||||
from app.scheduler.handlers.video_handler import VideoHandler
|
||||
|
||||
logger = logging.getLogger("scheduler")
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=log_format,
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
setup_logging()
|
||||
engine = AsyncEngine()
|
||||
engine.register(VideoHandler())
|
||||
engine.register(AvatarHandler())
|
||||
engine.register(ImageHandler())
|
||||
engine.register(SubtitleHandler())
|
||||
engine.register(CopyHandler())
|
||||
engine.register(ScriptHandler())
|
||||
await engine.run_forever(interval=10.0, min_interval=2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Scheduler stopped by user")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user