553 lines
16 KiB
Python
553 lines
16 KiB
Python
"""
|
|
AI 模型管理与生成 API
|
|
=====================
|
|
|
|
提供模型列表查询、文本生成、脚本生成、润色等功能。
|
|
|
|
模型配置存储在 config/ai_models.yaml,支持热重载。
|
|
"""
|
|
|
|
import logging

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from app.ai.model_router import get_model_router
from app.core.config_loader import get_config_loader
from app.schemas.common import ApiResponse, success_response
from app.services.ai_response_utils import (
    safe_parse_ai_json_response,
    validate_and_normalize_shots,
    validate_shots_structure,
)

# Module logger. Defined after the imports so every import stays at the top
# of the module (PEP 8, flake8 E402).
logger = logging.getLogger(__name__)

# Router on which every endpoint in this module is registered.
router = APIRouter()
|
|
|
|
|
|
# ============ 请求/响应 Schema ============
|
|
|
|
|
|
class PlatformResponse(BaseModel):
    """One configured AI platform, as listed by ``GET /platforms``."""

    id: str  # platform identifier from the config loader
    name: str  # human-readable platform name
    provider: str  # provider name; presumably selects the client implementation — TODO confirm
|
|
|
|
|
|
class ModelResponse(BaseModel):
    """One model known to the model router, as listed by ``GET /models``."""

    id: str  # model id within its platform
    platform_id: str  # id of the platform that hosts the model
    model_name: str  # model name as configured
    display_name: str  # human-readable label
    capabilities: list[str]  # capability tags, e.g. "script", "polish", "chat"
    default_params: dict  # default parameters as configured for the model
    is_enabled: bool  # GET /models only lists enabled models, so always True there
    full_model_id: str  # "<platform_id>/<id>"
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``model_id`` names a model explicitly; ``task_type`` lets the model router
    choose one. How the router resolves the two (both or neither given) lives
    in ``app.ai.model_router`` and is not visible here.
    """

    prompt: str = Field(..., description="提示词")
    model_id: str | None = Field(None, description="指定模型 ID")
    task_type: str | None = Field(
        None, description="任务类型,用于自动选模型: script/polish"
    )
    # NOTE(review): the documented range is 0-2 but no bounds are enforced here;
    # out-of-range values are passed straight to the router.
    temperature: float | None = Field(None, description="随机性 (0-2)")
    max_tokens: int | None = Field(None, description="最大生成长度")
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``."""

    content: str  # generated text
    model: str  # model identifier reported on the router's result
    usage: dict | None  # usage info reported with the result, if any
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    status: str  # "healthy" if at least one checked model is available, else "unhealthy"
    total_models: int  # number of models that were checked
    available_models: int  # number of checked models reported as available
    # per-model entries with keys: id, name, is_available, response_time, last_error
    models: list[dict]
|
|
|
|
|
|
# ============ API 路由 ============
|
|
|
|
|
|
@router.get("/platforms", response_model=ApiResponse[list[PlatformResponse]])
async def list_platforms():
    """Return every configured platform (id, name, provider)."""
    loader = get_config_loader()
    payload = [
        PlatformResponse(id=item.id, name=item.name, provider=item.provider)
        for item in loader.get_all_platforms()
    ]
    return success_response(data=payload)
|
|
|
|
|
|
@router.get("/models", response_model=ApiResponse[list[ModelResponse]])
async def list_models(capability: str | None = None):
    """Return the enabled models known to the model router.

    Args:
        capability: optional capability filter, e.g. "script", "polish", "chat".
    """
    model_router = await get_model_router()

    def _to_response(entry: dict) -> ModelResponse:
        entry_id = entry["id"]
        platform_id = entry["platform_id"]
        return ModelResponse(
            id=entry_id,
            platform_id=platform_id,
            model_name=entry["model_name"],
            display_name=entry["display_name"],
            capabilities=entry["capabilities"],
            default_params=entry["default_params"],
            # the router's list only contains enabled models
            is_enabled=True,
            full_model_id=f"{platform_id}/{entry_id}",
        )

    entries = model_router.list_models(capability=capability)
    return success_response(data=[_to_response(entry) for entry in entries])
