From d3069d423ba9e232d8154e373b413da387abd157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E9=B1=BC=E5=BC=80=E5=8F=91?= Date: Sat, 16 May 2026 14:48:28 +0800 Subject: [PATCH] =?UTF-8?q?perf(material):=20batch=5Fmatch=20=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E6=9F=A5=E8=AF=A2=E4=BC=98=E5=8C=96=EF=BC=8C=E5=87=8F?= =?UTF-8?q?=E5=B0=91=20DB=20=E5=BE=80=E8=BF=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CRUD 新增 get_by_names_and_level() 批量查分类 - CRUD 新增 get_active_by_categories() 批量查素材 - CRUD 新增 increment_usage_count_batch() 批量更新 usage_count - 重写 batch_match:从 N 次 DB 往返降到 3 次(查分类 + 查素材 + UPDATE) - Redis 改用 pipeline 批量 sadd + expire - 解决并发/连接池不足导致的间歇性 500 错误 --- python-api/app/crud/broll_category.py | 15 +++ python-api/app/crud/broll_material.py | 26 +++++ python-api/app/services/material_service.py | 106 ++++++++++++++++++-- 3 files changed, 137 insertions(+), 10 deletions(-) diff --git a/python-api/app/crud/broll_category.py b/python-api/app/crud/broll_category.py index fe2c043..9a96966 100644 --- a/python-api/app/crud/broll_category.py +++ b/python-api/app/crud/broll_category.py @@ -29,6 +29,21 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]): ) return result.scalar_one_or_none() + async def get_by_names_and_level( + self, db: AsyncSession, *, names: list[str], level: int + ) -> list[BrollCategory]: + """批量根据名称和层级获取启用的分类""" + if not names: + return [] + result = await db.execute( + select(BrollCategory).where( + BrollCategory.name.in_(names), + BrollCategory.level == level, + BrollCategory.status == "active", + ) + ) + return list(result.scalars().all()) + # 导出实例 broll_category = BrollCategoryCRUD() diff --git a/python-api/app/crud/broll_material.py b/python-api/app/crud/broll_material.py index 9acf827..6a460fb 100644 --- a/python-api/app/crud/broll_material.py +++ b/python-api/app/crud/broll_material.py @@ -35,6 +35,20 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]): ) return list(result.scalars().all()) + async def get_active_by_categories( + self, db: AsyncSession, *, category_ids: list[int] + ) -> list[BrollMaterial]: + """批量获取指定分类下状态为 active 的素材列表(不过滤时长)""" + if not category_ids: + return [] + result = await db.execute( + select(BrollMaterial).where( + BrollMaterial.category_id.in_(category_ids), + BrollMaterial.status == "active", + ) + ) + return list(result.scalars().all()) + async def increment_usage_count( self, db: AsyncSession, *, material_id: int ) -> None: @@ -49,6 +63,18 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]): .values(usage_count=BrollMaterial.usage_count + 1) ) + async def increment_usage_count_batch( + self, db: AsyncSession, *, material_ids: list[int] + ) -> None: + """批量原子递增素材使用次数""" + if not material_ids: + return + await db.execute( + update(BrollMaterial) + .where(BrollMaterial.id.in_(material_ids)) + .values(usage_count=BrollMaterial.usage_count + 1) + ) + # 导出实例 broll_material = BrollMaterialCRUD() diff --git a/python-api/app/services/material_service.py b/python-api/app/services/material_service.py index ecca521..c2cccea 100644 --- a/python-api/app/services/material_service.py +++ b/python-api/app/services/material_service.py @@ -151,9 +151,14 @@ async def batch_match( project_id: str | None = None, ) -> list[dict | None]: """ - 批量匹配素材 + 批量匹配素材(优化版:3 次 DB 往返 + 1 次 Redis) - 按 scenes 顺序逐个调用 match_material,保证同项目下的去重连续性。 + 优化策略: + 1. 一次性批量查询所有三级分类(1 次 DB)。 + 2. 一次性批量查询所有相关素材(1 次 DB)。 + 3. 内存中按 scene + duration 过滤并加权随机选择。 + 4. 批量 UPDATE usage_count(1 次 DB)。 + 5. 批量 Redis sadd(1 次 Redis pipeline)。 Args: db: 数据库 Session @@ -163,13 +168,94 @@ async def batch_match( Returns: 与 scenes 长度一致的列表,元素为 {"url": str, "duration": float} 或 None """ + if not scenes: + return [] + + # 1. 标准化所有 scene 并去重 + normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes] + unique_names = list(set(normalized_scenes)) + + # 2. 批量查询分类(1 次 DB) + categories = await broll_category.get_by_names_and_level( + db, names=unique_names, level=3 + ) + category_map = {c.name: c for c in categories} + + # 3. 收集所有需要的 category_id + needed_category_ids = [ + category_map[name].id + for name in unique_names + if name in category_map + ] + + # 4. 批量查询素材(1 次 DB) + all_materials = await broll_material.get_active_by_categories( + db, category_ids=list(set(needed_category_ids)) + ) + + # 按 category_id 分组,方便内存过滤 + materials_by_category: dict[int, list] = {} + for m in all_materials: + materials_by_category.setdefault(m.category_id, []).append(m) + + # 5. Redis 获取已使用素材(1 次 Redis) + used_urls: set[str] = set() + if project_id: + try: + redis = get_redis_client() + key = f"proj:{project_id}:used_materials" + used_urls = set(await redis.smembers(key)) + except Exception as e: + logger.warning(f"Redis 去重查询失败,降级为不去重: {e}") + + # 6. 内存中逐个匹配 results: list[dict | None] = [] - for item in scenes: - result = await match_material( - db, - scene=item["scene"], - required_duration=item["duration"], - project_id=project_id, - ) - results.append(result) + chosen_materials: list = [] # 记录选中的素材,用于批量更新 + + for idx, scene_name in enumerate(normalized_scenes): + required_duration = scenes[idx]["duration"] + + category = category_map.get(scene_name) + if category is None: + results.append(None) + continue + + materials = materials_by_category.get(category.id, []) + # 按时长过滤 + candidates = [m for m in materials if m.duration >= required_duration] + if not candidates: + results.append(None) + continue + + # Redis 去重 + unused = [m for m in candidates if m.url not in used_urls] + target_pool = unused if unused else candidates + + # 加权随机 + chosen = _weighted_choice(target_pool) + chosen_materials.append(chosen) + used_urls.add(chosen.url) + results.append({"url": chosen.url, "duration": float(chosen.duration)}) + + # 7. 批量更新 usage_count(1 次 DB) + if chosen_materials: + material_ids = [m.id for m in chosen_materials] + try: + await broll_material.increment_usage_count_batch(db, material_ids=material_ids) + except Exception as e: + logger.warning(f"批量更新 usage_count 失败: {e}") + + # 8. 批量记录到 Redis(pipeline) + if project_id and chosen_materials: + try: + redis = get_redis_client() + key = f"proj:{project_id}:used_materials" + pipe = redis.pipeline() + for m in chosen_materials: + pipe.sadd(key, m.url) + pipe.expire(key, _USED_MATERIALS_TTL) + await pipe.execute() + except Exception as e: + logger.warning(f"Redis 去重记录失败: {e}") + return results