"""
B-roll (empty shot) material service
====================================

Looks up materials in PostgreSQL and picks one per scene using usage-weighted
random selection, with optional per-project de-duplication stored in Redis.
"""

import logging
import math
import random
import re

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
from app.models.broll_category import BrollCategory

logger = logging.getLogger(__name__)

# TTL of the per-project Redis set of already-used material URLs (7 days).
_USED_MATERIALS_TTL = 7 * 24 * 3600

# If no material is at least as long as required, accept materials reaching
# this fraction of the required duration before falling back to any material.
_DURATION_RELAX_RATIO = 0.7

# Suffixes that level-2 category names commonly carry (e.g. "电路施工镜").
_PARENT_NAME_SUFFIXES = ("镜", "阶段")

# Runs of Unicode whitespace: ASCII spaces, full-width spaces, newlines, tabs...
_WHITESPACE_RE = re.compile(r"\s+")

# Common punctuation in ASCII and full-width / CJK forms. Written with escapes
# so quote characters cannot accidentally terminate the pattern string. The
# "-" separator is deliberately kept: scene names are split on it later.
_PUNCTUATION_RE = re.compile(
    "["
    ",\uff0c"  # comma (ASCII, full-width)
    "\u3002"  # ideographic full stop
    "!\uff01"  # exclamation mark (ASCII, full-width)
    "?\uff1f"  # question mark (ASCII, full-width)
    "\u3001"  # ideographic comma
    ";\uff1b"  # semicolon (ASCII, full-width)
    ":\uff1a"  # colon (ASCII, full-width)
    "\"'\u201c\u201d\u2018\u2019"  # straight and curly quotes
    "()\uff08\uff09"  # parentheses (ASCII, full-width)
    "\u3010\u3011\u300a\u300b"  # lenticular and double angle brackets
    "]+"
)

# Zero-width / invisible characters: U+200B..U+200F (ZWSP, ZWNJ, ZWJ, LRM,
# RLM) and U+FEFF (BOM / zero-width no-break space).
_ZERO_WIDTH_RE = re.compile(r"[\u200b-\u200f\ufeff]+")


def _normalize_scene(scene: str) -> str:
    """
    Normalize a scene description so it can be compared with category names.

    Removes all Unicode whitespace, common ASCII / full-width punctuation
    (keeping "-") and zero-width characters.
    """
    cleaned = _WHITESPACE_RE.sub("", scene)
    cleaned = _PUNCTUATION_RE.sub("", cleaned)
    return _ZERO_WIDTH_RE.sub("", cleaned)


def _weighted_choice(materials: list) -> object:  # noqa: ANN001
    """
    Pick one material at random, favouring rarely used ones.

    weight = 1 / sqrt(usage_count + 1), so popular materials are picked less
    often without ever being excluded. Every weight is strictly positive.

    Raises:
        ValueError: if ``materials`` is empty.
    """
    if not materials:
        raise ValueError("候选池为空")
    if len(materials) == 1:
        return materials[0]

    # A NULL or negative usage_count is treated as "never used" instead of
    # crashing (TypeError) or hitting sqrt of a negative / division by zero.
    weights = [
        1.0 / math.sqrt(max(m.usage_count or 0, 0) + 1) for m in materials
    ]
    return random.choices(materials, weights=weights, k=1)[0]


async def _try_fallback_to_parent(
    db: AsyncSession,
    normalized_scene: str,
) -> BrollCategory | None:
    """
    Fall back to a level-2 (parent) category when no level-3 category matched,
    and return one of its level-3 children chosen uniformly at random.

    Parent name resolution:
    - the last non-empty "-"-separated part of the scene is used
      (e.g. "电路施工-电路施工" -> "电路施工"); without "-" that is the whole
      scene. An empty result returns None immediately, because step 2 would
      otherwise become LIKE '%%' and match an arbitrary parent.

    Matching strategy (progressively looser):
    1. exact match on the level-2 category name;
    2. fuzzy match (name LIKE %parent_name%), e.g. "电路施工" -> "电路施工镜";
    3. common suffixes ("镜", "阶段"): strip the suffix when the name ends
       with it (e.g. "电路施工镜" -> "电路施工"), otherwise append it, then
       match exactly again.

    Returns:
        A randomly chosen level-3 child category, or None.
    """
    name_parts = [part for part in normalized_scene.split("-") if part]
    if not name_parts:
        return None
    parent_name = name_parts[-1]

    # 1. Exact match.
    parent = await broll_category.get_by_name_and_level(
        db, name=parent_name, level=2
    )

    # 2. Fuzzy match (e.g. "电路施工" -> "电路施工镜").
    if parent is None:
        parent = await broll_category.get_by_name_like_and_level(
            db, name=parent_name, level=2
        )

    # 3. Strip (or, failing that, append) a common suffix and retry exactly.
    if parent is None:
        for suffix in _PARENT_NAME_SUFFIXES:
            if parent_name.endswith(suffix):
                candidate = parent_name.removesuffix(suffix)
            else:
                candidate = parent_name + suffix
            if not candidate:
                continue
            parent = await broll_category.get_by_name_and_level(
                db, name=candidate, level=2
            )
            if parent is not None:
                break

    if parent is None:
        return None

    children = await broll_category.get_children_by_parent_id(
        db, parent_id=parent.id, level=3
    )
    return random.choice(children) if children else None