|
|
|
|
|
|
@router.post("/generate", response_model=ApiResponse[GenerateResponse])
async def generate_text(data: GenerateRequest):
    """Generate text, letting the model router dispatch to the right platform.

    The router receives ``model_id`` / ``task_type`` together with the
    generation parameters from the request body.

    Raises:
        HTTPException: 500 when generation fails. ``detail`` carries the
            original error message; the traceback is logged and the original
            exception is chained as ``__cause__``.
    """
    model_router = await get_model_router()

    try:
        result = await model_router.generate(
            prompt=data.prompt,
            model_id=data.model_id,
            task_type=data.task_type,
            temperature=data.temperature,
            max_tokens=data.max_tokens,
        )
        return success_response(
            data=GenerateResponse(
                content=result.content,
                model=result.model,
                usage=result.usage,
            )
        )
    except Exception as e:
        # API boundary: keep the client-facing detail unchanged, but do not
        # lose the traceback (log it) or the cause (chain it).
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/health", response_model=ApiResponse[HealthResponse])
async def health_check(model_id: str | None = None):
    """Report model availability as seen by the model router.

    Args:
        model_id: optional model id forwarded to the router's ``health_check``
            (presumably restricts the check to that model — confirm in
            ``app.ai.model_router``).
    """
    model_router = await get_model_router()
    results = await model_router.health_check(model_id)

    entries = [
        {
            "id": mid,
            "name": status.name,
            "is_available": status.is_available,
            "response_time": status.response_time,
            "last_error": status.last_error,
        }
        for mid, status in results.items()
    ]
    available = sum(1 for entry in entries if entry["is_available"])

    return success_response(
        data={
            "status": "unhealthy" if available == 0 else "healthy",
            "total_models": len(entries),
            "available_models": available,
            "models": entries,
        }
    )
|
|
|
|
|
|
@router.get("/platforms/{platform_id}/test", response_model=ApiResponse[dict])
async def test_platform_connection(platform_id: str):
    """Test connectivity to one configured platform.

    Builds a throwaway ``PlatformInstance`` from the platform's config (the
    instance reads its API key from Settings itself) and runs the provider's
    health check. Failures are reported in the payload, not as HTTP errors.

    Returns:
        ``platform_id``, ``success``, ``response_time`` and ``message``. When
        an unexpected exception occurs, ``response_time`` is ``None`` and the
        payload also carries ``error`` (kept for existing clients).

    Raises:
        HTTPException: 404 if the platform is not configured.
    """
    from app.ai.model_router import PlatformInstance

    platform = get_config_loader().get_platform(platform_id)
    if not platform:
        raise HTTPException(status_code=404, detail="平台不存在")

    try:
        # PlatformInstance reads the API key from Settings on its own.
        instance = PlatformInstance(
            {
                "id": platform.id,
                "name": platform.name,
                "provider": platform.provider,
                "base_url": platform.base_url,
            }
        )
        result = await instance.provider.health_check()
        return success_response(
            data={
                "platform_id": platform_id,
                "success": result.is_available,
                "response_time": result.response_time,
                "message": "连接成功" if result.is_available else result.last_error,
            }
        )
    except Exception as e:
        # Deliberate best-effort: report the failure in the payload, but log
        # the traceback so the cause is not lost.
        logger.exception("Platform connection test failed: %s", platform_id)
        return success_response(
            data={
                "platform_id": platform_id,
                "success": False,
                "response_time": None,
                "message": str(e),
                "error": str(e),
            }
        )
