447f3c2ffe
后端: - 新增 BrollCategory/BrollMaterial/BrollTag 模型及表(mjk_categories/materials/tags) - 新增 Alembic 迁移 69274ce979a5 - 新增 broll_category/broll_material CRUD 层 - 重构 material_service:删除 JSON 配置,改用 PostgreSQL + Redis 去重 - 新增 /materials/batch-match 接口,删除 /materials/reload - usage_count 原子递增,Redis 失败自动降级 前端: - materials API 改为 projectId 去重,新增 batchMatch - VideoGeneration 批量匹配改用 batchMatch,删除 usedUrls 手动维护 - 修复积分不足时进度弹窗与充值弹窗叠加的 bug - 操作前预检积分,不足时显示提示条+立即充值按钮
175 lines
5.2 KiB
Python
175 lines
5.2 KiB
Python
"""
空镜素材服务
============

从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
"""
|
||
|
||
import logging
|
||
import math
|
||
import random
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.exceptions import ValidationException
|
||
from app.core.redis_client import get_redis_client
|
||
from app.crud import broll_category, broll_material
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lifetime (in seconds) of each per-project Redis set that records the
# material URLs already used by that project: 7 days.
_USED_MATERIALS_TTL = 60 * 60 * 24 * 7
|
||
|
||
|
||
def _normalize_scene(scene: str) -> str:
    """
    Normalise a scene description for exact matching against level-3 category names.

    Every whitespace character is removed wherever it appears. This covers
    ASCII spaces, full-width spaces (U+3000), tabs, newlines, non-breaking
    spaces, and so on.

    Example: ``"卫生间基层清理 - 防水施工"`` -> ``"卫生间基层清理-防水施工"``.
    """
    # str.split() with no argument splits on runs of *any* Unicode whitespace
    # (U+3000 included). Re-joining the pieces therefore drops all of it in one
    # pass. The previous version only removed inner ASCII / full-width spaces,
    # so an inner tab, newline or NBSP broke the category match.
    return "".join(scene.split())
|
||
|
||
|
||
def _weighted_choice(materials: list) -> object:
    """
    Pick one material at random, favouring less-used ones.

    Each candidate is weighted by ``1 / sqrt(usage_count + 1)``. Heavily used
    materials can still be picked, just less often, which spreads usage across
    the pool instead of concentrating it on a few popular items.

    Args:
        materials: Non-empty list of objects exposing a numeric
            ``usage_count`` attribute. A ``None`` or negative count is
            treated as 0.

    Returns:
        One element of ``materials``.

    Raises:
        ValueError: ``materials`` is empty.
    """
    if not materials:
        raise ValueError("候选池为空")

    if len(materials) == 1:
        return materials[0]

    # Clamp to >= 0. Without the clamp, a NULL or negative counter would crash
    # (TypeError / ZeroDivisionError / math domain error) instead of simply
    # counting as "never used".
    weights = [
        1.0 / math.sqrt(max(m.usage_count or 0, 0) + 1) for m in materials
    ]

    # Every weight lies in (0, 1], so the total is always positive and a
    # zero-total fallback is unnecessary.
    return random.choices(materials, weights=weights, k=1)[0]
|
||
|
||
|
||
def _used_materials_key(project_id: str) -> str:
    """Build the Redis key of the set holding URLs already used by ``project_id``."""
    return f"proj:{project_id}:used_materials"


async def _get_used_urls(project_id: str) -> set[str]:
    """
    Return the material URLs already used by ``project_id``, as recorded in Redis.

    This is best-effort. On any Redis error it logs a warning and returns an
    empty set, so the caller falls back to "no de-duplication" instead of
    failing.
    """
    try:
        redis = get_redis_client()
        members = await redis.smembers(_used_materials_key(project_id))
        # redis-py returns bytes unless the client was created with
        # decode_responses=True. Normalise to str so membership tests against
        # the str URLs loaded from the database work either way.
        return {
            m.decode("utf-8") if isinstance(m, bytes) else m for m in members
        }
    except Exception as e:
        logger.warning("Redis 去重查询失败,降级为不去重: %s", e)
        return set()