def _filter_by_duration(materials: list, required_duration: float) -> list:
    """
    Select duration-appropriate candidates, relaxing the requirement stepwise.

    1. materials with duration >= required_duration;
    2. otherwise materials reaching _DURATION_RELAX_RATIO of it;
    3. otherwise all materials (a short clip beats no clip for continuity).

    Returns an empty list only when ``materials`` itself is empty.
    """
    strict = [m for m in materials if m.duration >= required_duration]
    if strict:
        return strict
    relaxed_min = required_duration * _DURATION_RELAX_RATIO
    relaxed = [m for m in materials if m.duration >= relaxed_min]
    return relaxed or list(materials)


def _used_materials_key(project_id: str) -> str:
    """Redis key of the set holding a project's already-used material URLs."""
    return f"proj:{project_id}:used_materials"


async def _load_used_urls(project_id: str | None) -> set[str]:
    """
    Return the material URLs already used by ``project_id``.

    Returns an empty set when no project is given. Best effort: Redis errors
    are logged and treated as "nothing used yet" (no de-duplication).
    """
    if not project_id:
        return set()
    try:
        redis = get_redis_client()
        members = await redis.smembers(_used_materials_key(project_id))
    except Exception as e:
        logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
        return set()
    # Clients created without decode_responses=True return bytes; decode so
    # the members compare equal to the str URLs of the materials.
    return {
        member.decode("utf-8", errors="replace")
        if isinstance(member, bytes)
        else member
        for member in members
    }


async def _record_used_urls(project_id: str | None, urls: list[str]) -> None:
    """
    Add ``urls`` to the project's used-materials set and refresh its TTL.

    Both commands are sent through one pipeline. Best effort: errors are only
    logged and never interrupt the caller.
    """
    if not project_id or not urls:
        return
    try:
        redis = get_redis_client()
        key = _used_materials_key(project_id)
        pipe = redis.pipeline()
        pipe.sadd(key, *urls)
        pipe.expire(key, _USED_MATERIALS_TTL)
        await pipe.execute()
    except Exception as e:
        logger.warning(f"Redis 去重记录失败: {e}")


async def _find_category(
    db: AsyncSession, normalized: str
) -> BrollCategory | None:
    """
    Resolve a normalized scene to a level-3 category.

    Order: exact DB match -> in-memory match against every normalized level-3
    name (tolerates invisible characters stored in the DB) -> the same with
    "A-B" reversed to "B-A" -> random child of the parent (level-2) category.
    """
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )
    if category is not None:
        return category

    # First category wins when several normalize to the same name.
    by_normalized_name: dict[str, BrollCategory] = {}
    for c in await broll_category.get_by_level(db, level=3):
        by_normalized_name.setdefault(_normalize_scene(c.name), c)

    category = by_normalized_name.get(normalized)
    if category is not None:
        logger.info(
            f"素材分类全量内存匹配命中: '{normalized}' -> '{category.name}'"
        )
        return category

    parts = normalized.rsplit("-", 1)
    if len(parts) == 2:
        reversed_name = f"{parts[1]}-{parts[0]}"
        category = by_normalized_name.get(reversed_name)
        if category is not None:
            logger.info(
                f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
            )
            return category

    category = await _try_fallback_to_parent(db, normalized)
    if category is not None:
        logger.info(
            f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
        )
    return category


async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """
    Match one B-roll material for a scene description and required duration.

    Strategy:
    1. Normalize ``scene`` and find a level-3 category by exact name, then by
       comparing against every normalized level-3 name.
    2. If that fails, retry with "A-B" reversed to "B-A".
    3. If that fails, fall back to a random level-3 child of the parent
       (level-2) category.
    4. Load the category's active materials and keep those with
       duration >= required_duration; if none, relax to 70% of it; if still
       none, use all of them.
    5. If ``project_id`` is given, exclude URLs the project already used
       (read from Redis).
    6. Pick from the unused candidates with usage-weighted randomness; if all
       were used, pick from every candidate (reuse keeps the composition
       continuous).
    7. Atomically increment the chosen material's usage_count (avoids lost
       updates under concurrency) and record its URL in the project's Redis
       set (7-day TTL).

    Args:
        db: Database session.
        scene: Shot scene description (e.g. "卫生间基层清理 - 防水施工").
        required_duration: Required duration in seconds; must be > 0.
        project_id: Project ID used for de-duplication; None disables it.

    Returns:
        {"url": str, "duration": float}, or None when nothing matched.

    Raises:
        ValidationException: if ``scene`` is blank or
            ``required_duration <= 0``.
    """
    if not scene or not scene.strip():
        raise ValidationException("场景描述不能为空")
    if required_duration <= 0:
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    category = await _find_category(db, normalized)
    if category is None:
        logger.warning(f"素材匹配失败: 未找到分类 '{normalized}' (原始 scene: '{scene}')")
        return None

    all_materials = await broll_material.get_active_by_categories(
        db, category_ids=[category.id]
    )
    if not all_materials:
        logger.warning(f"素材匹配失败: 分类 '{normalized}' 下无任何可用素材")
        return None

    # Never empty here: the relaxation ends with "all materials".
    materials = _filter_by_duration(all_materials, required_duration)

    used_urls = await _load_used_urls(project_id)
    unused = [m for m in materials if m.url not in used_urls]
    chosen = _weighted_choice(unused or materials)

    await broll_material.increment_usage_count(db, material_id=chosen.id)
    await _record_used_urls(project_id, [chosen.url])

    return {"url": chosen.url, "duration": float(chosen.duration)}