|
|
|
|
|
|
@router.post("/reload", response_model=ApiResponse[dict])
async def reload_config():
    """Ask the model router to re-read its configuration file.

    ``reloaded`` in the payload reflects whether the router reported a reload.
    """
    model_router = await get_model_router()
    changed = model_router.reload_config()
    note = "配置已重新加载" if changed else "配置文件无变化"
    return success_response(data={"reloaded": bool(changed)}, message=note)
|
|
|
|
|
|
# =============================================================================
|
|
# Prompt 模板 API
|
|
# =============================================================================
|
|
|
|
from app.ai.prompts import (
|
|
SCRIPT_TYPES,
|
|
VIDEO_STYLES,
|
|
PolishPromptBuilder,
|
|
ScriptPromptBuilder,
|
|
)
|
|
|
|
|
|
class PromptTemplatesResponse(BaseModel):
    """Response body for ``GET /prompts/templates``: the selectable prompt options."""

    script_types: list[dict]  # entries with id, name, description, key_points
    video_styles: list[dict]  # entries with id, name, description
    tones: list[str]  # available tone labels
|
|
|
|
|
|
class ScriptGenerateRequest(BaseModel):
    """Request body for ``POST /scripts/generate``."""

    # Topic of the short video.
    # NOTE(review): on Pydantic v2, ``example=`` is deprecated in favour of
    # ``examples=[...]`` — confirm which Pydantic version the project uses.
    topic: str = Field(..., description="脚本主题", example="水电改造的3个致命错误")
    # Target video length in seconds; validated to 15-120.
    duration: int = Field(30, ge=15, le=120, description="视频时长(秒)")
    # Script type label; presumably a key of the prompt module's script-type table — TODO confirm.
    script_type: str = Field("干货型", description="脚本类型")
    # Video style label; presumably a key of the prompt module's video-style table — TODO confirm.
    video_style: str = Field("口播", description="视频风格")
    tone: str | None = Field(None, description="语气风格")  # optional tone of voice
    requirements: str | None = Field(None, description="额外要求")  # free-form extra requirements
    # Explicit model id; when omitted the router picks the model.
    model_id: str | None = Field(None, description="指定模型ID,默认使用系统默认模型")
|
|
|
|
|
|
class ScriptGenerateResponse(BaseModel):
    """Response body for ``POST /scripts/generate``, shaped for the frontend.

    On success ``script`` holds the shot list and the statistics fields are
    filled in. On failure ``success`` is False, ``script`` and the statistics
    are None, and ``error`` / ``raw_content`` carry diagnostics.
    """

    success: bool
    # Shot list; each shot has shot_number, type, scene/prompt, voiceover,
    # duration, word_count. Must be nullable: failure payloads send None, and a
    # non-nullable field would make FastAPI's response validation turn every
    # failure into a 500.
    script: list[dict | None] | None
    total_duration: int | None  # estimated total duration (seconds)
    target_duration: int  # requested duration (seconds)
    total_word_count: int | None  # total voice-over characters (for display)
    segment_count: int | None  # number of storyboard (segment) shots (for display)
    empty_shot_count: int | None  # number of empty shots (for display)
    script_type: str
    model: str
    usage: dict | None
    error: str | None
    raw_content: str | None
|
|
|
|
|
|
class PolishRequest(BaseModel):
    """Request body for ``POST /scripts/polish``."""

    content: str = Field(..., description="需要润色的内容")  # text to polish
    # Polish mode, documented as "scene" or "voiceover"; not validated here,
    # the value is passed to the polish prompt builder as-is.
    polish_type: str = Field("voiceover", description="润色类型:scene/voiceover")
    model_id: str | None = Field(None, description="指定模型ID")  # explicit model id
|
|
|
|
|
|
class PolishResponse(BaseModel):
    """Response body for ``POST /scripts/polish``."""

    success: bool
    original: str  # the text that was submitted
    polished: str | None  # the model's polished text
    polish_type: str  # echo of the requested polish mode
    model: str  # model identifier reported on the router's result
    usage: dict | None  # usage info reported with the result, if any
