2a36e4ec3d
AI 生成 scene 时常将三级分类名中的 '-' 前后顺序写反 (如 瓷砖铺贴-瓷砖完工展示 vs 瓷砖完工展示-瓷砖铺贴), 导致精确匹配失败、素材匹配为空。 - match_material: 精确匹配失败后,尝试倒序匹配 - batch_match: 批量查询时同时查询原始名和倒序名, 内存中构建 scene->category 映射,优先精确匹配、fallback 倒序
298 lines
9.9 KiB
Python
298 lines
9.9 KiB
Python
"""
|
||
空镜素材服务
|
||
============
|
||
|
||
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
|
||
"""
|
||
|
||
import logging
|
||
import math
|
||
import random
|
||
import re
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.exceptions import ValidationException
|
||
from app.core.redis_client import get_redis_client
|
||
from app.crud import broll_category, broll_material
|
||
|
||
logger = logging.getLogger(__name__)

# TTL, in seconds, of the Redis set that records already-used materials (7 days)
_USED_MATERIALS_TTL = 7 * 24 * 3600
|
||
|
||
|
||
def _normalize_scene(scene: str) -> str:
    """Return ``scene`` with all whitespace removed, for comparison with level-3 category names."""
    # On str patterns ``\s`` covers Unicode whitespace as well (plain and
    # full-width spaces, newlines, tabs, ...); splitting on each run and
    # gluing the pieces back together drops all of it.
    pieces = re.split(r"\s+", scene)
    return "".join(pieces)
|
||
|
||
|
||
def _weighted_choice(materials: list) -> object:  # noqa: ANN001
    """
    Pick one material at random, favouring those that have been used less.

    Each candidate is weighted by ``1 / sqrt(usage_count + 1)`` so heavily used
    materials stay eligible without dominating the selection.

    A ``usage_count`` that is ``None`` or negative is treated as 0, so a single
    malformed candidate cannot make the whole selection raise.

    Args:
        materials: Candidate objects; each must expose a ``usage_count``
            attribute (it is not read when there is only one candidate).

    Returns:
        One element of ``materials``.

    Raises:
        ValueError: If ``materials`` is empty.
    """
    if not materials:
        raise ValueError("候选池为空")

    if len(materials) == 1:
        return materials[0]

    # Clamp before the square root: None would raise TypeError, -1 would divide
    # by zero and anything below -1 is outside sqrt's domain.
    weights = [
        1.0 / math.sqrt(max(m.usage_count or 0, 0) + 1) for m in materials
    ]
    total_weight = sum(weights)

    # Every weight collapsed to 0.0 (e.g. infinite counts): fall back to a
    # uniform pick instead of sampling from an empty range.
    if total_weight == 0:
        return random.choice(materials)

    r = random.uniform(0, total_weight)
    cumulative = 0.0
    for m, w in zip(materials, weights):
        cumulative += w
        if r <= cumulative:
            return m

    # Floating-point rounding can leave r marginally above the last running
    # total; the final candidate is the correct answer in that case.
    return materials[-1]
