diff --git a/python-api/app/services/material_service.py b/python-api/app/services/material_service.py index c2cccea..441d196 100644 --- a/python-api/app/services/material_service.py +++ b/python-api/app/services/material_service.py @@ -94,10 +94,22 @@ async def match_material( normalized = _normalize_scene(scene) - # 1. 查找三级分类 + # 1. 查找三级分类(精确匹配 + 顺序颠倒兜底) category = await broll_category.get_by_name_and_level( db, name=normalized, level=3 ) + # 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配 + if category is None: + parts = normalized.rsplit("-", 1) + if len(parts) == 2: + reversed_name = f"{parts[1]}-{parts[0]}" + category = await broll_category.get_by_name_and_level( + db, name=reversed_name, level=3 + ) + if category: + logger.info( + f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'" + ) if category is None: logger.debug(f"未找到分类: {normalized}") return None @@ -175,17 +187,41 @@ async def batch_match( normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes] unique_names = list(set(normalized_scenes)) - # 2. 批量查询分类(1 次 DB) + # 2. 批量查询分类(1 次 DB)—— 同时查询原始名和倒序名 + reversed_names: list[str] = [] + name_to_reversed: dict[str, str] = {} + for name in unique_names: + parts = name.rsplit("-", 1) + if len(parts) == 2: + rev = f"{parts[1]}-{parts[0]}" + reversed_names.append(rev) + name_to_reversed[name] = rev + + all_query_names = unique_names + reversed_names categories = await broll_category.get_by_names_and_level( - db, names=unique_names, level=3 + db, names=all_query_names, level=3 ) - category_map = {c.name: c for c in categories} + category_map: dict[str, object] = {} + for c in categories: + category_map[c.name] = c + + # 构建原始 scene -> category 的映射(优先精确匹配,fallback 倒序匹配) + scene_to_category: dict[str, object] = {} + for name in unique_names: + if name in category_map: + scene_to_category[name] = category_map[name] + elif name in name_to_reversed and name_to_reversed[name] in category_map: + rev = name_to_reversed[name] + scene_to_category[name] = category_map[rev] + logger.info( + f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'" + ) # 3. 收集所有需要的 category_id needed_category_ids = [ - category_map[name].id + scene_to_category[name].id for name in unique_names - if name in category_map + if name in scene_to_category ] # 4. 批量查询素材(1 次 DB) @@ -215,7 +251,7 @@ async def batch_match( for idx, scene_name in enumerate(normalized_scenes): required_duration = scenes[idx]["duration"] - category = category_map.get(scene_name) + category = scene_to_category.get(scene_name) if category is None: results.append(None) continue