|
|
|
|
|
|
@router.get("/prompts/templates", response_model=ApiResponse[PromptTemplatesResponse])
async def get_prompt_templates():
    """Return the selectable prompt options.

    Lists the script types (every entry except the "default" key), the video
    styles and the available tone labels.
    """
    script_types = []
    for type_id, spec in SCRIPT_TYPES.items():
        if type_id == "default":
            continue  # presumably a fallback template, not a user choice
        script_types.append(
            {
                "id": type_id,
                "name": spec["name"],
                "description": spec["description"],
                "key_points": spec["key_points"],
            }
        )

    video_styles = [
        {"id": style_id, "name": spec["name"], "description": spec["description"]}
        for style_id, spec in VIDEO_STYLES.items()
    ]

    return success_response(
        data={
            "script_types": script_types,
            "video_styles": video_styles,
            "tones": ["专业", "亲和", "幽默", "严肃", "激情"],
        }
    )
|
|
|
|
|
|
@router.post("/prompts/build", response_model=ApiResponse[dict])
async def build_system_prompt(
    duration: int = 30,
    script_type: str = "干货型",
    video_style: str = "口播",
    tone: str | None = None,
):
    """Build the script system prompt for preview or debugging.

    Uses ``ScriptPromptBuilder`` with the industry fixed to "家装" and returns
    the prompt, its length and the parameters it was built from.
    """
    params = {
        "duration": duration,
        "script_type": script_type,
        "video_style": video_style,
        "tone": tone,
    }
    system_prompt = ScriptPromptBuilder().build(**params, industry="家装")
    return success_response(
        data={
            "system_prompt": system_prompt,
            "length": len(system_prompt),
            "parameters": params,
        }
    )
|
|
|
|
|
|
# Seconds assumed for a shot whose duration is missing or unparseable.
_DEFAULT_SHOT_SECONDS = 5


def _parse_shot_seconds(value) -> float:
    """Parse a shot duration such as ``"5s"``, ``"5秒"``, ``5`` or ``"4.5s"`` into seconds.

    Unparseable values (including None) fall back to the default shot length
    and are logged, instead of failing the whole request.
    """
    if isinstance(value, (int, float)):
        return float(value)
    text = str(value).strip().rstrip("s秒")
    try:
        return float(text)
    except ValueError:
        logger.warning(
            "Unparseable shot duration %r; assuming %ss", value, _DEFAULT_SHOT_SECONDS
        )
        return float(_DEFAULT_SHOT_SECONDS)


def _script_stats(script) -> dict:
    """Compute the display statistics for a shot list; non-dict entries are ignored.

    Returns total_duration (seconds, rounded to int), total_word_count
    (voice-over characters), segment_count and empty_shot_count.
    """
    shots = [shot for shot in (script or []) if isinstance(shot, dict)]
    return {
        "total_duration": round(
            sum(_parse_shot_seconds(shot.get("duration", "5s")) for shot in shots)
        ),
        # `or ""` guards against an explicit None voice-over.
        "total_word_count": sum(len(str(shot.get("voiceover") or "")) for shot in shots),
        "segment_count": sum(1 for shot in shots if shot.get("type") == "segment"),
        "empty_shot_count": sum(
            1 for shot in shots if shot.get("type") == "empty_shot"
        ),
    }


def _script_failure(data: ScriptGenerateRequest, result, error: str) -> dict:
    """Build the ``success=False`` payload used when the model output is unusable."""
    return {
        "success": False,
        "script": None,
        "total_duration": None,
        "target_duration": data.duration,
        "total_word_count": None,
        "segment_count": None,
        "empty_shot_count": None,
        "script_type": data.script_type,
        "model": result.model,
        "usage": result.usage,
        "error": error,
        "raw_content": result.content,
    }


