perf(material): batch_match 批量查询优化,减少 DB 往返
- CRUD 新增 get_by_names_and_level() 批量查分类 - CRUD 新增 get_active_by_categories() 批量查素材 - CRUD 新增 increment_usage_count_batch() 批量更新 usage_count - 重写 batch_match:从 N 次 DB 往返降到 3 次(查分类 + 查素材 + UPDATE) - Redis 改用 pipeline 批量 sadd + expire - 解决并发/连接池不足导致的间歇性 500 错误
This commit is contained in:
@@ -29,6 +29,21 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_names_and_level(
|
||||
self, db: AsyncSession, *, names: list[str], level: int
|
||||
) -> list[BrollCategory]:
|
||||
"""批量根据名称和层级获取启用的分类"""
|
||||
if not names:
|
||||
return []
|
||||
result = await db.execute(
|
||||
select(BrollCategory).where(
|
||||
BrollCategory.name.in_(names),
|
||||
BrollCategory.level == level,
|
||||
BrollCategory.status == "active",
|
||||
)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
# 导出实例
|
||||
broll_category = BrollCategoryCRUD()
|
||||
|
||||
@@ -35,6 +35,20 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_active_by_categories(
|
||||
self, db: AsyncSession, *, category_ids: list[int]
|
||||
) -> list[BrollMaterial]:
|
||||
"""批量获取指定分类下状态为 active 的素材列表(不过滤时长)"""
|
||||
if not category_ids:
|
||||
return []
|
||||
result = await db.execute(
|
||||
select(BrollMaterial).where(
|
||||
BrollMaterial.category_id.in_(category_ids),
|
||||
BrollMaterial.status == "active",
|
||||
)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def increment_usage_count(
|
||||
self, db: AsyncSession, *, material_id: int
|
||||
) -> None:
|
||||
@@ -49,6 +63,18 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
|
||||
.values(usage_count=BrollMaterial.usage_count + 1)
|
||||
)
|
||||
|
||||
async def increment_usage_count_batch(
|
||||
self, db: AsyncSession, *, material_ids: list[int]
|
||||
) -> None:
|
||||
"""批量原子递增素材使用次数"""
|
||||
if not material_ids:
|
||||
return
|
||||
await db.execute(
|
||||
update(BrollMaterial)
|
||||
.where(BrollMaterial.id.in_(material_ids))
|
||||
.values(usage_count=BrollMaterial.usage_count + 1)
|
||||
)
|
||||
|
||||
|
||||
# 导出实例
|
||||
broll_material = BrollMaterialCRUD()
|
||||
|
||||
@@ -151,9 +151,14 @@ async def batch_match(
|
||||
project_id: str | None = None,
|
||||
) -> list[dict | None]:
|
||||
"""
|
||||
批量匹配素材
|
||||
批量匹配素材(优化版:3 次 DB 往返 + 1 次 Redis)
|
||||
|
||||
按 scenes 顺序逐个调用 match_material,保证同项目下的去重连续性。
|
||||
优化策略:
|
||||
1. 一次性批量查询所有三级分类(1 次 DB)。
|
||||
2. 一次性批量查询所有相关素材(1 次 DB)。
|
||||
3. 内存中按 scene + duration 过滤并加权随机选择。
|
||||
4. 批量 UPDATE usage_count(1 次 DB)。
|
||||
5. 批量 Redis sadd(1 次 Redis pipeline)。
|
||||
|
||||
Args:
|
||||
db: 数据库 Session
|
||||
@@ -163,13 +168,94 @@ async def batch_match(
|
||||
Returns:
|
||||
与 scenes 长度一致的列表,元素为 {"url": str, "duration": float} 或 None
|
||||
"""
|
||||
if not scenes:
|
||||
return []
|
||||
|
||||
# 1. 标准化所有 scene 并去重
|
||||
normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
|
||||
unique_names = list(set(normalized_scenes))
|
||||
|
||||
# 2. 批量查询分类(1 次 DB)
|
||||
categories = await broll_category.get_by_names_and_level(
|
||||
db, names=unique_names, level=3
|
||||
)
|
||||
category_map = {c.name: c for c in categories}
|
||||
|
||||
# 3. 收集所有需要的 category_id
|
||||
needed_category_ids = [
|
||||
category_map[name].id
|
||||
for name in unique_names
|
||||
if name in category_map
|
||||
]
|
||||
|
||||
# 4. 批量查询素材(1 次 DB)
|
||||
all_materials = await broll_material.get_active_by_categories(
|
||||
db, category_ids=list(set(needed_category_ids))
|
||||
)
|
||||
|
||||
# 按 category_id 分组,方便内存过滤
|
||||
materials_by_category: dict[int, list] = {}
|
||||
for m in all_materials:
|
||||
materials_by_category.setdefault(m.category_id, []).append(m)
|
||||
|
||||
# 5. Redis 获取已使用素材(1 次 Redis)
|
||||
used_urls: set[str] = set()
|
||||
if project_id:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"proj:{project_id}:used_materials"
|
||||
used_urls = set(await redis.smembers(key))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
|
||||
|
||||
# 6. 内存中逐个匹配
|
||||
results: list[dict | None] = []
|
||||
for item in scenes:
|
||||
result = await match_material(
|
||||
db,
|
||||
scene=item["scene"],
|
||||
required_duration=item["duration"],
|
||||
project_id=project_id,
|
||||
)
|
||||
results.append(result)
|
||||
chosen_materials: list = [] # 记录选中的素材,用于批量更新
|
||||
|
||||
for idx, scene_name in enumerate(normalized_scenes):
|
||||
required_duration = scenes[idx]["duration"]
|
||||
|
||||
category = category_map.get(scene_name)
|
||||
if category is None:
|
||||
results.append(None)
|
||||
continue
|
||||
|
||||
materials = materials_by_category.get(category.id, [])
|
||||
# 按时长过滤
|
||||
candidates = [m for m in materials if m.duration >= required_duration]
|
||||
if not candidates:
|
||||
results.append(None)
|
||||
continue
|
||||
|
||||
# Redis 去重
|
||||
unused = [m for m in candidates if m.url not in used_urls]
|
||||
target_pool = unused if unused else candidates
|
||||
|
||||
# 加权随机
|
||||
chosen = _weighted_choice(target_pool)
|
||||
chosen_materials.append(chosen)
|
||||
used_urls.add(chosen.url)
|
||||
results.append({"url": chosen.url, "duration": float(chosen.duration)})
|
||||
|
||||
# 7. 批量更新 usage_count(1 次 DB)
|
||||
if chosen_materials:
|
||||
material_ids = [m.id for m in chosen_materials]
|
||||
try:
|
||||
await broll_material.increment_usage_count_batch(db, material_ids=material_ids)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量更新 usage_count 失败: {e}")
|
||||
|
||||
# 8. 批量记录到 Redis(pipeline)
|
||||
if project_id and chosen_materials:
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"proj:{project_id}:used_materials"
|
||||
pipe = redis.pipeline()
|
||||
for m in chosen_materials:
|
||||
pipe.sadd(key, m.url)
|
||||
pipe.expire(key, _USED_MATERIALS_TTL)
|
||||
await pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 去重记录失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user