Files
meijiaka-zy/scripts/check_prompt_category_consistency.py
T

166 lines
5.8 KiB
Python

#!/usr/bin/env python3
"""
检查提示词素材库标题与数据库三级分类名的一致性
"""
import re
from pathlib import Path
from collections import defaultdict
def extract_db_categories(sql_path: str) -> set[str]:
    """Collect the names of all level-3 categories from seed_categories.sql.

    A row is recognised by the literal shape
    ``VALUES ('<id>', '<name>', <parent_id>, 3,`` — i.e. the fourth value
    is the level and must be ``3``.

    Args:
        sql_path: Path to the seed SQL file (read as UTF-8).

    Returns:
        The set of distinct level-3 category names.
    """
    level3_row = re.compile(r"VALUES \('[^']+', '([^']+)', \d+, 3,")
    sql_text = Path(sql_path).read_text(encoding="utf-8")
    return {name for name in level3_row.findall(sql_text)}
def extract_prompt_titles(prompts_dir: str) -> dict[str, list[str]]:
    """Extract the built-in material-library titles from every prompt file.

    Recursively scans *prompts_dir* for ``*.txt`` files. In each file the
    section starting right after "【内置完整素材库标题】" (preferred) or
    "【内置素材库标题】" and running up to the next "【" (or end of file) is
    parsed one title per line:

    * blank lines, lines starting with "(" or "（", and lines starting with
      "备注" are skipped as comments;
    * a leading number prefix such as "1、" / "1." and a leading bullet
      "-" / "•" are stripped.

    Args:
        prompts_dir: Directory to scan.

    Returns:
        Mapping from each file's path (relative to *prompts_dir*) to its
        titles in file order. Files without a marker, or whose section yields
        no titles, are omitted.
    """
    markers = ("【内置完整素材库标题】", "【内置素材库标题】")
    comment_prefixes = ("(", "（", "备注")
    number_prefix = re.compile(r"^\d+[、.]\s*")
    bullet_prefix = re.compile(r"^[-•]\s*")

    results: dict[str, list[str]] = {}
    dir_path = Path(prompts_dir)
    # Sorted so the result order does not depend on filesystem listing order.
    for txt_file in sorted(dir_path.rglob("*.txt")):
        content = txt_file.read_text(encoding="utf-8")

        # Locate the title section; the first marker wins when both exist.
        start_idx = -1
        for marker in markers:
            idx = content.find(marker)
            if idx != -1:
                start_idx = idx + len(marker)
                break
        if start_idx == -1:
            continue

        # The section ends at the next "【...】" heading. Searching for an
        # empty string here would match at index 0 and discard everything.
        section = content[start_idx:]
        next_marker = section.find("【")
        if next_marker != -1:
            section = section[:next_marker]

        titles = []
        for raw_line in section.strip().split("\n"):
            line = raw_line.strip()
            if not line or line.startswith(comment_prefixes):
                continue
            line = number_prefix.sub("", line)
            line = bullet_prefix.sub("", line)
            if line:
                titles.append(line)
        if titles:
            results[str(txt_file.relative_to(dir_path))] = titles
    return results
def normalize_for_compare(text: str) -> str:
    """Return *text* with every whitespace character removed.

    Lets titles that differ only in spacing (e.g. "客厅 吊顶" vs "客厅吊顶")
    compare equal. No other characters are touched.
    """
    whitespace = re.compile(r"\s")
    return whitespace.sub("", text)
def _parent_of(category: str) -> str:
    """Infer the parent category from the text after the last "-".

    E.g. "卧室原始结构-毛坯基础" -> "毛坯基础"; names without "-" map to "其他".
    """
    parts = category.split("-")
    return parts[-1] if len(parts) >= 2 else "其他"


def _matches(text: str, exact: set[str], normalized: set[str]) -> bool:
    """Return True if *text* is in *exact*, or its whitespace-free form is in *normalized*."""
    return text in exact or normalize_for_compare(text) in normalized


def main():
    """Compare prompt-library titles with level-3 DB categories and print a report.

    Reads seed_categories.sql and the system prompt ``*.txt`` files from
    paths fixed relative to this script, then prints:
      1. titles present in prompts but not in the DB,
      2. DB categories not referenced by any prompt (grouped by inferred parent),
      3. per-file title counts with any mismatched titles,
    followed by the overall match rate.
    """
    root = Path(__file__).parent.parent
    db_path = root / "python-api" / "scripts" / "seed_categories.sql"
    prompts_dir = root / "python-api" / "app" / "ai" / "prompts" / "system"
    db_categories = extract_db_categories(str(db_path))
    prompt_titles = extract_prompt_titles(str(prompts_dir))
    print(f"数据库三级分类总数: {len(db_categories)}")
    print(f"包含素材库标题的提示词文件数: {len(prompt_titles)}")
    print()

    # Union of titles across all prompt files (deduplicated).
    all_prompt_titles = {t for titles in prompt_titles.values() for t in titles}
    print(f"提示词中素材库标题总数(去重): {len(all_prompt_titles)}")
    print()

    # Normalize each side once instead of rebuilding the set per comparison.
    db_normalized = {normalize_for_compare(c) for c in db_categories}
    prompt_normalized = {normalize_for_compare(t) for t in all_prompt_titles}

    # Titles referenced by prompts but missing from the DB.
    in_prompt_not_db = [
        t for t in all_prompt_titles if not _matches(t, db_categories, db_normalized)
    ]
    # DB categories never referenced by any prompt.
    in_db_not_prompt = [
        c for c in db_categories if not _matches(c, all_prompt_titles, prompt_normalized)
    ]

    print("=" * 60)
    print("【不一致统计】")
    print("=" * 60)

    print(f"\n1. 提示词中有但数据库中无(可能为错误或过时标题): {len(in_prompt_not_db)}")
    if in_prompt_not_db:
        for t in sorted(in_prompt_not_db):
            print(f" - {t}")
    else:
        print(" (无)")

    print(f"\n2. 数据库中有但提示词中无(缺少素材引用): {len(in_db_not_prompt)}")
    if in_db_not_prompt:
        by_parent = defaultdict(list)
        for cat in in_db_not_prompt:
            by_parent[_parent_of(cat)].append(cat)
        for parent in sorted(by_parent):
            cats = by_parent[parent]
            print(f"\n【{parent}】({len(cats)}个)")
            for cat in sorted(cats):
                print(f" - {cat}")
    else:
        print(" (无)")

    print("\n3. 各提示词文件素材库标题数量:")
    for rel_path in sorted(prompt_titles):
        titles = prompt_titles[rel_path]
        mismatched = [t for t in titles if not _matches(t, db_categories, db_normalized)]
        # Report the count, not the list itself.
        status = f" ⚠️ 有{len(mismatched)}个不一致" if mismatched else " ✅ 一致"
        print(f" {rel_path}: {len(titles)}个标题{status}")
        for m in mismatched:
            print(f"    → {m}")

    # Overall match rate over the deduplicated prompt titles.
    matched = len(all_prompt_titles) - len(in_prompt_not_db)
    match_rate = (matched / len(all_prompt_titles) * 100) if all_prompt_titles else 0
    print(f"\n总体匹配率: {match_rate:.1f}% ({matched}/{len(all_prompt_titles)})")


if __name__ == "__main__":
    main()