Files
meijiaka-zy/python-api/app/services/material_service.py
T
小鱼开发 2a36e4ec3d fix(material): 支持 scene 名称顺序颠倒兜底匹配
AI 生成 scene 时常将三级分类名中的 '-' 前后顺序写反
(如 瓷砖铺贴-瓷砖完工展示 vs 瓷砖完工展示-瓷砖铺贴),
导致精确匹配失败、素材匹配为空。

- match_material: 精确匹配失败后,尝试倒序匹配
- batch_match: 批量查询时同时查询原始名和倒序名,
  内存中构建 scene->category 映射,优先精确匹配、fallback 倒序
2026-05-17 21:35:44 +08:00

298 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
空镜素材服务
============
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
"""
import logging
import math
import random
import re
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
logger = logging.getLogger(__name__)

# Lifetime of the per-project Redis set of already-used material URLs:
# 7 days expressed in seconds (7 * 24 * 3600).
_USED_MATERIALS_TTL = 604_800
def _normalize_scene(scene: str) -> str:
    """Normalize a scene description for matching against level-3 category names.

    Every run of Unicode whitespace (ASCII spaces, full-width spaces, newlines,
    tabs, ...) is deleted; all other characters are kept unchanged.
    """
    whitespace_run = re.compile(r"\s+")
    return whitespace_run.sub("", scene)
def _weighted_choice(materials: list) -> object:  # noqa: ANN001
    """Pick one material at random, favouring rarely used ones.

    Each candidate is weighted by ``1 / sqrt(usage_count + 1)``, so heavily
    used materials remain eligible but are chosen less often, which avoids
    concentrating on the same popular clips.

    Args:
        materials: Candidate objects exposing a numeric ``usage_count``.

    Returns:
        One element of ``materials``.

    Raises:
        ValueError: If ``materials`` is empty.
    """
    if not materials:
        raise ValueError("候选池为空")
    if len(materials) == 1:
        return materials[0]

    weights = [1.0 / math.sqrt(m.usage_count + 1) for m in materials]
    # random.choices rejects an all-zero weight vector; that can only happen
    # if every usage_count is infinite, so fall back to a uniform pick.
    if not any(weights):
        return random.choice(materials)
    # The stdlib does the cumulative-weight bisection and clamps the index,
    # so float rounding can never run past the last candidate.
    return random.choices(materials, weights=weights, k=1)[0]
async def match_material(
    db: AsyncSession,
    scene: str,
    required_duration: float,
    project_id: str | None = None,
) -> dict | None:
    """Match one B-roll (empty-shot) material for a scene and a required duration.

    Matching strategy:
    1. Normalize ``scene`` and look up a level-3 category whose ``name``
       equals it exactly. If none exists and the name has the form ``"A-B"``,
       retry with ``"B-A"`` (AI-generated scenes often swap the two halves).
    2. Load the category's active materials whose duration is at least
       ``required_duration``.
    3. If ``project_id`` is given, read the URLs this project already used
       from Redis and exclude them.
    4. Pick by weighted random choice among unused candidates; if all are
       used, fall back to every candidate (reuse keeps composition going).
    5. Increment the chosen material's ``usage_count`` and add its URL to the
       project's Redis set with a 7-day TTL.

    Redis errors are logged and ignored: deduplication degrades, matching
    still succeeds.

    Args:
        db: Database session.
        scene: Storyboard scene description, e.g. "卫生间基层清理 - 防水施工".
        required_duration: Required clip length in seconds; must be > 0.
        project_id: Project ID used for deduplication; ``None`` disables it.

    Returns:
        ``{"url": str, "duration": float}``, or ``None`` when no category or
        no long-enough material is found.

    Raises:
        ValidationException: ``scene`` is blank, or ``required_duration`` is
            not greater than 0 (NaN included).
    """
    if not scene or not scene.strip():
        raise ValidationException("场景描述不能为空")
    # Written as `not (x > 0)` rather than `x <= 0` so NaN is rejected too
    # (every comparison with NaN is False).
    if not (required_duration > 0):
        raise ValidationException("所需时长必须大于 0")

    normalized = _normalize_scene(scene)

    # 1. Level-3 category: exact name first, then the swapped "B-A" fallback.
    category = await broll_category.get_by_name_and_level(
        db, name=normalized, level=3
    )
    if category is None:
        head, sep, tail = normalized.rpartition("-")
        # Only retry when both halves are non-empty; "-A" or "A-" would just
        # produce another meaningless name and waste a DB round trip.
        if sep and head and tail:
            reversed_name = f"{tail}-{head}"
            category = await broll_category.get_by_name_and_level(
                db, name=reversed_name, level=3
            )
            if category:
                logger.info(
                    "素材分类顺序颠倒兜底命中: '%s' -> '%s'", normalized, reversed_name
                )
    if category is None:
        logger.debug("未找到分类: %s", normalized)
        return None

    # 2. Candidate materials that are long enough.
    materials = await broll_material.get_active_by_category_and_duration(
        db, category_id=category.id, min_duration=required_duration
    )
    if not materials:
        logger.debug(
            "分类 %s 无足够时长的素材 (需 >= %ss)", normalized, required_duration
        )
        return None

    # 3. Project-level dedupe: URLs already used by this project.
    key = f"proj:{project_id}:used_materials"
    used_urls: set[str] = set()
    if project_id:
        try:
            redis = get_redis_client()
            members = await redis.smembers(key)
            # Decode in case the client returns bytes (e.g. a redis-py client
            # without decode_responses=True); otherwise no str URL would match.
            used_urls = {
                u.decode() if isinstance(u, bytes) else u for u in members
            }
        except Exception as e:
            logger.warning("Redis 去重查询失败,降级为不去重: %s", e)

    # 4. Prefer unused candidates; allow reuse when all are used.
    unused = [m for m in materials if m.url not in used_urls]
    chosen = _weighted_choice(unused or materials)

    # 5. Bump usage_count, then record the URL for this project (best-effort).
    await broll_material.increment_usage_count(db, material_id=chosen.id)
    if project_id:
        try:
            redis = get_redis_client()
            await redis.sadd(key, chosen.url)
            await redis.expire(key, _USED_MATERIALS_TTL)
        except Exception as e:
            logger.warning("Redis 去重记录失败: %s", e)

    return {"url": chosen.url, "duration": float(chosen.duration)}
async def batch_match(
    db: AsyncSession,
    scenes: list[dict],
    project_id: str | None = None,
) -> list[dict | None]:
    """Match B-roll materials for many scenes with a fixed number of round trips.

    Round trips: at most three DB calls (categories, materials, usage_count
    update), plus one Redis read and one pipelined Redis write.

    Strategy:
    1. Normalize every scene and fetch all level-3 categories in one query,
       covering both each name and, for names of the form "A-B", its "B-A"
       reversal. An exact match wins over the reversed one.
    2. Fetch the active materials of all matched categories in one query.
    3. In memory, filter each scene's candidates by duration, prefer URLs not
       yet used by the project (including picks made earlier in this batch),
       and choose by weighted random selection.
    4. Increment usage_count for all picks in one batch call.
    5. Record the picked URLs in the project's Redis set via a pipeline.

    Failures of the usage_count update and of Redis are logged and ignored.

    Args:
        db: Database session.
        scenes: Items of the form ``{"scene": str, "duration": float}``.
        project_id: Project ID used for deduplication; ``None`` disables it.

    Returns:
        A list aligned with ``scenes``; each item is
        ``{"url": str, "duration": float}`` or ``None`` when nothing matched.
    """
    if not scenes:
        return []

    # 1. Normalize scenes; each distinct name is queried only once.
    normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
    unique_names = set(normalized_scenes)

    # 2. One category query for original and swapped ("B-A") names.
    name_to_reversed: dict[str, str] = {}
    for name in unique_names:
        head, sep, tail = name.rpartition("-")
        # Skip degenerate names such as "-A" / "A-" whose swap is meaningless.
        if sep and head and tail:
            name_to_reversed[name] = f"{tail}-{head}"
    query_names = list(unique_names | set(name_to_reversed.values()))
    categories = await broll_category.get_by_names_and_level(
        db, names=query_names, level=3
    )
    category_map = {c.name: c for c in categories}

    # Map each normalized scene to its category: exact match first, then swap.
    scene_to_category: dict[str, object] = {}
    for name in unique_names:
        if name in category_map:
            scene_to_category[name] = category_map[name]
            continue
        rev = name_to_reversed.get(name)
        if rev is not None and rev in category_map:
            scene_to_category[name] = category_map[rev]
            logger.info("批量匹配顺序颠倒兜底命中: '%s' -> '%s'", name, rev)

    # Nothing matched: skip the material query and Redis entirely.
    if not scene_to_category:
        return [None] * len(scenes)

    # 3-4. One query for the active materials of every matched category,
    # grouped by category_id for in-memory filtering.
    category_ids = list({c.id for c in scene_to_category.values()})
    all_materials = await broll_material.get_active_by_categories(
        db, category_ids=category_ids
    )
    materials_by_category: dict[int, list] = {}
    for m in all_materials:
        materials_by_category.setdefault(m.category_id, []).append(m)

    # 5. Read the project's already-used URLs once.
    key = f"proj:{project_id}:used_materials"
    used_urls: set[str] = set()
    if project_id:
        try:
            redis = get_redis_client()
            members = await redis.smembers(key)
            # Decode in case the client returns bytes (e.g. a redis-py client
            # without decode_responses=True); otherwise no str URL would match.
            used_urls = {
                u.decode() if isinstance(u, bytes) else u for u in members
            }
        except Exception as e:
            logger.warning("Redis 去重查询失败,降级为不去重: %s", e)

    # 6. Match each scene in input order.
    results: list[dict | None] = []
    chosen_materials: list = []  # picks to persist in steps 7-8
    for spec, scene_name in zip(scenes, normalized_scenes):
        required_duration = spec["duration"]
        category = scene_to_category.get(scene_name)
        if category is None:
            results.append(None)
            continue

        candidates = [
            m
            for m in materials_by_category.get(category.id, [])
            if m.duration >= required_duration
        ]
        if not candidates:
            results.append(None)
            continue

        unused = [m for m in candidates if m.url not in used_urls]
        chosen = _weighted_choice(unused or candidates)
        chosen_materials.append(chosen)
        # Treat this pick as used so later scenes in the batch avoid it.
        used_urls.add(chosen.url)
        results.append({"url": chosen.url, "duration": float(chosen.duration)})

    if not chosen_materials:
        return results

    # 7. Batch usage_count update (best-effort).
    try:
        await broll_material.increment_usage_count_batch(
            db, material_ids=[m.id for m in chosen_materials]
        )
    except Exception as e:
        logger.warning("批量更新 usage_count 失败: %s", e)

    # 8. Record picks in Redis through one pipeline (best-effort).
    if project_id:
        try:
            redis = get_redis_client()
            pipe = redis.pipeline()
            for m in chosen_materials:
                pipe.sadd(key, m.url)
            pipe.expire(key, _USED_MATERIALS_TTL)
            await pipe.execute()
        except Exception as e:
            logger.warning("Redis 去重记录失败: %s", e)

    return results