Files
meijiaka-zy/python-api/app/api/v1/ai_models.py
T

553 lines
16 KiB
Python

"""
AI 模型管理与生成 API
=====================
提供模型列表查询、文本生成、脚本生成、润色等功能。
模型配置存储在 config/ai_models.yaml,支持热重载。
"""
import logging
logger = logging.getLogger(__name__)
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.ai.model_router import get_model_router
from app.core.config_loader import get_config_loader
from app.schemas.common import ApiResponse, success_response
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
validate_and_normalize_shots,
validate_shots_structure,
)
router = APIRouter()
# ============ 请求/响应 Schema ============
class PlatformResponse(BaseModel):
    """Platform entry returned by the platform list endpoint."""

    # Platform identifier, copied from the platform configuration.
    id: str
    # Platform name, copied from the platform configuration.
    name: str
    # Provider value, copied from the platform configuration.
    provider: str
class ModelResponse(BaseModel):
    """Model entry returned by the model list endpoint."""

    id: str
    platform_id: str
    model_name: str
    display_name: str
    # Capability tags; the list endpoint accepts a capability filter
    # (e.g. script, polish, chat) that is forwarded to the model router.
    capabilities: list[str]
    default_params: dict
    # The list endpoint always sets this to True.
    is_enabled: bool
    # The list endpoint builds this as "<platform_id>/<id>".
    full_model_id: str
class GenerateRequest(BaseModel):
    """Request body for the free-form text generation endpoint."""

    # Prompt text forwarded to the model (required).
    prompt: str = Field(..., description="提示词")
    # Explicit model ID; optional.
    model_id: str | None = Field(None, description="指定模型 ID")
    # Task type (script/polish), described as driving automatic model selection.
    task_type: str | None = Field(
        None, description="任务类型,用于自动选模型: script/polish"
    )
    # Sampling randomness; the description mentions 0-2, but no bounds are
    # enforced by this field definition.
    temperature: float | None = Field(None, description="随机性 (0-2)")
    # Maximum generation length.
    max_tokens: int | None = Field(None, description="最大生成长度")
class GenerateResponse(BaseModel):
    """Result of a text generation call."""

    # Generated text, taken from the router result's `content`.
    content: str
    # Model value reported by the router result.
    model: str
    # Usage information reported by the router result; may be None.
    usage: dict | None
class HealthResponse(BaseModel):
    """Aggregated result of the model health-check endpoint."""

    # "healthy" when at least one model is available, otherwise "unhealthy".
    status: str
    # Number of models included in the check.
    total_models: int
    # Number of checked models that reported themselves available.
    available_models: int
    # Per-model entries with id, name, is_available, response_time, last_error.
    models: list[dict]
# ============ API 路由 ============
@router.get("/platforms", response_model=ApiResponse[list[PlatformResponse]])
async def list_platforms():
    """Return every platform known to the configuration loader."""
    loader = get_config_loader()
    payload = []
    for platform in loader.get_all_platforms():
        payload.append(
            PlatformResponse(
                id=platform.id,
                name=platform.name,
                provider=platform.provider,
            )
        )
    return success_response(data=payload)
@router.get("/models", response_model=ApiResponse[list[ModelResponse]])
async def list_models(capability: str | None = None):
    """Return the models exposed by the model router.

    Args:
        capability: Optional capability filter, e.g. script, polish, chat.
    """
    # Named `model_router` so the module-level APIRouter `router` is not shadowed.
    model_router = await get_model_router()

    entries = []
    for info in model_router.list_models(capability=capability):
        model_key = info["id"]
        platform_key = info["platform_id"]
        entries.append(
            ModelResponse(
                id=model_key,
                platform_id=platform_key,
                model_name=info["model_name"],
                display_name=info["display_name"],
                capabilities=info["capabilities"],
                default_params=info["default_params"],
                # Everything the router lists is treated as enabled.
                is_enabled=True,
                full_model_id=f"{platform_key}/{model_key}",
            )
        )
    return success_response(data=entries)
