166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
检查提示词素材库标题与数据库三级分类名的一致性
|
|
"""
|
|
|
|
import re
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
|
|
def extract_db_categories(sql_path: str) -> set[str]:
    """Collect the names of level-3 categories from a seed SQL file.

    Scans the whole file for row literals shaped like
    ``VALUES ('<id>', '<name>', <int>, 3,`` and returns the distinct
    ``<name>`` values. The pattern assumes the columns are ordered
    (id, name, parent_id, level, ...) with exactly one space after each
    comma. It only sees the first row after each ``VALUES (``.
    NOTE(review): confirm this layout against seed_categories.sql.

    Args:
        sql_path: Path to the SQL seed file, read as UTF-8.

    Returns:
        The set of category names found in rows whose fourth value is 3.
    """
    level3_row = re.compile(r"VALUES \('[^']+', '([^']+)', \d+, 3,")
    with open(sql_path, encoding="utf-8") as sql_file:
        sql_text = sql_file.read()
    return {hit.group(1) for hit in level3_row.finditer(sql_text)}
|
|
|
|
|
|
# Section headers that introduce the built-in title list, in priority order:
# when a file contains both, the "complete" variant is used.
_TITLE_MARKERS = ("【内置完整素材库标题】", "【内置素材库标题】")

# Line prefixes that mark notes rather than titles. Chinese text commonly uses
# the full-width "（", so it is skipped alongside the ASCII "(".
_NOTE_PREFIXES = ("(", "（", "备注")


def extract_prompt_titles(prompts_dir: str) -> dict[str, list[str]]:
    """Extract the built-in material-library titles from every prompt file.

    Recursively scans *prompts_dir* for ``*.txt`` files, read as UTF-8. In each
    file, the title list starts right after the first marker in
    ``_TITLE_MARKERS`` that occurs. It runs until the next "【" or the end of
    the file.

    Args:
        prompts_dir: Root directory of the prompt files.

    Returns:
        A mapping from each file's path (relative to *prompts_dir*) to its
        titles, in the order they appear. Files with no marker, or with an
        empty title section, are omitted.
    """
    root = Path(prompts_dir)
    results: dict[str, list[str]] = {}
    # sorted() gives a deterministic mapping order regardless of filesystem.
    for txt_file in sorted(root.rglob("*.txt")):
        with open(txt_file, "r", encoding="utf-8") as f:
            content = f.read()
        titles = _parse_title_lines(_find_title_section(content))
        if titles:
            results[str(txt_file.relative_to(root))] = titles
    return results


def _find_title_section(content: str) -> str:
    """Return the text after the first found title marker, up to the next "【".

    Returns "" when *content* contains none of ``_TITLE_MARKERS``.
    """
    for marker in _TITLE_MARKERS:
        idx = content.find(marker)
        if idx != -1:
            section = content[idx + len(marker):]
            end = section.find("【")
            return section if end == -1 else section[:end]
    return ""


def _parse_title_lines(section: str) -> list[str]:
    """Split a title section into cleaned, non-empty titles.

    Blank lines and lines starting with one of ``_NOTE_PREFIXES`` are dropped.
    A leading ordinal such as "1、" / "1." or a "-" / "•" bullet is stripped.
    """
    titles = []
    for raw_line in section.strip().split("\n"):
        line = raw_line.strip()
        if not line or line.startswith(_NOTE_PREFIXES):
            continue
        line = re.sub(r"^\d+[、.]\s*", "", line)
        line = re.sub(r"^[-•]\s*", "", line)
        if line:
            titles.append(line)
    return titles
|
|
|
|
|
|
def normalize_for_compare(text: str) -> str:
    """Return *text* with every whitespace character removed, for lenient comparison.

    Only whitespace (including Unicode spaces such as U+3000) is removed.
    Punctuation such as "-" is kept, so "卧室 - 毛坯" becomes "卧室-毛坯".
    """
    # str.split() with no argument splits on runs of Unicode whitespace.
    return "".join(text.split())
|
|
|
|
|
|
def main():
    """Compare prompt material-library titles with level-3 DB category names and print a report.

    Inputs are resolved relative to this script's grandparent directory:
    ``python-api/scripts/seed_categories.sql`` and
    ``python-api/app/ai/prompts/system``.

    The printed report contains:
      1. prompt titles with no matching DB category;
      2. DB categories never referenced by a prompt, grouped by the last
         "-"-separated segment of the name;
      3. per-file title counts with each file's mismatched titles;
      4. the overall match rate of distinct prompt titles.

    A match is either exact equality or equality after
    ``normalize_for_compare``.
    """
    repo_root = Path(__file__).parent.parent
    db_path = repo_root / "python-api" / "scripts" / "seed_categories.sql"
    prompts_dir = repo_root / "python-api" / "app" / "ai" / "prompts" / "system"

    db_categories = extract_db_categories(str(db_path))
    prompt_titles = extract_prompt_titles(str(prompts_dir))

    print(f"数据库三级分类总数: {len(db_categories)}")
    print(f"包含素材库标题的提示词文件数: {len(prompt_titles)}")
    print()

    # Distinct titles across all prompt files.
    all_prompt_titles = {title for titles in prompt_titles.values() for title in titles}

    print(f"提示词中素材库标题总数(去重): {len(all_prompt_titles)}")
    print()

    # Build each normalized lookup set once, so every membership test is O(1).
    db_normalized = {normalize_for_compare(c) for c in db_categories}
    prompt_normalized = {normalize_for_compare(t) for t in all_prompt_titles}

    # In prompts but not in the DB.
    in_prompt_not_db = [
        t for t in all_prompt_titles if not _matches(t, db_categories, db_normalized)
    ]
    # In the DB but not in any prompt.
    in_db_not_prompt = [
        c for c in db_categories if not _matches(c, all_prompt_titles, prompt_normalized)
    ]

    print("=" * 60)
    print("【不一致统计】")
    print("=" * 60)

    print(f"\n1. 提示词中有但数据库中无(可能为错误或过时标题): {len(in_prompt_not_db)} 个")
    if in_prompt_not_db:
        for t in sorted(in_prompt_not_db):
            print(f" - {t}")
    else:
        print(" (无)")

    print(f"\n2. 数据库中有但提示词中无(缺少素材引用): {len(in_db_not_prompt)} 个")
    if in_db_not_prompt:
        by_parent = defaultdict(list)
        for cat in in_db_not_prompt:
            by_parent[_parent_of(cat)].append(cat)
        for parent in sorted(by_parent):
            cats = by_parent[parent]
            print(f"\n 【{parent}】({len(cats)}个)")
            for cat in sorted(cats):
                print(f" - {cat}")
    else:
        print(" (无)")

    print("\n3. 各提示词文件素材库标题数量:")
    for rel_path in sorted(prompt_titles):
        titles = prompt_titles[rel_path]
        mismatched = [t for t in titles if not _matches(t, db_categories, db_normalized)]
        # Report the count, not the list itself.
        status = f" ⚠️ 有{len(mismatched)}个不一致" if mismatched else " ✅ 一致"
        print(f" {rel_path}: {len(titles)}个标题{status}")
        for m in mismatched:
            print(f" ❌ {m}")

    # Overall match rate over distinct prompt titles; 0 when there are none.
    matched = len(all_prompt_titles) - len(in_prompt_not_db)
    match_rate = (matched / len(all_prompt_titles) * 100) if all_prompt_titles else 0
    print(f"\n总体匹配率: {match_rate:.1f}% ({matched}/{len(all_prompt_titles)})")


def _matches(text: str, exact: set[str], normalized: set[str]) -> bool:
    """Return True if *text* is in *exact* or its normalized form is in *normalized*."""
    return text in exact or normalize_for_compare(text) in normalized


def _parent_of(category: str) -> str:
    """Return the last "-"-separated segment of *category*, or "其他" if it has no "-".

    Example: "卧室原始结构-毛坯基础" -> "毛坯基础".
    """
    parts = category.split("-")
    return parts[-1] if len(parts) >= 2 else "其他"


if __name__ == "__main__":
    main()
|