Files
meijiaka-zy/python-api/app/services/ai_response_utils.py
T

332 lines
9.2 KiB
Python

"""
AI 响应处理工具
===============
提供安全的 AI 响应解析、验证和清洗功能。
这是 AI 输出和后端/前端之间的防火墙。
"""
import json
import logging
import re
from typing import Any
logger = logging.getLogger(__name__)
def extract_json_from_markdown(content: str) -> str | None:
"""
从 Markdown 代码块中提取 JSON 字符串
支持格式:
- ```json {...} ```
- ``` {...} ```
- 纯 JSON 文本
Args:
content: 原始内容
Returns:
提取的 JSON 字符串,如果无法提取则返回 None
"""
if not content:
return None
content = content.strip()
# 匹配 ```json ... ``` 或 ``` ... ```
pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
matches = re.findall(pattern, content)
if matches:
# 取最后一个匹配(避免前面有示例代码)
result = matches[-1].strip()
return result if result else None
# 如果没有代码块,返回原始内容
return content
def sanitize_string(value: Any, max_length: int = 5000) -> str | None:
"""
清洗字符串值
- 去除 HTML 标签
- 去除控制字符
- 标准化空白字符
- 截断超长内容
Args:
value: 原始值
max_length: 最大长度
Returns:
清洗后的字符串
"""
if value is None:
return None
# 转换为字符串
text = str(value)
# 去除 HTML 标签
text = re.sub(r"<[^>]+>", "", text)
# 去除控制字符(保留换行和制表符)
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
# 标准化空白字符
text = re.sub(r"[\t ]+", " ", text)
text = re.sub(r"\n+", "\n", text)
text = text.strip()
# 截断超长内容
if len(text) > max_length:
logger.warning(f"内容被截断: {len(text)} -> {max_length} 字符")
text = text[:max_length] + "..."
return text
def parse_duration(duration_value: Any) -> str:
"""
解析时长字段
支持格式:
- 数字 (5) -> "5s"
- 字符串带单位 ("5s", "5秒") -> "5s"
- 其他 -> "5s" (默认)
Args:
duration_value: 原始时长值
Returns:
标准化的时长字符串
"""
if duration_value is None:
return "5s"
# 如果是数字,直接加 s
if isinstance(duration_value, int | float):
seconds = max(1, min(int(duration_value), 300)) # 限制 1-300 秒
return f"{seconds}s"
# 如果是字符串
text = str(duration_value).strip().lower()
# 提取数字
match = re.search(r"(\d+)", text)
if match:
seconds = int(match.group(1))
seconds = max(1, min(seconds, 300))
return f"{seconds}s"
return "5s"
def validate_and_normalize_shots(raw_data: Any) -> list[dict[str, Any]]:
"""
验证并标准化分镜数据
这是一个防御性函数,处理各种可能的 AI 返回格式:
- 列表格式: [{...}, {...}]
- 包装格式: {"shots": [...]} -> 提取 shots
- 单对象格式: {...} -> 包装为列表
- 无效格式: 返回空列表
Args:
raw_data: AI 返回的原始数据
Returns:
标准化的分镜列表
"""
if raw_data is None:
logger.warning("AI 返回数据为空")
return []
shots = []
# 处理字典格式(可能是包装对象)
if isinstance(raw_data, dict):
# 尝试提取常见的包装字段
for key in ["shots", "data", "segments", "script", "result", "list"]:
if key in raw_data and isinstance(raw_data[key], list):
shots = raw_data[key]
logger.info(f"从字典字段 '{key}' 提取到 {len(shots)} 个分镜")
break
else:
# 没有列表字段,将整个字典作为一个分镜
logger.info("将字典作为单个分镜处理")
shots = [raw_data]
# 处理列表格式
elif isinstance(raw_data, list):
shots = raw_data
# 其他格式无法处理
else:
logger.error(f"无法处理的 AI 返回格式: {type(raw_data)}")
return []
# 验证并标准化每个分镜
normalized_shots = []
for idx, item in enumerate(shots):
if not isinstance(item, dict):
logger.warning(f"跳过非字典分镜 (索引 {idx}): {type(item)}")
continue
# 字段映射和清洗
normalized: dict[str, Any] = {
"id": str(idx + 1), # 强制按索引递增,Segment 模型要求 str
"type": "segment", # 默认类型
"scene": None,
"voiceover": "",
"duration": 5, # Segment 模型要求 int(秒)
}
# 提取 ID(支持字符串和数字,最终转为 str)
raw_id = item.get("id", idx + 1)
try:
normalized["id"] = str(int(raw_id))
except (ValueError, TypeError):
normalized["id"] = str(idx + 1)
# 提取类型
raw_type = item.get("type", "segment")
if isinstance(raw_type, str):
normalized["type"] = raw_type.strip().lower()
# 提取场景描述(支持多种字段名)
scene = (
item.get("scene")
or item.get("prompt")
or item.get("description")
or item.get("image_prompt")
or item.get("visual")
)
normalized["scene"] = sanitize_string(scene, max_length=2000)
# 提取配音文案(支持多种字段名)
voiceover = (
item.get("voiceover")
or item.get("text")
or item.get("content")
or item.get("narration")
or item.get("script")
)
normalized["voiceover"] = sanitize_string(voiceover, max_length=2000) or ""
# 提取时长(Segment 模型要求 int 秒数)
duration = item.get("duration")
duration_str = parse_duration(duration) # 返回如 "5s"
try:
normalized["duration"] = int(re.search(r"\d+", duration_str).group())
except (AttributeError, ValueError):
normalized["duration"] = 5
# 计算字数
normalized["word_count"] = len(normalized["voiceover"])
normalized_shots.append(normalized)
return normalized_shots
def _normalize_json_quotes(json_str: str) -> str:
"""
将中文引号(弯引号)替换为英文引号
某些 AI 模型会在长文本生成中混用中英文标点,导致 JSON 解析失败。
此函数将中文引号 "" 替换为标准 JSON 使用的英文引号 "
Args:
json_str: 原始 JSON 字符串
Returns:
规范化后的 JSON 字符串
"""
# 中文左双引号 " 和右双引号 " 都替换为英文双引号 "
return json_str.replace('"', '"').replace('"', '"')
def safe_parse_ai_json_response(content: str) -> tuple[bool, Any, str | None]:
"""
安全地解析 AI JSON 响应
Args:
content: AI 返回的原始内容
Returns:
tuple: (是否成功, 解析后的数据, 错误信息)
"""
if not content or not content.strip():
return False, None, "AI 返回内容为空"
# 提取 JSON 字符串
json_str = extract_json_from_markdown(content)
if not json_str:
logger.error(f"无法从内容中提取 JSON: {content[:200]}...")
return False, None, "无法从 AI 输出中提取 JSON"
# 尝试直接解析 JSON
try:
data = json.loads(json_str)
return True, data, None
except json.JSONDecodeError:
pass # 解析失败,尝试修复
# 尝试修复中文引号问题
normalized = _normalize_json_quotes(json_str)
try:
data = json.loads(normalized)
logger.info("JSON 引号规范化成功")
return True, data, None
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败: {e}")
logger.error(f"原始内容前 500 字符: {json_str[:500]!r}")
return False, None, f"JSON 解析失败: {e}"
except Exception as e:
logger.error(f"解析 AI 响应时发生未知错误: {e}")
return False, None, f"解析错误: {e}"
def validate_shots_structure(shots: list[dict]) -> tuple[bool, list[str]]:
"""
验证分镜列表的结构完整性
Args:
shots: 分镜列表
Returns:
tuple: (是否有效, 错误信息列表)
"""
errors = []
if not shots:
errors.append("分镜列表为空")
return False, errors
for idx, shot in enumerate(shots):
# 检查必需字段
if not isinstance(shot, dict):
errors.append(f"分镜 {idx + 1} 不是字典类型")
continue
# 检查 voiceover(允许为空字符串但不允许缺失)
if "voiceover" not in shot:
errors.append(f"分镜 {idx + 1} 缺少 voiceover 字段")
# 检查 id
if "id" not in shot:
errors.append(f"分镜 {idx + 1} 缺少 id 字段")
elif not isinstance(shot.get("id"), int):
errors.append(f"分镜 {idx + 1} 的 id 不是整数")
# 检查 type
if "type" not in shot:
errors.append(f"分镜 {idx + 1} 缺少 type 字段")
return len(errors) == 0, errors