@router.post("/generate", response_model=ApiResponse[GenerateResponse])
async def generate_text(data: GenerateRequest):
    """Generate text; the model router picks the platform to call.

    Args:
        data: Prompt plus optional model ID, task type and sampling parameters.

    Returns:
        A success envelope wrapping the generated content, model and usage.

    Raises:
        HTTPException: 500 with the underlying error message when generation
            or response construction fails.
    """
    # Named `model_router` so the module-level APIRouter `router` is not shadowed.
    model_router = await get_model_router()
    try:
        result = await model_router.generate(
            prompt=data.prompt,
            model_id=data.model_id,
            task_type=data.task_type,
            temperature=data.temperature,
            max_tokens=data.max_tokens,
        )
        return success_response(
            data=GenerateResponse(
                content=result.content,
                model=result.model,
                usage=result.usage,
            )
        )
    except Exception as e:
        # Keep a server-side traceback; the client only receives str(e).
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/health", response_model=ApiResponse[HealthResponse])
async def health_check(model_id: str | None = None):
    """Report the health of the models known to the model router.

    Args:
        model_id: Optional model ID forwarded to the router's health check.
    """
    model_router = await get_model_router()
    report = await model_router.health_check(model_id)

    statuses = [
        {
            "id": key,
            "name": state.name,
            "is_available": state.is_available,
            "response_time": state.response_time,
            "last_error": state.last_error,
        }
        for key, state in report.items()
    ]
    reachable = sum(1 for entry in statuses if entry["is_available"])

    # A single reachable model is enough for the overall status to be healthy.
    overall = "healthy" if reachable > 0 else "unhealthy"
    return success_response(
        data={
            "status": overall,
            "total_models": len(statuses),
            "available_models": reachable,
            "models": statuses,
        }
    )
@router.get("/platforms/{platform_id}/test", response_model=ApiResponse[dict])
async def test_platform_connection(platform_id: str):
    """Probe a configured platform and report whether it is reachable.

    Args:
        platform_id: ID of the platform to look up in the configuration.

    Returns:
        A success envelope whose data always carries ``platform_id``,
        ``success``, ``response_time`` and ``message``; when the probe itself
        raises, ``error`` additionally holds the exception text.

    Raises:
        HTTPException: 404 when the platform is not present in the configuration.
    """
    from app.ai.model_router import PlatformInstance

    config_loader = get_config_loader()
    platform = config_loader.get_platform(platform_id)
    if not platform:
        raise HTTPException(status_code=404, detail="平台不存在")

    try:
        # Per the original note, PlatformInstance reads the API key from Settings itself.
        instance = PlatformInstance(
            {
                "id": platform.id,
                "name": platform.name,
                "provider": platform.provider,
                "base_url": platform.base_url,
            }
        )
        # Attempt a real call against the provider.
        result = await instance.provider.health_check()
    except Exception as e:
        # Best-effort probe: the failure is reported in the payload rather than
        # raised, but it is logged so it is not silently lost on the server.
        logger.warning(
            "Platform connection test failed for %s: %s",
            platform_id,
            e,
            exc_info=True,
        )
        return success_response(
            data={
                "platform_id": platform_id,
                "success": False,
                # Same keys as the normal path so clients can read one shape.
                "response_time": None,
                "message": str(e),
                "error": str(e),
            }
        )

    return success_response(
        data={
            "platform_id": platform_id,
            "success": result.is_available,
            "response_time": result.response_time,
            "message": "连接成功" if result.is_available else result.last_error,
        }
    )
@router.post("/reload", response_model=ApiResponse[dict])
async def reload_config():
    """Ask the model router to reload its configuration and report the outcome."""
    model_router = await get_model_router()
    changed = model_router.reload_config()

    # Guard clause for the "nothing changed" case.
    if not changed:
        return success_response(data={"reloaded": False}, message="配置文件无变化")
    return success_response(data={"reloaded": True}, message="配置已重新加载")
# =============================================================================
# Prompt 模板 API
# =============================================================================
from app.ai.prompts import (
SCRIPT_TYPES,
VIDEO_STYLES,
PolishPromptBuilder,
ScriptPromptBuilder,
)
class PromptTemplatesResponse(BaseModel):
    """Prompt template options returned by the templates endpoint."""

    # Script type entries (id, name, description, key_points).
    script_types: list[dict]
    # Video style entries (id, name, description).
    video_styles: list[dict]
    # Available tone names.
    tones: list[str]
class ScriptGenerateRequest(BaseModel):
    """Request body for the script generation endpoint."""

    # Script topic (required).
    # NOTE(review): `example=` as a bare Field kwarg is deprecated in
    # Pydantic v2 - confirm which Pydantic version this project targets.
    topic: str = Field(..., description="脚本主题", example="水电改造的3个致命错误")
    # Target video length in seconds; validated to 15-120, default 30.
    duration: int = Field(30, ge=15, le=120, description="视频时长(秒)")
    # Script type name; defaults to "干货型".
    script_type: str = Field("干货型", description="脚本类型")
    # Video style name; defaults to "口播".
    video_style: str = Field("口播", description="视频风格")
    # Optional tone of voice.
    tone: str | None = Field(None, description="语气风格")
    # Optional extra requirements.
    requirements: str | None = Field(None, description="额外要求")
    # Optional explicit model ID; the description says the system default is used otherwise.
    model_id: str | None = Field(None, description="指定模型ID,默认使用系统默认模型")