|
||
|
||
|
||
async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """
    Match a single b-roll material to a scene description and a minimum duration.

    Strategy:
        1. Normalize ``scene`` and look up the level-3 category with exactly
           that name. If none exists and the name contains "-", retry once
           with the text around the last "-" swapped ("A-B" -> "B-A").
        2. Load the active materials of that category whose duration is at
           least ``required_duration``.
        3. When ``project_id`` is given, read the URLs already used by the
           project from Redis; a Redis failure is logged and de-duplication
           is skipped.
        4. Choose by weighted random from the unused candidates, or from all
           candidates when every one has already been used (reuse is allowed
           so that a result is still produced).
        5. Increment the chosen material's usage count, then add its URL to
           the project's Redis set and refresh the set's TTL; a Redis failure
           here is logged and does not affect the result.

    Args:
        db: Database session.
        scene: Scene description, e.g. "卫生间基层清理 - 防水施工".
        required_duration: Minimum duration in seconds; must be greater than 0.
        project_id: Project id used for de-duplication; ``None`` disables it.

    Returns:
        ``{"url": str, "duration": float}``, or ``None`` when no category or
        no sufficiently long material is found.

    Raises:
        ValidationException: ``scene`` is empty/blank or
            ``required_duration`` is not greater than 0.
    """
    # Reject unusable input before touching the database.
    if not scene or not scene.strip():
        raise ValidationException("场景描述不能为空")
    if required_duration <= 0:
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    # Step 1: exact category lookup first, swapped "B-A" form as a fallback.
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )
    if category is None:
        head, dash, tail = normalized.rpartition("-")
        if dash:
            reversed_name = f"{tail}-{head}"
            category = await broll_category.get_by_name_and_level(
                db, name=reversed_name, level=3
            )
            if category:
                logger.info(
                    f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
                )
    if category is None:
        logger.debug(f"未找到分类: {normalized}")
        return None

    # Step 2: candidates of that category that are long enough.
    candidates = await broll_material.get_active_by_category_and_duration(
        db, category_id=category.id, min_duration=required_duration
    )
    if not candidates:
        logger.debug(
            f"分类 {normalized} 无足够时长的素材 (需 >= {required_duration}s)"
        )
        return None

    # Step 3: URLs this project has already consumed (best effort).
    # NOTE(review): the membership test below assumes the Redis client returns
    # str members; with bytes members nothing would ever be excluded - confirm.
    already_used: set[str] = set()
    if project_id:
        try:
            already_used = set(
                await get_redis_client().smembers(
                    f"proj:{project_id}:used_materials"
                )
            )
        except Exception as e:
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # Step 4: prefer material the project has not used yet; reuse otherwise.
    fresh = [item for item in candidates if item.url not in already_used]
    chosen = _weighted_choice(fresh or candidates)

    # Step 5: bump the usage counter through the CRUD helper.
    await broll_material.increment_usage_count(db, material_id=chosen.id)

    # Step 6: remember the pick for this project; failures must not break the match.
    if project_id:
        try:
            client = get_redis_client()
            set_key = f"proj:{project_id}:used_materials"
            await client.sadd(set_key, chosen.url)
            await client.expire(set_key, _USED_MATERIALS_TTL)
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return {"url": chosen.url, "duration": float(chosen.duration)}
|
||
|
||
|
||
async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """
    Match materials for many scenes at once with a fixed number of round trips.

    Flow:
        1. Normalize every scene name and fetch all level-3 categories in one
           query, asking for each name and for its swapped "B-A" form.
        2. Resolve each distinct scene name to a category in memory (exact
           name first, swapped name second). If nothing resolves, return the
           all-``None`` result immediately without any further I/O.
        3. Fetch the active materials of all matched categories in one query.
        4. When ``project_id`` is given, read the project's used URLs from
           Redis once; a failure is logged and de-duplication is skipped.
        5. Per scene, keep materials that are long enough, prefer URLs not
           used yet (URLs picked earlier in this batch count as used) and
           choose by weighted random.
        6. Increment usage counts with one batch call and record the chosen
           URLs with one Redis pipeline; a failure of either is logged and
           does not change the returned list.

    Args:
        db: Database session.
        scenes: Items shaped like ``{"scene": str, "duration": float}``.
        project_id: Project id used for de-duplication; ``None`` disables it.

    Returns:
        A list as long as ``scenes``; each element is
        ``{"url": str, "duration": float}`` or ``None`` when nothing matched.
    """
    if not scenes:
        return []

    # 1. Normalize, then de-duplicate while keeping first-seen order so the
    #    query arguments and log lines are deterministic.
    normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
    unique_names = list(dict.fromkeys(normalized_scenes))

    # 2. One category query covering every name and its swapped "B-A" form.
    name_to_reversed: dict[str, str] = {}
    for name in unique_names:
        head, dash, tail = name.rpartition("-")
        if dash:
            name_to_reversed[name] = f"{tail}-{head}"

    # A swapped name can coincide with another scene's own name; send each
    # distinct name only once.
    all_query_names = list(
        dict.fromkeys([*unique_names, *name_to_reversed.values()])
    )
    categories = await broll_category.get_by_names_and_level(
        db, names=all_query_names, level=3
    )
    category_map: dict[str, object] = {c.name: c for c in categories}

    # Exact match wins; the swapped name is only a fallback.
    scene_to_category: dict[str, object] = {}
    for name in unique_names:
        if name in category_map:
            scene_to_category[name] = category_map[name]
            continue
        rev = name_to_reversed.get(name)
        if rev is not None and rev in category_map:
            scene_to_category[name] = category_map[rev]
            logger.info(
                f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
            )

    # No scene resolved to a category: every result is None, so skip the
    # material query (which would get an empty id list) and the Redis call.
    if not scene_to_category:
        return [None] * len(scenes)

    # 3. One material query for all matched categories.
    category_ids = list({c.id for c in scene_to_category.values()})
    all_materials = await broll_material.get_active_by_categories(
        db, category_ids=category_ids
    )

    # Group by category so the per-scene filtering below stays in memory.
    materials_by_category: dict[int, list] = {}
    for m in all_materials:
        materials_by_category.setdefault(m.category_id, []).append(m)

    # 4. URLs this project has already consumed (best effort).
    # NOTE(review): the membership test below assumes the Redis client returns
    # str members; with bytes members nothing would ever be excluded - confirm.
    used_urls: set[str] = set()
    if project_id:
        try:
            redis = get_redis_client()
            key = f"proj:{project_id}:used_materials"
            used_urls = set(await redis.smembers(key))
        except Exception as e:
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # 5. Pick one material per scene, in input order.
    results: list[dict | None] = []
    chosen_materials: list = []  # picks to persist in steps 6 and 7

    for item, scene_name in zip(scenes, normalized_scenes):
        required_duration = item["duration"]

        category = scene_to_category.get(scene_name)
        if category is None:
            results.append(None)
            continue

        candidates = [
            m
            for m in materials_by_category.get(category.id, [])
            if m.duration >= required_duration
        ]
        if not candidates:
            results.append(None)
            continue

        # Prefer unused material; reuse only when everything was used already.
        unused = [m for m in candidates if m.url not in used_urls]
        chosen = _weighted_choice(unused or candidates)

        chosen_materials.append(chosen)
        # Later scenes in this batch should avoid the URL just picked.
        used_urls.add(chosen.url)
        results.append({"url": chosen.url, "duration": float(chosen.duration)})

    # 6. One batch usage-count update; a failure is logged, not raised.
    if chosen_materials:
        # NOTE(review): an id appears more than once when a material is reused
        # within the batch; whether every occurrence is counted depends on
        # increment_usage_count_batch - confirm.
        material_ids = [m.id for m in chosen_materials]
        try:
            await broll_material.increment_usage_count_batch(
                db, material_ids=material_ids
            )
        except Exception as e:
            logger.warning(f"批量更新 usage_count 失败: {e}")

    # 7. Record all picks with a single Redis pipeline (best effort).
    if project_id and chosen_materials:
        try:
            redis = get_redis_client()
            key = f"proj:{project_id}:used_materials"
            pipe = redis.pipeline()
            for m in chosen_materials:
                pipe.sadd(key, m.url)
            pipe.expire(key, _USED_MATERIALS_TTL)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return results
|