feat: 素材匹配 fallback 到上级分类随机选取

当三级分类(level=3)精确匹配失败时,回退到上级(level=2)
分类随机选取一个子分类,避免 AI 生成无效 scene(如
'电路施工-电路施工')导致素材匹配完全失败。

- CRUD: 新增 get_children_by_parent_id 方法
- match_material: 新增 _try_fallback_to_parent 辅助函数
- batch_match: 同步增加 fallback 逻辑
- 顺手修复 zip() 缺少 strict 参数的 lint 问题
This commit is contained in:
小鱼开发
2026-06-01 19:05:41 +08:00
parent f109a115d4
commit af8c483910
2 changed files with 73 additions and 6 deletions
+13
View File
@@ -44,6 +44,19 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
)
return list(result.scalars().all())
async def get_children_by_parent_id(
self, db: AsyncSession, *, parent_id: int, level: int
) -> list[BrollCategory]:
"""根据父分类 ID 和层级获取启用的子分类"""
result = await db.execute(
select(BrollCategory).where(
BrollCategory.parent_id == parent_id,
BrollCategory.level == level,
BrollCategory.status == "active",
)
)
return list(result.scalars().all())
# 导出实例
broll_category = BrollCategoryCRUD()
+60 -6
View File
@@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
from app.models.broll_category import BrollCategory
logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
r = random.uniform(0, total_weight)
cumulative = 0.0
for m, w in zip(materials, weights):
for m, w in zip(materials, weights, strict=True):
cumulative += w
if r <= cumulative:
return m
@@ -57,6 +58,40 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
return materials[-1]
async def _try_fallback_to_parent(
db: AsyncSession,
normalized_scene: str,
) -> BrollCategory | None:
"""
三级分类匹配失败时,回退到上级(level=2)分类随机选取子分类。
解析逻辑:
- 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工'
- 若不含 '-',直接以整个 scene 作为 parent_name
返回:
随机选中的一个 level=3 子分类,或 None
"""
if "-" in normalized_scene:
parent_name = normalized_scene.rsplit("-", 1)[-1]
else:
parent_name = normalized_scene
parent = await broll_category.get_by_name_and_level(
db, name=parent_name, level=2
)
if parent is None:
return None
children = await broll_category.get_children_by_parent_id(
db, parent_id=parent.id, level=3
)
if not children:
return None
return random.choice(children)
async def match_material(
db: AsyncSession,
scene: str,
@@ -68,11 +103,13 @@ async def match_material(
匹配策略:
1. 标准化 scene,精确匹配三级分类(level=3)的 name。
2. 查询该分类下状态为 active、时长 >= required_duration 的素材
3. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除
4. 优先从未使用候选中加权随机选择;若未用候选为空,
2. 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配
3. 若仍失败,回退到上级(level=2)分类,随机选取一个子分类
4. 查询该分类下状态为 active、时长 >= required_duration 的素材。
5. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。
6. 优先从未使用候选中加权随机选择;若未用候选为空,
fallback 到全部候选(允许复用,保证合成连续性)。
5. 原子递增 usage_count,并将选中的 URL 写入 Redis Set7 天 TTL)。
7. 原子递增 usage_count,并将选中的 URL 写入 Redis Set7 天 TTL)。
Args:
db: 数据库 Session
@@ -110,6 +147,13 @@ async def match_material(
logger.info(
f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
)
# 若仍失败,回退到上级分类随机选取
if category is None:
category = await _try_fallback_to_parent(db, normalized)
if category:
logger.info(
f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
)
if category is None:
logger.debug(f"未找到分类: {normalized}")
return None
@@ -217,7 +261,17 @@ async def batch_match(
f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
)
# 3. 收集所有需要的 category_id
# 3. 未匹配的 scene 回退到上级分类随机选取
unmatched = [name for name in unique_names if name not in scene_to_category]
for name in unmatched:
fallback_cat = await _try_fallback_to_parent(db, name)
if fallback_cat:
scene_to_category[name] = fallback_cat
logger.info(
f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
)
# 4. 收集所有需要的 category_id
needed_category_ids = [
scene_to_category[name].id
for name in unique_names