class ScriptGenerateResponse(BaseModel):
    """Script generation result, shaped for direct display by the frontend.

    On failure (``success`` is False) the script and all statistics are None,
    while ``error`` and ``raw_content`` carry the diagnostic information.
    """

    success: bool
    # Shot list; each shot holds shot_number, type, scene/prompt, voiceover,
    # duration and word_count. Nullable because the generation endpoint
    # returns None here on its failure paths, which a non-nullable list
    # annotation would reject during response validation.
    script: list[dict | None] | None
    total_duration: int | None  # Estimated total duration (seconds)
    target_duration: int  # Requested duration (seconds)
    total_word_count: int | None  # Total voiceover characters (for display)
    segment_count: int | None  # Number of "segment" shots (for display)
    empty_shot_count: int | None  # Number of "empty_shot" shots (for display)
    script_type: str
    model: str
    usage: dict | None
    error: str | None
    raw_content: str | None
class PolishRequest(BaseModel):
    """Request body for the polish endpoint."""

    # Text to be polished (required).
    content: str = Field(..., description="需要润色的内容")
    # Polish type, described as scene or voiceover; defaults to "voiceover".
    # The value is not validated by this field definition.
    polish_type: str = Field("voiceover", description="润色类型:scene/voiceover")
    # Optional explicit model ID.
    model_id: str | None = Field(None, description="指定模型ID")
class PolishResponse(BaseModel):
    """Result of a polish call."""

    success: bool
    # The content exactly as submitted in the request.
    original: str
    # The model output; may be None.
    polished: str | None
    # Polish type echoed back from the request.
    polish_type: str
    # Model value reported by the router result.
    model: str
    # Usage information reported by the router result; may be None.
    usage: dict | None
@router.get("/prompts/templates", response_model=ApiResponse[PromptTemplatesResponse])
async def get_prompt_templates():
    """Return the selectable prompt template options.

    Covers script types, video styles and tone names.
    """
    script_types = []
    for type_id, spec in SCRIPT_TYPES.items():
        # The "default" entry is not offered as a selectable option.
        if type_id == "default":
            continue
        script_types.append(
            {
                "id": type_id,
                "name": spec["name"],
                "description": spec["description"],
                "key_points": spec["key_points"],
            }
        )

    video_styles = []
    for style_id, spec in VIDEO_STYLES.items():
        video_styles.append(
            {
                "id": style_id,
                "name": spec["name"],
                "description": spec["description"],
            }
        )

    return success_response(
        data={
            "script_types": script_types,
            "video_styles": video_styles,
            "tones": ["专业", "亲和", "幽默", "严肃", "激情"],
        }
    )
@router.post("/prompts/build", response_model=ApiResponse[dict])
async def build_system_prompt(
    duration: int = 30,
    script_type: str = "干货型",
    video_style: str = "口播",
    tone: str | None = None,
):
    """Build the script system prompt for debugging or preview.

    Returns the prompt text, its length and the parameters it was built from.
    """
    echoed_parameters = {
        "duration": duration,
        "script_type": script_type,
        "video_style": video_style,
        "tone": tone,
    }
    system_prompt = ScriptPromptBuilder().build(
        duration=duration,
        script_type=script_type,
        video_style=video_style,
        industry="家装",
        tone=tone,
    )
    return success_response(
        data={
            "system_prompt": system_prompt,
            "length": len(system_prompt),
            "parameters": echoed_parameters,
        }
    )
def _script_failure_payload(data, result, error):
    """Build the ``success=False`` payload shared by the script failure paths."""
    return {
        "success": False,
        "script": None,
        "total_duration": None,
        "target_duration": data.duration,
        "total_word_count": None,
        "segment_count": None,
        "empty_shot_count": None,
        "script_type": data.script_type,
        "model": result.model,
        "usage": result.usage,
        "error": error,
        "raw_content": result.content,
    }


def _shot_duration_seconds(shot):
    """Return a shot's duration in whole seconds (5 when the key is absent).

    Accepts plain numbers as well as strings such as "5", "5s" or "5秒".
    Raises ValueError for anything that cannot be read as a number.
    """
    raw = shot.get("duration", "5s")
    if isinstance(raw, (int, float)) and not isinstance(raw, bool):
        return int(raw)
    return int(float(str(raw).strip().rstrip("s秒")))


