3587559a87
- 新增 broll_category.get_by_name_like_and_level() 模糊匹配方法 - _try_fallback_to_parent 增加三级降级策略: 1. 精确匹配 2. 模糊匹配 LIKE %parent_name%(兼容'电路施工'→'电路施工镜') 3. 自动补后缀'镜'/'阶段'再精确匹配 - 解决 scene 中 parent_name 与数据库二级分类 name 不一致导致回退失败的问题
376 lines
13 KiB
Python
376 lines
13 KiB
Python
"""
|
||
空镜素材服务
|
||
============
|
||
|
||
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
|
||
"""
|
||
|
||
import logging
|
||
import math
|
||
import random
|
||
import re
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.exceptions import ValidationException
|
||
from app.core.redis_client import get_redis_client
|
||
from app.crud import broll_category, broll_material
|
||
from app.models.broll_category import BrollCategory
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lifetime, in seconds, of the per-project Redis set of used material URLs
# (7 days).
_USED_MATERIALS_TTL = 7 * 24 * 60 * 60
|
||
|
||
|
||
def _normalize_scene(scene: str) -> str:
    """Normalize a scene description for matching against category names.

    All Unicode whitespace (spaces, full-width spaces, tabs, newlines, ...)
    is removed; every other character is kept in its original order.
    """
    # Splitting on whitespace runs and joining the pieces back together is
    # equivalent to deleting those runs.
    pieces = re.split(r"\s+", scene)
    return "".join(pieces)
|
||
|
||
|
||
def _weighted_choice(materials: list) -> object: # noqa: ANN001
|
||
"""
|
||
加权随机选择素材
|
||
|
||
weight = 1 / sqrt(usage_count + 1),避免过度集中热门素材。
|
||
"""
|
||
if not materials:
|
||
raise ValueError("候选池为空")
|
||
|
||
if len(materials) == 1:
|
||
return materials[0]
|
||
|
||
weights = [1.0 / math.sqrt(m.usage_count + 1) for m in materials]
|
||
total_weight = sum(weights)
|
||
|
||
if total_weight == 0:
|
||
return random.choice(materials)
|
||
|
||
r = random.uniform(0, total_weight)
|
||
cumulative = 0.0
|
||
for m, w in zip(materials, weights, strict=True):
|
||
cumulative += w
|
||
if r <= cumulative:
|
||
return m
|
||
|
||
# 兜底返回最后一个
|
||
return materials[-1]
|
||
|
||
|
||
async def _try_fallback_to_parent(
    db: AsyncSession,
    normalized_scene: str,
) -> BrollCategory | None:
    """Fall back to a level-2 category and pick one of its level-3 children.

    Used when no level-3 category matched the scene name.

    Parent name resolution:
        - If the scene contains '-', the text after the last '-' is the
          parent name (e.g. '电路施工-电路施工' -> '电路施工').
        - Otherwise the whole scene is the parent name.

    Matching strategy (each step runs only if the previous one failed):
        1. Exact match on a level-2 category name.
        2. Fuzzy match (described as LIKE %parent_name%), so that
           "电路施工" can find "电路施工镜".
        3. Append a common suffix ("镜", "阶段") that the name does not
           already end with, then exact match again.

    Args:
        db: Database session.
        normalized_scene: Scene name with whitespace already removed.

    Returns:
        A randomly chosen level-3 child of the matched level-2 category, or
        ``None`` when the parent name is empty, no parent matched, or the
        parent has no children.
    """
    # rsplit yields the whole string when there is no '-', so one
    # expression covers both resolution rules.
    parent_name = normalized_scene.rsplit("-", 1)[-1]

    # An empty parent name (blank scene, or a scene ending in '-') must not
    # reach the fuzzy lookup: a '%%' pattern would match any level-2
    # category and return an unrelated one.
    if not parent_name:
        return None

    # 1. Exact match.
    parent = await broll_category.get_by_name_and_level(
        db, name=parent_name, level=2
    )

    # 2. Fuzzy match.
    if parent is None:
        parent = await broll_category.get_by_name_like_and_level(
            db, name=parent_name, level=2
        )

    # 3. Append a common suffix and retry the exact match.
    if parent is None:
        for suffix in ("镜", "阶段"):
            if parent_name.endswith(suffix):
                continue
            parent = await broll_category.get_by_name_and_level(
                db, name=parent_name + suffix, level=2
            )
            if parent:
                break

    if parent is None:
        return None

    children = await broll_category.get_children_by_parent_id(
        db, parent_id=parent.id, level=3
    )
    if not children:
        return None

    return random.choice(children)