async def _resolve_categories(
    db: AsyncSession, unique_names: list[str]
) -> dict[str, BrollCategory]:
    """
    Map each normalized scene name to a level-3 category where possible.

    One batched exact query. A full level-3 scan happens only when some names
    miss. Unmatched names are tried reversed ("A-B" -> "B-A"), then against
    the parent category. Names that still fail are absent from the result.
    """
    categories = await broll_category.get_by_names_and_level(
        db, names=unique_names, level=3
    )
    category_map: dict[str, BrollCategory] = {}
    for c in categories:
        category_map[_normalize_scene(c.name)] = c

    # Exact query missed some names: compare against every normalized name
    # (tolerates invisible characters stored in the DB).
    if any(name not in category_map for name in unique_names):
        for c in await broll_category.get_by_level(db, level=3):
            category_map.setdefault(_normalize_scene(c.name), c)

    scene_to_category: dict[str, BrollCategory] = {}
    for name in unique_names:
        if name in category_map:
            scene_to_category[name] = category_map[name]
            continue

        parts = name.rsplit("-", 1)
        if len(parts) == 2:
            rev = f"{parts[1]}-{parts[0]}"
            if rev in category_map:
                scene_to_category[name] = category_map[rev]
                logger.info(f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'")
                continue

        fallback_cat = await _try_fallback_to_parent(db, name)
        if fallback_cat is not None:
            scene_to_category[name] = fallback_cat
            logger.info(
                f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
            )
    return scene_to_category


async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """
    Match materials for many scenes with as few round trips as possible.

    Round trips:
    - one batched level-3 category query, plus one full level-3 scan only
      when some names miss;
    - parent-fallback queries for scenes that still did not match;
    - one batched material query and one batched usage_count update;
    - one Redis read and one Redis pipeline write.

    Candidates are filtered by duration and chosen with usage-weighted
    randomness in memory. URLs chosen earlier in the same batch count as
    used, so repeats are avoided while unused candidates remain.

    Args:
        db: Database session.
        scenes: Items of the form {"scene": str, "duration": float}.
        project_id: Project ID used for de-duplication; None disables it.

    Returns:
        A list aligned with ``scenes``; each item is
        {"url": str, "duration": float} or None when nothing matched.
    """
    if not scenes:
        return []

    normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    unique_names = list(dict.fromkeys(normalized_scenes))

    scene_to_category = await _resolve_categories(db, unique_names)

    category_ids = list({c.id for c in scene_to_category.values()})
    all_materials = (
        await broll_material.get_active_by_categories(
            db, category_ids=category_ids
        )
        if category_ids
        else []
    )

    materials_by_category: dict[int, list] = {}
    for m in all_materials:
        materials_by_category.setdefault(m.category_id, []).append(m)

    used_urls = await _load_used_urls(project_id)

    results: list[dict | None] = []
    chosen_materials: list = []  # for the batched usage_count / Redis update
    for spec, scene_name in zip(scenes, normalized_scenes, strict=True):
        required_duration = spec["duration"]
        category = scene_to_category.get(scene_name)
        if category is None:
            original_scene = spec["scene"]
            logger.warning(
                f"批量素材匹配失败: 未找到分类 '{scene_name}' (原始 scene: '{original_scene}')"
            )
            results.append(None)
            continue

        materials = materials_by_category.get(category.id, [])
        candidates = _filter_by_duration(materials, required_duration)
        if not candidates:
            # Only reachable when the category has no active material at all.
            logger.warning(
                f"批量素材匹配失败: 分类 '{scene_name}' -> '{category.name}' 下无任何可用素材"
            )
            results.append(None)
            continue

        unused = [m for m in candidates if m.url not in used_urls]
        chosen = _weighted_choice(unused or candidates)
        chosen_materials.append(chosen)
        used_urls.add(chosen.url)
        results.append({"url": chosen.url, "duration": float(chosen.duration)})

    if chosen_materials:
        try:
            await broll_material.increment_usage_count_batch(
                db, material_ids=[m.id for m in chosen_materials]
            )
        except Exception as e:
            logger.warning(f"批量更新 usage_count 失败: {e}")

    await _record_used_urls(project_id, [m.url for m in chosen_materials])

    return results