@router.post("/scripts/generate", response_model=ApiResponse[ScriptGenerateResponse])
async def generate_script(data: ScriptGenerateRequest):
    """Generate a home-renovation short-video script.

    Builds the system prompt from the request, calls the model, parses the
    JSON answer into a shot list and returns it together with display
    statistics (segment count, empty-shot count, total word count).

    Args:
        data: Topic, duration, script type, video style and optional tone,
            extra requirements and model ID.

    Returns:
        A success envelope. Its data has ``success=True`` with the script and
        statistics, or ``success=False`` with ``error`` and ``raw_content``
        when the model answer cannot be parsed or normalized.

    Raises:
        HTTPException: 500 when the model call or post-processing raises.
    """
    # Named `model_router` so the module-level APIRouter `router` is not shadowed.
    model_router = await get_model_router()

    # Build the system prompt. `tone` must come from the request's tone field;
    # the extra requirements are passed separately as custom_requirements.
    builder = ScriptPromptBuilder()
    system_prompt = builder.build(
        duration=data.duration,
        script_type=data.script_type,
        video_style=data.video_style,
        industry="家装",
        tone=data.tone,
        custom_requirements=data.requirements,
    )

    # Build the user input.
    user_prompt = f"""主题是"{data.topic}"
要求:
1. 严格按照时长要求控制
2. 每个镜头的配音字数必须匹配时长(4-5字/秒)
3. 空镜必须有画外音,不能为空
4. 只返回JSON数组,不要有其他文字"""
    full_prompt = f"{system_prompt}\n\n【用户输入】\n{user_prompt}\n\n请生成脚本,只返回JSON数组:"

    try:
        # Call the model.
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="script",
            temperature=0.7,
            max_tokens=2500,
        )

        # Parse the JSON answer without raising on malformed output.
        success_parsed, parsed_data, error_msg = safe_parse_ai_json_response(
            result.content
        )
        if not success_parsed:
            logger.error(f"AI 响应解析失败: {error_msg}")
            return success_response(
                data=_script_failure_payload(
                    data, result, error_msg or "JSON解析失败"
                )
            )

        # Validate and normalize the shot data.
        try:
            script = validate_and_normalize_shots(parsed_data)
        except Exception as e:
            logger.error(f"分镜数据标准化失败: {e}")
            return success_response(
                data=_script_failure_payload(
                    data, result, f"分镜数据格式错误: {e}"
                )
            )

        # Structural problems are logged only; processing continues.
        is_valid, validation_errors = validate_shots_structure(script)
        if not is_valid:
            logger.warning(f"分镜结构验证失败: {validation_errors}")

        # Statistics for the frontend; non-dict entries are ignored.
        shots = [shot for shot in script if isinstance(shot, dict)]
        total_duration = sum(_shot_duration_seconds(shot) for shot in shots)
        # `or ""` keeps an explicit None voiceover from breaking len().
        total_word_count = sum(len(shot.get("voiceover") or "") for shot in shots)
        segment_count = sum(1 for shot in shots if shot.get("type") == "segment")
        empty_shot_count = sum(
            1 for shot in shots if shot.get("type") == "empty_shot"
        )

        return success_response(
            data={
                "success": True,
                "script": script,
                "total_duration": total_duration,
                "target_duration": data.duration,
                "total_word_count": total_word_count,
                "segment_count": segment_count,
                "empty_shot_count": empty_shot_count,
                "script_type": data.script_type,
                "model": result.model,
                "usage": result.usage,
                "error": None,
                "raw_content": None,
            }
        )
    except Exception as e:
        # Keep a server-side traceback; the client only receives the message.
        logger.exception("Script generation failed")
        raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}") from e
@router.post("/scripts/polish", response_model=ApiResponse[PolishResponse])
async def polish_script_content(data: PolishRequest):
    """Polish a piece of script content.

    Builds a polish prompt for the requested polish type (scene description
    or voiceover copy) and sends it to the model.

    Args:
        data: Content to polish, polish type and optional model ID.

    Returns:
        A success envelope with the original text, the polished text, the
        polish type, the model and usage information.

    Raises:
        HTTPException: 500 when the model call or response construction fails.
    """
    # Named `model_router` so the module-level APIRouter `router` is not shadowed.
    model_router = await get_model_router()

    # Build the polish prompt.
    builder = PolishPromptBuilder()
    system_prompt = builder.build(data.polish_type)
    full_prompt = f"{system_prompt}\n\n【待润色内容】\n{data.content}\n\n请润色:"

    try:
        # Call the model.
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="polish",
            temperature=0.6,
            max_tokens=1000,
        )
        return success_response(
            data={
                "success": True,
                "original": data.content,
                "polished": result.content,
                "polish_type": data.polish_type,
                "model": result.model,
                "usage": result.usage,
            }
        )
    except Exception as e:
        # Keep a server-side traceback; the client only receives the message.
        logger.exception("Polish request failed")
        raise HTTPException(status_code=500, detail=f"润色失败: {str(e)}") from e