perf(material): batch_match 批量查询优化,减少 DB 往返

- CRUD 新增 get_by_names_and_level() 批量查分类
- CRUD 新增 get_active_by_categories() 批量查素材
- CRUD 新增 increment_usage_count_batch() 批量更新 usage_count
- 重写 batch_match:从 N 次 DB 往返降到 3 次(查分类 + 查素材 + UPDATE)
- Redis 改用 pipeline 批量 sadd + expire
- 解决并发/连接池不足导致的间歇性 500 错误
This commit is contained in:
小鱼开发
2026-05-16 14:48:28 +08:00
parent b8aad2ea62
commit d3069d423b
3 changed files with 137 additions and 10 deletions
+15
View File
@@ -29,6 +29,21 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
)
return result.scalar_one_or_none()
async def get_by_names_and_level(
self, db: AsyncSession, *, names: list[str], level: int
) -> list[BrollCategory]:
"""批量根据名称和层级获取启用的分类"""
if not names:
return []
result = await db.execute(
select(BrollCategory).where(
BrollCategory.name.in_(names),
BrollCategory.level == level,
BrollCategory.status == "active",
)
)
return list(result.scalars().all())
# 导出实例
broll_category = BrollCategoryCRUD()
+26
View File
@@ -35,6 +35,20 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
)
return list(result.scalars().all())
async def get_active_by_categories(
self, db: AsyncSession, *, category_ids: list[int]
) -> list[BrollMaterial]:
"""批量获取指定分类下状态为 active 的素材列表(不过滤时长)"""
if not category_ids:
return []
result = await db.execute(
select(BrollMaterial).where(
BrollMaterial.category_id.in_(category_ids),
BrollMaterial.status == "active",
)
)
return list(result.scalars().all())
async def increment_usage_count(
self, db: AsyncSession, *, material_id: int
) -> None:
@@ -49,6 +63,18 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
.values(usage_count=BrollMaterial.usage_count + 1)
)
async def increment_usage_count_batch(
self, db: AsyncSession, *, material_ids: list[int]
) -> None:
"""批量原子递增素材使用次数"""
if not material_ids:
return
await db.execute(
update(BrollMaterial)
.where(BrollMaterial.id.in_(material_ids))
.values(usage_count=BrollMaterial.usage_count + 1)
)
# 导出实例
broll_material = BrollMaterialCRUD()
+96 -10
View File
@@ -151,9 +151,14 @@ async def batch_match(
project_id: str | None = None,
) -> list[dict | None]:
"""
批量匹配素材
批量匹配素材(优化版:3 次 DB 往返 + 1 次 Redis
按 scenes 顺序逐个调用 match_material,保证同项目下的去重连续性。
优化策略:
1. 一次性批量查询所有三级分类(1 次 DB)。
2. 一次性批量查询所有相关素材(1 次 DB)。
3. 内存中按 scene + duration 过滤并加权随机选择。
4. 批量 UPDATE usage_count1 次 DB)。
5. 批量 Redis sadd1 次 Redis pipeline)。
Args:
db: 数据库 Session
@@ -163,13 +168,94 @@ async def batch_match(
Returns:
与 scenes 长度一致的列表,元素为 {"url": str, "duration": float} 或 None
"""
if not scenes:
return []
# 1. 标准化所有 scene 并去重
normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
unique_names = list(set(normalized_scenes))
# 2. 批量查询分类(1 次 DB
categories = await broll_category.get_by_names_and_level(
db, names=unique_names, level=3
)
category_map = {c.name: c for c in categories}
# 3. 收集所有需要的 category_id
needed_category_ids = [
category_map[name].id
for name in unique_names
if name in category_map
]
# 4. 批量查询素材(1 次 DB
all_materials = await broll_material.get_active_by_categories(
db, category_ids=list(set(needed_category_ids))
)
# 按 category_id 分组,方便内存过滤
materials_by_category: dict[int, list] = {}
for m in all_materials:
materials_by_category.setdefault(m.category_id, []).append(m)
# 5. Redis 获取已使用素材(1 次 Redis)
used_urls: set[str] = set()
if project_id:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
used_urls = set(await redis.smembers(key))
except Exception as e:
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
# 6. 内存中逐个匹配
results: list[dict | None] = []
for item in scenes:
result = await match_material(
db,
scene=item["scene"],
required_duration=item["duration"],
project_id=project_id,
)
results.append(result)
chosen_materials: list = [] # 记录选中的素材,用于批量更新
for idx, scene_name in enumerate(normalized_scenes):
required_duration = scenes[idx]["duration"]
category = category_map.get(scene_name)
if category is None:
results.append(None)
continue
materials = materials_by_category.get(category.id, [])
# 按时长过滤
candidates = [m for m in materials if m.duration >= required_duration]
if not candidates:
results.append(None)
continue
# Redis 去重
unused = [m for m in candidates if m.url not in used_urls]
target_pool = unused if unused else candidates
# 加权随机
chosen = _weighted_choice(target_pool)
chosen_materials.append(chosen)
used_urls.add(chosen.url)
results.append({"url": chosen.url, "duration": float(chosen.duration)})
# 7. 批量更新 usage_count1 次 DB
if chosen_materials:
material_ids = [m.id for m in chosen_materials]
try:
await broll_material.increment_usage_count_batch(db, material_ids=material_ids)
except Exception as e:
logger.warning(f"批量更新 usage_count 失败: {e}")
# 8. 批量记录到 Redispipeline
if project_id and chosen_materials:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
pipe = redis.pipeline()
for m in chosen_materials:
pipe.sadd(key, m.url)
pipe.expire(key, _USED_MATERIALS_TTL)
await pipe.execute()
except Exception as e:
logger.warning(f"Redis 去重记录失败: {e}")
return results