feat: 素材匹配 fallback 到上级分类随机选取
当三级分类(level=3)精确匹配失败时,回退到上级(level=2) 分类随机选取一个子分类,避免 AI 生成无效 scene(如 '电路施工-电路施工')导致素材匹配完全失败。 - CRUD: 新增 get_children_by_parent_id 方法 - match_material: 新增 _try_fallback_to_parent 辅助函数 - batch_match: 同步增加 fallback 逻辑 - 顺手修复 zip() 缺少 strict 参数的 lint 问题
This commit is contained in:
@@ -44,6 +44,19 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
|
|||||||
)
|
)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_children_by_parent_id(
|
||||||
|
self, db: AsyncSession, *, parent_id: int, level: int
|
||||||
|
) -> list[BrollCategory]:
|
||||||
|
"""根据父分类 ID 和层级获取启用的子分类"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BrollCategory).where(
|
||||||
|
BrollCategory.parent_id == parent_id,
|
||||||
|
BrollCategory.level == level,
|
||||||
|
BrollCategory.status == "active",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
# 导出实例
|
# 导出实例
|
||||||
broll_category = BrollCategoryCRUD()
|
broll_category = BrollCategoryCRUD()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.core.exceptions import ValidationException
|
from app.core.exceptions import ValidationException
|
||||||
from app.core.redis_client import get_redis_client
|
from app.core.redis_client import get_redis_client
|
||||||
from app.crud import broll_category, broll_material
|
from app.crud import broll_category, broll_material
|
||||||
|
from app.models.broll_category import BrollCategory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -48,7 +49,7 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
|
|||||||
|
|
||||||
r = random.uniform(0, total_weight)
|
r = random.uniform(0, total_weight)
|
||||||
cumulative = 0.0
|
cumulative = 0.0
|
||||||
for m, w in zip(materials, weights):
|
for m, w in zip(materials, weights, strict=True):
|
||||||
cumulative += w
|
cumulative += w
|
||||||
if r <= cumulative:
|
if r <= cumulative:
|
||||||
return m
|
return m
|
||||||
@@ -57,6 +58,40 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
|
|||||||
return materials[-1]
|
return materials[-1]
|
||||||
|
|
||||||
|
|
||||||
|
async def _try_fallback_to_parent(
|
||||||
|
db: AsyncSession,
|
||||||
|
normalized_scene: str,
|
||||||
|
) -> BrollCategory | None:
|
||||||
|
"""
|
||||||
|
三级分类匹配失败时,回退到上级(level=2)分类随机选取子分类。
|
||||||
|
|
||||||
|
解析逻辑:
|
||||||
|
- 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工')
|
||||||
|
- 若不含 '-',直接以整个 scene 作为 parent_name
|
||||||
|
|
||||||
|
返回:
|
||||||
|
随机选中的一个 level=3 子分类,或 None
|
||||||
|
"""
|
||||||
|
if "-" in normalized_scene:
|
||||||
|
parent_name = normalized_scene.rsplit("-", 1)[-1]
|
||||||
|
else:
|
||||||
|
parent_name = normalized_scene
|
||||||
|
|
||||||
|
parent = await broll_category.get_by_name_and_level(
|
||||||
|
db, name=parent_name, level=2
|
||||||
|
)
|
||||||
|
if parent is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
children = await broll_category.get_children_by_parent_id(
|
||||||
|
db, parent_id=parent.id, level=3
|
||||||
|
)
|
||||||
|
if not children:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return random.choice(children)
|
||||||
|
|
||||||
|
|
||||||
async def match_material(
|
async def match_material(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
scene: str,
|
scene: str,
|
||||||
@@ -68,11 +103,13 @@ async def match_material(
|
|||||||
|
|
||||||
匹配策略:
|
匹配策略:
|
||||||
1. 标准化 scene,精确匹配三级分类(level=3)的 name。
|
1. 标准化 scene,精确匹配三级分类(level=3)的 name。
|
||||||
2. 查询该分类下状态为 active、时长 >= required_duration 的素材。
|
2. 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配。
|
||||||
3. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。
|
3. 若仍失败,回退到上级(level=2)分类,随机选取一个子分类。
|
||||||
4. 优先从未使用候选中加权随机选择;若未用候选为空,
|
4. 查询该分类下状态为 active、时长 >= required_duration 的素材。
|
||||||
|
5. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。
|
||||||
|
6. 优先从未使用候选中加权随机选择;若未用候选为空,
|
||||||
fallback 到全部候选(允许复用,保证合成连续性)。
|
fallback 到全部候选(允许复用,保证合成连续性)。
|
||||||
5. 原子递增 usage_count,并将选中的 URL 写入 Redis Set(7 天 TTL)。
|
7. 原子递增 usage_count,并将选中的 URL 写入 Redis Set(7 天 TTL)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库 Session
|
db: 数据库 Session
|
||||||
@@ -110,6 +147,13 @@ async def match_material(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
|
f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
|
||||||
)
|
)
|
||||||
|
# 若仍失败,回退到上级分类随机选取
|
||||||
|
if category is None:
|
||||||
|
category = await _try_fallback_to_parent(db, normalized)
|
||||||
|
if category:
|
||||||
|
logger.info(
|
||||||
|
f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
|
||||||
|
)
|
||||||
if category is None:
|
if category is None:
|
||||||
logger.debug(f"未找到分类: {normalized}")
|
logger.debug(f"未找到分类: {normalized}")
|
||||||
return None
|
return None
|
||||||
@@ -217,7 +261,17 @@ async def batch_match(
|
|||||||
f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
|
f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 收集所有需要的 category_id
|
# 3. 未匹配的 scene 回退到上级分类随机选取
|
||||||
|
unmatched = [name for name in unique_names if name not in scene_to_category]
|
||||||
|
for name in unmatched:
|
||||||
|
fallback_cat = await _try_fallback_to_parent(db, name)
|
||||||
|
if fallback_cat:
|
||||||
|
scene_to_category[name] = fallback_cat
|
||||||
|
logger.info(
|
||||||
|
f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 收集所有需要的 category_id
|
||||||
needed_category_ids = [
|
needed_category_ids = [
|
||||||
scene_to_category[name].id
|
scene_to_category[name].id
|
||||||
for name in unique_names
|
for name in unique_names
|
||||||
|
|||||||
Reference in New Issue
Block a user