384 lines
15 KiB
Python
384 lines
15 KiB
Python
"""
|
||
空镜素材服务
|
||
============
|
||
|
||
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
|
||
"""
|
||
|
||
import logging
|
||
import math
|
||
import random
|
||
import re
|
||
from collections.abc import Awaitable
|
||
from typing import Any, cast
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.exceptions import ValidationException
|
||
from app.core.redis_client import get_redis_client
|
||
from app.crud import broll_category, broll_material
|
||
from app.models.broll_category import BrollCategory
|
||
from app.models.broll_material import BrollMaterial
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Redis 已使用素材 Set 的 TTL(7 天)
|
||
_USED_MATERIALS_TTL = 7 * 24 * 3600
|
||
|
||
|
||
def _normalize_scene(scene: str) -> str:
|
||
"""标准化场景描述,用于匹配三级分类 name"""
|
||
# 去除所有 Unicode 空白字符(空格、全角空格、换行、tab 等)
|
||
cleaned = re.sub(r"\s+", "", scene)
|
||
# 去除常见中文标点符号(逗号、句号、感叹号、问号、顿号、分号、冒号、引号、括号等)
|
||
cleaned = re.sub(r"[,。!?、;:" "''()【】《》]+", "", cleaned)
|
||
# 去除零宽字符(零宽空格、零宽非连接符、零宽连接符、零宽非断空格等)
|
||
cleaned = re.sub(r"[\u200b-\u200f\ufeff]+", "", cleaned)
|
||
return cleaned
|
||
|
||
|
||
def _weighted_choice(materials: list[BrollMaterial]) -> BrollMaterial:
|
||
"""
|
||
加权随机选择素材
|
||
|
||
weight = 1 / sqrt(usage_count + 1),避免过度集中热门素材。
|
||
"""
|
||
if not materials:
|
||
raise ValueError("候选池为空")
|
||
|
||
if len(materials) == 1:
|
||
return materials[0]
|
||
|
||
weights = [1.0 / math.sqrt(m.usage_count + 1) for m in materials]
|
||
total_weight = sum(weights)
|
||
|
||
if total_weight == 0:
|
||
return random.choice(materials) # nosec B311 素材抽样,非加密场景
|
||
|
||
r = random.uniform(0, total_weight) # nosec B311 素材抽样,非加密场景
|
||
cumulative = 0.0
|
||
for m, w in zip(materials, weights, strict=True):
|
||
cumulative += w
|
||
if r <= cumulative:
|
||
return m
|
||
|
||
# 兜底返回最后一个
|
||
return materials[-1]
|
||
|
||
|
||
async def _try_fallback_to_parent(
|
||
db: AsyncSession,
|
||
normalized_scene: str,
|
||
) -> BrollCategory | None:
|
||
"""
|
||
三级分类匹配失败时,回退到上级(level=2)分类随机选取子分类。
|
||
|
||
解析逻辑:
|
||
- 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工')
|
||
- 若不含 '-',直接以整个 scene 作为 parent_name
|
||
|
||
匹配策略:
|
||
1. 精确匹配 level=2 分类 name
|
||
2. 模糊匹配(LIKE %parent_name%)
|
||
|
||
返回:
|
||
随机选中的一个 level=3 子分类,或 None
|
||
"""
|
||
if "-" in normalized_scene:
|
||
parent_name = normalized_scene.rsplit("-", 1)[-1]
|
||
else:
|
||
parent_name = normalized_scene
|
||
|
||
# 1. 精确匹配
|
||
parent = await broll_category.get_by_name_and_level(db, name=parent_name, level=2)
|
||
|
||
# 2. 模糊匹配
|
||
if parent is None:
|
||
parent = await broll_category.get_by_name_like_and_level(db, name=parent_name, level=2)
|
||
|
||
if parent is None:
|
||
return None
|
||
|
||
children = await broll_category.get_children_by_parent_id(db, parent_id=parent.id, level=3)
|
||
if not children:
|
||
return None
|
||
|
||
return random.choice(children) # nosec B311 素材抽样,非加密场景
|
||
|
||
|
||
async def match_material(
|
||
db: AsyncSession,
|
||
scene: str,
|
||
required_duration: float,
|
||
project_id: str | None = None,
|
||
) -> dict | None:
|
||
"""
|
||
根据场景描述和所需时长匹配空镜素材
|
||
|
||
匹配策略:
|
||
1. 标准化 scene,精确匹配三级分类(level=3)的 name。
|
||
2. 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配。
|
||
3. 若仍失败,回退到上级(level=2)分类,随机选取一个子分类。
|
||
4. 查询该分类下状态为 active、时长 >= required_duration 的素材。
|
||
5. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。
|
||
6. 优先从未使用候选中加权随机选择;若未用候选为空,
|
||
fallback 到全部候选(允许复用,保证合成连续性)。
|
||
7. 原子递增 usage_count,并将选中的 URL 写入 Redis Set(7 天 TTL)。
|
||
|
||
Args:
|
||
db: 数据库 Session
|
||
scene: 分镜场景描述(如 "卫生间基层清理 - 防水施工")
|
||
required_duration: 所需时长(秒),必须大于 0
|
||
project_id: 项目ID,用于去重。为 None 时不做去重。
|
||
|
||
Returns:
|
||
{"url": str, "duration": float} 或 None
|
||
|
||
Raises:
|
||
ValidationException: scene 为空或 duration <= 0
|
||
"""
|
||
# 参数校验
|
||
if not scene or not scene.strip():
|
||
raise ValidationException("场景描述不能为空")
|
||
if required_duration <= 0:
|
||
raise ValidationException("所需时长必须大于 0")
|
||
|
||
normalized = _normalize_scene(scene)
|
||
|
||
# 1. 查找三级分类(精确匹配 -> 全量内存匹配兜底 -> 顺序颠倒 -> 上级回退)
|
||
category = await broll_category.get_by_name_and_level(db, name=normalized, level=3)
|
||
# 精确匹配失败时,全量查询后在内存标准化匹配(兼容数据库 name 含不可见字符)
|
||
if category is None:
|
||
all_categories = await broll_category.get_by_level(db, level=3)
|
||
for c in all_categories:
|
||
if _normalize_scene(c.name) == normalized:
|
||
category = c
|
||
logger.info(f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'")
|
||
break
|
||
# 若仍失败,尝试将 "A-B" 倒序为 "B-A" 再匹配
|
||
if category is None:
|
||
parts = normalized.rsplit("-", 1)
|
||
if len(parts) == 2:
|
||
reversed_name = f"{parts[1]}-{parts[0]}"
|
||
category = await broll_category.get_by_name_and_level(db, name=reversed_name, level=3)
|
||
if category:
|
||
logger.info(f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'")
|
||
# 若仍失败,回退到上级分类随机选取
|
||
if category is None:
|
||
category = await _try_fallback_to_parent(db, normalized)
|
||
if category:
|
||
logger.info(f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'")
|
||
if category is None:
|
||
logger.warning(f"素材匹配失败: 未找到分类 '{normalized}' (原始 scene: '{scene}')")
|
||
return None
|
||
|
||
# 2. 查询该分类下所有 active 素材(先不过滤时长,用于日志诊断)
|
||
all_materials = await broll_material.get_active_by_categories(db, category_ids=[category.id])
|
||
if not all_materials:
|
||
logger.warning(f"素材匹配失败: 分类 '{normalized}' 下无任何可用素材")
|
||
return None
|
||
|
||
# 按时长过滤(优先严格匹配,失败时逐步放宽到 70% 兜底)
|
||
materials = [m for m in all_materials if m.duration >= required_duration]
|
||
if not materials:
|
||
materials = [m for m in all_materials if m.duration >= required_duration * 0.7]
|
||
if not materials:
|
||
materials = all_materials
|
||
if not materials:
|
||
max_duration = max(m.duration for m in all_materials)
|
||
logger.warning(
|
||
f"素材匹配失败: 分类 '{normalized}' 无足够时长的素材 (需 >= {required_duration}s, 最大可用: {max_duration}s)"
|
||
)
|
||
return None
|
||
|
||
# 3. Redis 去重:获取项目已使用素材
|
||
used_urls: set[str] = set()
|
||
if project_id:
|
||
try:
|
||
redis = get_redis_client()
|
||
key = f"proj:{project_id}:used_materials"
|
||
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
|
||
except Exception as e:
|
||
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
|
||
|
||
# 4. 区分未用候选和全部候选
|
||
unused = [m for m in materials if m.url not in used_urls]
|
||
target_pool = unused if unused else materials
|
||
|
||
# 5. 加权随机选择
|
||
chosen = _weighted_choice(target_pool)
|
||
|
||
# 6. 原子递增 usage_count(避免并发覆盖)
|
||
await broll_material.increment_usage_count(db, material_id=chosen.id)
|
||
|
||
# 7. 记录到 Redis(异常不影响主流程)
|
||
if project_id:
|
||
try:
|
||
redis = get_redis_client()
|
||
key = f"proj:{project_id}:used_materials"
|
||
await cast(Awaitable[int], redis.sadd(key, chosen.url))
|
||
await cast(Awaitable[int], redis.expire(key, _USED_MATERIALS_TTL))
|
||
except Exception as e:
|
||
logger.warning(f"Redis 去重记录失败: {e}")
|
||
|
||
return {"url": chosen.url, "duration": float(chosen.duration)}
|
||
|
||
|
||
async def batch_match(
|
||
db: AsyncSession,
|
||
scenes: list[dict],
|
||
project_id: str | None = None,
|
||
) -> list[dict | None]:
|
||
"""
|
||
批量匹配素材(优化版:3 次 DB 往返 + 1 次 Redis)
|
||
|
||
优化策略:
|
||
1. 一次性批量查询所有三级分类(1 次 DB)。
|
||
2. 一次性批量查询所有相关素材(1 次 DB)。
|
||
3. 内存中按 scene + duration 过滤并加权随机选择。
|
||
4. 批量 UPDATE usage_count(1 次 DB)。
|
||
5. 批量 Redis sadd(1 次 Redis pipeline)。
|
||
|
||
Args:
|
||
db: 数据库 Session
|
||
scenes: 每个元素为 {"scene": str, "duration": float}
|
||
project_id: 项目ID,用于去重
|
||
|
||
Returns:
|
||
与 scenes 长度一致的列表,元素为 {"url": str, "duration": float} 或 None
|
||
"""
|
||
if not scenes:
|
||
return []
|
||
|
||
# 1. 标准化所有 scene 并去重
|
||
normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
|
||
unique_names = list(set(normalized_scenes))
|
||
|
||
# 2. 批量查询分类:优先精确查询,失败时全量内存匹配兜底
|
||
categories = await broll_category.get_by_names_and_level(db, names=unique_names, level=3)
|
||
category_map: dict[str, BrollCategory] = {}
|
||
for c in categories:
|
||
category_map[_normalize_scene(c.name)] = c
|
||
|
||
# 收集未命中的 name,准备全量兜底
|
||
unmatched_by_exact = [name for name in unique_names if name not in category_map]
|
||
if unmatched_by_exact:
|
||
all_categories = await broll_category.get_by_level(db, level=3)
|
||
for c in all_categories:
|
||
normalized_db_name = _normalize_scene(c.name)
|
||
if normalized_db_name not in category_map:
|
||
category_map[normalized_db_name] = c
|
||
|
||
# 构建原始 scene -> category 的映射
|
||
reversed_map: dict[str, str] = {}
|
||
for name in unique_names:
|
||
parts = name.rsplit("-", 1)
|
||
if len(parts) == 2:
|
||
reversed_map[name] = f"{parts[1]}-{parts[0]}"
|
||
|
||
scene_to_category: dict[str, BrollCategory] = {}
|
||
for name in unique_names:
|
||
if name in category_map:
|
||
scene_to_category[name] = category_map[name]
|
||
elif name in reversed_map and reversed_map[name] in category_map:
|
||
rev = reversed_map[name]
|
||
scene_to_category[name] = category_map[rev]
|
||
logger.info(f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'")
|
||
|
||
# 3. 未匹配的 scene 回退到上级分类随机选取
|
||
unmatched = [name for name in unique_names if name not in scene_to_category]
|
||
for name in unmatched:
|
||
fallback_cat = await _try_fallback_to_parent(db, name)
|
||
if fallback_cat:
|
||
scene_to_category[name] = fallback_cat
|
||
logger.info(f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'")
|
||
|
||
# 4. 收集所有需要的 category_id
|
||
needed_category_ids = [
|
||
scene_to_category[name].id for name in unique_names if name in scene_to_category
|
||
]
|
||
|
||
# 4. 批量查询素材(1 次 DB)
|
||
all_materials = await broll_material.get_active_by_categories(
|
||
db, category_ids=list(set(needed_category_ids))
|
||
)
|
||
|
||
# 按 category_id 分组,方便内存过滤
|
||
materials_by_category: dict[int, list[BrollMaterial]] = {}
|
||
for m in all_materials:
|
||
materials_by_category.setdefault(m.category_id, []).append(m)
|
||
|
||
# 5. Redis 获取已使用素材(1 次 Redis)
|
||
used_urls: set[str] = set()
|
||
if project_id:
|
||
try:
|
||
redis = get_redis_client()
|
||
key = f"proj:{project_id}:used_materials"
|
||
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
|
||
except Exception as e:
|
||
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
|
||
|
||
# 6. 内存中逐个匹配
|
||
results: list[dict | None] = []
|
||
chosen_materials: list[BrollMaterial] = [] # 记录选中的素材,用于批量更新
|
||
|
||
for idx, scene_name in enumerate(normalized_scenes):
|
||
required_duration = scenes[idx]["duration"]
|
||
|
||
category = scene_to_category.get(scene_name)
|
||
if category is None:
|
||
original_scene = scenes[idx]["scene"]
|
||
logger.warning(
|
||
f"批量素材匹配失败: 未找到分类 '{scene_name}' (原始 scene: '{original_scene}')"
|
||
)
|
||
results.append(None)
|
||
continue
|
||
|
||
materials = materials_by_category.get(category.id, [])
|
||
# 按时长过滤(优先严格匹配,失败时逐步放宽到 70% 兜底)
|
||
candidates = [m for m in materials if m.duration >= required_duration]
|
||
if not candidates:
|
||
candidates = [m for m in materials if m.duration >= required_duration * 0.7]
|
||
if not candidates:
|
||
candidates = materials
|
||
if not candidates:
|
||
max_duration = max((m.duration for m in materials), default=0)
|
||
logger.warning(
|
||
f"批量素材匹配失败: 分类 '{scene_name}' -> '{category.name}' 无足够时长的素材 (需 >= {required_duration}s, 最大可用: {max_duration}s)"
|
||
)
|
||
results.append(None)
|
||
continue
|
||
|
||
# Redis 去重
|
||
unused = [m for m in candidates if m.url not in used_urls]
|
||
target_pool = unused if unused else candidates
|
||
|
||
# 加权随机
|
||
chosen = _weighted_choice(target_pool)
|
||
chosen_materials.append(chosen)
|
||
used_urls.add(chosen.url)
|
||
results.append({"url": chosen.url, "duration": float(chosen.duration)})
|
||
|
||
# 7. 批量更新 usage_count(1 次 DB)
|
||
if chosen_materials:
|
||
material_ids = [m.id for m in chosen_materials]
|
||
try:
|
||
await broll_material.increment_usage_count_batch(db, material_ids=material_ids)
|
||
except Exception as e:
|
||
logger.warning(f"批量更新 usage_count 失败: {e}")
|
||
|
||
# 8. 批量记录到 Redis(pipeline)
|
||
if project_id and chosen_materials:
|
||
try:
|
||
redis = get_redis_client()
|
||
key = f"proj:{project_id}:used_materials"
|
||
pipe = redis.pipeline()
|
||
for m in chosen_materials:
|
||
pipe.sadd(key, m.url)
|
||
pipe.expire(key, _USED_MATERIALS_TTL)
|
||
await pipe.execute()
|
||
except Exception as e:
|
||
logger.warning(f"Redis 去重记录失败: {e}")
|
||
|
||
return results
|