From af8c483910e0d8b610f0227e3bc38cda99de2598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E9=B1=BC=E5=BC=80=E5=8F=91?= Date: Mon, 1 Jun 2026 19:05:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=B4=A0=E6=9D=90=E5=8C=B9=E9=85=8D=20?= =?UTF-8?q?fallback=20=E5=88=B0=E4=B8=8A=E7=BA=A7=E5=88=86=E7=B1=BB?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E9=80=89=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当三级分类(level=3)精确匹配失败时,回退到上级(level=2) 分类随机选取一个子分类,避免 AI 生成无效 scene(如 '电路施工-电路施工')导致素材匹配完全失败。 - CRUD: 新增 get_children_by_parent_id 方法 - match_material: 新增 _try_fallback_to_parent 辅助函数 - batch_match: 同步增加 fallback 逻辑 - 顺手修复 zip() 缺少 strict 参数的 lint 问题 --- python-api/app/crud/broll_category.py | 13 ++++ python-api/app/services/material_service.py | 66 +++++++++++++++++++-- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/python-api/app/crud/broll_category.py b/python-api/app/crud/broll_category.py index 9a96966..8fa1ad3 100644 --- a/python-api/app/crud/broll_category.py +++ b/python-api/app/crud/broll_category.py @@ -44,6 +44,19 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]): ) return list(result.scalars().all()) + async def get_children_by_parent_id( + self, db: AsyncSession, *, parent_id: int, level: int + ) -> list[BrollCategory]: + """根据父分类 ID 和层级获取启用的子分类""" + result = await db.execute( + select(BrollCategory).where( + BrollCategory.parent_id == parent_id, + BrollCategory.level == level, + BrollCategory.status == "active", + ) + ) + return list(result.scalars().all()) + # 导出实例 broll_category = BrollCategoryCRUD() diff --git a/python-api/app/services/material_service.py b/python-api/app/services/material_service.py index 441d196..d6fe5a4 100644 --- a/python-api/app/services/material_service.py +++ b/python-api/app/services/material_service.py @@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.exceptions import ValidationException from app.core.redis_client import get_redis_client from app.crud import broll_category, broll_material +from app.models.broll_category import BrollCategory logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001 r = random.uniform(0, total_weight) cumulative = 0.0 - for m, w in zip(materials, weights): + for m, w in zip(materials, weights, strict=True): cumulative += w if r <= cumulative: return m @@ -57,6 +58,40 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001 return materials[-1] +async def _try_fallback_to_parent( + db: AsyncSession, + normalized_scene: str, +) -> BrollCategory | None: + """ + 三级分类匹配失败时,回退到上级(level=2)分类随机选取子分类。 + + 解析逻辑: + - 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工') + - 若不含 '-',直接以整个 scene 作为 parent_name + + 返回: + 随机选中的一个 level=3 子分类,或 None + """ + if "-" in normalized_scene: + parent_name = normalized_scene.rsplit("-", 1)[-1] + else: + parent_name = normalized_scene + + parent = await broll_category.get_by_name_and_level( + db, name=parent_name, level=2 + ) + if parent is None: + return None + + children = await broll_category.get_children_by_parent_id( + db, parent_id=parent.id, level=3 + ) + if not children: + return None + + return random.choice(children) + + async def match_material( db: AsyncSession, scene: str, @@ -68,11 +103,13 @@ async def match_material( 匹配策略: 1. 标准化 scene,精确匹配三级分类(level=3)的 name。 - 2. 查询该分类下状态为 active、时长 >= required_duration 的素材。 - 3. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。 - 4. 优先从未使用候选中加权随机选择;若未用候选为空, + 2. 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配。 + 3. 若仍失败,回退到上级(level=2)分类,随机选取一个子分类。 + 4. 查询该分类下状态为 active、时长 >= required_duration 的素材。 + 5. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。 + 6. 优先从未使用候选中加权随机选择;若未用候选为空, fallback 到全部候选(允许复用,保证合成连续性)。 - 5. 原子递增 usage_count,并将选中的 URL 写入 Redis Set(7 天 TTL)。 + 7. 原子递增 usage_count,并将选中的 URL 写入 Redis Set(7 天 TTL)。 Args: db: 数据库 Session @@ -110,6 +147,13 @@ async def match_material( logger.info( f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'" ) + # 若仍失败,回退到上级分类随机选取 + if category is None: + category = await _try_fallback_to_parent(db, normalized) + if category: + logger.info( + f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'" + ) if category is None: logger.debug(f"未找到分类: {normalized}") return None @@ -217,7 +261,17 @@ async def batch_match( f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'" ) - # 3. 收集所有需要的 category_id + # 3. 未匹配的 scene 回退到上级分类随机选取 + unmatched = [name for name in unique_names if name not in scene_to_category] + for name in unmatched: + fallback_cat = await _try_fallback_to_parent(db, name) + if fallback_cat: + scene_to_category[name] = fallback_cat + logger.info( + f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'" + ) + + # 4. 收集所有需要的 category_id needed_category_ids = [ scene_to_category[name].id for name in unique_names