|
||
|
||
|
||
async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """Match one B-roll material for a scene description and duration.

    Strategy:
        1. Normalize ``scene`` and look for a level-3 category with exactly
           that name.
        2. If that fails and the name has the form "A-B", retry with "B-A".
        3. If that fails too, fall back to the parent (level-2) category and
           take a random child category.
        4. Query the category's active materials whose duration is at least
           ``required_duration``.
        5. When ``project_id`` is given, read the URLs already used by that
           project from Redis.
        6. Choose by weighted random among unused candidates; if none are
           unused, choose among all candidates (reuse is allowed).
        7. Increment the chosen material's usage count and add its URL to
           the project's Redis set (7-day TTL).

    Args:
        db: Database session.
        scene: Storyboard scene description
            (e.g. "卫生间基层清理 - 防水施工").
        required_duration: Required duration in seconds; must be > 0.
        project_id: Project id used for de-duplication; ``None`` disables it.

    Returns:
        ``{"url": str, "duration": float}`` or ``None`` when no category or
        no long-enough material is found.

    Raises:
        ValidationException: If ``scene`` is blank or
            ``required_duration`` <= 0.
    """
    # Argument validation.
    if not (scene and scene.strip()):
        raise ValidationException("场景描述不能为空")
    if required_duration <= 0:
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    # 1. Level-3 category: exact name first.
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )

    # Swap the two halves around the last '-' and try again.
    if category is None and "-" in normalized:
        head, _, tail = normalized.rpartition("-")
        reversed_name = f"{tail}-{head}"
        category = await broll_category.get_by_name_and_level(
            db, name=reversed_name, level=3
        )
        if category:
            logger.info(
                f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
            )

    # Last resort: random child of the parent category.
    if category is None:
        category = await _try_fallback_to_parent(db, normalized)
        if category:
            logger.info(
                f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
            )

    if category is None:
        logger.debug(f"未找到分类: {normalized}")
        return None

    # 2. Candidate materials of sufficient duration.
    materials = await broll_material.get_active_by_category_and_duration(
        db, category_id=category.id, min_duration=required_duration
    )
    if not materials:
        logger.debug(
            f"分类 {normalized} 无足够时长的素材 (需 >= {required_duration}s)"
        )
        return None

    # 3. URLs this project has already used; a Redis failure only disables
    #    de-duplication.
    used_urls: set[str] = set()
    if project_id:
        try:
            client = get_redis_client()
            members = await client.smembers(f"proj:{project_id}:used_materials")
            used_urls = set(members)
        except Exception as e:
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # 4. Prefer unused candidates, otherwise allow reuse.
    target_pool = [m for m in materials if m.url not in used_urls] or materials

    # 5. Weighted random pick.
    chosen = _weighted_choice(target_pool)

    # 6. Increment usage_count through the CRUD helper.
    await broll_material.increment_usage_count(db, material_id=chosen.id)

    # 7. Remember the pick in Redis; failures do not affect the result.
    if project_id:
        try:
            client = get_redis_client()
            redis_key = f"proj:{project_id}:used_materials"
            await client.sadd(redis_key, chosen.url)
            await client.expire(redis_key, _USED_MATERIALS_TTL)
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return dict(url=chosen.url, duration=float(chosen.duration))
|
||
|
||
|
||
async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """Match materials for many scenes with batched DB and Redis access.

    Flow:
        1. Normalize all scene names; blank names are never looked up.
        2. Fetch all level-3 categories for the names and their "B-A"
           reversed forms in one query.
        3. For names still unmatched, fall back to the parent category one
           name at a time (this adds queries per unmatched name, so the
           "3 DB round trips" figure only holds when no fallback is needed).
        4. Fetch all active materials of the resolved categories in one
           query.
        5. Read the project's used URLs from Redis once.
        6. Filter by duration and pick by weighted random in memory; URLs
           picked earlier in the same batch count as used.
        7. Increment usage counts in one batch call.
        8. Record the picked URLs in Redis through one pipeline.

    Args:
        db: Database session.
        scenes: Items shaped like ``{"scene": str, "duration": float}``.
        project_id: Project id used for de-duplication; ``None`` disables it.

    Returns:
        A list of the same length as ``scenes``; each element is
        ``{"url": str, "duration": float}`` or ``None`` when no category or
        no long-enough material was found for that scene.
    """
    if not scenes:
        return []

    # 1. Normalize and de-duplicate. A blank name cannot identify a
    #    category, and passing it to the parent fallback would let the fuzzy
    #    lookup match an arbitrary category, so blank names are dropped here
    #    and end up as None in the result.
    normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
    unique_names = list({name for name in normalized_scenes if name})
    if not unique_names:
        return [None] * len(scenes)

    # 2. One category query covering original and reversed ("B-A") names.
    name_to_reversed: dict[str, str] = {}
    for name in unique_names:
        head, sep, tail = name.rpartition("-")
        if sep:
            name_to_reversed[name] = f"{tail}-{head}"

    all_query_names = unique_names + list(name_to_reversed.values())
    categories = await broll_category.get_by_names_and_level(
        db, names=all_query_names, level=3
    )
    category_map: dict[str, object] = {c.name: c for c in categories}

    # Exact name wins; the reversed name is the second choice.
    scene_to_category: dict[str, object] = {}
    for name in unique_names:
        if name in category_map:
            scene_to_category[name] = category_map[name]
            continue
        rev = name_to_reversed.get(name)
        if rev is not None and rev in category_map:
            scene_to_category[name] = category_map[rev]
            logger.info(
                f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
            )

    # 3. Parent-category fallback for whatever is still unmatched.
    unmatched = [name for name in unique_names if name not in scene_to_category]
    for name in unmatched:
        fallback_cat = await _try_fallback_to_parent(db, name)
        if fallback_cat:
            scene_to_category[name] = fallback_cat
            logger.info(
                f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
            )

    # 4. One material query for all resolved categories. Nothing resolved
    #    means nothing can match, so skip the remaining DB/Redis calls.
    needed_category_ids = {cat.id for cat in scene_to_category.values()}
    if not needed_category_ids:
        return [None] * len(scenes)

    all_materials = await broll_material.get_active_by_categories(
        db, category_ids=list(needed_category_ids)
    )

    # Group by category for in-memory filtering.
    materials_by_category: dict[int, list] = {}
    for m in all_materials:
        materials_by_category.setdefault(m.category_id, []).append(m)

    # 5. URLs already used by this project; a Redis failure only disables
    #    de-duplication.
    # NOTE(review): the membership test below assumes smembers returns str
    # values (client created with decode_responses=True) - confirm.
    used_urls: set[str] = set()
    if project_id:
        try:
            redis = get_redis_client()
            key = f"proj:{project_id}:used_materials"
            used_urls = set(await redis.smembers(key))
        except Exception as e:
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # 6. Pick a material per scene, in input order.
    results: list[dict | None] = []
    chosen_materials: list = []  # picks to persist in steps 7 and 8

    for spec, scene_name in zip(scenes, normalized_scenes, strict=True):
        required_duration = spec["duration"]

        category = scene_to_category.get(scene_name)
        if category is None:
            results.append(None)
            continue

        candidates = [
            m
            for m in materials_by_category.get(category.id, [])
            if m.duration >= required_duration
        ]
        if not candidates:
            results.append(None)
            continue

        # Prefer unused candidates, otherwise allow reuse.
        unused = [m for m in candidates if m.url not in used_urls]
        chosen = _weighted_choice(unused or candidates)

        chosen_materials.append(chosen)
        # Later scenes in this batch should avoid this URL as well.
        used_urls.add(chosen.url)
        results.append({"url": chosen.url, "duration": float(chosen.duration)})

    # 7. One batch call to increment usage_count.
    # NOTE(review): material_ids can contain the same id more than once when
    # a material is reused within the batch; whether the CRUD helper counts
    # each occurrence is not visible here - confirm.
    if chosen_materials:
        material_ids = [m.id for m in chosen_materials]
        try:
            await broll_material.increment_usage_count_batch(
                db, material_ids=material_ids
            )
        except Exception as e:
            logger.warning(f"批量更新 usage_count 失败: {e}")

    # 8. Record the picks in Redis through one pipeline; failures are logged
    #    and do not affect the result.
    if project_id and chosen_materials:
        try:
            redis = get_redis_client()
            key = f"proj:{project_id}:used_materials"
            pipe = redis.pipeline()
            for m in chosen_materials:
                pipe.sadd(key, m.url)
            pipe.expire(key, _USED_MATERIALS_TTL)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return results
|