@router.post("/scripts/generate", response_model=ApiResponse[ScriptGenerateResponse])
async def generate_script(data: ScriptGenerateRequest):
    """Generate a home-renovation (家装) short-video script.

    Builds the system prompt with ``ScriptPromptBuilder``, asks the model for
    a JSON array of shots (storyboard segments plus empty shots), then parses,
    normalizes and validates it. The response includes display statistics:
    total duration, total voice-over characters, and segment / empty-shot counts.

    When the model output cannot be parsed or normalized, a ``success=False``
    payload with the raw model output is returned instead of an HTTP error.
    Structural validation problems are only logged.

    Raises:
        HTTPException: 500 if the model call or post-processing fails
            unexpectedly (the cause is logged and chained).
    """
    model_router = await get_model_router()

    system_prompt = ScriptPromptBuilder().build(
        duration=data.duration,
        script_type=data.script_type,
        video_style=data.video_style,
        industry="家装",
        # The tone comes from the request's own `tone` field; free-form
        # requirements go to custom_requirements.
        tone=data.tone,
        custom_requirements=data.requirements,
    )

    user_prompt = f"""主题是"{data.topic}"

要求:
1. 严格按照时长要求控制
2. 每个镜头的配音字数必须匹配时长(4-5字/秒)
3. 空镜必须有画外音,不能为空
4. 只返回JSON数组,不要有其他文字"""

    full_prompt = f"{system_prompt}\n\n【用户输入】\n{user_prompt}\n\n请生成脚本,只返回JSON数组:"

    try:
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="script",
            temperature=0.7,
            max_tokens=2500,
        )

        parsed_ok, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
        if not parsed_ok:
            logger.error("AI 响应解析失败: %s", error_msg)
            return success_response(
                data=_script_failure(data, result, error_msg or "JSON解析失败")
            )

        try:
            script = validate_and_normalize_shots(parsed_data)
        except Exception as e:
            logger.error("分镜数据标准化失败: %s", e)
            return success_response(
                data=_script_failure(data, result, f"分镜数据格式错误: {e}")
            )

        is_valid, validation_errors = validate_shots_structure(script)
        if not is_valid:
            # Non-fatal: the script is still returned; only a warning is logged.
            logger.warning("分镜结构验证失败: %s", validation_errors)

        stats = _script_stats(script)
        return success_response(
            data={
                "success": True,
                "script": script,
                "total_duration": stats["total_duration"],
                "target_duration": data.duration,
                "total_word_count": stats["total_word_count"],
                "segment_count": stats["segment_count"],
                "empty_shot_count": stats["empty_shot_count"],
                "script_type": data.script_type,
                "model": result.model,
                "usage": result.usage,
                "error": None,
                "raw_content": None,
            }
        )

    except Exception as e:
        logger.exception("Script generation failed")
        raise HTTPException(status_code=500, detail=f"生成失败: {str(e)}") from e
|
|
|
|
|
|
@router.post("/scripts/polish", response_model=ApiResponse[PolishResponse])
async def polish_script_content(data: PolishRequest):
    """Polish a scene description or voice-over text.

    The system prompt comes from ``PolishPromptBuilder`` for
    ``data.polish_type``; the model is chosen via ``model_id`` or the
    "polish" task type.

    Raises:
        HTTPException: 500 if the model call fails (the cause is logged and
            chained).
    """
    model_router = await get_model_router()

    system_prompt = PolishPromptBuilder().build(data.polish_type)
    full_prompt = f"{system_prompt}\n\n【待润色内容】\n{data.content}\n\n请润色:"

    try:
        result = await model_router.generate(
            prompt=full_prompt,
            model_id=data.model_id,
            task_type="polish",
            temperature=0.6,
            max_tokens=1000,
        )
        return success_response(
            data={
                "success": True,
                "original": data.content,
                "polished": result.content,
                "polish_type": data.polish_type,
                "model": result.model,
                "usage": result.usage,
            }
        )
    except Exception as e:
        # API boundary: keep the traceback (log) and the cause (chain).
        logger.exception("Polishing failed")
        raise HTTPException(status_code=500, detail=f"润色失败: {str(e)}") from e
|