feat: init meijiaka-zj project from ai-meijiaka template
This commit is contained in:
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
AI 响应处理工具
|
||||
===============
|
||||
|
||||
提供安全的 AI 响应解析、验证和清洗功能。
|
||||
这是 AI 输出和后端/前端之间的防火墙。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_json_from_markdown(content: str) -> str | None:
|
||||
"""
|
||||
从 Markdown 代码块中提取 JSON 字符串
|
||||
|
||||
支持格式:
|
||||
- ```json {...} ```
|
||||
- ``` {...} ```
|
||||
- 纯 JSON 文本
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
|
||||
Returns:
|
||||
提取的 JSON 字符串,如果无法提取则返回 None
|
||||
"""
|
||||
if not content:
|
||||
return None
|
||||
|
||||
content = content.strip()
|
||||
|
||||
# 匹配 ```json ... ``` 或 ``` ... ```
|
||||
pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||||
matches = re.findall(pattern, content)
|
||||
|
||||
if matches:
|
||||
# 取最后一个匹配(避免前面有示例代码)
|
||||
result = matches[-1].strip()
|
||||
return result if result else None
|
||||
|
||||
# 如果没有代码块,返回原始内容
|
||||
return content
|
||||
|
||||
|
||||
def sanitize_string(value: Any, max_length: int = 5000) -> str | None:
|
||||
"""
|
||||
清洗字符串值
|
||||
|
||||
- 去除 HTML 标签
|
||||
- 去除控制字符
|
||||
- 标准化空白字符
|
||||
- 截断超长内容
|
||||
|
||||
Args:
|
||||
value: 原始值
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
清洗后的字符串
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# 转换为字符串
|
||||
text = str(value)
|
||||
|
||||
# 去除 HTML 标签
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
|
||||
# 去除控制字符(保留换行和制表符)
|
||||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
|
||||
|
||||
# 标准化空白字符
|
||||
text = re.sub(r"[\t ]+", " ", text)
|
||||
text = re.sub(r"\n+", "\n", text)
|
||||
text = text.strip()
|
||||
|
||||
# 截断超长内容
|
||||
if len(text) > max_length:
|
||||
logger.warning(f"内容被截断: {len(text)} -> {max_length} 字符")
|
||||
text = text[:max_length] + "..."
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def parse_duration(duration_value: Any) -> str:
|
||||
"""
|
||||
解析时长字段
|
||||
|
||||
支持格式:
|
||||
- 数字 (5) -> "5s"
|
||||
- 字符串带单位 ("5s", "5秒") -> "5s"
|
||||
- 其他 -> "5s" (默认)
|
||||
|
||||
Args:
|
||||
duration_value: 原始时长值
|
||||
|
||||
Returns:
|
||||
标准化的时长字符串
|
||||
"""
|
||||
if duration_value is None:
|
||||
return "5s"
|
||||
|
||||
# 如果是数字,直接加 s
|
||||
if isinstance(duration_value, int | float):
|
||||
seconds = max(1, min(int(duration_value), 300)) # 限制 1-300 秒
|
||||
return f"{seconds}s"
|
||||
|
||||
# 如果是字符串
|
||||
text = str(duration_value).strip().lower()
|
||||
|
||||
# 提取数字
|
||||
match = re.search(r"(\d+)", text)
|
||||
if match:
|
||||
seconds = int(match.group(1))
|
||||
seconds = max(1, min(seconds, 300))
|
||||
return f"{seconds}s"
|
||||
|
||||
return "5s"
|
||||
|
||||
|
||||
def validate_and_normalize_shots(raw_data: Any) -> list[dict[str, Any]]:
    """
    Validate and normalize storyboard ("shot") data returned by the AI.

    Accepted input shapes:
    - list: ``[{...}, {...}]`` is used as-is
    - wrapper dict: the first of ``shots`` / ``data`` / ``segments`` /
      ``script`` / ``result`` / ``list`` that holds a list is unwrapped
    - any other dict: treated as one single shot
    - anything else (including ``None``): produces an empty list

    Entries that are not dicts are skipped with a warning. A malformed id
    or duration on one entry never aborts the batch; the entry falls back
    to its positional id / the default duration instead.

    Args:
        raw_data: Parsed AI output of unknown shape.

    Returns:
        A list of dicts with the keys ``id`` (str), ``type`` (str),
        ``scene`` (str or None), ``voiceover`` (str), ``duration``
        (int, seconds) and ``word_count`` (int, character count of the
        cleaned voiceover).
    """
    if raw_data is None:
        logger.warning("AI 返回数据为空")
        return []

    if isinstance(raw_data, dict):
        # Look for a well-known wrapper field that carries the actual list.
        for key in ["shots", "data", "segments", "script", "result", "list"]:
            if key in raw_data and isinstance(raw_data[key], list):
                shots = raw_data[key]
                logger.info(f"从字典字段 '{key}' 提取到 {len(shots)} 个分镜")
                break
        else:
            # No list-valued wrapper field: the dict itself is the only shot.
            logger.info("将字典作为单个分镜处理")
            shots = [raw_data]
    elif isinstance(raw_data, list):
        shots = raw_data
    else:
        logger.error(f"无法处理的 AI 返回格式: {type(raw_data)}")
        return []

    normalized_shots = []
    for idx, item in enumerate(shots):
        if not isinstance(item, dict):
            logger.warning(f"跳过非字典分镜 (索引 {idx}): {type(item)}")
            continue

        # Defaults; NOTE(review): the Segment model (defined elsewhere) is
        # said to need a str id and an int duration in seconds — confirm.
        normalized: dict[str, Any] = {
            "id": str(idx + 1),
            "type": "segment",
            "scene": None,
            "voiceover": "",
            "duration": 5,
        }

        # ID: accept ints or numeric strings, always stored as str.
        # OverflowError covers int(float("inf")), which json.loads can yield.
        raw_id = item.get("id", idx + 1)
        try:
            normalized["id"] = str(int(raw_id))
        except (ValueError, TypeError, OverflowError):
            normalized["id"] = str(idx + 1)

        # Type: only string values override the default.
        raw_type = item.get("type", "segment")
        if isinstance(raw_type, str):
            normalized["type"] = raw_type.strip().lower()

        # Scene description: first truthy value among the known aliases.
        scene = (
            item.get("scene")
            or item.get("prompt")
            or item.get("description")
            or item.get("image_prompt")
            or item.get("visual")
        )
        normalized["scene"] = sanitize_string(scene, max_length=2000)

        # Voiceover text: first truthy value among the known aliases.
        voiceover = (
            item.get("voiceover")
            or item.get("text")
            or item.get("content")
            or item.get("narration")
            or item.get("script")
        )
        normalized["voiceover"] = sanitize_string(voiceover, max_length=2000) or ""

        # Duration: parse_duration yields e.g. "5s"; keep only the integer.
        # The call sits inside the try so a failure on an odd value (such
        # as a non-finite float) degrades to the default for this entry.
        try:
            duration_str = parse_duration(item.get("duration"))
            normalized["duration"] = int(re.search(r"\d+", duration_str).group())
        except (AttributeError, ValueError, OverflowError):
            normalized["duration"] = 5

        # Length of the cleaned voiceover, in characters.
        normalized["word_count"] = len(normalized["voiceover"])

        normalized_shots.append(normalized)

    return normalized_shots
|
||||
|
||||
|
||||
def _normalize_json_quotes(json_str: str) -> str:
|
||||
"""
|
||||
将中文引号(弯引号)替换为英文引号
|
||||
|
||||
某些 AI 模型会在长文本生成中混用中英文标点,导致 JSON 解析失败。
|
||||
此函数将中文引号 " 和 " 替换为标准 JSON 使用的英文引号 "。
|
||||
|
||||
Args:
|
||||
json_str: 原始 JSON 字符串
|
||||
|
||||
Returns:
|
||||
规范化后的 JSON 字符串
|
||||
"""
|
||||
# 中文左双引号 " 和右双引号 " 都替换为英文双引号 "
|
||||
return json_str.replace('"', '"').replace('"', '"')
|
||||
|
||||
|
||||
def safe_parse_ai_json_response(content: str) -> tuple[bool, Any, str | None]:
|
||||
"""
|
||||
安全地解析 AI JSON 响应
|
||||
|
||||
Args:
|
||||
content: AI 返回的原始内容
|
||||
|
||||
Returns:
|
||||
tuple: (是否成功, 解析后的数据, 错误信息)
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return False, None, "AI 返回内容为空"
|
||||
|
||||
# 提取 JSON 字符串
|
||||
json_str = extract_json_from_markdown(content)
|
||||
|
||||
if not json_str:
|
||||
logger.error(f"无法从内容中提取 JSON: {content[:200]}...")
|
||||
return False, None, "无法从 AI 输出中提取 JSON"
|
||||
|
||||
# 尝试直接解析 JSON
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
return True, data, None
|
||||
except json.JSONDecodeError:
|
||||
pass # 解析失败,尝试修复
|
||||
|
||||
# 尝试修复中文引号问题
|
||||
normalized = _normalize_json_quotes(json_str)
|
||||
|
||||
try:
|
||||
data = json.loads(normalized)
|
||||
logger.info("JSON 引号规范化成功")
|
||||
return True, data, None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON 解析失败: {e}")
|
||||
logger.error(f"原始内容前 500 字符: {json_str[:500]!r}")
|
||||
return False, None, f"JSON 解析失败: {e}"
|
||||
except Exception as e:
|
||||
logger.error(f"解析 AI 响应时发生未知错误: {e}")
|
||||
return False, None, f"解析错误: {e}"
|
||||
|
||||
|
||||
def validate_shots_structure(shots: list[dict]) -> tuple[bool, list[str]]:
    """
    Check that every shot in the list carries the required fields.

    Required per shot: ``voiceover`` (may be an empty string, but the key
    must exist), ``id`` and ``type``. The ``id`` may be a ``str`` or an
    ``int``: ``validate_and_normalize_shots`` in this module emits string
    ids, so an int-only check would reject every shot it produces.

    Args:
        shots: List of shot dicts to check.

    Returns:
        ``(is_valid, errors)``: ``is_valid`` is ``True`` only when
        ``errors`` is empty; an empty ``shots`` list is reported as invalid.
    """
    errors: list[str] = []

    if not shots:
        errors.append("分镜列表为空")
        return False, errors

    for idx, shot in enumerate(shots):
        if not isinstance(shot, dict):
            errors.append(f"分镜 {idx + 1} 不是字典类型")
            continue

        # voiceover: an empty string is fine, a missing key is not.
        if "voiceover" not in shot:
            errors.append(f"分镜 {idx + 1} 缺少 voiceover 字段")

        # id: must exist and be a string or an integer.
        if "id" not in shot:
            errors.append(f"分镜 {idx + 1} 缺少 id 字段")
        elif not isinstance(shot["id"], (int, str)):
            errors.append(f"分镜 {idx + 1} 的 id 不是字符串或整数")

        # type: only presence is checked.
        if "type" not in shot:
            errors.append(f"分镜 {idx + 1} 缺少 type 字段")

    return len(errors) == 0, errors
|
||||
Reference in New Issue
Block a user