Files

384 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
空镜素材服务
============
从 PostgreSQL 查询素材,支持加权随机选择和 Redis 项目级去重。
"""
import logging
import math
import random
import re
from collections.abc import Awaitable
from typing import Any, cast
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
from app.models.broll_category import BrollCategory
from app.models.broll_material import BrollMaterial
logger = logging.getLogger(__name__)
# Redis 已使用素材 Set 的 TTL7 天)
_USED_MATERIALS_TTL = 7 * 24 * 3600
def _normalize_scene(scene: str) -> str:
"""标准化场景描述,用于匹配三级分类 name"""
# 去除所有 Unicode 空白字符(空格、全角空格、换行、tab 等)
cleaned = re.sub(r"\s+", "", scene)
# 去除常见中文标点符号(逗号、句号、感叹号、问号、顿号、分号、冒号、引号、括号等)
cleaned = re.sub(r"[,。!?、;:" "''()【】《》]+", "", cleaned)
# 去除零宽字符(零宽空格、零宽非连接符、零宽连接符、零宽非断空格等)
cleaned = re.sub(r"[\u200b-\u200f\ufeff]+", "", cleaned)
return cleaned
def _weighted_choice(materials: list[BrollMaterial]) -> BrollMaterial:
"""
加权随机选择素材
weight = 1 / sqrt(usage_count + 1),避免过度集中热门素材。
"""
if not materials:
raise ValueError("候选池为空")
if len(materials) == 1:
return materials[0]
weights = [1.0 / math.sqrt(m.usage_count + 1) for m in materials]
total_weight = sum(weights)
if total_weight == 0:
return random.choice(materials) # nosec B311 素材抽样,非加密场景
r = random.uniform(0, total_weight) # nosec B311 素材抽样,非加密场景
cumulative = 0.0
for m, w in zip(materials, weights, strict=True):
cumulative += w
if r <= cumulative:
return m
# 兜底返回最后一个
return materials[-1]
async def _try_fallback_to_parent(
db: AsyncSession,
normalized_scene: str,
) -> BrollCategory | None:
"""
三级分类匹配失败时,回退到上级(level=2)分类随机选取子分类。
解析逻辑:
- 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工'
- 若不含 '-',直接以整个 scene 作为 parent_name
匹配策略:
1. 精确匹配 level=2 分类 name
2. 模糊匹配(LIKE %parent_name%
返回:
随机选中的一个 level=3 子分类,或 None
"""
if "-" in normalized_scene:
parent_name = normalized_scene.rsplit("-", 1)[-1]
else:
parent_name = normalized_scene
# 1. 精确匹配
parent = await broll_category.get_by_name_and_level(db, name=parent_name, level=2)
# 2. 模糊匹配
if parent is None:
parent = await broll_category.get_by_name_like_and_level(db, name=parent_name, level=2)
if parent is None:
return None
children = await broll_category.get_children_by_parent_id(db, parent_id=parent.id, level=3)
if not children:
return None
return random.choice(children) # nosec B311 素材抽样,非加密场景
async def match_material(
db: AsyncSession,
scene: str,
required_duration: float,
project_id: str | None = None,
) -> dict | None:
"""
根据场景描述和所需时长匹配空镜素材
匹配策略:
1. 标准化 scene,精确匹配三级分类(level=3)的 name。
2. 若精确匹配失败,尝试将 "A-B" 倒序为 "B-A" 再匹配。
3. 若仍失败,回退到上级(level=2)分类,随机选取一个子分类。
4. 查询该分类下状态为 active、时长 >= required_duration 的素材。
5. 若提供 project_id,从 Redis 获取该项目已使用的 URL 并排除。
6. 优先从未使用候选中加权随机选择;若未用候选为空,
fallback 到全部候选(允许复用,保证合成连续性)。
7. 原子递增 usage_count,并将选中的 URL 写入 Redis Set7 天 TTL)。
Args:
db: 数据库 Session
scene: 分镜场景描述(如 "卫生间基层清理 - 防水施工"
required_duration: 所需时长(秒),必须大于 0
project_id: 项目ID,用于去重。为 None 时不做去重。
Returns:
{"url": str, "duration": float} 或 None
Raises:
ValidationException: scene 为空或 duration <= 0
"""
# 参数校验
if not scene or not scene.strip():
raise ValidationException("场景描述不能为空")
if required_duration <= 0:
raise ValidationException("所需时长必须大于 0")
normalized = _normalize_scene(scene)
# 1. 查找三级分类(精确匹配 -> 全量内存匹配兜底 -> 顺序颠倒 -> 上级回退)
category = await broll_category.get_by_name_and_level(db, name=normalized, level=3)
# 精确匹配失败时,全量查询后在内存标准化匹配(兼容数据库 name 含不可见字符)
if category is None:
all_categories = await broll_category.get_by_level(db, level=3)
for c in all_categories:
if _normalize_scene(c.name) == normalized:
category = c
logger.info(f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'")
break
# 若仍失败,尝试将 "A-B" 倒序为 "B-A" 再匹配
if category is None:
parts = normalized.rsplit("-", 1)
if len(parts) == 2:
reversed_name = f"{parts[1]}-{parts[0]}"
category = await broll_category.get_by_name_and_level(db, name=reversed_name, level=3)
if category:
logger.info(f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'")
# 若仍失败,回退到上级分类随机选取
if category is None:
category = await _try_fallback_to_parent(db, normalized)
if category:
logger.info(f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'")
if category is None:
logger.warning(f"素材匹配失败: 未找到分类 '{normalized}' (原始 scene: '{scene}')")
return None
# 2. 查询该分类下所有 active 素材(先不过滤时长,用于日志诊断)
all_materials = await broll_material.get_active_by_categories(db, category_ids=[category.id])
if not all_materials:
logger.warning(f"素材匹配失败: 分类 '{normalized}' 下无任何可用素材")
return None
# 按时长过滤(优先严格匹配,失败时逐步放宽到 70% 兜底)
materials = [m for m in all_materials if m.duration >= required_duration]
if not materials:
materials = [m for m in all_materials if m.duration >= required_duration * 0.7]
if not materials:
materials = all_materials
if not materials:
max_duration = max(m.duration for m in all_materials)
logger.warning(
f"素材匹配失败: 分类 '{normalized}' 无足够时长的素材 (需 >= {required_duration}s, 最大可用: {max_duration}s)"
)
return None
# 3. Redis 去重:获取项目已使用素材
used_urls: set[str] = set()
if project_id:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
except Exception as e:
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
# 4. 区分未用候选和全部候选
unused = [m for m in materials if m.url not in used_urls]
target_pool = unused if unused else materials
# 5. 加权随机选择
chosen = _weighted_choice(target_pool)
# 6. 原子递增 usage_count(避免并发覆盖)
await broll_material.increment_usage_count(db, material_id=chosen.id)
# 7. 记录到 Redis(异常不影响主流程)
if project_id:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
await cast(Awaitable[int], redis.sadd(key, chosen.url))
await cast(Awaitable[int], redis.expire(key, _USED_MATERIALS_TTL))
except Exception as e:
logger.warning(f"Redis 去重记录失败: {e}")
return {"url": chosen.url, "duration": float(chosen.duration)}
async def batch_match(
db: AsyncSession,
scenes: list[dict],
project_id: str | None = None,
) -> list[dict | None]:
"""
批量匹配素材(优化版:3 次 DB 往返 + 1 次 Redis
优化策略:
1. 一次性批量查询所有三级分类(1 次 DB)。
2. 一次性批量查询所有相关素材(1 次 DB)。
3. 内存中按 scene + duration 过滤并加权随机选择。
4. 批量 UPDATE usage_count1 次 DB)。
5. 批量 Redis sadd1 次 Redis pipeline)。
Args:
db: 数据库 Session
scenes: 每个元素为 {"scene": str, "duration": float}
project_id: 项目ID,用于去重
Returns:
与 scenes 长度一致的列表,元素为 {"url": str, "duration": float} 或 None
"""
if not scenes:
return []
# 1. 标准化所有 scene 并去重
normalized_scenes = [_normalize_scene(s["scene"]) for s in scenes]
unique_names = list(set(normalized_scenes))
# 2. 批量查询分类:优先精确查询,失败时全量内存匹配兜底
categories = await broll_category.get_by_names_and_level(db, names=unique_names, level=3)
category_map: dict[str, BrollCategory] = {}
for c in categories:
category_map[_normalize_scene(c.name)] = c
# 收集未命中的 name,准备全量兜底
unmatched_by_exact = [name for name in unique_names if name not in category_map]
if unmatched_by_exact:
all_categories = await broll_category.get_by_level(db, level=3)
for c in all_categories:
normalized_db_name = _normalize_scene(c.name)
if normalized_db_name not in category_map:
category_map[normalized_db_name] = c
# 构建原始 scene -> category 的映射
reversed_map: dict[str, str] = {}
for name in unique_names:
parts = name.rsplit("-", 1)
if len(parts) == 2:
reversed_map[name] = f"{parts[1]}-{parts[0]}"
scene_to_category: dict[str, BrollCategory] = {}
for name in unique_names:
if name in category_map:
scene_to_category[name] = category_map[name]
elif name in reversed_map and reversed_map[name] in category_map:
rev = reversed_map[name]
scene_to_category[name] = category_map[rev]
logger.info(f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'")
# 3. 未匹配的 scene 回退到上级分类随机选取
unmatched = [name for name in unique_names if name not in scene_to_category]
for name in unmatched:
fallback_cat = await _try_fallback_to_parent(db, name)
if fallback_cat:
scene_to_category[name] = fallback_cat
logger.info(f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'")
# 4. 收集所有需要的 category_id
needed_category_ids = [
scene_to_category[name].id for name in unique_names if name in scene_to_category
]
# 4. 批量查询素材(1 次 DB
all_materials = await broll_material.get_active_by_categories(
db, category_ids=list(set(needed_category_ids))
)
# 按 category_id 分组,方便内存过滤
materials_by_category: dict[int, list[BrollMaterial]] = {}
for m in all_materials:
materials_by_category.setdefault(m.category_id, []).append(m)
# 5. Redis 获取已使用素材(1 次 Redis)
used_urls: set[str] = set()
if project_id:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
except Exception as e:
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
# 6. 内存中逐个匹配
results: list[dict | None] = []
chosen_materials: list[BrollMaterial] = [] # 记录选中的素材,用于批量更新
for idx, scene_name in enumerate(normalized_scenes):
required_duration = scenes[idx]["duration"]
category = scene_to_category.get(scene_name)
if category is None:
original_scene = scenes[idx]["scene"]
logger.warning(
f"批量素材匹配失败: 未找到分类 '{scene_name}' (原始 scene: '{original_scene}')"
)
results.append(None)
continue
materials = materials_by_category.get(category.id, [])
# 按时长过滤(优先严格匹配,失败时逐步放宽到 70% 兜底)
candidates = [m for m in materials if m.duration >= required_duration]
if not candidates:
candidates = [m for m in materials if m.duration >= required_duration * 0.7]
if not candidates:
candidates = materials
if not candidates:
max_duration = max((m.duration for m in materials), default=0)
logger.warning(
f"批量素材匹配失败: 分类 '{scene_name}' -> '{category.name}' 无足够时长的素材 (需 >= {required_duration}s, 最大可用: {max_duration}s)"
)
results.append(None)
continue
# Redis 去重
unused = [m for m in candidates if m.url not in used_urls]
target_pool = unused if unused else candidates
# 加权随机
chosen = _weighted_choice(target_pool)
chosen_materials.append(chosen)
used_urls.add(chosen.url)
results.append({"url": chosen.url, "duration": float(chosen.duration)})
# 7. 批量更新 usage_count1 次 DB
if chosen_materials:
material_ids = [m.id for m in chosen_materials]
try:
await broll_material.increment_usage_count_batch(db, material_ids=material_ids)
except Exception as e:
logger.warning(f"批量更新 usage_count 失败: {e}")
# 8. 批量记录到 Redispipeline
if project_id and chosen_materials:
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
pipe = redis.pipeline()
for m in chosen_materials:
pipe.sadd(key, m.url)
pipe.expire(key, _USED_MATERIALS_TTL)
await pipe.execute()
except Exception as e:
logger.warning(f"Redis 去重记录失败: {e}")
return results