Files
meijiaka-zy/python-api/app/services/material_service.py
T
小鱼开发 af8c483910 feat: 素材匹配 fallback 到上级分类随机选取
当三级分类(level=3)精确匹配失败时,回退到上级(level=2)
分类随机选取一个子分类,避免 AI 生成无效 scene(如
'电路施工-电路施工')导致素材匹配完全失败。

- CRUD: 新增 get_children_by_parent_id 方法
- match_material: 新增 _try_fallback_to_parent 辅助函数
- batch_match: 同步增加 fallback 逻辑
- 顺手修复 zip() 缺少 strict 参数的 lint 问题
2026-06-01 19:05:41 +08:00

352 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
空镜素材服务
============
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
"""
import logging
import math
import random
import re
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
from app.models.broll_category import BrollCategory
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# TTL, in seconds, of the Redis Set that records already-used materials (7 days).
_USED_MATERIALS_TTL = 7 * 24 * 3600
def _normalize_scene(scene: str) -> str:
    """Strip all whitespace from a scene description before category lookup.

    Every run of Unicode whitespace (ASCII space, full-width space, tab,
    newline, ...) is removed, so ``"A - B"`` and ``"A-B"`` normalize to the
    same string.
    """
    # str.split() with no separator splits on the same Unicode whitespace
    # that the regex class ``\s`` matches, so re-joining the pieces drops it.
    return "".join(scene.split())
def _weighted_choice(materials: list) -> object:  # noqa: ANN001
    """Pick one material at random, favouring rarely used ones.

    Each candidate is weighted by ``1 / sqrt(usage_count + 1)`` so that
    frequently used materials are not selected disproportionately often.

    Args:
        materials: Candidate objects, each exposing a ``usage_count`` attribute.

    Returns:
        One element of ``materials``.

    Raises:
        ValueError: If ``materials`` is empty.
    """
    if not materials:
        raise ValueError("候选池为空")
    if len(materials) == 1:
        # A single candidate needs no random draw.
        return materials[0]

    weights = [1.0 / math.sqrt(item.usage_count + 1) for item in materials]
    weight_sum = sum(weights)
    if weight_sum == 0:
        return random.choice(materials)

    # Draw a point on [0, weight_sum] and walk the running total until it is
    # reached; the candidate whose slice contains the point wins.
    threshold = random.uniform(0, weight_sum)
    running = 0.0
    for index, weight in enumerate(weights):
        running += weight
        if running >= threshold:
            return materials[index]

    # Safety net against floating-point drift: fall back to the last candidate.
    return materials[-1]
async def _try_fallback_to_parent(
    db: AsyncSession,
    normalized_scene: str,
) -> BrollCategory | None:
    """Fall back to a level-2 category and pick one of its children at random.

    Used when no level-3 category matches the scene name.

    The parent name is derived from the scene:
        - with a ``-`` in the scene, the text after the last ``-`` is used
          (e.g. ``'电路施工-电路施工'`` -> ``'电路施工'``);
        - without a ``-``, the whole scene is used.

    Args:
        db: Database session passed through to the category CRUD helpers.
        normalized_scene: Scene name, already whitespace-normalized.

    Returns:
        A randomly chosen level-3 child of the matched level-2 category, or
        ``None`` when no such parent exists or it has no children.
    """
    # rpartition yields the text after the last '-', or the whole string
    # when the scene contains no '-'.
    _, _, parent_name = normalized_scene.rpartition("-")

    parent = await broll_category.get_by_name_and_level(
        db, name=parent_name, level=2
    )
    if parent is None:
        return None

    children = await broll_category.get_children_by_parent_id(
        db, parent_id=parent.id, level=3
    )
    return random.choice(children) if children else None
async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """Pick one B-roll material for a scene description and a minimum duration.

    Matching strategy:
        1. Normalize ``scene`` and look for a level-3 category with exactly
           that name.
        2. On a miss, swap an "A-B" name into "B-A" and look again.
        3. On a further miss, fall back to the level-2 parent category and
           take one of its children at random.
        4. Load that category's active materials whose duration is at least
           ``required_duration``.
        5. If ``project_id`` is given, read the URLs this project has already
           used from Redis and exclude them.
        6. Prefer a weighted random pick among unused candidates; when none
           are left, reuse from the full candidate list so the caller still
           gets a material.
        7. Increment the chosen material's usage count and record its URL in
           the project's Redis Set (7-day TTL).

    Args:
        db: Database session.
        scene: Scene description, e.g. "卫生间基层清理 - 防水施工".
        required_duration: Minimum duration in seconds; must be greater than 0.
        project_id: Project id used for de-duplication. ``None`` disables it.

    Returns:
        ``{"url": str, "duration": float}``, or ``None`` when no category or
        no sufficiently long material is found.

    Raises:
        ValidationException: If ``scene`` is blank or ``required_duration``
            is not greater than 0.
    """
    # Parameter validation.
    if not scene or not scene.strip():
        raise ValidationException("场景描述不能为空")
    if required_duration <= 0:
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    # 1. Exact match on the level-3 category name.
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )

    # 2. Retry with the two halves swapped ("A-B" -> "B-A").
    if category is None:
        parts = normalized.rsplit("-", 1)
        if len(parts) == 2:
            reversed_name = f"{parts[1]}-{parts[0]}"
            # A name such as "X-X" reverses to itself; the exact lookup above
            # already missed it, so repeating the query would be wasted work.
            if reversed_name != normalized:
                category = await broll_category.get_by_name_and_level(
                    db, name=reversed_name, level=3
                )
                if category:
                    logger.info(
                        f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
                    )

    # 3. Still nothing: fall back to a random child of the parent category.
    if category is None:
        category = await _try_fallback_to_parent(db, normalized)
        if category:
            logger.info(
                f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
            )

    if category is None:
        logger.debug(f"未找到分类: {normalized}")
        return None

    # 4. Candidate materials for the resolved category.
    materials = await broll_material.get_active_by_category_and_duration(
        db, category_id=category.id, min_duration=required_duration
    )
    if not materials:
        logger.debug(
            f"分类 {normalized} 无足够时长的素材 (需 >= {required_duration}s)"
        )
        return None

    # 5. Redis de-duplication: URLs this project has already used.
    used_key = f"proj:{project_id}:used_materials" if project_id else None
    used_urls: set[str] = set()
    if used_key:
        try:
            members = await get_redis_client().smembers(used_key)
            # NOTE(review): a Redis client created without
            # decode_responses=True returns bytes members; normalise to str
            # so the comparison with m.url works either way. Confirm against
            # get_redis_client's configuration.
            used_urls = {
                u.decode() if isinstance(u, bytes) else u for u in members
            }
        except Exception as e:
            # Best effort: a Redis outage only disables de-duplication.
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # 6. Prefer unused candidates; reuse only when every candidate was used.
    unused = [m for m in materials if m.url not in used_urls]
    chosen = _weighted_choice(unused or materials)

    # 7a. Bump the usage counter (the increment is delegated to the CRUD layer).
    await broll_material.increment_usage_count(db, material_id=chosen.id)

    # 7b. Remember the pick in Redis; failures must not break the main flow.
    if used_key:
        try:
            redis = get_redis_client()
            await redis.sadd(used_key, chosen.url)
            await redis.expire(used_key, _USED_MATERIALS_TTL)
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return {"url": chosen.url, "duration": float(chosen.duration)}
async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """Match B-roll materials for many scenes in one pass.

    Strategy:
        1. Resolve all level-3 categories (exact and "B-A" reversed names)
           with a single query.
        2. Scenes that are still unmatched fall back to a random child of
           their level-2 parent category. This step issues extra category
           queries for each unmatched unique scene, so the "3 DB round trips"
           figure only holds when every scene matches in step 1.
        3. Load the materials of all resolved categories with a single query.
        4. Filter by duration and pick by weighted random choice in memory.
        5. Update the usage counts with a single batch call.
        6. Record the chosen URLs in Redis through one pipeline.

    Args:
        db: Database session.
        scenes: Items of the form ``{"scene": str, "duration": float}``.
        project_id: Project id used for de-duplication; ``None`` disables it.

    Returns:
        A list as long as ``scenes`` whose elements are
        ``{"url": str, "duration": float}`` or ``None`` when nothing matched.
    """
    if not scenes:
        return []

    # 1. Normalize every scene; identical names are resolved only once.
    normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
    unique_names = list(set(normalized_scenes))

    # 2. One category query covering each name and its "B-A" reversal.
    name_to_reversed: dict[str, str] = {}
    for name in unique_names:
        parts = name.rsplit("-", 1)
        if len(parts) == 2:
            rev = f"{parts[1]}-{parts[0]}"
            # "X-X" reverses to itself and adds nothing to the query.
            if rev != name:
                name_to_reversed[name] = rev
    # dict.fromkeys drops names that appear both as an original and as a
    # reversal, while keeping a stable order.
    all_query_names = list(
        dict.fromkeys([*unique_names, *name_to_reversed.values()])
    )
    categories = await broll_category.get_by_names_and_level(
        db, names=all_query_names, level=3
    )
    category_map = {c.name: c for c in categories}

    # Map each scene name to its category: exact match first, reversal second.
    scene_to_category: dict[str, BrollCategory] = {}
    for name in unique_names:
        if name in category_map:
            scene_to_category[name] = category_map[name]
            continue
        rev = name_to_reversed.get(name)
        if rev is not None and rev in category_map:
            scene_to_category[name] = category_map[rev]
            logger.info(
                f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
            )

    # 3. Unmatched scenes fall back to a random child of the parent category.
    for name in unique_names:
        if name in scene_to_category:
            continue
        fallback_cat = await _try_fallback_to_parent(db, name)
        if fallback_cat:
            scene_to_category[name] = fallback_cat
            logger.info(
                f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
            )

    # 4. Collect the distinct category ids that are actually needed.
    needed_category_ids = {cat.id for cat in scene_to_category.values()}
    if not needed_category_ids:
        # No scene resolved to a category: every result is None, so the
        # material query and the Redis round trips can be skipped.
        return [None] * len(scenes)

    # 5. One material query for all categories, grouped for in-memory lookup.
    all_materials = await broll_material.get_active_by_categories(
        db, category_ids=list(needed_category_ids)
    )
    materials_by_category: dict[int, list] = {}
    for m in all_materials:
        materials_by_category.setdefault(m.category_id, []).append(m)

    # 6. URLs this project has already used (one Redis read).
    used_key = f"proj:{project_id}:used_materials" if project_id else None
    used_urls: set[str] = set()
    if used_key:
        try:
            members = await get_redis_client().smembers(used_key)
            # NOTE(review): a Redis client created without
            # decode_responses=True returns bytes members; normalise to str
            # so the comparison with m.url works either way. Confirm against
            # get_redis_client's configuration.
            used_urls = {
                u.decode() if isinstance(u, bytes) else u for u in members
            }
        except Exception as e:
            # Best effort: a Redis outage only disables de-duplication.
            logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")

    # 7. Match each scene in memory, in input order.
    results: list[dict | None] = []
    chosen_materials: list = []  # picks to persist in steps 8 and 9
    for scene, scene_name in zip(scenes, normalized_scenes, strict=True):
        required_duration = scene["duration"]
        category = scene_to_category.get(scene_name)
        if category is None:
            results.append(None)
            continue

        # Keep only materials that are long enough.
        candidates = [
            m
            for m in materials_by_category.get(category.id, [])
            if m.duration >= required_duration
        ]
        if not candidates:
            results.append(None)
            continue

        # Prefer unused candidates; reuse only when every candidate was used.
        unused = [m for m in candidates if m.url not in used_urls]
        chosen = _weighted_choice(unused or candidates)
        chosen_materials.append(chosen)
        # Mark as used right away so later scenes in this batch avoid it.
        used_urls.add(chosen.url)
        results.append({"url": chosen.url, "duration": float(chosen.duration)})

    # 8. Batch-update usage counts (one DB call, best effort).
    if chosen_materials:
        # NOTE(review): material_ids can contain repeats when one material is
        # picked for several scenes; whether the batch helper counts each
        # occurrence is not visible from this file.
        material_ids = [m.id for m in chosen_materials]
        try:
            await broll_material.increment_usage_count_batch(
                db, material_ids=material_ids
            )
        except Exception as e:
            logger.warning(f"批量更新 usage_count 失败: {e}")

    # 9. Record the picks in Redis through a single pipeline.
    if used_key and chosen_materials:
        try:
            pipe = get_redis_client().pipeline()
            for m in chosen_materials:
                pipe.sadd(used_key, m.url)
            pipe.expire(used_key, _USED_MATERIALS_TTL)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Redis 去重记录失败: {e}")

    return results