async def _record_used_url(project_id: str, url: str) -> None:
    """
    Add ``url`` to the project's used-material set and refresh its TTL.

    This is best-effort. Redis errors are logged and otherwise ignored, so they
    never break material matching.
    """
    try:
        redis = get_redis_client()
        key = _used_materials_key(project_id)
        await redis.sadd(key, url)
        await redis.expire(key, _USED_MATERIALS_TTL)
    except Exception as e:
        logger.warning("Redis 去重记录失败: %s", e)


async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """
    Match a B-roll material to a storyboard scene description and required duration.

    Matching strategy:
        1. Normalise ``scene`` and look up the level-3 category with exactly
           that name.
        2. Load the category's active materials whose duration is at least
           ``required_duration``.
        3. If ``project_id`` is given, fetch the URLs already used by that
           project from Redis and exclude them. A Redis failure degrades to
           no de-duplication.
        4. Pick, by usage-weighted random choice, from the unused candidates.
           If every candidate has already been used, pick from the full
           candidate list instead (reuse is allowed so composition can
           continue).
        5. Increment the chosen material's ``usage_count`` through the CRUD
           layer. If ``project_id`` is given, also record the URL in the
           project's Redis set (7-day TTL, best-effort).

    Args:
        db: Database session.
        scene: Storyboard scene description, e.g. "卫生间基层清理 - 防水施工".
        required_duration: Required duration in seconds; must be > 0.
        project_id: Project ID used for de-duplication. ``None`` disables it.

    Returns:
        ``{"url": str, "duration": float}``, or ``None`` if no category or no
        long-enough material matched.

    Raises:
        ValidationException: ``scene`` is empty/blank, or ``required_duration``
            is not a positive number.
    """
    # Argument validation
    if not scene or not scene.strip():
        raise ValidationException("场景描述不能为空")
    # Written as `not (x > 0)` rather than `x <= 0` so that NaN is rejected
    # too: every comparison involving NaN is False.
    if not (required_duration > 0):
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    # 1. Find the level-3 category
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )
    if category is None:
        logger.debug("未找到分类: %s", normalized)
        return None

    # 2. Load candidate materials
    materials = await broll_material.get_active_by_category_and_duration(
        db, category_id=category.id, min_duration=required_duration
    )
    if not materials:
        logger.debug(
            "分类 %s 无足够时长的素材 (需 >= %ss)", normalized, required_duration
        )
        return None

    # 3. Per-project de-duplication via Redis (best-effort)
    used_urls = await _get_used_urls(project_id) if project_id else set()

    # 4. Prefer unused candidates; fall back to all of them so matching never
    #    fails merely because the project has used everything already.
    unused = [m for m in materials if m.url not in used_urls]
    chosen = _weighted_choice(unused or materials)

    # 5. Increment usage_count through the CRUD helper (avoids concurrent
    #    lost updates), then record the pick for this project.
    await broll_material.increment_usage_count(db, material_id=chosen.id)
    if project_id:
        await _record_used_url(project_id, chosen.url)

    return {"url": chosen.url, "duration": float(chosen.duration)}
|
||
|
||
|
||
async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """
    Match one material per scene, preserving the order of ``scenes``.

    Scenes are matched strictly one after another, never concurrently. When
    ``project_id`` is given, each ``match_material`` call records its pick in
    the project's Redis set before the next scene is matched. Later scenes in
    the same batch therefore avoid URLs already chosen earlier (as long as
    Redis is reachable). The shared ``db`` session is also only used by one
    call at a time.

    Args:
        db: Database session.
        scenes: Items shaped like ``{"scene": str, "duration": float}``.
        project_id: Project ID used for de-duplication.

    Returns:
        A list with the same length as ``scenes``. Each element is
        ``{"url": str, "duration": float}`` or ``None``.
    """
    # An async list comprehension awaits each call in turn, so this stays
    # sequential exactly like an explicit for-loop would.
    return [
        await match_material(
            db,
            scene=entry["scene"],
            required_duration=entry["duration"],
            project_id=project_id,
        )
        for entry in scenes
    ]
|