chore(release): bump version to 1.9.1 and apply pending changes
This commit is contained in:
@@ -0,0 +1,537 @@
|
||||
diff --git a/python-api/app/api/v1/caption.py b/python-api/app/api/v1/caption.py
|
||||
index afae831..99a771a 100644
|
||||
--- a/python-api/app/api/v1/caption.py
|
||||
+++ b/python-api/app/api/v1/caption.py
|
||||
@@ -24,19 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/caption", tags=["Caption"])
|
||||
|
||||
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
@router.post("/ata/align")
|
||||
async def auto_align_caption(
|
||||
request_body: AutoAlignSubmitRequest,
|
||||
@@ -88,9 +75,3 @@ async def auto_align_caption(
|
||||
except Exception as e:
|
||||
logger.error(f"自动打轴异常: {e}")
|
||||
raise HTTPException(status_code=500, detail="字幕打轴失败,请稍后重试")
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
-
|
||||
diff --git a/python-api/app/api/v1/points.py b/python-api/app/api/v1/points.py
|
||||
index b2c7208..23bafa9 100644
|
||||
--- a/python-api/app/api/v1/points.py
|
||||
+++ b/python-api/app/api/v1/points.py
|
||||
@@ -6,7 +6,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
-from datetime import UTC, datetime
|
||||
+from datetime import UTC, datetime, timedelta
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
+from app.core.exceptions import InsufficientPointsException
|
||||
from app.crud.point_recharge_order import point_recharge_order
|
||||
from app.crud.point_transaction import point_transaction
|
||||
from app.models.user import User
|
||||
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/points", tags=["Points"])
|
||||
|
||||
# ── 余额查询 ──────────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/balance", response_model=ApiResponse[PointBalanceResponse])
|
||||
async def get_balance(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -47,6 +49,7 @@ async def get_balance(
|
||||
|
||||
# ── 流水查询 ──────────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/transactions", response_model=ApiResponse[PointTransactionListResponse])
|
||||
async def list_transactions(
|
||||
pagination: PaginationParams = Depends(),
|
||||
@@ -127,12 +130,13 @@ async def list_transactions(
|
||||
|
||||
# ── 充值 ──────────────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.post("/recharge", response_model=ApiResponse[RechargeResponse])
|
||||
async def create_recharge_order(
|
||||
request: RechargeRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
- http_request: Request = None,
|
||||
+ http_request: Request = None, # type: ignore[assignment]
|
||||
):
|
||||
"""
|
||||
创建积分充值订单(微信支付 Native 扫码)
|
||||
@@ -303,9 +307,7 @@ async def handle_wxpay_notify(
|
||||
return _wx_response()
|
||||
|
||||
# 查找订单
|
||||
- order = await point_recharge_order.get_by_out_trade_no(
|
||||
- db, out_trade_no=out_trade_no
|
||||
- )
|
||||
+ order = await point_recharge_order.get_by_out_trade_no(db, out_trade_no=out_trade_no)
|
||||
if not order:
|
||||
logger.error(f"[WechatPay] 回调订单不存在: {out_trade_no}")
|
||||
return _wx_response()
|
||||
@@ -400,9 +402,7 @@ async def query_recharge_status(
|
||||
|
||||
wxpay = get_wxpay_service()
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
|
||||
- wx_result = await wxpay.query_order(
|
||||
- client, out_trade_no=order.out_trade_no
|
||||
- )
|
||||
+ wx_result = await wxpay.query_order(client, out_trade_no=order.out_trade_no)
|
||||
|
||||
order.query_result = str(wx_result)
|
||||
trade_state = wx_result.get("trade_state", "")
|
||||
@@ -465,6 +465,7 @@ async def query_recharge_status(
|
||||
|
||||
# ── 充值档位查询 ──────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/recharge-options", response_model=ApiResponse[list[dict]])
|
||||
async def get_recharge_options(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -480,6 +481,7 @@ async def get_recharge_options(
|
||||
|
||||
# ── 扣费业务类型查询 ──────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/chargeable-types", response_model=ApiResponse[list[str]])
|
||||
async def get_chargeable_types(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -496,6 +498,7 @@ async def get_chargeable_types(
|
||||
|
||||
# ── 积分规则查询 ──────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/rules", response_model=ApiResponse[list[dict]])
|
||||
async def get_points_rules(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -530,9 +533,9 @@ async def get_points_rules(
|
||||
# ── 积分预估查询 ──────────────────────────────────────
|
||||
|
||||
|
||||
-
|
||||
# ── 今日消费统计 ──────────────────────────────────────
|
||||
|
||||
+
|
||||
@router.get("/today-consumed", response_model=ApiResponse[dict])
|
||||
async def get_today_consumed(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -545,6 +548,7 @@ async def get_today_consumed(
|
||||
|
||||
# ── 直接消费扣费(前端/Rust 层调用)───────────────────
|
||||
|
||||
+
|
||||
@router.post("/consume", response_model=ApiResponse[dict])
|
||||
async def consume_points(
|
||||
request: ConsumeRequest,
|
||||
@@ -569,12 +573,12 @@ async def consume_points(
|
||||
source_type=request.source_type,
|
||||
source_id=request.source_id,
|
||||
description=f"【{request.description or request.source_type}】",
|
||||
- allow_negative=False,
|
||||
+ allow_negative=request.allow_negative,
|
||||
)
|
||||
- except ValueError as e:
|
||||
+ except InsufficientPointsException:
|
||||
# 余额不足(在同一事务内判断,避免竞态)
|
||||
- if "积分不足" in str(e):
|
||||
- raise HTTPException(status_code=402, detail=str(e))
|
||||
+ raise
|
||||
+ except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
await db.commit()
|
||||
|
||||
@@ -587,5 +591,3 @@ async def consume_points(
|
||||
},
|
||||
message="消费成功",
|
||||
)
|
||||
-
|
||||
-
|
||||
diff --git a/python-api/app/api/v1/script.py b/python-api/app/api/v1/script.py
|
||||
index 5e7b18f..26a11a3 100644
|
||||
--- a/python-api/app/api/v1/script.py
|
||||
+++ b/python-api/app/api/v1/script.py
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.ai.model_router import get_model_router
|
||||
from app.ai.prompts import list_categories, list_prompt_files, load_prompt, render_template
|
||||
from app.api.deps import get_current_user
|
||||
+from app.core.exceptions import AITimeoutException, InsufficientPointsException
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
@@ -71,9 +72,8 @@ async def polish_content(
|
||||
required_points = ps._calculate_cost("polish")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
- raise HTTPException(
|
||||
- status_code=402,
|
||||
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
+ raise InsufficientPointsException(
|
||||
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -99,11 +99,11 @@ async def polish_content(
|
||||
data=polished,
|
||||
message=f"{type_name}润色完成",
|
||||
)
|
||||
+ except InsufficientPointsException:
|
||||
+ raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
- if "积分不足" in str(e):
|
||||
- raise HTTPException(status_code=402, detail=str(e))
|
||||
logger.warning(f"[Polish] 润色失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="润色失败,请检查输入内容后重试")
|
||||
except Exception as e:
|
||||
@@ -111,9 +111,6 @@ async def polish_content(
|
||||
raise HTTPException(status_code=500, detail=f"{type_name}润色失败,请稍后重试")
|
||||
|
||||
|
||||
-
|
||||
-
|
||||
-
|
||||
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
|
||||
async def generate_title(
|
||||
request: GenerateTitleRequest,
|
||||
@@ -146,7 +143,11 @@ async def generate_title(
|
||||
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
|
||||
|
||||
# 渲染用户提示词
|
||||
- title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
|
||||
+ title_type_desc = (
|
||||
+ "大标题(主标题,提炼核心卖点,吸睛)"
|
||||
+ if request.title_type == "main"
|
||||
+ else "小标题(副标题,补充说明或制造悬念)"
|
||||
+ )
|
||||
user_prompt = render_template(
|
||||
user_template,
|
||||
title_type=request.title_type,
|
||||
@@ -163,9 +164,8 @@ async def generate_title(
|
||||
required_points = ps._calculate_cost("title")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
- raise HTTPException(
|
||||
- status_code=402,
|
||||
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
+ raise InsufficientPointsException(
|
||||
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -179,10 +179,10 @@ async def generate_title(
|
||||
|
||||
title = result.content.strip() if result.content else ""
|
||||
# 去除可能的引号
|
||||
- title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
|
||||
+ title = title.strip('"').strip("'").strip("「」").strip("『』").strip("《》")
|
||||
# 截断到最大长度
|
||||
if len(title) > request.max_length:
|
||||
- title = title[:request.max_length]
|
||||
+ title = title[: request.max_length]
|
||||
|
||||
# 扣费
|
||||
points = ps._calculate_cost("title")
|
||||
@@ -200,15 +200,13 @@ async def generate_title(
|
||||
data=GenerateTitleResponse(title=title),
|
||||
message="标题生成成功",
|
||||
)
|
||||
+ except InsufficientPointsException:
|
||||
+ raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except TimeoutError:
|
||||
logger.warning("[generate_title] 标题生成超时")
|
||||
- raise HTTPException(status_code=500, detail="标题生成超时,请重试")
|
||||
- except ValueError as e:
|
||||
- if "积分不足" in str(e):
|
||||
- raise HTTPException(status_code=402, detail=str(e))
|
||||
- raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
|
||||
+ raise AITimeoutException("标题生成超时,请稍后重试")
|
||||
except Exception as e:
|
||||
logger.error(f"[generate_title] 标题生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
|
||||
diff --git a/python-api/app/api/v1/tasks.py b/python-api/app/api/v1/tasks.py
|
||||
index ae17fe3..55965ca 100644
|
||||
--- a/python-api/app/api/v1/tasks.py
|
||||
+++ b/python-api/app/api/v1/tasks.py
|
||||
@@ -18,6 +18,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
+from app.core.exceptions import InsufficientPointsException
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.models.user import User
|
||||
@@ -38,7 +39,6 @@ class ScriptParams(BaseModel):
|
||||
category: str = Field(..., min_length=1, description="大类代码")
|
||||
filename: str = Field(..., min_length=1, description="提示词文件名")
|
||||
|
||||
-
|
||||
@field_validator("category")
|
||||
@classmethod
|
||||
def validate_category(cls, v: str) -> str:
|
||||
@@ -96,7 +96,9 @@ class VideoParams(BaseModel):
|
||||
volume: int = Field(default=0, ge=0, le=10, description="音量")
|
||||
ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
|
||||
planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
|
||||
- total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
|
||||
+ total_planned_duration: float = Field(
|
||||
+ ..., gt=0, description="所有分镜规划时长之和(秒),用于预检"
|
||||
+ )
|
||||
batch_id: str | None = Field(default=None, description="批次ID(可选)")
|
||||
|
||||
@field_validator("video_url")
|
||||
@@ -134,6 +136,7 @@ class TaskStatusResponse(BaseModel):
|
||||
total: int = Field(0, description="总子任务数")
|
||||
result: dict | None = Field(None, description="任务结果(完成时)")
|
||||
error: str | None = Field(None, description="错误信息(失败时)")
|
||||
+ error_code: str | None = Field(None, description="错误码(失败时,如 content_violation)")
|
||||
created_at: str = Field("", description="任务创建时间(ISO格式)")
|
||||
|
||||
|
||||
@@ -222,9 +225,8 @@ async def create_task(
|
||||
f"[API] 积分不足: user={user_id}, type={task_type}, "
|
||||
f"required={required_points}, balance={check['balance']}"
|
||||
)
|
||||
- raise HTTPException(
|
||||
- status_code=402,
|
||||
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
+ raise InsufficientPointsException(
|
||||
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
# ── 3. 写入 Redis ──────────────────────────────────
|
||||
@@ -246,18 +248,41 @@ async def create_task(
|
||||
params=validated_params,
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
-
|
||||
- logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
|
||||
- return TaskCreateResponse(
|
||||
- task_id=task_id,
|
||||
- status="running",
|
||||
- message=f"{task_type} 任务已创建",
|
||||
- )
|
||||
-
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Failed to update registry: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
|
||||
|
||||
+ # ── 4. 脚本生成:Redis 写入成功后再扣费 ─────────────
|
||||
+ if task_type == "script" and required_points > 0:
|
||||
+ try:
|
||||
+ async with AsyncSessionLocal() as db:
|
||||
+ await ps.consume(
|
||||
+ db,
|
||||
+ user_id=user_id,
|
||||
+ points=required_points,
|
||||
+ source_type="script",
|
||||
+ source_id=task_id,
|
||||
+ description="【脚本生成】",
|
||||
+ )
|
||||
+ await db.commit()
|
||||
+ except InsufficientPointsException:
|
||||
+ # 余额不足:将任务标记为失败,避免无费执行
|
||||
+ await registry.update(task_id, status="failed", message="积分不足")
|
||||
+ await registry.remove_running(task_id)
|
||||
+ raise
|
||||
+ except Exception as e:
|
||||
+ logger.error(f"[API] 脚本任务扣费失败: {e}")
|
||||
+ await registry.update(task_id, status="failed", message="扣费失败")
|
||||
+ await registry.remove_running(task_id)
|
||||
+ raise HTTPException(status_code=500, detail="扣费失败,请稍后重试")
|
||||
+
|
||||
+ logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
|
||||
+ return TaskCreateResponse(
|
||||
+ task_id=task_id,
|
||||
+ status="running",
|
||||
+ message=f"{task_type} 任务已创建",
|
||||
+ )
|
||||
+
|
||||
|
||||
@router.get("", response_model=list[TaskStatusResponse])
|
||||
async def list_tasks(
|
||||
@@ -294,6 +319,7 @@ async def list_tasks(
|
||||
total=task.total,
|
||||
result=None, # 列表查询不返回 result,避免数据过大
|
||||
error=task.error,
|
||||
+ error_code=task.error_code,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
)
|
||||
@@ -337,6 +363,7 @@ async def get_task_status(
|
||||
total=task.total,
|
||||
result=task.result,
|
||||
error=task.error,
|
||||
+ error_code=task.error_code,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
|
||||
diff --git a/python-api/app/api/v1/vidu.py b/python-api/app/api/v1/vidu.py
|
||||
index c0fb04a..401d8da 100644
|
||||
--- a/python-api/app/api/v1/vidu.py
|
||||
+++ b/python-api/app/api/v1/vidu.py
|
||||
@@ -44,10 +44,12 @@ async def vidu_callback(request: Request):
|
||||
headers_dict = dict(request.headers)
|
||||
|
||||
# 使用 APP_BASE_URL 构建 callback_url,确保与提交任务时传给 Vidu 的一致
|
||||
- #(Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
|
||||
+ # (Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
|
||||
app_base_url = get_settings().app_base_url
|
||||
callback_url = f"{app_base_url}/api/v1/vidu/callback" if app_base_url else str(request.url)
|
||||
- logger.info(f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}")
|
||||
+ logger.info(
|
||||
+ f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}"
|
||||
+ )
|
||||
|
||||
try:
|
||||
task_status = await gateway.handle_webhook(
|
||||
@@ -64,15 +66,13 @@ async def vidu_callback(request: Request):
|
||||
logger.error(f"[Vidu] 回调处理失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="回调处理失败,请稍后重试")
|
||||
|
||||
- logger.info(f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}")
|
||||
+ logger.info(
|
||||
+ f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}"
|
||||
+ )
|
||||
|
||||
# 2. 通过 platform_task_id 反查 Async Engine 内部 task_id,更新 TaskRegistry
|
||||
- platform_task_id = (
|
||||
- task_status.result.get("task_id") if task_status.result else None
|
||||
- )
|
||||
- video_url = (
|
||||
- task_status.result.get("video_url") if task_status.result else None
|
||||
- )
|
||||
+ platform_task_id = task_status.result.get("task_id") if task_status.result else None
|
||||
+ video_url = task_status.result.get("video_url") if task_status.result else None
|
||||
|
||||
logger.info(f"[Vidu] 准备反查 internal_task_id: platform_task_id={platform_task_id}")
|
||||
|
||||
@@ -121,8 +121,6 @@ async def vidu_callback(request: Request):
|
||||
f"platform={platform_task_id}"
|
||||
)
|
||||
else:
|
||||
- logger.warning(
|
||||
- f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}"
|
||||
- )
|
||||
+ logger.warning(f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}")
|
||||
|
||||
return success_response(message="回调已接收")
|
||||
diff --git a/python-api/app/api/v1/voice.py b/python-api/app/api/v1/voice.py
|
||||
index ba8cac9..dad2b7b 100644
|
||||
--- a/python-api/app/api/v1/voice.py
|
||||
+++ b/python-api/app/api/v1/voice.py
|
||||
@@ -10,12 +10,13 @@ import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
+
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
-from app.core.exceptions import PlatformError
|
||||
+from app.core.exceptions import InsufficientPointsException, PlatformError
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
@@ -49,7 +50,9 @@ class TTSSynthesizeRequest(BaseModel):
|
||||
class VoiceCloneSubmitRequest(BaseModel):
|
||||
"""声音复刻提交请求"""
|
||||
|
||||
- source_audio_url: str | None = Field(None, description="源音频 URL(5-30秒,mp3/wav,需公开可访问)")
|
||||
+ source_audio_url: str | None = Field(
|
||||
+ None, description="源音频 URL(5-30秒,mp3/wav,需公开可访问)"
|
||||
+ )
|
||||
source_video_url: str | None = Field(None, description="源视频 URL(可选)")
|
||||
video_id: str | None = Field(None, description="历史作品ID(可选)")
|
||||
voice_name: str | None = Field(None, description="自定义音色名称(≤20字符)")
|
||||
@@ -111,7 +114,7 @@ async def synthesize_speech(
|
||||
# 宽松预检:余额为负或零时阻止,避免浪费第三方资源
|
||||
balance_info = await ps.get_user_balance(db, current_user.id)
|
||||
if balance_info["balance"] <= 0:
|
||||
- raise HTTPException(status_code=402, detail="余额不足,请先充值")
|
||||
+ raise InsufficientPointsException("余额不足,请先充值")
|
||||
|
||||
try:
|
||||
audio_url = await service.synthesize(
|
||||
@@ -137,10 +140,8 @@ async def synthesize_speech(
|
||||
allow_negative=True,
|
||||
)
|
||||
await db.commit()
|
||||
- except ValueError as e:
|
||||
- if "积分不足" in str(e):
|
||||
- raise HTTPException(status_code=402, detail=str(e))
|
||||
- logger.error(f"[Voice] TTS 扣费失败: {e}")
|
||||
+ except InsufficientPointsException:
|
||||
+ raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] TTS 扣费失败: {e}")
|
||||
|
||||
@@ -165,7 +166,6 @@ async def synthesize_speech(
|
||||
raise HTTPException(status_code=500, detail="语音合成失败,请稍后重试")
|
||||
|
||||
|
||||
-
|
||||
def _normalize_voice_id(name: str | None) -> str:
|
||||
"""
|
||||
将用户输入的名称规范化为 Vidu 合法的 voice_id。
|
||||
@@ -220,9 +220,8 @@ async def submit_clone_task(
|
||||
required_points = ps._calculate_cost("voice_clone")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
- raise HTTPException(
|
||||
- status_code=402,
|
||||
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
+ raise InsufficientPointsException(
|
||||
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -244,10 +243,8 @@ async def submit_clone_task(
|
||||
description="【声音复刻】",
|
||||
)
|
||||
await db.commit()
|
||||
- except ValueError as e:
|
||||
- if "积分不足" in str(e):
|
||||
- raise HTTPException(status_code=402, detail=str(e))
|
||||
- logger.error(f"[Voice] 克隆扣费失败: {e}")
|
||||
+ except InsufficientPointsException:
|
||||
+ raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] 克隆扣费失败: {e}")
|
||||
|
||||
@@ -292,5 +289,3 @@ async def query_clone_task(
|
||||
),
|
||||
message="克隆已完成",
|
||||
)
|
||||
-
|
||||
-
|
||||
@@ -0,0 +1,146 @@
|
||||
diff --git a/python-api/app/core/exceptions.py b/python-api/app/core/exceptions.py
|
||||
index d9970d5..837f8d3 100644
|
||||
--- a/python-api/app/core/exceptions.py
|
||||
+++ b/python-api/app/core/exceptions.py
|
||||
@@ -24,9 +24,16 @@ class AppException(HTTPException):
|
||||
status_code: int,
|
||||
message: str = "操作失败",
|
||||
detail: dict | None = None,
|
||||
+ *,
|
||||
+ error_code: str | None = None,
|
||||
):
|
||||
- super().__init__(status_code=status_code, detail=detail or {})
|
||||
+ body = detail or {}
|
||||
+ body["message"] = message
|
||||
+ if error_code:
|
||||
+ body["error_code"] = error_code
|
||||
+ super().__init__(status_code=status_code, detail=body)
|
||||
self.message = message
|
||||
+ self.error_code = error_code
|
||||
|
||||
|
||||
class NotFoundException(AppException):
|
||||
@@ -44,7 +51,7 @@ class ValidationException(AppException):
|
||||
|
||||
def __init__(self, message: str = "参数验证失败"):
|
||||
super().__init__(
|
||||
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
+ status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
message=message,
|
||||
)
|
||||
|
||||
@@ -79,6 +86,17 @@ class BusinessException(AppException):
|
||||
)
|
||||
|
||||
|
||||
+class InsufficientPointsException(AppException):
|
||||
+ """积分不足"""
|
||||
+
|
||||
+ def __init__(self, message: str = "积分不足"):
|
||||
+ super().__init__(
|
||||
+ status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
+ message=message,
|
||||
+ error_code="insufficient_points",
|
||||
+ )
|
||||
+
|
||||
+
|
||||
class ModelUnavailableException(AppException):
|
||||
"""AI 模型不可用"""
|
||||
|
||||
@@ -99,6 +117,50 @@ class TaskFailedException(AppException):
|
||||
)
|
||||
|
||||
|
||||
+class PromptNotFoundException(AppException):
|
||||
+ """提示词文件不存在"""
|
||||
+
|
||||
+ def __init__(self, message: str = "未找到提示词"):
|
||||
+ super().__init__(
|
||||
+ status_code=status.HTTP_404_NOT_FOUND,
|
||||
+ message=message,
|
||||
+ error_code="prompt_not_found",
|
||||
+ )
|
||||
+
|
||||
+
|
||||
+class AIEmptyResponseException(AppException):
|
||||
+ """AI 返回内容为空"""
|
||||
+
|
||||
+ def __init__(self, message: str = "AI 返回内容为空"):
|
||||
+ super().__init__(
|
||||
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
+ message=message,
|
||||
+ error_code="empty_result",
|
||||
+ )
|
||||
+
|
||||
+
|
||||
+class AIParseErrorException(AppException):
|
||||
+ """AI 返回内容解析失败"""
|
||||
+
|
||||
+ def __init__(self, message: str = "AI 返回格式解析失败"):
|
||||
+ super().__init__(
|
||||
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
+ message=message,
|
||||
+ error_code="parse_error",
|
||||
+ )
|
||||
+
|
||||
+
|
||||
+class AITimeoutException(AppException):
|
||||
+ """AI 调用超时"""
|
||||
+
|
||||
+ def __init__(self, message: str = "AI 请求超时,请稍后重试"):
|
||||
+ super().__init__(
|
||||
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
+ message=message,
|
||||
+ error_code="timeout",
|
||||
+ )
|
||||
+
|
||||
+
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# 第三方平台异常(PlatformError 体系)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
@@ -111,14 +173,15 @@ class PlatformErrorType:
|
||||
确保前端和网关能够统一处理。
|
||||
"""
|
||||
|
||||
- RATE_LIMIT = "rate_limit" # 429,可重试
|
||||
- AUTH_FAILED = "auth_failed" # 401/403,不可重试
|
||||
- TIMEOUT = "timeout" # 连接/读取超时,可重试
|
||||
- SERVER_ERROR = "server_error" # 第三方 5xx,可重试
|
||||
- BAD_REQUEST = "bad_request" # 参数错误,不可重试
|
||||
+ RATE_LIMIT = "rate_limit" # 429,可重试
|
||||
+ AUTH_FAILED = "auth_failed" # 401/403,不可重试
|
||||
+ TIMEOUT = "timeout" # 连接/读取超时,可重试
|
||||
+ SERVER_ERROR = "server_error" # 第三方 5xx,可重试
|
||||
+ BAD_REQUEST = "bad_request" # 参数错误,不可重试
|
||||
QUOTA_EXHAUSTED = "quota_exhausted" # 额度用完,不可重试(或延迟重试)
|
||||
- NOT_FOUND = "not_found" # 资源不存在,不可重试
|
||||
- UNKNOWN = "unknown" # 兜底
|
||||
+ NOT_FOUND = "not_found" # 资源不存在,不可重试
|
||||
+ CONTENT_VIOLATION = "content_violation" # 内容安全/审核不通过,不可重试
|
||||
+ UNKNOWN = "unknown" # 兜底
|
||||
|
||||
|
||||
class PlatformError(Exception):
|
||||
@@ -145,12 +208,14 @@ class PlatformError(Exception):
|
||||
retryable: bool = False,
|
||||
error_type: str = PlatformErrorType.UNKNOWN,
|
||||
status_code: int | None = None,
|
||||
+ raw_code: str | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.platform = platform
|
||||
self.retryable = retryable
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
+ self.raw_code = raw_code
|
||||
|
||||
def to_http_status(self) -> int:
|
||||
"""根据 error_type 和 retryable 返回标准 HTTP 状态码"""
|
||||
@@ -161,6 +226,7 @@ class PlatformError(Exception):
|
||||
PlatformErrorType.AUTH_FAILED: 401,
|
||||
PlatformErrorType.BAD_REQUEST: 400,
|
||||
PlatformErrorType.NOT_FOUND: 404,
|
||||
+ PlatformErrorType.CONTENT_VIOLATION: 400,
|
||||
}
|
||||
if self.error_type in mapping:
|
||||
return mapping[self.error_type]
|
||||
@@ -0,0 +1,345 @@
|
||||
diff --git a/python-api/app/ai/providers/vidu_provider.py b/python-api/app/ai/providers/vidu_provider.py
|
||||
index cab5902..fccfbbf 100644
|
||||
--- a/python-api/app/ai/providers/vidu_provider.py
|
||||
+++ b/python-api/app/ai/providers/vidu_provider.py
|
||||
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
-def _map_vidu_error(status: int, message: str) -> PlatformError:
|
||||
- """把 Vidu HTTP 错误映射为标准 PlatformError"""
|
||||
+# Vidu 错误码分类
|
||||
+_VIDU_AUDIT_ERROR_CODES = {
|
||||
+ "TaskPromptPolicyViolation",
|
||||
+ "AuditSubmitIllegal",
|
||||
+ "CreationPolicyViolation",
|
||||
+ "PhotoAuditNotPass",
|
||||
+ "AuditFailed",
|
||||
+ "ImageCheckBodyJointsFailed",
|
||||
+ "ImageCheckFaceFailed",
|
||||
+ "ImageObjectsUndetected",
|
||||
+ "FaceDetectFailure",
|
||||
+ "FaceDetectNotPass",
|
||||
+ "NoFaceDetected",
|
||||
+ "MultiFaceDetected",
|
||||
+}
|
||||
+
|
||||
+_VIDU_RETRYABLE_ERROR_CODES = {
|
||||
+ "InternalServiceFailure",
|
||||
+ "ModelUnavailable",
|
||||
+ "Unknown",
|
||||
+}
|
||||
+
|
||||
+_VIDU_RATE_LIMIT_ERROR_CODES = {
|
||||
+ "QuotaExceeded",
|
||||
+ "TooManyRequests",
|
||||
+ "SystemThrottling",
|
||||
+ "OperationInProcess",
|
||||
+}
|
||||
+
|
||||
+
|
||||
+def _extract_vidu_error_code(message: str | None) -> str | None:
|
||||
+ """从 Vidu 错误信息中提取错误码"""
|
||||
+ if not message:
|
||||
+ return None
|
||||
+ # Vidu 错误码格式:"ErrorCode: 中文描述"
|
||||
+ return message.split(":")[0].strip() or None
|
||||
+
|
||||
+
|
||||
+def _map_vidu_error(
|
||||
+ status: int,
|
||||
+ message: str,
|
||||
+ *,
|
||||
+ err_code: str | None = None,
|
||||
+) -> PlatformError:
|
||||
+ """把 Vidu HTTP 错误映射为标准 PlatformError
|
||||
+
|
||||
+ 优先根据 Vidu 业务错误码(err_code)判断类型,HTTP status 仅作为兜底。
|
||||
+ """
|
||||
+ raw_code = err_code or _extract_vidu_error_code(message)
|
||||
+
|
||||
+ # 1. 内容安全/审核类:不可重试
|
||||
+ if raw_code in _VIDU_AUDIT_ERROR_CODES:
|
||||
+ return PlatformError(
|
||||
+ message=message,
|
||||
+ platform="vidu",
|
||||
+ retryable=False,
|
||||
+ error_type=PlatformErrorType.CONTENT_VIOLATION,
|
||||
+ status_code=status,
|
||||
+ raw_code=raw_code,
|
||||
+ )
|
||||
+
|
||||
+ # 2. 平台内部/模型不可用:可重试
|
||||
+ if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
|
||||
+ return PlatformError(
|
||||
+ message=message,
|
||||
+ platform="vidu",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.SERVER_ERROR,
|
||||
+ status_code=status,
|
||||
+ raw_code=raw_code,
|
||||
+ )
|
||||
+
|
||||
+ # 3. 限流类:可重试
|
||||
+ if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
|
||||
+ return PlatformError(
|
||||
+ message=message,
|
||||
+ platform="vidu",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.RATE_LIMIT,
|
||||
+ status_code=status,
|
||||
+ raw_code=raw_code,
|
||||
+ )
|
||||
+
|
||||
+ # 4. HTTP status 兜底
|
||||
mapping = {
|
||||
429: (PlatformErrorType.RATE_LIMIT, True),
|
||||
401: (PlatformErrorType.AUTH_FAILED, False),
|
||||
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
|
||||
retryable=retryable,
|
||||
error_type=error_type,
|
||||
status_code=status,
|
||||
+ raw_code=raw_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,7 +149,9 @@ class ViduProvider:
|
||||
from app.core.platform_config import get_platform_config_loader
|
||||
|
||||
platform_config = get_platform_config_loader().get_platform("vidu")
|
||||
- self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
|
||||
+ self.base_url = (
|
||||
+ platform_config.base_url if platform_config else "https://api.vidu.cn"
|
||||
+ ).rstrip("/")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
|
||||
@@ -135,9 +220,12 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
- logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
- raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
|
||||
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
+ logger.error(
|
||||
+ f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
+ )
|
||||
+ raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu TTS] 网络错误: {e}")
|
||||
@@ -182,9 +270,14 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
- logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
- raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
|
||||
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
+ logger.error(
|
||||
+ f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
+ )
|
||||
+ raise _map_vidu_error(
|
||||
+ resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
|
||||
+ )
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu Clone] 网络错误: {e}")
|
||||
@@ -238,9 +331,14 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body)
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
- logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
- raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
|
||||
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
+ logger.error(
|
||||
+ f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
+ )
|
||||
+ raise _map_vidu_error(
|
||||
+ resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
|
||||
+ )
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu LipSync] 网络错误: {e}")
|
||||
@@ -264,9 +362,14 @@ class ViduProvider:
|
||||
resp = await self.client.get(url)
|
||||
data = resp.json()
|
||||
if resp.status_code != 200:
|
||||
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
- logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
- raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
|
||||
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
+ logger.error(
|
||||
+ f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
+ )
|
||||
+ raise _map_vidu_error(
|
||||
+ resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
|
||||
+ )
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu Query] 网络错误: {e}")
|
||||
diff --git a/python-api/app/ai/providers/volcengine_caption_provider.py b/python-api/app/ai/providers/volcengine_caption_provider.py
|
||||
index 0f2f271..09ddcc7 100644
|
||||
--- a/python-api/app/ai/providers/volcengine_caption_provider.py
|
||||
+++ b/python-api/app/ai/providers/volcengine_caption_provider.py
|
||||
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
||||
if code is not None and code in error_mapping:
|
||||
error_type, retryable = error_mapping[code]
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_caption",
|
||||
- retryable=retryable, error_type=error_type,
|
||||
+ message,
|
||||
+ platform="volcengine_caption",
|
||||
+ retryable=retryable,
|
||||
+ error_type=error_type,
|
||||
status_code=status,
|
||||
)
|
||||
|
||||
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
||||
}
|
||||
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_caption",
|
||||
- retryable=retryable, error_type=error_type,
|
||||
+ message,
|
||||
+ platform="volcengine_caption",
|
||||
+ retryable=retryable,
|
||||
+ error_type=error_type,
|
||||
status_code=status,
|
||||
)
|
||||
|
||||
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
|
||||
max_lines: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""提交字幕生成任务,返回 {id: task_id}"""
|
||||
- params = {
|
||||
+ params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"language": language,
|
||||
"caption_type": caption_type,
|
||||
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
|
||||
except PlatformError:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
||||
+ raise _map_caption_error(
|
||||
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
||||
+ ) from e
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
raise PlatformError(
|
||||
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
||||
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
||||
+ f"字幕服务网络错误: {e}",
|
||||
+ platform="volcengine_caption",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.TIMEOUT,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
|
||||
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
|
||||
blocking: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""查询字幕任务结果,返回原始 JSON"""
|
||||
- params = {
|
||||
+ params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"id": task_id,
|
||||
"blocking": 1 if blocking else 0,
|
||||
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
|
||||
except PlatformError:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
||||
+ raise _map_caption_error(
|
||||
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
||||
+ ) from e
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
raise PlatformError(
|
||||
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
||||
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
||||
+ f"字幕服务网络错误: {e}",
|
||||
+ platform="volcengine_caption",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.TIMEOUT,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
|
||||
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
|
||||
sta_punc_mode: int = 3,
|
||||
) -> dict[str, Any]:
|
||||
"""提交自动字幕打轴任务,返回 {id: task_id}"""
|
||||
- params = {
|
||||
+ params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"caption_type": caption_type,
|
||||
"sta_punc_mode": sta_punc_mode,
|
||||
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if "id" not in data:
|
||||
- raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
|
||||
+ raise _map_caption_error(
|
||||
+ 500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
|
||||
+ )
|
||||
return data
|
||||
except PlatformError:
|
||||
raise
|
||||
diff --git a/python-api/app/ai/providers/volcengine_provider.py b/python-api/app/ai/providers/volcengine_provider.py
|
||||
index 0e2a5d5..9f029a0 100644
|
||||
--- a/python-api/app/ai/providers/volcengine_provider.py
|
||||
+++ b/python-api/app/ai/providers/volcengine_provider.py
|
||||
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
|
||||
|
||||
if status == 429 or "rate limit" in message.lower():
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_ark", retryable=True,
|
||||
- error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
|
||||
+ message,
|
||||
+ platform="volcengine_ark",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.RATE_LIMIT,
|
||||
+ status_code=status,
|
||||
)
|
||||
elif status in (401, 403) or "authentication" in message.lower():
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_ark", retryable=False,
|
||||
- error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
|
||||
+ message,
|
||||
+ platform="volcengine_ark",
|
||||
+ retryable=False,
|
||||
+ error_type=PlatformErrorType.AUTH_FAILED,
|
||||
+ status_code=status,
|
||||
)
|
||||
elif status and status >= 500:
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_ark", retryable=True,
|
||||
- error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
|
||||
+ message,
|
||||
+ platform="volcengine_ark",
|
||||
+ retryable=True,
|
||||
+ error_type=PlatformErrorType.SERVER_ERROR,
|
||||
+ status_code=status,
|
||||
)
|
||||
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_ark", retryable=True,
|
||||
+ message,
|
||||
+ platform="volcengine_ark",
|
||||
+ retryable=True,
|
||||
error_type=PlatformErrorType.TIMEOUT,
|
||||
)
|
||||
else:
|
||||
return PlatformError(
|
||||
- message, platform="volcengine_ark", retryable=False,
|
||||
+ message,
|
||||
+ platform="volcengine_ark",
|
||||
+ retryable=False,
|
||||
error_type=PlatformErrorType.UNKNOWN,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,331 @@
|
||||
diff --git a/python-api/app/services/point_service.py b/python-api/app/services/point_service.py
|
||||
index 0e5d79e..709f78d 100644
|
||||
--- a/python-api/app/services/point_service.py
|
||||
+++ b/python-api/app/services/point_service.py
|
||||
@@ -25,7 +25,7 @@ import logging
|
||||
import math
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
-from typing import TYPE_CHECKING
|
||||
+from typing import TYPE_CHECKING, Any
|
||||
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
+from app.core.exceptions import InsufficientPointsException
|
||||
from app.models.point_batch import PointBatch
|
||||
from app.models.point_transaction import PointTransaction
|
||||
from app.models.user_point import UserPoint
|
||||
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
|
||||
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
|
||||
|
||||
|
||||
-def _load_points_config() -> dict:
|
||||
+def _load_points_config() -> dict[str, Any]:
|
||||
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
|
||||
if not _CONFIG_PATH.exists():
|
||||
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
|
||||
- with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
+ with open(_CONFIG_PATH, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
|
||||
merged: dict[str, dict] = {}
|
||||
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
|
||||
return merged
|
||||
|
||||
|
||||
-POINTS_CONFIG: dict[str, dict] = _load_points_config()
|
||||
+POINTS_CONFIG: dict[str, Any] = _load_points_config()
|
||||
|
||||
|
||||
def get_recharge_options() -> list[dict]:
|
||||
"""获取充值档位配置(由后端控制,支持积分赠送)"""
|
||||
- return POINTS_CONFIG.get("_recharge_options", [])
|
||||
+ options = POINTS_CONFIG.get("_recharge_options", [])
|
||||
+ if isinstance(options, list):
|
||||
+ return options
|
||||
+ return []
|
||||
|
||||
|
||||
def get_chargeable_source_types() -> list[str]:
|
||||
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
|
||||
return [
|
||||
- key for key, cfg in POINTS_CONFIG.items()
|
||||
+ key
|
||||
+ for key, cfg in POINTS_CONFIG.items()
|
||||
if not key.startswith("_") and cfg.get("mode") != "free"
|
||||
]
|
||||
|
||||
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
|
||||
|
||||
# ── 余额查询 ──────────────────────────────────────────
|
||||
|
||||
+
|
||||
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
||||
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
|
||||
- result = await db.execute(
|
||||
- select(UserPoint).where(UserPoint.user_id == user_id)
|
||||
- )
|
||||
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
if not up:
|
||||
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
||||
from sqlalchemy import func as _func
|
||||
|
||||
available_result = await db.execute(
|
||||
- select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
|
||||
- .where(
|
||||
+ select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
|
||||
PointBatch.user_id == user_id,
|
||||
PointBatch.remaining > 0,
|
||||
PointBatch.expired_at > _now(),
|
||||
@@ -221,6 +224,7 @@ async def check_balance(
|
||||
|
||||
# ── 充值 ──────────────────────────────────────────────
|
||||
|
||||
+
|
||||
async def recharge(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -247,8 +251,7 @@ async def recharge(
|
||||
# 幂等保护:同一笔订单(order_id)只能充值一次
|
||||
if order_id:
|
||||
existing_result = await db.execute(
|
||||
- select(PointTransaction)
|
||||
- .where(
|
||||
+ select(PointTransaction).where(
|
||||
PointTransaction.source_id == str(order_id),
|
||||
PointTransaction.type == "recharge",
|
||||
)
|
||||
@@ -259,9 +262,7 @@ async def recharge(
|
||||
return existing_tx
|
||||
|
||||
# 1. 获取或创建用户积分账户
|
||||
- result = await db.execute(
|
||||
- select(UserPoint).where(UserPoint.user_id == user_id)
|
||||
- )
|
||||
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
if not up:
|
||||
@@ -353,7 +354,7 @@ async def consume(
|
||||
直接扣费(后置计费)。
|
||||
|
||||
业务执行成功后调用,按实际消耗直接扣除余额。
|
||||
- 默认不允许欠费(余额不足时抛出 ValueError)。
|
||||
+ 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。
|
||||
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
|
||||
|
||||
:param points: 实际消耗积分(正整数)
|
||||
@@ -383,9 +384,7 @@ async def consume(
|
||||
|
||||
# 2. 获取用户积分账户(加锁)
|
||||
result = await db.execute(
|
||||
- select(UserPoint)
|
||||
- .where(UserPoint.user_id == user_id)
|
||||
- .with_for_update()
|
||||
+ select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
|
||||
)
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
@@ -404,7 +403,7 @@ async def consume(
|
||||
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
|
||||
available = sum(b.remaining for b in batches)
|
||||
if not allow_negative and available < points:
|
||||
- raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
||||
+ raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
||||
|
||||
remaining_to_deduct = points
|
||||
for batch in batches:
|
||||
@@ -440,6 +439,7 @@ async def consume(
|
||||
|
||||
# ── 过期回收 ──────────────────────────────────────────
|
||||
|
||||
+
|
||||
async def expire_batches(db: AsyncSession) -> int:
|
||||
"""
|
||||
回收过期积分批次。返回过期积分总数。
|
||||
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
|
||||
|
||||
# 获取用户积分账户(加锁)
|
||||
up_result = await db.execute(
|
||||
- select(UserPoint)
|
||||
- .where(UserPoint.user_id == batch.user_id)
|
||||
- .with_for_update()
|
||||
+ select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
|
||||
)
|
||||
up = up_result.scalar_one_or_none()
|
||||
if not up:
|
||||
diff --git a/python-api/app/services/script_service.py b/python-api/app/services/script_service.py
|
||||
index 49aa4b1..60f58d5 100644
|
||||
--- a/python-api/app/services/script_service.py
|
||||
+++ b/python-api/app/services/script_service.py
|
||||
@@ -7,9 +7,16 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
+from typing import Any
|
||||
|
||||
from app.ai.model_router import get_model_router
|
||||
from app.ai.prompts import load_prompt_file, load_script_user_prompt
|
||||
+from app.core.exceptions import (
|
||||
+ AIEmptyResponseException,
|
||||
+ AIParseErrorException,
|
||||
+ AITimeoutException,
|
||||
+ PromptNotFoundException,
|
||||
+)
|
||||
from app.schemas.script import ScriptShot
|
||||
from app.services.ai_response_utils import (
|
||||
safe_parse_ai_json_response,
|
||||
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
|
||||
class ScriptService:
|
||||
"""脚本生成服务"""
|
||||
|
||||
-
|
||||
def __init__(self):
|
||||
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
||||
|
||||
-
|
||||
-
|
||||
def _load_prompt(self, name: str) -> str:
|
||||
"""加载 Prompt 模板"""
|
||||
prompt_file = self.prompts_dir / f"{name}.txt"
|
||||
@@ -58,7 +62,7 @@ class ScriptService:
|
||||
# 加载 Prompt
|
||||
system_prompt = load_prompt_file(category, filename)
|
||||
if not system_prompt:
|
||||
- raise ValueError(f"未找到提示词: category={category}, filename={filename}")
|
||||
+ raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
|
||||
|
||||
# 用户提示词
|
||||
user_prompt = load_script_user_prompt(
|
||||
@@ -75,24 +79,26 @@ class ScriptService:
|
||||
)
|
||||
|
||||
if not result.content or not result.content.strip():
|
||||
- raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
||||
+ raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
|
||||
|
||||
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
||||
|
||||
if not success:
|
||||
- raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
||||
+ raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
||||
|
||||
try:
|
||||
shots_data = validate_and_normalize_shots(parsed_data)
|
||||
|
||||
if not shots_data:
|
||||
- raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
||||
+ raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
|
||||
|
||||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||||
return shots
|
||||
|
||||
+ except (AIEmptyResponseException, AIParseErrorException):
|
||||
+ raise
|
||||
except Exception as e:
|
||||
- raise ValueError(f"分镜数据处理失败: {str(e)}")
|
||||
+ raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
|
||||
|
||||
async def polish_content(
|
||||
self,
|
||||
@@ -144,21 +150,23 @@ class ScriptService:
|
||||
)
|
||||
return result.content.strip()
|
||||
except TimeoutError:
|
||||
- raise ValueError("润色请求超时,请重试")
|
||||
+ raise AITimeoutException("润色请求超时,请重试")
|
||||
+ except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
|
||||
+ raise
|
||||
except Exception as e:
|
||||
- raise ValueError(f"润色失败: {str(e)}")
|
||||
+ raise AIParseErrorException(f"润色失败: {str(e)}")
|
||||
|
||||
async def check_model_health(self) -> dict:
|
||||
"""检查模型健康状态"""
|
||||
model_router = await get_model_router()
|
||||
health_results = await model_router.health_check()
|
||||
|
||||
- models = []
|
||||
+ models: list[dict[str, Any]] = []
|
||||
available_count = 0
|
||||
- recommended = None
|
||||
+ recommended: dict[str, Any] | None = None
|
||||
|
||||
for _provider_id, health in health_results.items():
|
||||
- model_info = {
|
||||
+ model_info: dict[str, Any] = {
|
||||
"id": health.id,
|
||||
"name": health.name,
|
||||
"is_available": health.is_available,
|
||||
@@ -169,9 +177,12 @@ class ScriptService:
|
||||
|
||||
if health.is_available:
|
||||
available_count += 1
|
||||
- if recommended is None or health.response_time < recommended.get(
|
||||
- "response_time", float("inf")
|
||||
- ):
|
||||
+ current_best = (
|
||||
+ float("inf")
|
||||
+ if recommended is None
|
||||
+ else float(recommended.get("response_time") or float("inf"))
|
||||
+ )
|
||||
+ if health.response_time < current_best:
|
||||
recommended = model_info
|
||||
|
||||
total = len(models)
|
||||
@@ -188,7 +199,6 @@ class ScriptService:
|
||||
"""测试指定模型连接"""
|
||||
model_router = await get_model_router()
|
||||
|
||||
-
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
diff --git a/python-api/app/services/vidu_service.py b/python-api/app/services/vidu_service.py
|
||||
index 054823f..61bbdcd 100644
|
||||
--- a/python-api/app/services/vidu_service.py
|
||||
+++ b/python-api/app/services/vidu_service.py
|
||||
@@ -207,8 +207,9 @@ class ViduService:
|
||||
error_type=PlatformErrorType.BAD_REQUEST,
|
||||
)
|
||||
|
||||
- logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
|
||||
- return result.data or {}
|
||||
+ clone_data = result.data or {}
|
||||
+ logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
|
||||
+ return clone_data
|
||||
|
||||
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
|
||||
"""Vidu 声音复刻是同步接口,无独立查询。
|
||||
@@ -270,6 +271,8 @@ class ViduService:
|
||||
result_data = status.result or {}
|
||||
return {
|
||||
"state": ViduAdapter.denormalize_state(status.state),
|
||||
- "creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
|
||||
+ "creations": (
|
||||
+ [{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
|
||||
+ ),
|
||||
"message": status.error_message,
|
||||
}
|
||||
diff --git a/python-api/app/services/volcengine_caption_service.py b/python-api/app/services/volcengine_caption_service.py
|
||||
index 83f59b4..9a565a5 100644
|
||||
--- a/python-api/app/services/volcengine_caption_service.py
|
||||
+++ b/python-api/app/services/volcengine_caption_service.py
|
||||
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
|
||||
error_type=PlatformErrorType.BAD_REQUEST,
|
||||
)
|
||||
|
||||
- logger.warning(
|
||||
- f"{task_name}超过最大轮询次数: task_id={task_id}, "
|
||||
- f"retries={retries}"
|
||||
- )
|
||||
+ logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
|
||||
raise PlatformError(
|
||||
f"{task_name}超时,请稍后重试",
|
||||
platform="volcengine_caption",
|
||||
@@ -9,7 +9,7 @@
|
||||
**美家卡智影**是一款面向桌面端的 AI 视频创作应用,采用"Python 后端 API + Tauri 桌面前端"的混合架构。
|
||||
|
||||
- **产品标识**: `cn.meijiaka.ai-video` / `cn.meijiaka.ai-zy`
|
||||
- **版本**: `1.8.2`
|
||||
- **版本**: `1.9.1`
|
||||
- **核心功能**: AI 脚本生成、AI 配音合成(TTS)、声音复刻、视频生成(Vidu)、视频字幕生成、压制成片(FFmpeg)、项目本地持久化
|
||||
|
||||
### 技术栈总览
|
||||
@@ -25,7 +25,7 @@
|
||||
| 前端构建 | Vite 7 |
|
||||
| 状态管理 | Zustand 5 + Immer 11 |
|
||||
| 路由 | `react-router-dom` |
|
||||
| 数据请求 | 自研智能路由客户端 + SWR |
|
||||
| 数据请求 | HTTP 客户端 (`src/api/client.ts`) + SWR |
|
||||
| 测试(后端) | pytest + pytest-asyncio |
|
||||
| 测试(前端) | Vitest 4 + jsdom + `@testing-library/react` |
|
||||
| 部署 | Docker + Docker Compose + Nginx |
|
||||
@@ -42,7 +42,7 @@
|
||||
├── tauri-app/ # Tauri 桌面前端
|
||||
├── docs/ # 项目文档(架构设计、API 对接指南等)
|
||||
├── scripts/ # 辅助脚本
|
||||
├── .gitlab-ci.yml # GitLab CI/CD 配置
|
||||
├── .github/workflows/ # GitHub Actions CI/CD 配置
|
||||
└── AGENTS.md # 本文档
|
||||
```
|
||||
|
||||
@@ -51,7 +51,7 @@
|
||||
```
|
||||
python-api/
|
||||
├── app/ # 主应用代码
|
||||
│ ├── api/v1/ # API 路由(按领域拆分:auth, script, voice, vidu, caption, tasks, upload, materials, system)
|
||||
│ ├── api/v1/ # API 路由(按领域拆分:auth, script, voice, vidu, caption, tasks, upload, materials, system, bgm_music, cover_background, image, points, update, events)
|
||||
│ ├── core/ # 核心工具(配置加载、安全、异常、Redis 客户端、健康检查)
|
||||
│ ├── db/ # 数据库配置与会话管理
|
||||
│ ├── models/ # SQLAlchemy ORM 模型(BaseModel 提供 UUID 主键 + 时间戳)
|
||||
@@ -62,7 +62,7 @@ python-api/
|
||||
│ ├── crud/ # 数据库 CRUD 封装
|
||||
│ ├── config.py # Pydantic Settings 配置管理
|
||||
│ └── main.py # FastAPI 应用入口(含 lifespan 管理)
|
||||
├── config/ # 运行时配置文件(platform-config.yaml, materials.json)
|
||||
├── config/ # 运行时配置文件(platform-config.yaml, points-config.yaml)
|
||||
├── alembic/ # 数据库迁移脚本
|
||||
├── nginx/ # Nginx 反向代理配置(含 acme.sh SSL 证书脚本)
|
||||
├── Dockerfile # 多阶段构建镜像(builder + production)
|
||||
@@ -81,12 +81,11 @@ python-api/
|
||||
tauri-app/
|
||||
├── src/ # React 前端源码
|
||||
│ ├── api/ # API 客户端与模块
|
||||
│ │ ├── client.ts # 智能路由客户端(HTTP / IPC 自动选择,camelCase ↔ snake_case 自动转换)
|
||||
│ │ ├── generated/ # OpenAPI 生成的 TypeScript 类型
|
||||
│ │ ├── client.ts # HTTP 客户端(camelCase ↔ snake_case 自动转换)
|
||||
│ │ └── modules/ # 按领域拆分的 API 模块
|
||||
│ ├── components/ # 可复用组件(PascalCase 文件夹)
|
||||
│ ├── pages/ # 页面级组件(PascalCase 文件夹)
|
||||
│ ├── store/ # Zustand 状态管理(含 __tests__)
|
||||
│ ├── store/ # Zustand 状态管理
|
||||
│ ├── hooks/ # 自定义 React Hooks
|
||||
│ ├── utils/ # 工具函数
|
||||
│ ├── styles/ # CSS 变量与全局样式
|
||||
@@ -97,11 +96,8 @@ tauri-app/
|
||||
│ │ ├── lib.rs # Tauri Builder、Command 定义、公共类型
|
||||
│ │ ├── ffmpeg_cmd.rs # FFmpeg 命令封装
|
||||
│ │ ├── video_processing.rs # 压制成片业务逻辑
|
||||
│ │ ├── api_proxy.rs # Python API 代理
|
||||
│ │ ├── auth.rs # 认证命令
|
||||
│ │ ├── avatar_cache.rs # 头像缓存
|
||||
│ │ ├── storage/ # 本地存储引擎(项目、认证、配置、头像等)
|
||||
│ │ ├── commands/ # Tauri IPC 命令(按领域拆分)
|
||||
│ │ ├── commands/ # Tauri IPC 命令(按领域拆分:asset, auth_state, cover_avatar, file, product, project, video_compose, voice 等)
|
||||
│ │ └── utils.rs # 通用工具
|
||||
│ ├── Cargo.toml # Rust 依赖
|
||||
│ ├── tauri.conf.json # Tauri 应用配置(窗口、CSP、打包、sidecar)
|
||||
@@ -231,8 +227,7 @@ npm run format:check # Prettier --check
|
||||
npm run stylelint # Stylelint
|
||||
npm run stylelint:fix # Stylelint --fix
|
||||
|
||||
# OpenAPI 类型生成
|
||||
npm run gen:api # 根据 openapi.json 生成 TypeScript 类型
|
||||
# 注:当前项目未配置 OpenAPI 自动生成 TypeScript 类型的脚本
|
||||
```
|
||||
|
||||
---
|
||||
@@ -280,7 +275,7 @@ npm run gen:api # 根据 openapi.json 生成 TypeScript 类型
|
||||
|
||||
1. **判断是否需要本地能力**(FFmpeg、文件系统、系统调用)。
|
||||
2. **不需要** → 直接在 `tauri-app/src/api/modules/` 使用 `client.get/post/put/delete` 调用 Python HTTP API。
|
||||
3. **需要** → 将 endpoint 加入 `tauri-app/src/api/client.ts` 的 IPC 处理逻辑,并在 `tauri-app/src-tauri/src/commands/` 或 `lib.rs` 中实现对应的 `#[tauri::command]` 处理器。
|
||||
3. **需要** → 在 `tauri-app/src/api/modules/` 中通过 `@tauri-apps/api/core` 的 `invoke` 调用 Rust 命令,并在 `tauri-app/src-tauri/src/commands/` 或 `lib.rs` 中实现对应的 `#[tauri::command]` 处理器。
|
||||
|
||||
### 语义层防护网(后端)
|
||||
|
||||
@@ -304,7 +299,6 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
|
||||
- **组件测试**: `@testing-library/react` + `@testing-library/jest-dom`
|
||||
- **文件位置**:
|
||||
- 全局 setup: `src/__tests__/setup.ts`
|
||||
- Store 测试: `src/store/__tests__/*.test.tsx`
|
||||
- 组件/页面测试: 建议放在被测文件同目录或 `__tests__` 子目录中
|
||||
- **Mock 策略**: `setup.ts` 中已全局 mock `localStorage`、`@tauri-apps/api/core` 的 `invoke` 方法、`window.__TAURI_INTERNALS__`。每个测试后自动调用 `vi.clearAllMocks()`。
|
||||
|
||||
@@ -327,13 +321,11 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
|
||||
|
||||
### 部署流程
|
||||
|
||||
#### 测试环境(GitLab CI)
|
||||
#### 测试环境(GitHub Actions)
|
||||
|
||||
`.gitlab-ci.yml` 定义了 `deploy-backend` 任务:
|
||||
1. 在部署服务器拉取代码(`master` 分支触发)。
|
||||
2. 构建 api + scheduler 镜像(`docker-compose.test.yml`)。
|
||||
3. 启动服务,健康检查 `curl http://localhost:8081/health`。
|
||||
4. 清理 7 天前的旧镜像。
|
||||
`.github/workflows/` 存放 CI/CD 工作流:
|
||||
- `release.yml`:Tauri 桌面端 Release 打包工作流(按 tag 触发,支持 macOS / Windows 平台及 sidecar 二进制下载)。
|
||||
- 后端 api + scheduler 的测试/生产部署目前通过手动执行 `docker-compose.test.yml` / `docker-compose.prod.yml` 完成。
|
||||
|
||||
#### 生产环境
|
||||
|
||||
@@ -384,7 +376,7 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
|
||||
- **ORM**: SQLAlchemy 2.0(异步,asyncpg 驱动)
|
||||
- **迁移工具**: Alembic
|
||||
- **基础模型**: `app.models.base.BaseModel` 提供 UUID 主键、`created_at`、`updated_at`
|
||||
- **当前模型**: `User`(用户/设备认证)
|
||||
- **当前模型**: 包括 `User`(用户/设备认证)、`UserDevice`、`UserPoint`、`PointRechargeOrder`、`PointTransaction`、`BrollCategory`、`BrollMaterial`、`BrollTag`、`CoverBackground`、`BgmMusic`、`Update` 等
|
||||
- **迁移注意**: Alembic 使用同步连接(psycopg2),会自动将 `+asyncpg` 替换掉。
|
||||
|
||||
---
|
||||
|
||||
+11
-17
@@ -115,26 +115,20 @@ lint-semantic: ## 语义层禁词检查(防止供应商术语泄漏到业务
|
||||
echo "❌ API 层发现 element_id(应使用 provider_element_id 或 human_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# Scheduler 层禁止 task_id 作为内部变量/Redis key(读取 Provider 返回除外)
|
||||
@errs=$$(grep -rn '\btask_id\b' app/scheduler --include='*.py' \
|
||||
| grep -v 'job_id' \
|
||||
| grep -v '__pycache__' \
|
||||
| grep -v '\.get("task_id")' \
|
||||
| grep -v 'result.get("task_id")' \
|
||||
| grep -v 'task_type' \
|
||||
| grep -v '"task_id"' \
|
||||
| grep -v "'task_id'"); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ Scheduler 层发现 task_id(应使用 job_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# Scheduler 层 Redis key 必须使用 job: 而非 task:
|
||||
@errs=$$(grep -rn 'redis.*task:' app/scheduler --include='*.py' \
|
||||
@# Scheduler 层统一使用 task 命名,禁止混用 job
|
||||
@errs=$$(grep -rn '\bjob_id\b' app/scheduler --include='*.py' \
|
||||
| grep -v '__pycache__'); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ Scheduler Redis key 使用 task:(应使用 job:)"; \
|
||||
echo "❌ Scheduler 层发现 job_id(应使用 task_id)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@# Scheduler 层 Redis key 必须使用 task: 而非 job:
|
||||
@errs=$$(grep -rn 'redis.*job:' app/scheduler --include='*.py' \
|
||||
| grep -v '__pycache__'); \
|
||||
if [ -n "$$errs" ]; then \
|
||||
echo "$$errs"; \
|
||||
echo "❌ Scheduler Redis key 使用 job:(应使用 task:)"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@echo "✅ 语义层检查通过"
|
||||
|
||||
@@ -113,7 +113,7 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
|
||||
result = await self.provider.clone_voice(
|
||||
audio_url=payload["audio_url"],
|
||||
voice_id=payload["voice_id"],
|
||||
text=payload.get("text"),
|
||||
text=payload.get("text") or "",
|
||||
)
|
||||
return AdapterResponse(
|
||||
success=True,
|
||||
@@ -219,6 +219,7 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
|
||||
) -> bool:
|
||||
"""验证 Vidu 回调 HMAC-SHA256 签名"""
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HTTP 头大小写不敏感:建立小写 key 的查找表
|
||||
@@ -233,6 +234,9 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
|
||||
if not all([signature, algorithm, access_key, signed_headers_str, date]):
|
||||
logger.warning(f"[Vidu] 签名验证失败: 缺少必要头, headers={list(headers.keys())}")
|
||||
return False
|
||||
assert signature is not None
|
||||
assert signed_headers_str is not None
|
||||
assert date is not None
|
||||
if algorithm != "hmac-sha256":
|
||||
logger.warning(f"[Vidu] 签名验证失败: 不支持的算法 {algorithm}")
|
||||
return False
|
||||
@@ -256,17 +260,15 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
|
||||
canonical_query_string = parsed.query or ""
|
||||
|
||||
signing_string = (
|
||||
f"POST\n"
|
||||
f"{http_uri}\n"
|
||||
f"{canonical_query_string}\n"
|
||||
f"vidu\n"
|
||||
f"{date}\n"
|
||||
f"POST\n" f"{http_uri}\n" f"{canonical_query_string}\n" f"vidu\n" f"{date}\n"
|
||||
)
|
||||
for name in header_names:
|
||||
signing_string += f"{name}:{header_values[name]}\n"
|
||||
|
||||
expected = base64.b64encode(
|
||||
hmac.new(secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256).digest()
|
||||
hmac.new(
|
||||
secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
if not hmac.compare_digest(signature, expected):
|
||||
@@ -307,6 +309,12 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
|
||||
|
||||
return TaskStatus(
|
||||
state=self.normalize_state(state),
|
||||
result={"video_url": video_url, "creations": creations, "task_id": task_id} if video_url else {"task_id": task_id},
|
||||
error_message=(data.get("err_code") or data.get("message")) if state == "failed" else None,
|
||||
result=(
|
||||
{"video_url": video_url, "creations": creations, "task_id": task_id}
|
||||
if video_url
|
||||
else {"task_id": task_id}
|
||||
),
|
||||
error_message=(
|
||||
(data.get("err_code") or data.get("message")) if state == "failed" else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -69,11 +69,11 @@ class VolcengineArkAdapter(PlatformAdapter, SyncCapable):
|
||||
)
|
||||
|
||||
elif method == Method.EMBEDDING:
|
||||
result = await self.provider.create_embeddings(
|
||||
embedding_result: dict[str, Any] = await self.provider.create_embeddings(
|
||||
texts=payload["texts"],
|
||||
model=payload.get("model"),
|
||||
)
|
||||
return AdapterResponse(success=True, data=result)
|
||||
return AdapterResponse(success=True, data=embedding_result)
|
||||
|
||||
else:
|
||||
return AdapterResponse(
|
||||
|
||||
@@ -24,7 +24,9 @@ logger = logging.getLogger(__name__)
|
||||
class LLMGateway:
|
||||
"""LLM 调用网关"""
|
||||
|
||||
def __init__(self, adapters: dict[str, SyncCapable], fallback_chains: dict[str, list[str]] | None = None):
|
||||
def __init__(
|
||||
self, adapters: dict[str, SyncCapable], fallback_chains: dict[str, list[str]] | None = None
|
||||
):
|
||||
self.adapters = adapters
|
||||
self.fallback_chains = fallback_chains or {}
|
||||
|
||||
@@ -55,15 +57,18 @@ class LLMGateway:
|
||||
for mid in models_to_try:
|
||||
adapter = self._get_adapter(platform)
|
||||
try:
|
||||
result = await adapter.call(Method.CHAT, {
|
||||
result = await adapter.call(
|
||||
Method.CHAT,
|
||||
{
|
||||
"prompt": prompt,
|
||||
"model": mid,
|
||||
**kwargs,
|
||||
})
|
||||
},
|
||||
)
|
||||
if result.success:
|
||||
if mid != model_id:
|
||||
logger.warning(f"[LLMGateway] 模型降级成功: {model_id} → {mid}")
|
||||
return result.data
|
||||
return result.data or {}
|
||||
else:
|
||||
last_error = PlatformError(
|
||||
result.error_message or f"模型 {mid} 调用失败",
|
||||
@@ -82,6 +87,3 @@ class LLMGateway:
|
||||
platform=platform,
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.ai.adapters.constants import Method
|
||||
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||||
from app.core.config_loader import AIModelConfigLoader, get_config_loader
|
||||
from app.core.exceptions import AppException, PlatformError
|
||||
from app.platform_gateway import PlatformGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,9 +27,7 @@ class _PlatformInstance:
|
||||
self.gateway = gateway
|
||||
self.provider_id = config.get("id", "")
|
||||
|
||||
async def generate(
|
||||
self, model_name: str, prompt: str, **kwargs
|
||||
) -> GenerationResult:
|
||||
async def generate(self, model_name: str, prompt: str, **kwargs) -> GenerationResult:
|
||||
"""调用生成(通过 PlatformGateway)"""
|
||||
if self.gateway:
|
||||
result = await self.gateway.call_sync(
|
||||
@@ -41,9 +40,7 @@ class _PlatformInstance:
|
||||
},
|
||||
)
|
||||
if not result.success:
|
||||
raise ProviderError(
|
||||
result.error_message or f"{self.provider_id} 调用失败"
|
||||
)
|
||||
raise ProviderError(result.error_message or f"{self.provider_id} 调用失败")
|
||||
data = result.data or {}
|
||||
return GenerationResult(
|
||||
content=data.get("content", ""),
|
||||
@@ -64,7 +61,11 @@ class _PlatformInstance:
|
||||
id=model_name or self.provider_id,
|
||||
name=self.provider_id,
|
||||
is_available=adapter_result.success,
|
||||
response_time=adapter_result.data.get("response_time_ms", 0) if adapter_result.data else 0,
|
||||
response_time=(
|
||||
adapter_result.data.get("response_time_ms", 0)
|
||||
if adapter_result.data
|
||||
else 0
|
||||
),
|
||||
last_error=adapter_result.error_message,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -89,7 +90,6 @@ class ModelRouter:
|
||||
- 模型自动选择
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.platforms: dict[str, _PlatformInstance] = {}
|
||||
self._config_loader: AIModelConfigLoader | None = None
|
||||
@@ -249,11 +249,7 @@ class ModelRouter:
|
||||
if task_type:
|
||||
model_id = self.select_model_for_task(task_type)
|
||||
if model_id is None:
|
||||
models = (
|
||||
self._config_loader.get_enabled_models()
|
||||
if self._config_loader
|
||||
else []
|
||||
)
|
||||
models = self._config_loader.get_enabled_models() if self._config_loader else []
|
||||
if models:
|
||||
model_id = models[0].id
|
||||
else:
|
||||
@@ -270,9 +266,9 @@ class ModelRouter:
|
||||
params = {**model.default_params, **kwargs}
|
||||
|
||||
try:
|
||||
return await platform.generate(
|
||||
prompt=prompt, model_name=model.model_name, **params
|
||||
)
|
||||
return await platform.generate(prompt=prompt, model_name=model.model_name, **params)
|
||||
except (PlatformError, AppException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ProviderError(f"模型 {model_id} 生成失败: {e}") from e
|
||||
|
||||
@@ -290,7 +286,11 @@ class ModelRouter:
|
||||
id=model.id,
|
||||
name=model.display_name,
|
||||
is_available=adapter_result.success,
|
||||
response_time=adapter_result.data.get("response_time_ms", 0) if adapter_result.data else 0,
|
||||
response_time=(
|
||||
adapter_result.data.get("response_time_ms", 0)
|
||||
if adapter_result.data
|
||||
else 0
|
||||
),
|
||||
last_error=adapter_result.error_message,
|
||||
)
|
||||
else:
|
||||
@@ -306,11 +306,12 @@ class ModelRouter:
|
||||
# fallback: 直接通过 PlatformInstance
|
||||
results = {}
|
||||
if model_id:
|
||||
model = self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
if model:
|
||||
platform = self.platforms.get(model.platform_id)
|
||||
target_model = self._config_loader.get_model(model_id) if self._config_loader else None
|
||||
if target_model is None:
|
||||
raise ProviderError(f"模型不存在: {model_id}")
|
||||
platform = self.platforms.get(target_model.platform_id)
|
||||
if platform:
|
||||
results[model_id] = await platform.health_check(model.model_name)
|
||||
results[model_id] = await platform.health_check(target_model.model_name)
|
||||
else:
|
||||
if self._config_loader:
|
||||
for model in self._config_loader.get_enabled_models():
|
||||
|
||||
@@ -94,10 +94,12 @@ def list_categories() -> list[dict]:
|
||||
for cat_meta in meta.get("categories", []):
|
||||
cat_code = cat_meta["code"]
|
||||
cat_name = cat_meta.get("name", cat_code)
|
||||
categories.append({
|
||||
categories.append(
|
||||
{
|
||||
"code": cat_code,
|
||||
"name": cat_name,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return categories
|
||||
|
||||
@@ -135,11 +137,13 @@ def list_prompt_files(category: str) -> list[dict]:
|
||||
else:
|
||||
label = name
|
||||
desc = ""
|
||||
files.append({
|
||||
files.append(
|
||||
{
|
||||
"filename": f.name,
|
||||
"label": label.strip(),
|
||||
"desc": desc.strip(),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
@@ -174,7 +178,6 @@ def load_system_prompt(category: str, subcategory: str) -> str:
|
||||
return load_prompt_file(category, chosen["filename"])
|
||||
|
||||
|
||||
|
||||
def load_script_user_prompt(
|
||||
topic: str,
|
||||
extra_params: str | None = None,
|
||||
|
||||
@@ -50,12 +50,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
【分镜固定结构规则】
|
||||
开篇的分镜为:一段人物出镜
|
||||
中间内容全部用空镜,空镜(内置完整素材库标题)与文案内容需匹配
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
合同签署
|
||||
厨卫原始毛坯状态-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
装修合同核对-现场交底
|
||||
卧室原始状态-翻新基础
|
||||
厨卫原始状态-翻新基础
|
||||
@@ -66,12 +66,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
验收合格签字确认-全屋验收
|
||||
|
||||
@@ -42,7 +42,7 @@
|
||||
墙面纯色面漆涂刷-面漆涂刷
|
||||
乳胶漆调配-面漆涂刷
|
||||
卫生间陶粒回填
|
||||
防水翻车漏水-施工翻车镜
|
||||
防水翻车漏水-施工翻车
|
||||
轻钢龙骨骨架搭建-吊顶造型
|
||||
木龙骨基础框架固定-吊顶造型
|
||||
全屋定制板材检查
|
||||
@@ -53,11 +53,11 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
【分镜固定结构规则】
|
||||
开篇的分镜为:(可选用讨好装修师傅、恶搞开篇或施工翻车镜,最好能贴近话术内容和主题)+ 一段人物出镜 + 一段空镜补充,不得有 2 段人物出镜
|
||||
分点阐述全部用空镜,空镜(素材库标题)与文案内容需匹配
|
||||
@@ -94,7 +94,7 @@ duration: “分镜时长”(如 “5s”,时长为 "配音文案" 的字数
|
||||
{
|
||||
"id": 3,
|
||||
"type": "empty_shot",
|
||||
"scene": "墙体掉落-施工翻车镜",
|
||||
"scene": "墙体掉落-施工翻车",
|
||||
"voiceover": "今天我告诉你几个监工关键时间点,你必须在场。",
|
||||
"duration": "5.00s"
|
||||
}
|
||||
|
||||
@@ -51,10 +51,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -214,12 +214,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -36,10 +36,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -199,12 +199,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -39,10 +39,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -202,12 +202,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -52,10 +52,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -215,12 +215,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -41,10 +41,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -204,12 +204,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -42,10 +42,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -205,12 +205,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -41,10 +41,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -204,12 +204,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -127,12 +127,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
踢脚线安装验收-软装进场
|
||||
【分镜固定结构规则】
|
||||
开篇的分镜为:一段网红开篇(可选用恶搞开篇或施工翻车镜,最好能贴近硬装收尾、软装进场、装修细节避坑主题,优先选工地恶搞、墙面空鼓、硬装完工全屋全景等相关)+ 一段人物出镜 + 一段空镜补充,不得有 2 段人物出镜
|
||||
|
||||
@@ -44,10 +44,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -207,12 +207,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -43,10 +43,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -206,12 +206,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -41,10 +41,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -204,12 +204,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -43,10 +43,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -206,12 +206,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -70,10 +70,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -233,12 +233,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -48,10 +48,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -211,12 +211,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -36,10 +36,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -199,12 +199,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -43,10 +43,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -206,12 +206,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -61,10 +61,10 @@ type为segment=人物出镜;type为empty_shot=从下方内置素材库选匹
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -224,12 +224,12 @@ type为segment=人物出镜;type为empty_shot=从下方内置素材库选匹
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -63,10 +63,10 @@
|
||||
客厅原始墙面-毛坯基础
|
||||
强弱电箱原始特写-毛坯基础
|
||||
毛坯全屋广角全景-毛坯基础
|
||||
阳台原始结构空镜-毛坯基础
|
||||
阳台原始结构-毛坯基础
|
||||
墙面点位弹线-现场交底
|
||||
开关插座定位-现场交底
|
||||
开工仪式简单镜头-现场交底
|
||||
开工仪式-现场交底
|
||||
施工方案现场讲解-现场交底
|
||||
甲乙工长三方对接-现场交底
|
||||
给排水点位标记-现场交底
|
||||
@@ -226,12 +226,12 @@
|
||||
暴力拆除-恶搞开篇
|
||||
炫技-恶搞开篇
|
||||
贴砖恶搞-恶搞开篇
|
||||
墙体掉落-施工翻车镜
|
||||
墙面开裂-施工翻车镜
|
||||
墙面空鼓-施工翻车镜
|
||||
水管错位-施工翻车镜
|
||||
电线乱接-施工翻车镜
|
||||
防水翻车漏水-施工翻车镜
|
||||
墙体掉落-施工翻车
|
||||
墙面开裂-施工翻车
|
||||
墙面空鼓-施工翻车
|
||||
水管错位-施工翻车
|
||||
电线乱接-施工翻车
|
||||
防水翻车漏水-施工翻车
|
||||
墙面漆面细节查验-全屋验收
|
||||
柜体开合顺畅度检查-全屋验收
|
||||
踢脚线安装验收-软装进场
|
||||
|
||||
@@ -14,8 +14,9 @@ from app.ai.providers.base import (
|
||||
# 火山方舟官方 SDK Provider
|
||||
# 需要: pip install 'volcengine-python-sdk[ark]'
|
||||
try:
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider
|
||||
from app.ai.providers.volcengine_provider import VolcengineProvider as _VolcengineProvider
|
||||
|
||||
VolcengineProvider: type | None = _VolcengineProvider
|
||||
VOLCENGINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
VOLCENGINE_AVAILABLE = False
|
||||
|
||||
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _map_vidu_error(status: int, message: str) -> PlatformError:
|
||||
"""把 Vidu HTTP 错误映射为标准 PlatformError"""
|
||||
# Vidu 错误码分类
|
||||
_VIDU_AUDIT_ERROR_CODES = {
|
||||
"TaskPromptPolicyViolation",
|
||||
"AuditSubmitIllegal",
|
||||
"CreationPolicyViolation",
|
||||
"PhotoAuditNotPass",
|
||||
"AuditFailed",
|
||||
"ImageCheckBodyJointsFailed",
|
||||
"ImageCheckFaceFailed",
|
||||
"ImageObjectsUndetected",
|
||||
"FaceDetectFailure",
|
||||
"FaceDetectNotPass",
|
||||
"NoFaceDetected",
|
||||
"MultiFaceDetected",
|
||||
}
|
||||
|
||||
_VIDU_RETRYABLE_ERROR_CODES = {
|
||||
"InternalServiceFailure",
|
||||
"ModelUnavailable",
|
||||
"Unknown",
|
||||
}
|
||||
|
||||
_VIDU_RATE_LIMIT_ERROR_CODES = {
|
||||
"QuotaExceeded",
|
||||
"TooManyRequests",
|
||||
"SystemThrottling",
|
||||
"OperationInProcess",
|
||||
}
|
||||
|
||||
|
||||
def _extract_vidu_error_code(message: str | None) -> str | None:
|
||||
"""从 Vidu 错误信息中提取错误码"""
|
||||
if not message:
|
||||
return None
|
||||
# Vidu 错误码格式:"ErrorCode: 中文描述"
|
||||
return message.split(":")[0].strip() or None
|
||||
|
||||
|
||||
def _map_vidu_error(
|
||||
status: int,
|
||||
message: str,
|
||||
*,
|
||||
err_code: str | None = None,
|
||||
) -> PlatformError:
|
||||
"""把 Vidu HTTP 错误映射为标准 PlatformError
|
||||
|
||||
优先根据 Vidu 业务错误码(err_code)判断类型,HTTP status 仅作为兜底。
|
||||
"""
|
||||
raw_code = err_code or _extract_vidu_error_code(message)
|
||||
|
||||
# 1. 内容安全/审核类:不可重试
|
||||
if raw_code in _VIDU_AUDIT_ERROR_CODES:
|
||||
return PlatformError(
|
||||
message=message,
|
||||
platform="vidu",
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.CONTENT_VIOLATION,
|
||||
status_code=status,
|
||||
raw_code=raw_code,
|
||||
)
|
||||
|
||||
# 2. 平台内部/模型不可用:可重试
|
||||
if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
|
||||
return PlatformError(
|
||||
message=message,
|
||||
platform="vidu",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.SERVER_ERROR,
|
||||
status_code=status,
|
||||
raw_code=raw_code,
|
||||
)
|
||||
|
||||
# 3. 限流类:可重试
|
||||
if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
|
||||
return PlatformError(
|
||||
message=message,
|
||||
platform="vidu",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.RATE_LIMIT,
|
||||
status_code=status,
|
||||
raw_code=raw_code,
|
||||
)
|
||||
|
||||
# 4. HTTP status 兜底
|
||||
mapping = {
|
||||
429: (PlatformErrorType.RATE_LIMIT, True),
|
||||
401: (PlatformErrorType.AUTH_FAILED, False),
|
||||
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
|
||||
retryable=retryable,
|
||||
error_type=error_type,
|
||||
status_code=status,
|
||||
raw_code=raw_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,7 +149,9 @@ class ViduProvider:
|
||||
from app.core.platform_config import get_platform_config_loader
|
||||
|
||||
platform_config = get_platform_config_loader().get_platform("vidu")
|
||||
self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
|
||||
self.base_url = (
|
||||
platform_config.base_url if platform_config else "https://api.vidu.cn"
|
||||
).rstrip("/")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
|
||||
@@ -135,9 +220,12 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
|
||||
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(
|
||||
f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
)
|
||||
raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu TTS] 网络错误: {e}")
|
||||
@@ -182,9 +270,14 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
|
||||
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(
|
||||
f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
)
|
||||
raise _map_vidu_error(
|
||||
resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
|
||||
)
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu Clone] 网络错误: {e}")
|
||||
@@ -238,9 +331,14 @@ class ViduProvider:
|
||||
resp = await self.client.post(url, json=body)
|
||||
data = resp.json()
|
||||
if resp.status_code != 200 or data.get("state") == "failed":
|
||||
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
|
||||
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(
|
||||
f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
)
|
||||
raise _map_vidu_error(
|
||||
resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
|
||||
)
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu LipSync] 网络错误: {e}")
|
||||
@@ -264,9 +362,14 @@ class ViduProvider:
|
||||
resp = await self.client.get(url)
|
||||
data = resp.json()
|
||||
if resp.status_code != 200:
|
||||
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
|
||||
raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
|
||||
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
|
||||
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
|
||||
logger.error(
|
||||
f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
|
||||
)
|
||||
raise _map_vidu_error(
|
||||
resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
|
||||
)
|
||||
return data
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
logger.error(f"[Vidu Query] 网络错误: {e}")
|
||||
|
||||
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
||||
if code is not None and code in error_mapping:
|
||||
error_type, retryable = error_mapping[code]
|
||||
return PlatformError(
|
||||
message, platform="volcengine_caption",
|
||||
retryable=retryable, error_type=error_type,
|
||||
message,
|
||||
platform="volcengine_caption",
|
||||
retryable=retryable,
|
||||
error_type=error_type,
|
||||
status_code=status,
|
||||
)
|
||||
|
||||
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
|
||||
}
|
||||
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
|
||||
return PlatformError(
|
||||
message, platform="volcengine_caption",
|
||||
retryable=retryable, error_type=error_type,
|
||||
message,
|
||||
platform="volcengine_caption",
|
||||
retryable=retryable,
|
||||
error_type=error_type,
|
||||
status_code=status,
|
||||
)
|
||||
|
||||
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
|
||||
max_lines: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""提交字幕生成任务,返回 {id: task_id}"""
|
||||
params = {
|
||||
params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"language": language,
|
||||
"caption_type": caption_type,
|
||||
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
|
||||
except PlatformError:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
||||
raise _map_caption_error(
|
||||
e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
||||
) from e
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
raise PlatformError(
|
||||
f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
||||
retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
||||
f"字幕服务网络错误: {e}",
|
||||
platform="volcengine_caption",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.TIMEOUT,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
|
||||
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
|
||||
blocking: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""查询字幕任务结果,返回原始 JSON"""
|
||||
params = {
|
||||
params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"id": task_id,
|
||||
"blocking": 1 if blocking else 0,
|
||||
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
|
||||
except PlatformError:
|
||||
raise
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
|
||||
raise _map_caption_error(
|
||||
e.response.status_code, f"HTTP错误: {e.response.status_code}"
|
||||
) from e
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
raise PlatformError(
|
||||
f"字幕服务网络错误: {e}", platform="volcengine_caption",
|
||||
retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
||||
f"字幕服务网络错误: {e}",
|
||||
platform="volcengine_caption",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.TIMEOUT,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
|
||||
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
|
||||
sta_punc_mode: int = 3,
|
||||
) -> dict[str, Any]:
|
||||
"""提交自动字幕打轴任务,返回 {id: task_id}"""
|
||||
params = {
|
||||
params: dict[str, str | int] = {
|
||||
"appid": self.appid,
|
||||
"caption_type": caption_type,
|
||||
"sta_punc_mode": sta_punc_mode,
|
||||
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if "id" not in data:
|
||||
raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
|
||||
raise _map_caption_error(
|
||||
500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
|
||||
)
|
||||
return data
|
||||
except PlatformError:
|
||||
raise
|
||||
|
||||
@@ -34,8 +34,10 @@ def _map_mediakit_error(status: int, message: str, code: int | None = None) -> P
|
||||
}
|
||||
error_type, retryable = error_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
|
||||
return PlatformError(
|
||||
message, platform="volcengine_mediakit",
|
||||
retryable=retryable, error_type=error_type,
|
||||
message,
|
||||
platform="volcengine_mediakit",
|
||||
retryable=retryable,
|
||||
error_type=error_type,
|
||||
status_code=status,
|
||||
)
|
||||
|
||||
@@ -167,8 +169,10 @@ class VolcengineMediakitProvider:
|
||||
) from e
|
||||
except (httpx.NetworkError, httpx.TimeoutException) as e:
|
||||
raise PlatformError(
|
||||
f"MediaKit 网络错误: {e}", platform="volcengine_mediakit",
|
||||
retryable=True, error_type=PlatformErrorType.TIMEOUT,
|
||||
f"MediaKit 网络错误: {e}",
|
||||
platform="volcengine_mediakit",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.TIMEOUT,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise _map_mediakit_error(500, f"抠图失败: {str(e)}") from e
|
||||
|
||||
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
|
||||
|
||||
if status == 429 or "rate limit" in message.lower():
|
||||
return PlatformError(
|
||||
message, platform="volcengine_ark", retryable=True,
|
||||
error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
|
||||
message,
|
||||
platform="volcengine_ark",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.RATE_LIMIT,
|
||||
status_code=status,
|
||||
)
|
||||
elif status in (401, 403) or "authentication" in message.lower():
|
||||
return PlatformError(
|
||||
message, platform="volcengine_ark", retryable=False,
|
||||
error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
|
||||
message,
|
||||
platform="volcengine_ark",
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.AUTH_FAILED,
|
||||
status_code=status,
|
||||
)
|
||||
elif status and status >= 500:
|
||||
return PlatformError(
|
||||
message, platform="volcengine_ark", retryable=True,
|
||||
error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
|
||||
message,
|
||||
platform="volcengine_ark",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.SERVER_ERROR,
|
||||
status_code=status,
|
||||
)
|
||||
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
|
||||
return PlatformError(
|
||||
message, platform="volcengine_ark", retryable=True,
|
||||
message,
|
||||
platform="volcengine_ark",
|
||||
retryable=True,
|
||||
error_type=PlatformErrorType.TIMEOUT,
|
||||
)
|
||||
else:
|
||||
return PlatformError(
|
||||
message, platform="volcengine_ark", retryable=False,
|
||||
message,
|
||||
platform="volcengine_ark",
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import select
|
||||
@@ -20,7 +22,7 @@ security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
async def get_db() -> AsyncSession:
|
||||
async def get_db() -> AsyncGenerator[AsyncSession]:
|
||||
"""获取数据库 Session"""
|
||||
async for session in db_session():
|
||||
yield session
|
||||
|
||||
@@ -80,7 +80,7 @@ async def send_code(
|
||||
async def login(
|
||||
request: MobileLoginRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
http_request: Request = None,
|
||||
http_request: Request = None, # type: ignore[assignment]
|
||||
):
|
||||
"""
|
||||
手机号验证码登录
|
||||
@@ -133,7 +133,7 @@ async def login(
|
||||
async def login_password(
|
||||
request: PasswordLoginRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
http_request: Request = None,
|
||||
http_request: Request = None, # type: ignore[assignment]
|
||||
):
|
||||
"""
|
||||
手机号密码登录
|
||||
|
||||
@@ -24,19 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/caption", tags=["Caption"])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/ata/align")
|
||||
async def auto_align_caption(
|
||||
request_body: AutoAlignSubmitRequest,
|
||||
@@ -88,9 +75,3 @@ async def auto_align_caption(
|
||||
except Exception as e:
|
||||
logger.error(f"自动打轴异常: {e}")
|
||||
raise HTTPException(status_code=500, detail="字幕打轴失败,请稍后重试")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.config import get_settings
|
||||
from app.core.exceptions import InsufficientPointsException
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
from app.services import point_service as ps
|
||||
@@ -33,6 +34,7 @@ settings = get_settings()
|
||||
|
||||
# ── Dependencies ──
|
||||
|
||||
|
||||
async def get_mediakit_service(request: Request) -> VolcengineMediakitService:
|
||||
"""FastAPI Depends:从 app.state 获取全局 VolcengineMediakitService 实例。"""
|
||||
service = getattr(request.app.state, "volcengine_mediakit_service", None)
|
||||
@@ -46,6 +48,7 @@ async def get_mediakit_service(request: Request) -> VolcengineMediakitService:
|
||||
|
||||
# ── Schemas ──
|
||||
|
||||
|
||||
class ImageUploadResponse(BaseModel):
|
||||
"""图片上传响应"""
|
||||
|
||||
@@ -64,11 +67,15 @@ class RemoveBackgroundRequest(BaseModel):
|
||||
"""抠图请求"""
|
||||
|
||||
image_url: str = Field(..., description="原始图片 URL")
|
||||
scene: str = Field(default="human", description="场景类型:general(通用)、human(人物,默认白色描边)或 product(商品)")
|
||||
scene: str = Field(
|
||||
default="human",
|
||||
description="场景类型:general(通用)、human(人物,默认白色描边)或 product(商品)",
|
||||
)
|
||||
|
||||
|
||||
# ── Endpoints ──
|
||||
|
||||
|
||||
@router.post("/upload/image", response_model=ApiResponse[ImageUploadResponse])
|
||||
async def upload_image(
|
||||
file: UploadFile = File(..., description="图片文件"),
|
||||
@@ -178,9 +185,8 @@ async def remove_background(
|
||||
required_points = ps._calculate_cost("cover_avatar")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
raise InsufficientPointsException(
|
||||
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -193,9 +199,7 @@ async def remove_background(
|
||||
)
|
||||
|
||||
if not result.image_url:
|
||||
logger.error(
|
||||
f"[RemoveBackground] 抠图返回空 URL: raw={result.raw}"
|
||||
)
|
||||
logger.error(f"[RemoveBackground] 抠图返回空 URL: raw={result.raw}")
|
||||
raise HTTPException(status_code=500, detail="抠图失败:未返回结果图片 URL")
|
||||
|
||||
logger.info(f"[RemoveBackground] 抠图成功: {result.image_url[:80]}...")
|
||||
@@ -256,7 +260,5 @@ async def remove_background(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[RemoveBackground] 抠图失败: image_url={req.image_url[:80]}..., error={e}"
|
||||
)
|
||||
logger.error(f"[RemoveBackground] 抠图失败: image_url={req.image_url[:80]}..., error={e}")
|
||||
raise HTTPException(status_code=500, detail=f"抠图失败: {e}")
|
||||
|
||||
@@ -59,9 +59,7 @@ async def batch_match_materials_endpoint(
|
||||
|
||||
根据分镜列表一次性匹配所有素材,自动进行项目级去重。
|
||||
"""
|
||||
raw_scenes = [
|
||||
{"scene": s.scene, "duration": s.duration} for s in request.scenes
|
||||
]
|
||||
raw_scenes = [{"scene": s.scene, "duration": s.duration} for s in request.scenes]
|
||||
|
||||
results = await batch_match(
|
||||
db,
|
||||
@@ -70,8 +68,7 @@ async def batch_match_materials_endpoint(
|
||||
)
|
||||
|
||||
matched: list[MaterialInfo | None] = [
|
||||
MaterialInfo(url=r["url"], duration=r["duration"]) if r else None
|
||||
for r in results
|
||||
MaterialInfo(url=r["url"], duration=r["duration"]) if r else None for r in results
|
||||
]
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.core.exceptions import InsufficientPointsException
|
||||
from app.crud.point_recharge_order import point_recharge_order
|
||||
from app.crud.point_transaction import point_transaction
|
||||
from app.models.user import User
|
||||
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/points", tags=["Points"])
|
||||
|
||||
# ── 余额查询 ──────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/balance", response_model=ApiResponse[PointBalanceResponse])
|
||||
async def get_balance(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -47,6 +49,7 @@ async def get_balance(
|
||||
|
||||
# ── 流水查询 ──────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/transactions", response_model=ApiResponse[PointTransactionListResponse])
|
||||
async def list_transactions(
|
||||
pagination: PaginationParams = Depends(),
|
||||
@@ -127,12 +130,13 @@ async def list_transactions(
|
||||
|
||||
# ── 充值 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/recharge", response_model=ApiResponse[RechargeResponse])
|
||||
async def create_recharge_order(
|
||||
request: RechargeRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
http_request: Request = None,
|
||||
http_request: Request = None, # type: ignore[assignment]
|
||||
):
|
||||
"""
|
||||
创建积分充值订单(微信支付 Native 扫码)
|
||||
@@ -213,6 +217,7 @@ async def create_recharge_order(
|
||||
logger.error(f"[Points] 微信统一下单未返回 code_url: {wx_result}")
|
||||
order.status = "failed"
|
||||
order.error_msg = "微信未返回二维码链接"
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=500, detail="微信支付下单失败")
|
||||
|
||||
order.prepay_id = wx_result.get("prepay_id")
|
||||
@@ -303,9 +308,7 @@ async def handle_wxpay_notify(
|
||||
return _wx_response()
|
||||
|
||||
# 查找订单
|
||||
order = await point_recharge_order.get_by_out_trade_no(
|
||||
db, out_trade_no=out_trade_no
|
||||
)
|
||||
order = await point_recharge_order.get_by_out_trade_no(db, out_trade_no=out_trade_no)
|
||||
if not order:
|
||||
logger.error(f"[WechatPay] 回调订单不存在: {out_trade_no}")
|
||||
return _wx_response()
|
||||
@@ -371,6 +374,7 @@ async def handle_wxpay_notify(
|
||||
await db.rollback()
|
||||
logger.exception(f"[WechatPay] 订单 {out_trade_no} 充值积分失败: {e}")
|
||||
# 记录错误但不抛出,返回 SUCCESS 避免微信重试
|
||||
order.status = "failed"
|
||||
order.error_msg = f"充值积分失败: {e}"
|
||||
await db.commit()
|
||||
|
||||
@@ -400,9 +404,7 @@ async def query_recharge_status(
|
||||
|
||||
wxpay = get_wxpay_service()
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
|
||||
wx_result = await wxpay.query_order(
|
||||
client, out_trade_no=order.out_trade_no
|
||||
)
|
||||
wx_result = await wxpay.query_order(client, out_trade_no=order.out_trade_no)
|
||||
|
||||
order.query_result = str(wx_result)
|
||||
trade_state = wx_result.get("trade_state", "")
|
||||
@@ -465,6 +467,7 @@ async def query_recharge_status(
|
||||
|
||||
# ── 充值档位查询 ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/recharge-options", response_model=ApiResponse[list[dict]])
|
||||
async def get_recharge_options(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -480,6 +483,7 @@ async def get_recharge_options(
|
||||
|
||||
# ── 扣费业务类型查询 ──────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/chargeable-types", response_model=ApiResponse[list[str]])
|
||||
async def get_chargeable_types(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -496,6 +500,7 @@ async def get_chargeable_types(
|
||||
|
||||
# ── 积分规则查询 ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/rules", response_model=ApiResponse[list[dict]])
|
||||
async def get_points_rules(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -530,9 +535,9 @@ async def get_points_rules(
|
||||
# ── 积分预估查询 ──────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
# ── 今日消费统计 ──────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/today-consumed", response_model=ApiResponse[dict])
|
||||
async def get_today_consumed(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -545,6 +550,7 @@ async def get_today_consumed(
|
||||
|
||||
# ── 直接消费扣费(前端/Rust 层调用)───────────────────
|
||||
|
||||
|
||||
@router.post("/consume", response_model=ApiResponse[dict])
|
||||
async def consume_points(
|
||||
request: ConsumeRequest,
|
||||
@@ -569,12 +575,12 @@ async def consume_points(
|
||||
source_type=request.source_type,
|
||||
source_id=request.source_id,
|
||||
description=f"【{request.description or request.source_type}】",
|
||||
allow_negative=False,
|
||||
allow_negative=request.allow_negative,
|
||||
)
|
||||
except ValueError as e:
|
||||
except InsufficientPointsException:
|
||||
# 余额不足(在同一事务内判断,避免竞态)
|
||||
if "积分不足" in str(e):
|
||||
raise HTTPException(status_code=402, detail=str(e))
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
await db.commit()
|
||||
|
||||
@@ -587,5 +593,3 @@ async def consume_points(
|
||||
},
|
||||
message="消费成功",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.ai.model_router import get_model_router
|
||||
from app.ai.prompts import list_categories, list_prompt_files, load_prompt, render_template
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.exceptions import (
|
||||
AITimeoutException,
|
||||
InsufficientPointsException,
|
||||
PlatformError,
|
||||
PlatformErrorType,
|
||||
)
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
@@ -33,6 +39,51 @@ router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _map_platform_error(e: PlatformError) -> HTTPException:
|
||||
"""把第三方平台错误映射为用户友好的 HTTP 异常(带标准 error_code)。"""
|
||||
if e.error_type == PlatformErrorType.CONTENT_VIOLATION:
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": "人物分镜台词未通过安全审核,请修改后重试",
|
||||
"error_code": "content_violation",
|
||||
},
|
||||
)
|
||||
if e.error_type == PlatformErrorType.RATE_LIMIT:
|
||||
return HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"message": "当前请求过于频繁,请稍后再试",
|
||||
"error_code": "rate_limit",
|
||||
},
|
||||
)
|
||||
if e.error_type == PlatformErrorType.TIMEOUT:
|
||||
return AITimeoutException("服务响应超时,请稍后重试")
|
||||
if e.error_type == PlatformErrorType.AUTH_FAILED:
|
||||
return HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"message": "第三方服务认证失败,请稍后重试或联系客服",
|
||||
"error_code": "auth_failed",
|
||||
},
|
||||
)
|
||||
if e.error_type == PlatformErrorType.SERVER_ERROR:
|
||||
return HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"message": "第三方服务繁忙,请稍后重试",
|
||||
"error_code": "server_error",
|
||||
},
|
||||
)
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": "请求失败,请检查后重试",
|
||||
"error_code": e.error_type or "unknown",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/categories", response_model=ApiResponse[list[CategoryItem]])
|
||||
async def get_categories():
|
||||
"""
|
||||
@@ -71,9 +122,8 @@ async def polish_content(
|
||||
required_points = ps._calculate_cost("polish")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
raise InsufficientPointsException(
|
||||
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -99,11 +149,13 @@ async def polish_content(
|
||||
data=polished,
|
||||
message=f"{type_name}润色完成",
|
||||
)
|
||||
except InsufficientPointsException:
|
||||
raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except PlatformError as e:
|
||||
raise _map_platform_error(e)
|
||||
except ValueError as e:
|
||||
if "积分不足" in str(e):
|
||||
raise HTTPException(status_code=402, detail=str(e))
|
||||
logger.warning(f"[Polish] 润色失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="润色失败,请检查输入内容后重试")
|
||||
except Exception as e:
|
||||
@@ -111,9 +163,6 @@ async def polish_content(
|
||||
raise HTTPException(status_code=500, detail=f"{type_name}润色失败,请稍后重试")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
|
||||
async def generate_title(
|
||||
request: GenerateTitleRequest,
|
||||
@@ -146,7 +195,11 @@ async def generate_title(
|
||||
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
|
||||
|
||||
# 渲染用户提示词
|
||||
title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
|
||||
title_type_desc = (
|
||||
"大标题(主标题,提炼核心卖点,吸睛)"
|
||||
if request.title_type == "main"
|
||||
else "小标题(副标题,补充说明或制造悬念)"
|
||||
)
|
||||
user_prompt = render_template(
|
||||
user_template,
|
||||
title_type=request.title_type,
|
||||
@@ -163,9 +216,8 @@ async def generate_title(
|
||||
required_points = ps._calculate_cost("title")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
raise InsufficientPointsException(
|
||||
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -179,10 +231,10 @@ async def generate_title(
|
||||
|
||||
title = result.content.strip() if result.content else ""
|
||||
# 去除可能的引号
|
||||
title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
|
||||
title = title.strip('"').strip("'").strip("「」").strip("『』").strip("《》")
|
||||
# 截断到最大长度
|
||||
if len(title) > request.max_length:
|
||||
title = title[:request.max_length]
|
||||
title = title[: request.max_length]
|
||||
|
||||
# 扣费
|
||||
points = ps._calculate_cost("title")
|
||||
@@ -200,15 +252,15 @@ async def generate_title(
|
||||
data=GenerateTitleResponse(title=title),
|
||||
message="标题生成成功",
|
||||
)
|
||||
except InsufficientPointsException:
|
||||
raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except PlatformError as e:
|
||||
raise _map_platform_error(e)
|
||||
except TimeoutError:
|
||||
logger.warning("[generate_title] 标题生成超时")
|
||||
raise HTTPException(status_code=500, detail="标题生成超时,请重试")
|
||||
except ValueError as e:
|
||||
if "积分不足" in str(e):
|
||||
raise HTTPException(status_code=402, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
|
||||
raise AITimeoutException("标题生成超时,请稍后重试")
|
||||
except Exception as e:
|
||||
logger.error(f"[generate_title] 标题生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
|
||||
|
||||
@@ -46,6 +46,3 @@ async def system_version():
|
||||
},
|
||||
message="获取版本成功",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@ import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.exceptions import InsufficientPointsException
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.models.user import User
|
||||
@@ -38,7 +39,6 @@ class ScriptParams(BaseModel):
|
||||
category: str = Field(..., min_length=1, description="大类代码")
|
||||
filename: str = Field(..., min_length=1, description="提示词文件名")
|
||||
|
||||
|
||||
@field_validator("category")
|
||||
@classmethod
|
||||
def validate_category(cls, v: str) -> str:
|
||||
@@ -69,6 +69,12 @@ class SubtitleParams(BaseModel):
|
||||
raise ValueError("video_path 不能为空")
|
||||
return v.strip()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_auto_align(self) -> "SubtitleParams":
|
||||
if self.mode == "auto_align" and (not self.audio_text or not self.audio_text.strip()):
|
||||
raise ValueError("auto_align 模式必须提供 audio_text")
|
||||
return self
|
||||
|
||||
|
||||
class TTSParams(BaseModel):
|
||||
"""TTS 语音合成参数"""
|
||||
@@ -96,7 +102,9 @@ class VideoParams(BaseModel):
|
||||
volume: int = Field(default=0, ge=0, le=10, description="音量")
|
||||
ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
|
||||
planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
|
||||
total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
|
||||
total_planned_duration: float = Field(
|
||||
..., gt=0, description="所有分镜规划时长之和(秒),用于预检"
|
||||
)
|
||||
batch_id: str | None = Field(default=None, description="批次ID(可选)")
|
||||
|
||||
@field_validator("video_url")
|
||||
@@ -106,6 +114,12 @@ class VideoParams(BaseModel):
|
||||
raise ValueError("video_url 不能为空")
|
||||
return v.strip()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_audio_or_text(self) -> "VideoParams":
|
||||
if not self.audio_url and not self.text:
|
||||
raise ValueError("audio_url 和 text 必须至少填一个")
|
||||
return self
|
||||
|
||||
|
||||
class TaskCreateRequest(BaseModel):
|
||||
"""创建任务请求"""
|
||||
@@ -134,6 +148,7 @@ class TaskStatusResponse(BaseModel):
|
||||
total: int = Field(0, description="总子任务数")
|
||||
result: dict | None = Field(None, description="任务结果(完成时)")
|
||||
error: str | None = Field(None, description="错误信息(失败时)")
|
||||
error_code: str | None = Field(None, description="错误码(失败时,如 content_violation)")
|
||||
created_at: str = Field("", description="任务创建时间(ISO格式)")
|
||||
|
||||
|
||||
@@ -175,6 +190,9 @@ async def create_task(
|
||||
validated_params = {
|
||||
"category": script_validated.category,
|
||||
"filename": script_validated.filename,
|
||||
"user_id": user_id,
|
||||
"required_points": required_points,
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
elif task_type == "subtitle":
|
||||
@@ -222,9 +240,8 @@ async def create_task(
|
||||
f"[API] 积分不足: user={user_id}, type={task_type}, "
|
||||
f"required={required_points}, balance={check['balance']}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
raise InsufficientPointsException(
|
||||
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
# ── 3. 写入 Redis ──────────────────────────────────
|
||||
@@ -246,6 +263,9 @@ async def create_task(
|
||||
params=validated_params,
|
||||
)
|
||||
await registry.add_running(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Failed to update registry: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
|
||||
|
||||
logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
|
||||
return TaskCreateResponse(
|
||||
@@ -254,10 +274,6 @@ async def create_task(
|
||||
message=f"{task_type} 任务已创建",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[API] Failed to update registry: {e}")
|
||||
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
|
||||
|
||||
|
||||
@router.get("", response_model=list[TaskStatusResponse])
|
||||
async def list_tasks(
|
||||
@@ -294,6 +310,7 @@ async def list_tasks(
|
||||
total=task.total,
|
||||
result=None, # 列表查询不返回 result,避免数据过大
|
||||
error=task.error,
|
||||
error_code=task.error_code,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
)
|
||||
@@ -337,6 +354,7 @@ async def get_task_status(
|
||||
total=task.total,
|
||||
result=task.result,
|
||||
error=task.error,
|
||||
error_code=task.error_code,
|
||||
created_at=task.created_at,
|
||||
)
|
||||
|
||||
|
||||
+143
-14
@@ -7,13 +7,15 @@
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.models.update import AppRelease, ReleasePackage
|
||||
from app.schemas.update import (
|
||||
PackageInfo,
|
||||
ReleaseCreate,
|
||||
ReleaseListItem,
|
||||
ReleaseResponse,
|
||||
@@ -38,9 +40,7 @@ async def check_update(
|
||||
如果无需更新,返回 204;如果有更新,返回 Tauri 标准格式的 JSON。
|
||||
"""
|
||||
# 查询最新版本
|
||||
result = await db.execute(
|
||||
select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1)
|
||||
)
|
||||
result = await db.execute(select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1))
|
||||
latest: AppRelease | None = result.scalar_one_or_none()
|
||||
|
||||
if not latest:
|
||||
@@ -52,11 +52,13 @@ async def check_update(
|
||||
|
||||
# 查询对应平台的包(优先返回 updater 用的包:有 signature 的 .app.tar.gz / .exe)
|
||||
result = await db.execute(
|
||||
select(ReleasePackage).where(
|
||||
select(ReleasePackage)
|
||||
.where(
|
||||
ReleasePackage.release_id == latest.id,
|
||||
ReleasePackage.platform == target,
|
||||
ReleasePackage.architecture == arch,
|
||||
).order_by(
|
||||
)
|
||||
.order_by(
|
||||
# 有 signature 的排前面(updater 包),空 signature 的排后面(dmg 安装包)
|
||||
ReleasePackage.signature.desc()
|
||||
)
|
||||
@@ -86,6 +88,133 @@ async def check_update(
|
||||
)
|
||||
|
||||
|
||||
def _parse_user_agent(user_agent: str | None) -> tuple[str, str] | None:
|
||||
"""
|
||||
从 User-Agent 解析 Tauri 平台标识和架构。
|
||||
|
||||
Tauri updater 使用的 platform 值为:darwin / windows / linux
|
||||
architecture 值为:x86_64 / aarch64 / i686
|
||||
"""
|
||||
if not user_agent:
|
||||
return None
|
||||
|
||||
ua = user_agent.lower()
|
||||
|
||||
# Windows
|
||||
if "windows" in ua:
|
||||
platform = "windows"
|
||||
if "arm64" in ua or "aarch64" in ua:
|
||||
arch = "aarch64"
|
||||
elif "win64" in ua or "x64" in ua:
|
||||
arch = "x86_64"
|
||||
else:
|
||||
arch = "x86_64"
|
||||
return platform, arch
|
||||
|
||||
# macOS
|
||||
if "macintosh" in ua or "mac os x" in ua:
|
||||
platform = "darwin"
|
||||
if "arm64" in ua or "aarch64" in ua:
|
||||
arch = "aarch64"
|
||||
elif "intel" in ua:
|
||||
arch = "x86_64"
|
||||
else:
|
||||
# 现代 Mac 默认按 Apple Silicon 处理;
|
||||
# 若浏览器/Rosetta 环境未暴露 arm64,可让用户手动选择
|
||||
arch = "aarch64"
|
||||
return platform, arch
|
||||
|
||||
# Linux
|
||||
if "linux" in ua:
|
||||
platform = "linux"
|
||||
if "aarch64" in ua or "arm64" in ua:
|
||||
arch = "aarch64"
|
||||
elif "x86_64" in ua or "x64" in ua:
|
||||
arch = "x86_64"
|
||||
else:
|
||||
arch = "x86_64"
|
||||
return platform, arch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
async def download_latest(
|
||||
request: Request,
|
||||
target: str | None = Query(None, description="平台:darwin / windows / linux"),
|
||||
arch: str | None = Query(None, description="架构:x86_64 / aarch64 / i686"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统一下载入口:自动匹配最新版本和当前环境安装包。
|
||||
|
||||
优先级:
|
||||
1. 查询参数 target + arch(最可靠,推荐前端显式传入)
|
||||
2. User-Agent 解析(兜底)
|
||||
|
||||
返回 302 重定向到对应安装包的存储地址。
|
||||
"""
|
||||
# 1. 确定平台和架构
|
||||
if target and arch:
|
||||
platform = target.lower()
|
||||
architecture = arch.lower()
|
||||
else:
|
||||
parsed = _parse_user_agent(request.headers.get("user-agent"))
|
||||
if not parsed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="无法识别您的操作系统,请通过官网或应用商店下载对应版本",
|
||||
)
|
||||
platform, architecture = parsed
|
||||
|
||||
# 2. 查询最新版本
|
||||
result = await db.execute(select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1))
|
||||
latest: AppRelease | None = result.scalar_one_or_none()
|
||||
if not latest:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="暂无可用下载",
|
||||
)
|
||||
|
||||
# 3. 查询该平台所有包(不限制架构,因为 macOS 常用 universal 包会同时写入 x86_64/aarch64)
|
||||
result = await db.execute(
|
||||
select(ReleasePackage).where(
|
||||
ReleasePackage.release_id == latest.id,
|
||||
ReleasePackage.platform == platform,
|
||||
)
|
||||
)
|
||||
platform_pkgs = list(result.scalars().all())
|
||||
|
||||
if not platform_pkgs:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"版本 {latest.version} 暂无可用的 {platform} 安装包",
|
||||
)
|
||||
|
||||
# 4. 优先选择用户安装包,而不是 updater 用的 .app.tar.gz
|
||||
# 优先级:.dmg / .exe / .msi / .AppImage > .app.tar.gz > 其他
|
||||
def _install_pkg_priority(pkg: ReleasePackage) -> int:
|
||||
name = pkg.filename.lower()
|
||||
if name.endswith(".dmg"):
|
||||
return 1
|
||||
if name.endswith(".exe"):
|
||||
return 2
|
||||
if name.endswith(".msi"):
|
||||
return 3
|
||||
if name.endswith(".appimage"):
|
||||
return 4
|
||||
if name.endswith(".app.tar.gz"):
|
||||
return 10
|
||||
return 5
|
||||
|
||||
# 先尝试精确架构 + 高优先级安装包
|
||||
exact_arch_pkgs = [p for p in platform_pkgs if p.architecture == architecture]
|
||||
candidate_pkgs = exact_arch_pkgs or platform_pkgs
|
||||
pkg = min(candidate_pkgs, key=_install_pkg_priority)
|
||||
|
||||
return RedirectResponse(url=pkg.file_url)
|
||||
|
||||
|
||||
@router.post("/releases", response_model=ReleaseResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_release(
|
||||
release: ReleaseCreate,
|
||||
@@ -139,14 +268,14 @@ async def create_release(
|
||||
mandatory=new_release.mandatory,
|
||||
created_at=new_release.created_at,
|
||||
packages=[
|
||||
{
|
||||
"platform": p.platform,
|
||||
"architecture": p.architecture,
|
||||
"filename": p.filename,
|
||||
"file_url": p.file_url,
|
||||
"file_size": p.file_size,
|
||||
"signature": p.signature,
|
||||
}
|
||||
PackageInfo(
|
||||
platform=p.platform,
|
||||
architecture=p.architecture,
|
||||
filename=p.filename,
|
||||
file_url=p.file_url,
|
||||
file_size=p.file_size,
|
||||
signature=p.signature,
|
||||
)
|
||||
for p in new_release.packages
|
||||
],
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
"""上传响应"""
|
||||
|
||||
@@ -103,8 +102,8 @@ async def upload_video(
|
||||
domain=domain,
|
||||
)
|
||||
|
||||
url = result.get("url")
|
||||
key = result.get("key")
|
||||
url = result.get("url") or ""
|
||||
key = result.get("key") or ""
|
||||
|
||||
if not url:
|
||||
raise HTTPException(status_code=500, detail="上传到七牛云失败:未返回 URL")
|
||||
@@ -126,8 +125,6 @@ async def upload_video(
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/audio", response_model=ApiResponse[UploadResponse])
|
||||
async def upload_audio(
|
||||
file: UploadFile = File(..., description="音频文件"),
|
||||
@@ -198,8 +195,8 @@ async def upload_audio(
|
||||
domain=domain,
|
||||
)
|
||||
|
||||
url = result.get("url")
|
||||
key = result.get("key")
|
||||
url = result.get("url") or ""
|
||||
key = result.get("key") or ""
|
||||
|
||||
if not url:
|
||||
raise HTTPException(status_code=500, detail="上传到七牛云失败:未返回 URL")
|
||||
|
||||
@@ -18,6 +18,10 @@ from app.core.exceptions import PlatformError
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.platform_gateway import PlatformGateway
|
||||
from app.schemas.common import success_response
|
||||
from app.utils.content_fingerprint import (
|
||||
extract_vidu_error_code,
|
||||
is_vidu_audit_error,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,10 +48,12 @@ async def vidu_callback(request: Request):
|
||||
headers_dict = dict(request.headers)
|
||||
|
||||
# 使用 APP_BASE_URL 构建 callback_url,确保与提交任务时传给 Vidu 的一致
|
||||
#(Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
|
||||
# (Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
|
||||
app_base_url = get_settings().app_base_url
|
||||
callback_url = f"{app_base_url}/api/v1/vidu/callback" if app_base_url else str(request.url)
|
||||
logger.info(f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}")
|
||||
logger.info(
|
||||
f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}"
|
||||
)
|
||||
|
||||
try:
|
||||
task_status = await gateway.handle_webhook(
|
||||
@@ -64,15 +70,13 @@ async def vidu_callback(request: Request):
|
||||
logger.error(f"[Vidu] 回调处理失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="回调处理失败,请稍后重试")
|
||||
|
||||
logger.info(f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}")
|
||||
logger.info(
|
||||
f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}"
|
||||
)
|
||||
|
||||
# 2. 通过 platform_task_id 反查 Async Engine 内部 task_id,更新 TaskRegistry
|
||||
platform_task_id = (
|
||||
task_status.result.get("task_id") if task_status.result else None
|
||||
)
|
||||
video_url = (
|
||||
task_status.result.get("video_url") if task_status.result else None
|
||||
)
|
||||
platform_task_id = task_status.result.get("task_id") if task_status.result else None
|
||||
video_url = task_status.result.get("video_url") if task_status.result else None
|
||||
|
||||
logger.info(f"[Vidu] 准备反查 internal_task_id: platform_task_id={platform_task_id}")
|
||||
|
||||
@@ -105,12 +109,23 @@ async def vidu_callback(request: Request):
|
||||
result={"video_url": video_url, "state": "success"},
|
||||
)
|
||||
elif task_status.state == "failed":
|
||||
await registry.update(
|
||||
internal_task_id,
|
||||
status="failed",
|
||||
message="视频生成失败",
|
||||
error=task_status.error_message or "视频生成失败",
|
||||
)
|
||||
error_message = task_status.error_message or "视频生成失败"
|
||||
err_code = extract_vidu_error_code(error_message)
|
||||
is_audit = err_code and is_vidu_audit_error(err_code)
|
||||
|
||||
update_kwargs = {
|
||||
"status": "failed",
|
||||
"message": (
|
||||
"人物分镜台词未通过安全审核,请修改后重试"
|
||||
if is_audit
|
||||
else "视频生成失败"
|
||||
),
|
||||
"error": error_message,
|
||||
}
|
||||
if is_audit:
|
||||
update_kwargs["error_code"] = "content_violation"
|
||||
|
||||
await registry.update(internal_task_id, **update_kwargs)
|
||||
logger.info(
|
||||
f"[Vidu] 回调已更新 TaskRegistry: task={internal_task_id}, "
|
||||
f"state={task_status.state}, video_url={video_url}"
|
||||
@@ -121,8 +136,6 @@ async def vidu_callback(request: Request):
|
||||
f"platform={platform_task_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}"
|
||||
)
|
||||
logger.warning(f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}")
|
||||
|
||||
return success_response(message="回调已接收")
|
||||
|
||||
@@ -10,12 +10,13 @@ import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.exceptions import PlatformError
|
||||
from app.core.exceptions import InsufficientPointsException, PlatformError
|
||||
from app.db.session import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.common import ApiResponse, success_response
|
||||
@@ -49,7 +50,9 @@ class TTSSynthesizeRequest(BaseModel):
|
||||
class VoiceCloneSubmitRequest(BaseModel):
|
||||
"""声音复刻提交请求"""
|
||||
|
||||
source_audio_url: str | None = Field(None, description="源音频 URL(5-30秒,mp3/wav,需公开可访问)")
|
||||
source_audio_url: str | None = Field(
|
||||
None, description="源音频 URL(5-30秒,mp3/wav,需公开可访问)"
|
||||
)
|
||||
source_video_url: str | None = Field(None, description="源视频 URL(可选)")
|
||||
video_id: str | None = Field(None, description="历史作品ID(可选)")
|
||||
voice_name: str | None = Field(None, description="自定义音色名称(≤20字符)")
|
||||
@@ -111,7 +114,7 @@ async def synthesize_speech(
|
||||
# 宽松预检:余额为负或零时阻止,避免浪费第三方资源
|
||||
balance_info = await ps.get_user_balance(db, current_user.id)
|
||||
if balance_info["balance"] <= 0:
|
||||
raise HTTPException(status_code=402, detail="余额不足,请先充值")
|
||||
raise InsufficientPointsException("余额不足,请先充值")
|
||||
|
||||
try:
|
||||
audio_url = await service.synthesize(
|
||||
@@ -122,7 +125,7 @@ async def synthesize_speech(
|
||||
pitch=request.pitch,
|
||||
)
|
||||
|
||||
# 探测音频时长并扣费
|
||||
# 探测音频时长并扣费(计费成功才返回结果)
|
||||
try:
|
||||
seconds = await get_audio_duration(audio_url)
|
||||
points = ps._calculate_cost("tts", {"seconds": seconds})
|
||||
@@ -137,12 +140,6 @@ async def synthesize_speech(
|
||||
allow_negative=True,
|
||||
)
|
||||
await db.commit()
|
||||
except ValueError as e:
|
||||
if "积分不足" in str(e):
|
||||
raise HTTPException(status_code=402, detail=str(e))
|
||||
logger.error(f"[Voice] TTS 扣费失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] TTS 扣费失败: {e}")
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
@@ -155,6 +152,14 @@ async def synthesize_speech(
|
||||
},
|
||||
message="合成成功",
|
||||
)
|
||||
except InsufficientPointsException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] TTS 扣费失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="语音合成计费失败,请稍后重试",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -165,7 +170,6 @@ async def synthesize_speech(
|
||||
raise HTTPException(status_code=500, detail="语音合成失败,请稍后重试")
|
||||
|
||||
|
||||
|
||||
def _normalize_voice_id(name: str | None) -> str:
|
||||
"""
|
||||
将用户输入的名称规范化为 Vidu 合法的 voice_id。
|
||||
@@ -220,9 +224,8 @@ async def submit_clone_task(
|
||||
required_points = ps._calculate_cost("voice_clone")
|
||||
check = await ps.check_balance(db, current_user.id, required_points)
|
||||
if not check["sufficient"]:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
|
||||
raise InsufficientPointsException(
|
||||
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -244,10 +247,8 @@ async def submit_clone_task(
|
||||
description="【声音复刻】",
|
||||
)
|
||||
await db.commit()
|
||||
except ValueError as e:
|
||||
if "积分不足" in str(e):
|
||||
raise HTTPException(status_code=402, detail=str(e))
|
||||
logger.error(f"[Voice] 克隆扣费失败: {e}")
|
||||
except InsufficientPointsException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice] 克隆扣费失败: {e}")
|
||||
|
||||
@@ -292,5 +293,3 @@ async def query_clone_task(
|
||||
),
|
||||
message="克隆已完成",
|
||||
)
|
||||
|
||||
|
||||
|
||||
+11
-17
@@ -24,7 +24,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 应用基础配置
|
||||
APP_NAME: str = Field(default="美家卡智影 API", description="应用名称")
|
||||
APP_VERSION: str = Field(default="1.8.2", description="应用版本")
|
||||
APP_VERSION: str = Field(default="1.9.1", description="应用版本")
|
||||
DEBUG: bool = Field(default=False, description="调试模式")
|
||||
ENV: Literal["development", "staging", "production"] = Field(
|
||||
default="development", description="运行环境"
|
||||
@@ -45,7 +45,9 @@ class Settings(BaseSettings):
|
||||
description="数据库连接字符串(PostgreSQL)",
|
||||
)
|
||||
DATABASE_POOL_SIZE: int = Field(default=10, description="数据库连接池常驻连接数")
|
||||
DATABASE_MAX_OVERFLOW: int = Field(default=10, description="连接池临时溢出上限(建议 ≤ pool_size)")
|
||||
DATABASE_MAX_OVERFLOW: int = Field(
|
||||
default=10, description="连接池临时溢出上限(建议 ≤ pool_size)"
|
||||
)
|
||||
DATABASE_POOL_RECYCLE: int = Field(
|
||||
default=1800, description="连接回收时间(秒),防止长连接被数据库静默断开"
|
||||
)
|
||||
@@ -73,7 +75,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 安全配置
|
||||
SECRET_KEY: str = Field(
|
||||
...,
|
||||
default="",
|
||||
description="JWT 签名密钥(生产环境必须修改)",
|
||||
)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(
|
||||
@@ -107,7 +109,9 @@ class Settings(BaseSettings):
|
||||
VOLCENGINE_CAPTION_TOKEN: str | None = Field(default=None, description="火山字幕 Token")
|
||||
|
||||
# 火山引擎 MediaKit 服务(背景移除等多媒体处理)
|
||||
VOLCENGINE_MEDIAKIT_TOKEN: str | None = Field(default=None, description="火山引擎 MediaKit Token")
|
||||
VOLCENGINE_MEDIAKIT_TOKEN: str | None = Field(
|
||||
default=None, description="火山引擎 MediaKit Token"
|
||||
)
|
||||
|
||||
# Vidu 密钥(base_url 已从 Settings 移除,改用 config/platform-config.yaml 配置)
|
||||
VIDU_API_KEY: str | None = Field(default=None, description="Vidu API Key")
|
||||
@@ -124,9 +128,7 @@ class Settings(BaseSettings):
|
||||
WXPAY_MCHID: str | None = Field(default=None, description="微信支付商户号")
|
||||
WXPAY_APPID: str | None = Field(default=None, description="微信支付 AppID")
|
||||
WXPAY_API_KEY: str | None = Field(default=None, description="微信支付 APIv2 密钥")
|
||||
WXPAY_NOTIFY_URL: str | None = Field(
|
||||
default=None, description="微信支付回调地址(完整 URL)"
|
||||
)
|
||||
WXPAY_NOTIFY_URL: str | None = Field(default=None, description="微信支付回调地址(完整 URL)")
|
||||
|
||||
# B2M 短信平台配置
|
||||
SMS_APP_ID: str | None = Field(default=None, description="B2M 短信平台 AppID")
|
||||
@@ -134,16 +136,12 @@ class Settings(BaseSettings):
|
||||
SMS_BASE_URL: str | None = Field(
|
||||
default=None, description="B2M 短信平台接口地址(如 http://sms.b2m.cn:8080)"
|
||||
)
|
||||
SMS_EXTENDED_CODE: str | None = Field(
|
||||
default=None, description="B2M 短信平台扩展码(选填)"
|
||||
)
|
||||
SMS_EXTENDED_CODE: str | None = Field(default=None, description="B2M 短信平台扩展码(选填)")
|
||||
SMS_CODE_WHITELIST: str = Field(
|
||||
default="",
|
||||
description="免验证码登录白名单(逗号分隔的手机号,如 13800138000,13900139000)",
|
||||
)
|
||||
|
||||
|
||||
|
||||
# 文件上传限制(字节)
|
||||
UPLOAD_MAX_VIDEO_SIZE: int = Field(
|
||||
default=500 * 1024 * 1024, description="视频最大上传大小(字节)"
|
||||
@@ -187,11 +185,7 @@ class Settings(BaseSettings):
|
||||
"""免验证码登录白名单(去重、去空格)"""
|
||||
if not self.SMS_CODE_WHITELIST:
|
||||
return set()
|
||||
return {
|
||||
mobile.strip()
|
||||
for mobile in self.SMS_CODE_WHITELIST.split(",")
|
||||
if mobile.strip()
|
||||
}
|
||||
return {mobile.strip() for mobile in self.SMS_CODE_WHITELIST.split(",") if mobile.strip()}
|
||||
|
||||
|
||||
@lru_cache
|
||||
|
||||
@@ -24,9 +24,16 @@ class AppException(HTTPException):
|
||||
status_code: int,
|
||||
message: str = "操作失败",
|
||||
detail: dict | None = None,
|
||||
*,
|
||||
error_code: str | None = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail=detail or {})
|
||||
body = detail or {}
|
||||
body["message"] = message
|
||||
if error_code:
|
||||
body["error_code"] = error_code
|
||||
super().__init__(status_code=status_code, detail=body)
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class NotFoundException(AppException):
|
||||
@@ -44,7 +51,7 @@ class ValidationException(AppException):
|
||||
|
||||
def __init__(self, message: str = "参数验证失败"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
message=message,
|
||||
)
|
||||
|
||||
@@ -79,6 +86,17 @@ class BusinessException(AppException):
|
||||
)
|
||||
|
||||
|
||||
class InsufficientPointsException(AppException):
|
||||
"""积分不足"""
|
||||
|
||||
def __init__(self, message: str = "积分不足"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
message=message,
|
||||
error_code="insufficient_points",
|
||||
)
|
||||
|
||||
|
||||
class ModelUnavailableException(AppException):
|
||||
"""AI 模型不可用"""
|
||||
|
||||
@@ -99,6 +117,50 @@ class TaskFailedException(AppException):
|
||||
)
|
||||
|
||||
|
||||
class PromptNotFoundException(AppException):
|
||||
"""提示词文件不存在"""
|
||||
|
||||
def __init__(self, message: str = "未找到提示词"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
message=message,
|
||||
error_code="prompt_not_found",
|
||||
)
|
||||
|
||||
|
||||
class AIEmptyResponseException(AppException):
|
||||
"""AI 返回内容为空"""
|
||||
|
||||
def __init__(self, message: str = "AI 返回内容为空"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=message,
|
||||
error_code="empty_result",
|
||||
)
|
||||
|
||||
|
||||
class AIParseErrorException(AppException):
|
||||
"""AI 返回内容解析失败"""
|
||||
|
||||
def __init__(self, message: str = "AI 返回格式解析失败"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=message,
|
||||
error_code="parse_error",
|
||||
)
|
||||
|
||||
|
||||
class AITimeoutException(AppException):
|
||||
"""AI 调用超时"""
|
||||
|
||||
def __init__(self, message: str = "AI 请求超时,请稍后重试"):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
message=message,
|
||||
error_code="timeout",
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# 第三方平台异常(PlatformError 体系)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
@@ -118,6 +180,7 @@ class PlatformErrorType:
|
||||
BAD_REQUEST = "bad_request" # 参数错误,不可重试
|
||||
QUOTA_EXHAUSTED = "quota_exhausted" # 额度用完,不可重试(或延迟重试)
|
||||
NOT_FOUND = "not_found" # 资源不存在,不可重试
|
||||
CONTENT_VIOLATION = "content_violation" # 内容安全/审核不通过,不可重试
|
||||
UNKNOWN = "unknown" # 兜底
|
||||
|
||||
|
||||
@@ -145,12 +208,14 @@ class PlatformError(Exception):
|
||||
retryable: bool = False,
|
||||
error_type: str = PlatformErrorType.UNKNOWN,
|
||||
status_code: int | None = None,
|
||||
raw_code: str | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.platform = platform
|
||||
self.retryable = retryable
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
self.raw_code = raw_code
|
||||
|
||||
def to_http_status(self) -> int:
|
||||
"""根据 error_type 和 retryable 返回标准 HTTP 状态码"""
|
||||
@@ -161,6 +226,7 @@ class PlatformError(Exception):
|
||||
PlatformErrorType.AUTH_FAILED: 401,
|
||||
PlatformErrorType.BAD_REQUEST: 400,
|
||||
PlatformErrorType.NOT_FOUND: 404,
|
||||
PlatformErrorType.CONTENT_VIOLATION: 400,
|
||||
}
|
||||
if self.error_type in mapping:
|
||||
return mapping[self.error_type]
|
||||
|
||||
@@ -143,9 +143,7 @@ class PlatformConfigLoader:
|
||||
启动时加载,全环境只读(不支持热重载)。
|
||||
"""
|
||||
|
||||
DEFAULT_CONFIG_PATH = (
|
||||
Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"
|
||||
)
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"
|
||||
|
||||
def __init__(self, config_path: str | None = None):
|
||||
self.config_path = Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
|
||||
@@ -157,9 +155,7 @@ class PlatformConfigLoader:
|
||||
def _load(self) -> None:
|
||||
"""加载并校验配置文件"""
|
||||
if not self.config_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"平台配置文件不存在: {self.config_path}"
|
||||
)
|
||||
raise FileNotFoundError(f"平台配置文件不存在: {self.config_path}")
|
||||
|
||||
try:
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
@@ -215,18 +211,10 @@ class PlatformConfigLoader:
|
||||
return [m for m in self._models.values() if m.is_enabled]
|
||||
|
||||
def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
|
||||
return [
|
||||
m
|
||||
for m in self._models.values()
|
||||
if m.is_enabled and capability in m.capabilities
|
||||
]
|
||||
return [m for m in self._models.values() if m.is_enabled and capability in m.capabilities]
|
||||
|
||||
def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
|
||||
return [
|
||||
m
|
||||
for m in self._models.values()
|
||||
if m.platform_id == platform_id and m.is_enabled
|
||||
]
|
||||
return [m for m in self._models.values() if m.platform_id == platform_id and m.is_enabled]
|
||||
|
||||
def get_default_model_for_task(self, task_type: str) -> str | None:
|
||||
if self._raw is None:
|
||||
|
||||
@@ -4,6 +4,8 @@ Redis 客户端
|
||||
全局 Redis 连接,供 Scheduler 和 RateLimiter 使用
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import get_settings
|
||||
@@ -19,7 +21,7 @@ def get_redis_client() -> Redis:
|
||||
settings = get_settings()
|
||||
|
||||
# 构建连接参数
|
||||
client_kwargs = {
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"host": settings.REDIS_HOST,
|
||||
"port": settings.REDIS_PORT,
|
||||
"db": settings.REDIS_DB,
|
||||
|
||||
@@ -70,9 +70,7 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_level(
|
||||
self, db: AsyncSession, *, level: int
|
||||
) -> list[BrollCategory]:
|
||||
async def get_by_level(self, db: AsyncSession, *, level: int) -> list[BrollCategory]:
|
||||
"""根据层级获取所有启用的分类"""
|
||||
result = await db.execute(
|
||||
select(BrollCategory).where(
|
||||
|
||||
@@ -49,9 +49,7 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def increment_usage_count(
|
||||
self, db: AsyncSession, *, material_id: int
|
||||
) -> None:
|
||||
async def increment_usage_count(self, db: AsyncSession, *, material_id: int) -> None:
|
||||
"""
|
||||
原子递增素材使用次数
|
||||
|
||||
|
||||
@@ -23,9 +23,7 @@ class PointRechargeOrderCRUD(CRUDBase[PointRechargeOrder]):
|
||||
) -> PointRechargeOrder | None:
|
||||
"""根据商户订单号查询"""
|
||||
result = await db.execute(
|
||||
select(PointRechargeOrder).where(
|
||||
PointRechargeOrder.out_trade_no == out_trade_no
|
||||
)
|
||||
select(PointRechargeOrder).where(PointRechargeOrder.out_trade_no == out_trade_no)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -34,9 +32,7 @@ class PointRechargeOrderCRUD(CRUDBase[PointRechargeOrder]):
|
||||
) -> PointRechargeOrder | None:
|
||||
"""根据微信支付订单号查询"""
|
||||
result = await db.execute(
|
||||
select(PointRechargeOrder).where(
|
||||
PointRechargeOrder.wx_order_no == wx_order_no
|
||||
)
|
||||
select(PointRechargeOrder).where(PointRechargeOrder.wx_order_no == wx_order_no)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime, time
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -24,7 +25,7 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
user_id: UUID | str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
tx_type: str | None = None,
|
||||
@@ -55,7 +56,7 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
user_id: UUID | str,
|
||||
tx_type: str | None = None,
|
||||
category: str | None = None,
|
||||
source_type: str | None = None,
|
||||
@@ -105,19 +106,16 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
user_id: UUID | str,
|
||||
) -> int:
|
||||
"""统计用户今日消费积分总和"""
|
||||
now = datetime.now()
|
||||
start_of_day = datetime.combine(now.date(), time.min)
|
||||
stmt = (
|
||||
select(func.coalesce(func.sum(PointTransaction.amount), 0))
|
||||
.where(
|
||||
stmt = select(func.coalesce(func.sum(PointTransaction.amount), 0)).where(
|
||||
PointTransaction.user_id == user_id,
|
||||
PointTransaction.type == "consume",
|
||||
PointTransaction.created_at >= start_of_day,
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
|
||||
@@ -5,7 +5,11 @@
|
||||
用户认证相关的数据访问。
|
||||
"""
|
||||
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
@@ -48,7 +52,7 @@ class UserCRUD(CRUDBase[User]):
|
||||
return user
|
||||
|
||||
async def update_login_info(
|
||||
self, db: AsyncSession, *, user_id: str, ip: str | None = None
|
||||
self, db: AsyncSession, *, user_id: UUID | str, ip: str | None = None
|
||||
) -> User | None:
|
||||
"""
|
||||
更新用户最后登录信息
|
||||
@@ -68,7 +72,7 @@ class UserCRUD(CRUDBase[User]):
|
||||
return user
|
||||
|
||||
async def update_password(
|
||||
self, db: AsyncSession, *, user_id: str, password_hash: str
|
||||
self, db: AsyncSession, *, user_id: UUID | str, password_hash: str
|
||||
) -> User | None:
|
||||
"""更新用户密码"""
|
||||
user = await self.get(db, id=user_id)
|
||||
@@ -80,9 +84,7 @@ class UserCRUD(CRUDBase[User]):
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def update_extra(
|
||||
self, db: AsyncSession, *, user_id: str, extra: dict
|
||||
) -> bool:
|
||||
async def update_extra(self, db: AsyncSession, *, user_id: UUID | str, extra: dict) -> bool:
|
||||
"""
|
||||
原子更新用户 extra 字段(JSONB)
|
||||
|
||||
@@ -90,12 +92,8 @@ class UserCRUD(CRUDBase[User]):
|
||||
"""
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.values(extra=extra)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
stmt = update(User).where(User.id == user_id).values(extra=extra)
|
||||
result = cast(CursorResult[Any], await db.execute(stmt))
|
||||
await db.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
核心操作是「覆盖」而非「新增」,使用 INSERT ... ON CONFLICT DO UPDATE 保证原子性。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -20,18 +22,16 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(UserDevice)
|
||||
|
||||
async def get_by_user_id(self, db: AsyncSession, *, user_id: str) -> UserDevice | None:
|
||||
async def get_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> UserDevice | None:
|
||||
"""根据用户 ID 获取设备记录"""
|
||||
result = await db.execute(
|
||||
select(UserDevice).where(UserDevice.user_id == user_id)
|
||||
)
|
||||
result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_or_update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
user_id: UUID | str,
|
||||
device_id: str,
|
||||
device_name: str | None = None,
|
||||
os_info: str | None = None,
|
||||
@@ -80,12 +80,10 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
|
||||
await db.commit()
|
||||
|
||||
# 返回最新的记录
|
||||
result = await db.execute(
|
||||
select(UserDevice).where(UserDevice.user_id == user_id)
|
||||
)
|
||||
result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
|
||||
return result.scalar_one()
|
||||
|
||||
async def delete_by_user_id(self, db: AsyncSession, *, user_id: str) -> bool:
|
||||
async def delete_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> bool:
|
||||
"""根据用户 ID 删除设备记录(登出时使用)"""
|
||||
device = await self.get_by_user_id(db, user_id=user_id)
|
||||
if device is None:
|
||||
@@ -100,9 +98,7 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
|
||||
) -> UserDevice | None:
|
||||
"""根据 Refresh Token 哈希获取设备记录"""
|
||||
result = await db.execute(
|
||||
select(UserDevice).where(
|
||||
UserDevice.refresh_token_hash == refresh_token_hash
|
||||
)
|
||||
select(UserDevice).where(UserDevice.refresh_token_hash == refresh_token_hash)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@@ -149,12 +149,8 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("VolcengineMediakitAdapter initialized")
|
||||
|
||||
if app.state.volcengine_provider:
|
||||
app.state.volcengine_ark_adapter = VolcengineArkAdapter(
|
||||
app.state.volcengine_provider
|
||||
)
|
||||
app.state.platform_gateway.register(
|
||||
"volcengine_ark", app.state.volcengine_ark_adapter
|
||||
)
|
||||
app.state.volcengine_ark_adapter = VolcengineArkAdapter(app.state.volcengine_provider)
|
||||
app.state.platform_gateway.register("volcengine_ark", app.state.volcengine_ark_adapter)
|
||||
logger.info("VolcengineArkAdapter registered to PlatformGateway")
|
||||
|
||||
logger.info("PlatformGateway initialized")
|
||||
@@ -177,9 +173,7 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("Vidu Service initialized")
|
||||
|
||||
if app.state.volcengine_caption_provider:
|
||||
app.state.volcengine_caption_service = VolcengineCaptionService(
|
||||
app.state.platform_gateway
|
||||
)
|
||||
app.state.volcengine_caption_service = VolcengineCaptionService(app.state.platform_gateway)
|
||||
logger.info("Volcengine Caption Service initialized")
|
||||
else:
|
||||
app.state.volcengine_caption_service = None
|
||||
@@ -276,12 +270,14 @@ def create_app() -> FastAPI:
|
||||
"code": exc.status_code or http_status,
|
||||
"message": str(exc),
|
||||
"data": None,
|
||||
"error_code": exc.error_type,
|
||||
}
|
||||
if settings.DEBUG:
|
||||
content["detail"] = {
|
||||
"platform": exc.platform,
|
||||
"error_type": exc.error_type,
|
||||
"retryable": exc.retryable,
|
||||
"raw_code": exc.raw_code,
|
||||
}
|
||||
return _cors_response(request, http_status, content)
|
||||
|
||||
@@ -294,6 +290,9 @@ def create_app() -> FastAPI:
|
||||
exc.detail if isinstance(exc.detail, str) else "请求失败"
|
||||
)
|
||||
detail = exc.detail if isinstance(exc.detail, dict) else None
|
||||
error_code = getattr(exc, "error_code", None)
|
||||
if not error_code and isinstance(detail, dict):
|
||||
error_code = detail.get("error_code")
|
||||
|
||||
return _cors_response(
|
||||
request,
|
||||
@@ -301,6 +300,7 @@ def create_app() -> FastAPI:
|
||||
{
|
||||
"code": exc.status_code,
|
||||
"message": message,
|
||||
"error_code": error_code,
|
||||
"detail": detail,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -33,18 +33,10 @@ class BgmMusic(BaseModelBigInt):
|
||||
category: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, index=True, comment="场景分类"
|
||||
)
|
||||
file_path: Mapped[str] = mapped_column(
|
||||
String(512), nullable=False, comment="相对文件路径"
|
||||
)
|
||||
url: Mapped[str] = mapped_column(
|
||||
String(1024), nullable=False, comment="七牛云 URL"
|
||||
)
|
||||
duration: Mapped[float] = mapped_column(
|
||||
Float, nullable=True, comment="时长(秒)"
|
||||
)
|
||||
file_path: Mapped[str] = mapped_column(String(512), nullable=False, comment="相对文件路径")
|
||||
url: Mapped[str] = mapped_column(String(1024), nullable=False, comment="七牛云 URL")
|
||||
duration: Mapped[float] = mapped_column(Float, nullable=True, comment="时长(秒)")
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), default="active", nullable=False, comment="状态: active/inactive"
|
||||
)
|
||||
sort_order: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False, comment="排序权重"
|
||||
)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0, nullable=False, comment="排序权重")
|
||||
|
||||
@@ -7,7 +7,16 @@ Tauri updater 插件所需的数据结构。
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, BigInteger, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.session import Base
|
||||
@@ -60,5 +69,11 @@ class ReleasePackage(Base):
|
||||
release: Mapped["AppRelease"] = relationship("AppRelease", back_populates="packages")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("release_id", "platform", "architecture", "filename", name="uix_app_pkg_platform_arch_filename"),
|
||||
UniqueConstraint(
|
||||
"release_id",
|
||||
"platform",
|
||||
"architecture",
|
||||
"filename",
|
||||
name="uix_app_pkg_platform_arch_filename",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class UserStatus(str, enum.Enum):
|
||||
class UserStatus(enum.StrEnum):
|
||||
"""用户状态"""
|
||||
|
||||
ACTIVE = "active" # 正常
|
||||
@@ -117,6 +117,3 @@ class User(BaseModel):
|
||||
def display_name(self) -> str:
|
||||
"""对外展示的名称"""
|
||||
return self.nickname or f"用户_{self.mobile[-4:]}"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ from app.ai.adapters.base import (
|
||||
TaskStatus,
|
||||
)
|
||||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||||
from app.utils.content_fingerprint import (
|
||||
compute_content_fingerprint,
|
||||
is_vidu_audit_error,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,6 +35,10 @@ logger = logging.getLogger(__name__)
|
||||
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
|
||||
_TASK_MAPPING_TTL = 7 * 24 * 60 * 60 # 7 天
|
||||
|
||||
# Redis key 前缀:内容审核失败缓存
|
||||
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
|
||||
_AUDIT_REJECTION_TTL = 24 * 60 * 60 # 24 小时
|
||||
|
||||
|
||||
class PlatformGateway:
|
||||
"""第三方平台统一调用网关"""
|
||||
@@ -61,10 +69,13 @@ class PlatformGateway:
|
||||
redis = self._get_redis()
|
||||
# 正向映射:internal → platform
|
||||
key = self._task_mapping_key(internal_task_id)
|
||||
await redis.hset(key, mapping={
|
||||
await redis.hset(
|
||||
key,
|
||||
mapping={
|
||||
"platform": platform,
|
||||
"platform_task_id": platform_task_id,
|
||||
})
|
||||
},
|
||||
)
|
||||
await redis.expire(key, _TASK_MAPPING_TTL)
|
||||
# 反向映射:platform → internal(供回调查找)
|
||||
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
|
||||
@@ -82,6 +93,33 @@ class PlatformGateway:
|
||||
"platform_task_id": data.get("platform_task_id", ""),
|
||||
}
|
||||
|
||||
def _audit_rejection_key(self, fingerprint: str) -> str:
|
||||
return f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"
|
||||
|
||||
async def _get_audit_rejection(self, fingerprint: str) -> str | None:
|
||||
"""查询该内容指纹是否近期审核失败。
|
||||
|
||||
Returns:
|
||||
失败错误码(如 "TaskPromptPolicyViolation"),未命中返回 None
|
||||
"""
|
||||
if not fingerprint:
|
||||
return None
|
||||
try:
|
||||
redis = self._get_redis()
|
||||
key = self._audit_rejection_key(fingerprint)
|
||||
return await redis.get(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"[PlatformGateway] 查询审核缓存失败: {e}")
|
||||
return None
|
||||
|
||||
async def _set_audit_rejection(self, fingerprint: str, error_code: str) -> None:
|
||||
"""缓存审核失败结果。"""
|
||||
if not fingerprint or not error_code:
|
||||
return
|
||||
redis = self._get_redis()
|
||||
key = self._audit_rejection_key(fingerprint)
|
||||
await redis.setex(key, _AUDIT_REJECTION_TTL, error_code)
|
||||
|
||||
async def get_internal_task_id_by_platform_task_id(
|
||||
self, platform: str, platform_task_id: str
|
||||
) -> str | None:
|
||||
@@ -156,15 +194,62 @@ class PlatformGateway:
|
||||
若提供,则直接使用该 ID 建立映射;否则自动生成。
|
||||
callback 场景必须传入,确保回调能反查到正确的 Registry 记录。
|
||||
"""
|
||||
internal_task_id = internal_task_id or uuid.uuid4().hex
|
||||
|
||||
# 1. 同一 internal_task_id 已提交过,直接返回(幂等)
|
||||
existing = await self._get_task_mapping(internal_task_id)
|
||||
if existing:
|
||||
return internal_task_id
|
||||
|
||||
# 2. Vidu 内容指纹防重:相同内容近期审核失败则直接拦截
|
||||
fingerprint: str | None = None
|
||||
if platform == "vidu":
|
||||
fingerprint = compute_content_fingerprint(
|
||||
task_type=task_type,
|
||||
video_url=payload.get("video_url"),
|
||||
audio_url=payload.get("audio_url"),
|
||||
ref_photo_url=payload.get("ref_photo_url"),
|
||||
text=payload.get("text"),
|
||||
voice_id=payload.get("voice_id"),
|
||||
)
|
||||
rejected_code = await self._get_audit_rejection(fingerprint)
|
||||
if rejected_code:
|
||||
raise PlatformError(
|
||||
"人物分镜台词未通过安全审核,请修改后重试",
|
||||
platform=platform,
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.CONTENT_VIOLATION,
|
||||
raw_code=rejected_code,
|
||||
)
|
||||
|
||||
# 3. 调用平台 Adapter 提交任务
|
||||
adapter = self._get_task_adapter(platform, task_type)
|
||||
try:
|
||||
result = await adapter.submit(task_type, payload, callback_url)
|
||||
except PlatformError as e:
|
||||
# Vidu 审核类错误:缓存内容指纹,防止重复调用
|
||||
err_code = e.raw_code
|
||||
if platform == "vidu" and fingerprint and err_code and is_vidu_audit_error(err_code):
|
||||
await self._set_audit_rejection(fingerprint, err_code)
|
||||
raise PlatformError(
|
||||
"人物分镜台词未通过安全审核,请修改后重试",
|
||||
platform=platform,
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.CONTENT_VIOLATION,
|
||||
raw_code=err_code,
|
||||
) from e
|
||||
raise
|
||||
|
||||
if not result.success:
|
||||
raw_code = result.error_code
|
||||
if platform == "vidu" and fingerprint and raw_code and is_vidu_audit_error(raw_code):
|
||||
await self._set_audit_rejection(fingerprint, raw_code)
|
||||
raise PlatformError(
|
||||
result.error_message or "任务提交失败",
|
||||
platform=platform,
|
||||
retryable=result.retryable,
|
||||
error_type=PlatformErrorType.UNKNOWN,
|
||||
raw_code=raw_code,
|
||||
)
|
||||
|
||||
platform_task_id = (result.data or {}).get("task_id", "")
|
||||
@@ -175,7 +260,6 @@ class PlatformGateway:
|
||||
retryable=False,
|
||||
error_type=PlatformErrorType.UNKNOWN,
|
||||
)
|
||||
internal_task_id = internal_task_id or uuid.uuid4().hex
|
||||
await self._store_task_mapping(internal_task_id, platform, platform_task_id)
|
||||
logger.info(
|
||||
f"Task submitted: internal={internal_task_id}, "
|
||||
|
||||
@@ -7,6 +7,7 @@ Async Engine 核心调度器
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.core.redis_client import get_redis_client
|
||||
@@ -17,6 +18,13 @@ from app.scheduler.slot_manager import SlotManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 各任务类型最大执行时间(秒),超过后自动标记为 failed
|
||||
TASK_TIMEOUT_SECONDS = {
|
||||
"script": 5 * 60,
|
||||
"subtitle": 10 * 60,
|
||||
"video": 30 * 60,
|
||||
}
|
||||
|
||||
|
||||
class AsyncEngine:
|
||||
"""统一异步作业调度引擎"""
|
||||
@@ -46,13 +54,50 @@ class AsyncEngine:
|
||||
logger.debug("Tick: no running tasks")
|
||||
return
|
||||
|
||||
# 2. 按 task_type 分组
|
||||
# 2. 按 task_type 分组,并处理超时任务
|
||||
tasks_by_type: dict[str, list[Any]] = {}
|
||||
timeout_changes: list[StateChange] = []
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for task_id in running_ids:
|
||||
record = await self.registry.get(task_id)
|
||||
if not record:
|
||||
await self.registry.remove_running(task_id)
|
||||
continue
|
||||
|
||||
max_duration = TASK_TIMEOUT_SECONDS.get(record.task_type)
|
||||
is_timeout = (
|
||||
max_duration
|
||||
and record.status == "running"
|
||||
and record.created_at
|
||||
and (now - datetime.fromisoformat(record.created_at)).total_seconds() > max_duration
|
||||
)
|
||||
|
||||
if is_timeout:
|
||||
logger.warning(
|
||||
f"Task timeout: {task_id}, type={record.task_type}, "
|
||||
f"created_at={record.created_at}"
|
||||
)
|
||||
timeout_changes.append(
|
||||
StateChange(task_id=task_id, field_path="status", value="failed")
|
||||
)
|
||||
timeout_changes.append(
|
||||
StateChange(
|
||||
task_id=task_id,
|
||||
field_path="message",
|
||||
value="任务执行超时,请稍后重试",
|
||||
)
|
||||
)
|
||||
timeout_changes.append(
|
||||
StateChange(
|
||||
task_id=task_id,
|
||||
field_path="error",
|
||||
value=f"任务执行超过 {max_duration} 秒",
|
||||
)
|
||||
)
|
||||
await self.registry.remove_running(task_id)
|
||||
continue
|
||||
|
||||
tasks_by_type.setdefault(record.task_type, []).append(record)
|
||||
|
||||
# 3. 并行执行各 Handler 的 tick
|
||||
@@ -63,10 +108,14 @@ class AsyncEngine:
|
||||
]
|
||||
)
|
||||
|
||||
# 4. 收集并应用状态变更
|
||||
# 4. 收集并应用状态变更(包含超时任务)
|
||||
all_changes: list[StateChange] = []
|
||||
for changes in results:
|
||||
if changes:
|
||||
await self._apply_changes(changes)
|
||||
all_changes.extend(changes)
|
||||
all_changes.extend(timeout_changes)
|
||||
if all_changes:
|
||||
await self._apply_changes(all_changes)
|
||||
|
||||
# 5. 清理已结束的作业
|
||||
await self._cleanup_finished()
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.ai.prompts.loader import _load_system_meta
|
||||
from app.core.exceptions import InsufficientPointsException
|
||||
from app.core.platform_config import get_platform_config_loader
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
@@ -38,6 +39,7 @@ def _get_category_name(category: str, filename: str) -> str:
|
||||
return f"{cat_name} · {label}"
|
||||
return cat_name
|
||||
|
||||
|
||||
SLOT_KEY = "script:slots"
|
||||
|
||||
|
||||
@@ -64,9 +66,7 @@ class ScriptHandler(AsyncHandler):
|
||||
|
||||
def _get_service(self) -> ScriptService:
|
||||
if self.service is None:
|
||||
raise RuntimeError(
|
||||
"ScriptHandler 需要通过构造函数传入 ScriptService 实例"
|
||||
)
|
||||
raise RuntimeError("ScriptHandler 需要通过构造函数传入 ScriptService 实例")
|
||||
return self.service
|
||||
|
||||
async def tick(
|
||||
@@ -83,7 +83,9 @@ class ScriptHandler(AsyncHandler):
|
||||
changes.extend(await self._process_task(task, registry, slots))
|
||||
except Exception as e:
|
||||
logger.exception(f"[Script {task.task_id}] failed")
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="status", value="failed"))
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="error", value=str(e)[:500])
|
||||
)
|
||||
@@ -129,30 +131,66 @@ class ScriptHandler(AsyncHandler):
|
||||
"shot_count": len(shots),
|
||||
}
|
||||
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="status", value="completed"))
|
||||
# 生成成功后再扣费
|
||||
user_id = params.get("user_id")
|
||||
required_points = params.get("required_points", 0)
|
||||
if user_id and required_points > 0:
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
await ps.consume(
|
||||
db,
|
||||
user_id=user_id,
|
||||
points=required_points,
|
||||
source_type="script",
|
||||
source_id=task.task_id,
|
||||
description="【脚本生成】",
|
||||
)
|
||||
await db.commit()
|
||||
except InsufficientPointsException:
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="message",
|
||||
value="积分不足",
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="error_code",
|
||||
value="insufficient_points",
|
||||
)
|
||||
)
|
||||
return changes
|
||||
except Exception as e:
|
||||
logger.error(f"[ScriptTask {task.task_id}] 扣费失败: {e}")
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="message",
|
||||
value="扣费失败",
|
||||
)
|
||||
)
|
||||
return changes
|
||||
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="completed")
|
||||
)
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="progress", value=100))
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="message", value="脚本生成完成")
|
||||
)
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="completed", value=1))
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="result", value=result_data))
|
||||
|
||||
# 后置扣费(独立 session,失败不影响任务结果)
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
points = ps._calculate_cost("script")
|
||||
await ps.consume(
|
||||
db,
|
||||
user_id=task.user_id,
|
||||
points=points,
|
||||
source_type="script",
|
||||
source_id=task.task_id,
|
||||
description="【脚本生成】",
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="result", value=result_data)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"[Script {task.task_id}] 扣费失败: {e}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f"[ScriptTask {task.task_id}] Failed")
|
||||
@@ -160,6 +198,8 @@ class ScriptHandler(AsyncHandler):
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="message", value=str(exc)[:200])
|
||||
)
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="error", value=str(exc)[:500]))
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="error", value=str(exc)[:500])
|
||||
)
|
||||
|
||||
return changes
|
||||
|
||||
@@ -45,9 +45,7 @@ class SubtitleHandler(AsyncHandler):
|
||||
|
||||
def _get_service(self) -> VolcengineCaptionService:
|
||||
if self.service is None:
|
||||
raise RuntimeError(
|
||||
"SubtitleHandler 需要通过构造函数传入 VolcengineCaptionService 实例"
|
||||
)
|
||||
raise RuntimeError("SubtitleHandler 需要通过构造函数传入 VolcengineCaptionService 实例")
|
||||
return self.service
|
||||
|
||||
async def tick(
|
||||
@@ -93,7 +91,9 @@ class SubtitleHandler(AsyncHandler):
|
||||
"utterances": utterances,
|
||||
}
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="completed")
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="status", value="completed"
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="progress", value=100)
|
||||
@@ -106,9 +106,13 @@ class SubtitleHandler(AsyncHandler):
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="completed", value=1)
|
||||
)
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="result", value=result_payload)
|
||||
StateChange(task_id=task.task_id, field_path="total", value=1)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="result", value=result_payload
|
||||
)
|
||||
)
|
||||
elif status.state == "failed":
|
||||
changes.append(
|
||||
@@ -173,9 +177,13 @@ class SubtitleHandler(AsyncHandler):
|
||||
if not volc_task_id:
|
||||
raise ValueError("未返回任务ID")
|
||||
params["volc_task_id"] = volc_task_id
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="params", value=params))
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="message", value="字幕任务已提交")
|
||||
StateChange(task_id=task.task_id, field_path="params", value=params)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="message", value="字幕任务已提交"
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
logger.error(f"[Subtitle {task.task_id}] service not initialized")
|
||||
|
||||
@@ -9,17 +9,27 @@ Video 任务处理器
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.exceptions import PlatformError, PlatformErrorType
|
||||
from app.core.platform_config import get_platform_config_loader
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.scheduler.handlers.base import AsyncHandler
|
||||
from app.scheduler.models import StateChange
|
||||
from app.scheduler.registry import TaskRegistry
|
||||
from app.scheduler.slot_manager import SlotManager
|
||||
from app.services.vidu_service import ViduService
|
||||
from app.utils.content_fingerprint import (
|
||||
compute_content_fingerprint,
|
||||
extract_vidu_error_code,
|
||||
is_vidu_audit_error,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SLOT_KEY = "vidu:video_slots"
|
||||
|
||||
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
|
||||
_AUDIT_REJECTION_TTL = 24 * 60 * 60 # 24 小时
|
||||
|
||||
|
||||
def _get_video_max_slots() -> int:
|
||||
"""从 platform-config.yaml 读取 rate_limit 配置作为 max_slots"""
|
||||
@@ -45,11 +55,39 @@ class VideoHandler(AsyncHandler):
|
||||
|
||||
def _get_service(self) -> ViduService:
|
||||
if self.service is None:
|
||||
raise RuntimeError(
|
||||
"VideoHandler 需要通过构造函数传入 ViduService 实例"
|
||||
)
|
||||
raise RuntimeError("VideoHandler 需要通过构造函数传入 ViduService 实例")
|
||||
return self.service
|
||||
|
||||
async def _cache_audit_rejection_if_needed(
|
||||
self, params: dict[str, Any], error_message: str | None
|
||||
) -> None:
|
||||
"""如果失败原因是 Vidu 审核类错误,缓存内容指纹防止重复提交。"""
|
||||
err_code = extract_vidu_error_code(error_message)
|
||||
if not err_code or not is_vidu_audit_error(err_code):
|
||||
return
|
||||
|
||||
fingerprint = compute_content_fingerprint(
|
||||
task_type="lip_sync",
|
||||
video_url=params.get("video_url"),
|
||||
audio_url=params.get("audio_url"),
|
||||
ref_photo_url=params.get("ref_photo_url"),
|
||||
text=params.get("text"),
|
||||
voice_id=params.get("voice_id"),
|
||||
)
|
||||
if not fingerprint:
|
||||
return
|
||||
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"
|
||||
await redis.setex(key, _AUDIT_REJECTION_TTL, err_code)
|
||||
logger.info(
|
||||
f"[Video] 审核失败内容已缓存: fingerprint={fingerprint[:16]}..., "
|
||||
f"err_code={err_code}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Video] 缓存审核失败结果出错: {e}")
|
||||
|
||||
async def tick(
|
||||
self, tasks: list[Any], registry: TaskRegistry, slots: SlotManager
|
||||
) -> list[StateChange]:
|
||||
@@ -87,11 +125,7 @@ class VideoHandler(AsyncHandler):
|
||||
value=1,
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="total", value=1
|
||||
)
|
||||
)
|
||||
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
|
||||
elif task.status == "failed":
|
||||
# callback 已标记失败,移除 running
|
||||
await registry.remove_running(task.task_id)
|
||||
@@ -150,6 +184,15 @@ class VideoHandler(AsyncHandler):
|
||||
f"video_url={video_url[:60]}..."
|
||||
)
|
||||
elif vidu_state == "failed":
|
||||
error_message = vidu_status.get("message") or "视频生成失败"
|
||||
err_code = extract_vidu_error_code(error_message)
|
||||
is_audit = err_code and is_vidu_audit_error(err_code)
|
||||
message = (
|
||||
"人物分镜台词未通过安全审核,请修改后重试"
|
||||
if is_audit
|
||||
else "视频生成失败"
|
||||
)
|
||||
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
@@ -161,20 +204,29 @@ class VideoHandler(AsyncHandler):
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="message",
|
||||
value="视频生成失败",
|
||||
value=message,
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="error",
|
||||
value=vidu_status.get("message") or "视频生成失败",
|
||||
value=error_message,
|
||||
)
|
||||
)
|
||||
if is_audit:
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="error_code",
|
||||
value="content_violation",
|
||||
)
|
||||
)
|
||||
await self._cache_audit_rejection_if_needed(params, error_message)
|
||||
await registry.remove_running(task.task_id)
|
||||
logger.warning(
|
||||
f"[Video {task.task_id}] 主动查询 Vidu 任务失败: "
|
||||
f"{vidu_status.get('message')}"
|
||||
f"{error_message}"
|
||||
)
|
||||
else:
|
||||
# 仍在处理中,继续等待
|
||||
@@ -182,15 +234,11 @@ class VideoHandler(AsyncHandler):
|
||||
f"[Video {task.task_id}] 主动查询 Vidu 状态: {vidu_state},继续等待"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Video {task.task_id}] 主动查询 Vidu 失败: {e}"
|
||||
)
|
||||
logger.warning(f"[Video {task.task_id}] 主动查询 Vidu 失败: {e}")
|
||||
continue # ← 已提交,不再重复提交
|
||||
|
||||
# 提交阶段:占用 slot,提交成功后自动释放
|
||||
async with slots.acquire_ctx(
|
||||
SLOT_KEY, task.task_id, self.max_slots
|
||||
) as acquired:
|
||||
async with slots.acquire_ctx(SLOT_KEY, task.task_id, self.max_slots) as acquired:
|
||||
if not acquired:
|
||||
continue
|
||||
|
||||
@@ -227,9 +275,7 @@ class VideoHandler(AsyncHandler):
|
||||
raise ValueError("未返回任务ID")
|
||||
params["vidu_task_id"] = vidu_task_id
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="params", value=params
|
||||
)
|
||||
StateChange(task_id=task.task_id, field_path="params", value=params)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
@@ -241,9 +287,7 @@ class VideoHandler(AsyncHandler):
|
||||
except RuntimeError:
|
||||
logger.error(f"[Video {task.task_id}] service not initialized")
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="status", value="failed"
|
||||
)
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
@@ -259,12 +303,50 @@ class VideoHandler(AsyncHandler):
|
||||
value="视频处理服务未就绪",
|
||||
)
|
||||
)
|
||||
except PlatformError as e:
|
||||
error_message = str(e)
|
||||
is_audit = e.error_type == PlatformErrorType.CONTENT_VIOLATION
|
||||
if is_audit:
|
||||
message = "人物分镜台词未通过安全审核,请修改后重试"
|
||||
elif e.error_type == PlatformErrorType.AUTH_FAILED:
|
||||
message = "视频服务认证失败,请联系客服"
|
||||
elif e.error_type == PlatformErrorType.RATE_LIMIT:
|
||||
message = "视频生成服务繁忙,请稍后重试"
|
||||
else:
|
||||
message = "视频生成任务提交失败,请稍后重试"
|
||||
|
||||
logger.error(f"[Video {task.task_id}] submit platform error: {e}")
|
||||
changes.append(
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="message",
|
||||
value=message,
|
||||
)
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="error",
|
||||
value=error_message,
|
||||
)
|
||||
)
|
||||
if is_audit:
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id,
|
||||
field_path="error_code",
|
||||
value="content_violation",
|
||||
)
|
||||
)
|
||||
# 审核类错误在 PlatformGateway 已写缓存,此处幂等补充
|
||||
await self._cache_audit_rejection_if_needed(params, error_message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Video {task.task_id}] submit error: {e}")
|
||||
changes.append(
|
||||
StateChange(
|
||||
task_id=task.task_id, field_path="status", value="failed"
|
||||
)
|
||||
StateChange(task_id=task.task_id, field_path="status", value="failed")
|
||||
)
|
||||
changes.append(
|
||||
StateChange(
|
||||
|
||||
@@ -24,6 +24,7 @@ class TaskRecord:
|
||||
total: int = 0
|
||||
result: dict[str, Any] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
params: dict[str, Any] | Any = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
|
||||
|
||||
@@ -7,8 +7,9 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
@@ -54,8 +55,8 @@ class TaskRegistry:
|
||||
if params:
|
||||
data["params"] = json.dumps(params, ensure_ascii=False)
|
||||
|
||||
await self.redis.hset(_task_key(task_id), mapping=data)
|
||||
await self.redis.expire(_task_key(task_id), ttl)
|
||||
await cast(Awaitable[int], self.redis.hset(_task_key(task_id), mapping=data))
|
||||
await cast(Awaitable[int], self.redis.expire(_task_key(task_id), ttl))
|
||||
logger.debug(f"Registry created: {task_id}, type={task_type}")
|
||||
|
||||
async def update(self, task_id: str, **fields: Any) -> None:
|
||||
@@ -68,12 +69,12 @@ class TaskRegistry:
|
||||
mapping[key] = ""
|
||||
else:
|
||||
mapping[key] = str(value)
|
||||
await self.redis.hset(_task_key(task_id), mapping=mapping)
|
||||
await cast(Awaitable[int], self.redis.hset(_task_key(task_id), mapping=mapping))
|
||||
logger.debug(f"Registry updated: {task_id}, fields={list(fields.keys())}")
|
||||
|
||||
async def get(self, task_id: str) -> TaskRecord | None:
|
||||
"""读取完整作业记录"""
|
||||
data = await self.redis.hgetall(_task_key(task_id))
|
||||
data = await cast(Awaitable[dict[Any, Any]], self.redis.hgetall(_task_key(task_id)))
|
||||
if not data:
|
||||
return None
|
||||
|
||||
@@ -88,6 +89,8 @@ class TaskRegistry:
|
||||
return int(raw)
|
||||
except ValueError:
|
||||
return 0
|
||||
if key in ("error_code", "error") and raw == "":
|
||||
return None
|
||||
return raw
|
||||
|
||||
parsed = {k: _parse(k, v) for k, v in data.items()}
|
||||
@@ -107,21 +110,22 @@ class TaskRegistry:
|
||||
total=parsed.get("total", 0),
|
||||
result=parsed.get("result", {}),
|
||||
error=parsed.get("error"),
|
||||
error_code=parsed.get("error_code"),
|
||||
params=params,
|
||||
created_at=parsed.get("created_at", ""),
|
||||
)
|
||||
|
||||
async def add_running(self, task_id: str) -> None:
|
||||
"""将作业标记为 running(加入全局 running 集合)"""
|
||||
await self.redis.sadd(KEY_RUNNING_SET, task_id)
|
||||
await cast(Awaitable[int], self.redis.sadd(KEY_RUNNING_SET, task_id))
|
||||
|
||||
async def remove_running(self, task_id: str) -> None:
|
||||
"""将作业从全局 running 集合移除"""
|
||||
await self.redis.srem(KEY_RUNNING_SET, task_id)
|
||||
await cast(Awaitable[int], self.redis.srem(KEY_RUNNING_SET, task_id))
|
||||
|
||||
async def get_running_task_ids(self) -> list[str]:
|
||||
"""获取所有 running 的作业 ID 列表"""
|
||||
members = await self.redis.smembers(KEY_RUNNING_SET)
|
||||
members = await cast(Awaitable[set[Any]], self.redis.smembers(KEY_RUNNING_SET))
|
||||
return list(members)
|
||||
|
||||
async def list_running_by_user(self, user_id: str) -> list[TaskRecord]:
|
||||
@@ -139,5 +143,5 @@ class TaskRegistry:
|
||||
|
||||
async def delete(self, task_id: str) -> None:
|
||||
"""删除作业记录"""
|
||||
await self.redis.delete(_task_key(task_id))
|
||||
await self.redis.srem(KEY_RUNNING_SET, task_id)
|
||||
await cast(Awaitable[int], self.redis.delete(_task_key(task_id)))
|
||||
await cast(Awaitable[int], self.redis.srem(KEY_RUNNING_SET, task_id))
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
@@ -37,7 +39,9 @@ class SlotManager:
|
||||
async def acquire(self, slot_key: str, slot_id: str, max_slots: int) -> bool:
|
||||
"""申请一个槽位。返回 True 表示成功,False 表示槽位已满。"""
|
||||
try:
|
||||
result = await self.redis.eval(_ACQUIRE_LUA, 1, slot_key, slot_id, str(max_slots))
|
||||
result = await cast(
|
||||
Awaitable[str], self.redis.eval(_ACQUIRE_LUA, 1, slot_key, slot_id, str(max_slots))
|
||||
)
|
||||
acquired = result == 1
|
||||
if acquired:
|
||||
logger.debug(f"Slot acquired: {slot_key}/{slot_id} (max={max_slots})")
|
||||
@@ -51,7 +55,7 @@ class SlotManager:
|
||||
async def release(self, slot_key: str, slot_id: str) -> None:
|
||||
"""释放一个槽位。"""
|
||||
try:
|
||||
await self.redis.srem(slot_key, slot_id)
|
||||
await cast(Awaitable[int], self.redis.srem(slot_key, slot_id))
|
||||
logger.debug(f"Slot released: {slot_key}/{slot_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Slot release error: {slot_key}/{slot_id}: {e}")
|
||||
@@ -78,6 +82,6 @@ class SlotManager:
|
||||
async def count(self, slot_key: str) -> int:
|
||||
"""获取当前已占用的槽位数量。"""
|
||||
try:
|
||||
return await self.redis.scard(slot_key)
|
||||
return await cast(Awaitable[int], self.redis.scard(slot_key))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
==============
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BgmMusicItem(BaseModel):
|
||||
@@ -18,8 +18,7 @@ class BgmMusicItem(BaseModel):
|
||||
duration: float | None = Field(default=None, description="时长(秒)")
|
||||
sort_order: int = Field(default=0, description="排序权重")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class BgmMusicListResponse(BaseModel):
|
||||
|
||||
@@ -6,14 +6,12 @@
|
||||
{ code: number; data: T; message: string }
|
||||
"""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
class ApiResponse[T](BaseModel):
|
||||
"""
|
||||
统一 API 响应格式
|
||||
|
||||
@@ -27,17 +25,18 @@ class ApiResponse(BaseModel, Generic[T]):
|
||||
data: T | None = Field(default=None, description="响应数据")
|
||||
message: str = Field(default="success", description="提示信息")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"code": 200,
|
||||
"data": {},
|
||||
"message": "success",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class PaginatedData(BaseModel, Generic[T]):
|
||||
class PaginatedData[T](BaseModel):
|
||||
"""分页数据包装"""
|
||||
|
||||
items: list[T] = Field(description="数据列表")
|
||||
@@ -61,18 +60,23 @@ class PaginationParams(BaseModel):
|
||||
class ApiErrorResponse(BaseModel):
|
||||
"""错误响应格式"""
|
||||
|
||||
code: int = Field(description="错误码")
|
||||
code: int = Field(description="HTTP 状态码")
|
||||
message: str = Field(description="错误信息")
|
||||
error_code: str | None = Field(default=None, description="应用级错误码")
|
||||
detail: dict[str, Any] | None = Field(default=None, description="详细错误信息")
|
||||
|
||||
|
||||
def success_response(data: T | None = None, message: str = "success") -> ApiResponse[T]:
|
||||
def success_response[T](data: T | None = None, message: str = "success") -> ApiResponse[T]:
|
||||
"""构造成功响应"""
|
||||
return ApiResponse(code=200, data=data, message=message)
|
||||
|
||||
|
||||
def error_response(
|
||||
code: int, message: str, detail: dict[str, Any] | None = None
|
||||
code: int,
|
||||
message: str,
|
||||
detail: dict[str, Any] | None = None,
|
||||
*,
|
||||
error_code: str | None = None,
|
||||
) -> ApiErrorResponse:
|
||||
"""构造错误响应"""
|
||||
return ApiErrorResponse(code=code, message=message, detail=detail)
|
||||
return ApiErrorResponse(code=code, message=message, error_code=error_code, detail=detail)
|
||||
|
||||
@@ -25,6 +25,3 @@ class SegmentStatus(StrEnum):
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -18,9 +18,7 @@ class MatchMaterialRequest(BaseModel):
|
||||
|
||||
scene: str = Field(description="分镜场景描述")
|
||||
duration: float = Field(description="所需时长(秒)", gt=0)
|
||||
project_id: str | None = Field(
|
||||
default=None, description="项目ID,用于跨分镜去重"
|
||||
)
|
||||
project_id: str | None = Field(default=None, description="项目ID,用于跨分镜去重")
|
||||
|
||||
@field_validator("scene")
|
||||
@classmethod
|
||||
@@ -56,9 +54,7 @@ class BatchMatchSceneItem(BaseModel):
|
||||
class BatchMatchMaterialRequest(BaseModel):
|
||||
"""批量匹配素材请求"""
|
||||
|
||||
project_id: str | None = Field(
|
||||
default=None, description="项目ID,用于跨分镜去重"
|
||||
)
|
||||
project_id: str | None = Field(default=None, description="项目ID,用于跨分镜去重")
|
||||
scenes: list[BatchMatchSceneItem] = Field(
|
||||
description="分镜场景列表",
|
||||
min_length=1,
|
||||
@@ -76,6 +72,4 @@ class BatchMatchMaterialResponse(BaseModel):
|
||||
"""批量匹配素材响应"""
|
||||
|
||||
project_id: str | None = Field(description="项目ID")
|
||||
results: list[MaterialInfo | None] = Field(
|
||||
description="匹配结果列表,与 scenes 一一对应"
|
||||
)
|
||||
results: list[MaterialInfo | None] = Field(description="匹配结果列表,与 scenes 一一对应")
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ── 余额查询 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class PointBalanceResponse(BaseModel):
|
||||
"""积分余额响应"""
|
||||
|
||||
@@ -20,6 +21,7 @@ class PointBalanceResponse(BaseModel):
|
||||
|
||||
# ── 流水记录 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class PointTransactionItem(BaseModel):
|
||||
"""单条流水记录"""
|
||||
|
||||
@@ -33,7 +35,10 @@ class PointTransactionItem(BaseModel):
|
||||
source_type: str | None = Field(None, description="消费来源类型")
|
||||
source_id: str | None = Field(None, description="消费来源业务 ID")
|
||||
duration: float | None = Field(None, description="时长(秒),按秒计费业务记录")
|
||||
category: str | None = Field(None, description="业务分类:脚本生成 / 配音合成 / 视频生成 / 压制成片 / 字幕烧录 / 封面设计 / 充值")
|
||||
category: str | None = Field(
|
||||
None,
|
||||
description="业务分类:脚本生成 / 配音合成 / 视频生成 / 压制成片 / 字幕烧录 / 封面设计 / 充值",
|
||||
)
|
||||
description: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
@@ -49,6 +54,7 @@ class PointTransactionListResponse(BaseModel):
|
||||
|
||||
# ── 充值 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class RechargeRequest(BaseModel):
|
||||
"""充值请求"""
|
||||
|
||||
@@ -70,6 +76,7 @@ class RechargeResponse(BaseModel):
|
||||
|
||||
# ── 积分预估 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class CostEstimateResponse(BaseModel):
|
||||
"""积分预估响应"""
|
||||
|
||||
@@ -80,6 +87,7 @@ class CostEstimateResponse(BaseModel):
|
||||
|
||||
# ── 充值订单 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class RechargeOrderItem(BaseModel):
|
||||
"""充值订单记录"""
|
||||
|
||||
@@ -104,11 +112,16 @@ class RechargeOrderListResponse(BaseModel):
|
||||
|
||||
# ── 消费扣费 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class ConsumeRequest(BaseModel):
|
||||
"""消费扣费请求(供前端/Rust 层调用)"""
|
||||
|
||||
points: int = Field(gt=0, description="消耗积分数量")
|
||||
source_type: str = Field(description="消费来源类型,如 compose / subtitle_burn / cover_design / video")
|
||||
source_type: str = Field(
|
||||
description="消费来源类型,如 compose / subtitle_burn / cover_design / video"
|
||||
)
|
||||
source_id: str = Field(description="消费来源业务 ID")
|
||||
description: str | None = Field(default=None, description="消费描述")
|
||||
allow_negative: bool = Field(default=False, description="是否允许扣费后余额为负(不确定消耗场景用)")
|
||||
allow_negative: bool = Field(
|
||||
default=False, description="是否允许扣费后余额为负(不确定消耗场景用)"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
===============
|
||||
"""
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.schemas.segment import Segment
|
||||
@@ -93,7 +92,9 @@ class GenerateTitleRequest(BaseModel):
|
||||
script_content: str = Field(..., description="脚本内容(utterances 文本拼接)", min_length=1)
|
||||
title_type: str = Field(..., description="标题类型:main(大标题) / sub(小标题)")
|
||||
max_length: int = Field(default=8, ge=1, le=100, description="最大字数限制")
|
||||
usage: str = Field(default="video", description="使用场景:video(视频画面标题) / cover(封面标题)")
|
||||
usage: str = Field(
|
||||
default="video", description="使用场景:video(视频画面标题) / cover(封面标题)"
|
||||
)
|
||||
|
||||
|
||||
class GenerateTitleResponse(BaseModel):
|
||||
|
||||
@@ -32,9 +32,7 @@ class Segment(BaseModel):
|
||||
duration: int | None = Field(default=None, description="时长(秒)")
|
||||
voice_id: str | None = Field(default=None, description="音色ID(空镜时使用)")
|
||||
status: SegmentStatus = Field(default=SegmentStatus.PENDING)
|
||||
provider_task_id: str | None = Field(
|
||||
default=None, description="供应商任务ID"
|
||||
)
|
||||
provider_task_id: str | None = Field(default=None, description="供应商任务ID")
|
||||
video_url: str | None = Field(default=None, description="生成后的视频URL")
|
||||
local_path: str | None = Field(default=None, description="本地视频路径")
|
||||
qiniu_url: str | None = Field(default=None, description="七牛云URL")
|
||||
|
||||
@@ -6,7 +6,6 @@ Tauri updater 插件所需的请求/响应模型。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -58,6 +57,7 @@ class ReleaseListItem(BaseModel):
|
||||
# Tauri updater 插件所需的 JSON 格式
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TauriPlatformInfo(BaseModel):
|
||||
"""Tauri updater 单平台信息"""
|
||||
|
||||
@@ -73,7 +73,7 @@ class TauriUpdateResponse(BaseModel):
|
||||
|
||||
version: str = Field(..., description="新版本号")
|
||||
notes: str = Field(default="", description="更新说明")
|
||||
pub_date: Optional[str] = Field(default=None, description="发布时间(RFC 3339)")
|
||||
pub_date: str | None = Field(default=None, description="发布时间(RFC 3339)")
|
||||
mandatory: bool = Field(default=False, description="是否强制更新(自定义扩展字段)")
|
||||
platforms: dict[str, TauriPlatformInfo] = Field(
|
||||
..., description="平台安装包映射,key 格式:OS-ARCH"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
@@ -30,8 +30,7 @@ class UserProfileResponse(BaseModel):
|
||||
last_login_at: datetime | None = Field(None, description="最后登录时间")
|
||||
created_at: datetime = Field(..., description="注册时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UpdateNicknameRequest(BaseModel):
|
||||
|
||||
@@ -220,7 +220,10 @@ def validate_and_normalize_shots(raw_data: Any) -> list[dict[str, Any]]:
|
||||
duration = item.get("duration")
|
||||
duration_str = parse_duration(duration) # 返回如 "5s"
|
||||
try:
|
||||
normalized["duration"] = int(re.search(r"\d+", duration_str).group())
|
||||
match = re.search(r"\d+", duration_str)
|
||||
if match is None:
|
||||
raise ValueError("时长字符串中未找到数字")
|
||||
normalized["duration"] = int(match.group())
|
||||
except (AttributeError, ValueError):
|
||||
normalized["duration"] = 5
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import hashlib
|
||||
import logging
|
||||
import random
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -73,6 +74,7 @@ async def _kick_old_device(user_id: str) -> None:
|
||||
|
||||
# ========== 验证码校验 ==========
|
||||
|
||||
|
||||
async def verify_sms_code(mobile: str, code: str) -> bool:
|
||||
"""
|
||||
校验短信验证码。
|
||||
@@ -145,9 +147,7 @@ async def send_sms_code(mobile: str) -> str:
|
||||
)
|
||||
else:
|
||||
# 配置不完整,记录警告但不打印验证码
|
||||
logger.warning(
|
||||
f"[SMS] B2M 短信配置不完整,验证码未发送: 手机号={mobile}"
|
||||
)
|
||||
logger.warning(f"[SMS] B2M 短信配置不完整,验证码未发送: 手机号={mobile}")
|
||||
except SMSError as e:
|
||||
logger.error(f"[SMS] 短信发送失败: {e}")
|
||||
# 短信发送失败不影响验证码生成
|
||||
@@ -159,6 +159,7 @@ async def send_sms_code(mobile: str) -> str:
|
||||
|
||||
# ========== Token 工具 ==========
|
||||
|
||||
|
||||
def _hash_refresh_token(token: str) -> str:
|
||||
"""Refresh Token SHA256 哈希(用于数据库存储)"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
@@ -166,6 +167,7 @@ def _hash_refresh_token(token: str) -> str:
|
||||
|
||||
# ========== 登录服务 ==========
|
||||
|
||||
|
||||
async def login_with_sms(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -260,9 +262,7 @@ async def refresh_access_token(
|
||||
refresh_token_hash = _hash_refresh_token(refresh_token)
|
||||
|
||||
# 2. 查设备记录
|
||||
device = await device_crud.get_by_refresh_token_hash(
|
||||
db, refresh_token_hash=refresh_token_hash
|
||||
)
|
||||
device = await device_crud.get_by_refresh_token_hash(db, refresh_token_hash=refresh_token_hash)
|
||||
if device is None:
|
||||
raise ValueError("设备已失效,请重新登录")
|
||||
|
||||
@@ -397,7 +397,7 @@ async def reset_password_with_sms(
|
||||
return True
|
||||
|
||||
|
||||
async def logout(db: AsyncSession, *, user_id: str) -> bool:
|
||||
async def logout(db: AsyncSession, *, user_id: UUID | str) -> bool:
|
||||
"""
|
||||
用户登出。
|
||||
|
||||
@@ -406,14 +406,14 @@ async def logout(db: AsyncSession, *, user_id: str) -> bool:
|
||||
2. 注销 SSE 连接
|
||||
"""
|
||||
await device_crud.delete_by_user_id(db, user_id=user_id)
|
||||
unregister_sse_connection(user_id)
|
||||
unregister_sse_connection(str(user_id))
|
||||
return True
|
||||
|
||||
|
||||
async def get_current_user_device(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
user_id: UUID | str,
|
||||
) -> UserDevice | None:
|
||||
"""获取当前用户的设备记录"""
|
||||
return await device_crud.get_by_user_id(db, user_id=user_id)
|
||||
|
||||
@@ -9,6 +9,8 @@ import logging
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -16,6 +18,7 @@ from app.core.exceptions import ValidationException
|
||||
from app.core.redis_client import get_redis_client
|
||||
from app.crud import broll_category, broll_material
|
||||
from app.models.broll_category import BrollCategory
|
||||
from app.models.broll_material import BrollMaterial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,13 +31,13 @@ def _normalize_scene(scene: str) -> str:
|
||||
# 去除所有 Unicode 空白字符(空格、全角空格、换行、tab 等)
|
||||
cleaned = re.sub(r"\s+", "", scene)
|
||||
# 去除常见中文标点符号(逗号、句号、感叹号、问号、顿号、分号、冒号、引号、括号等)
|
||||
cleaned = re.sub(r"[,。!?、;:""''()【】《》]+", "", cleaned)
|
||||
cleaned = re.sub(r"[,。!?、;:" "''()【】《》]+", "", cleaned)
|
||||
# 去除零宽字符(零宽空格、零宽非连接符、零宽连接符、零宽非断空格等)
|
||||
cleaned = re.sub(r"[\u200b-\u200f\ufeff]+", "", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _weighted_choice(materials: list) -> object: # noqa: ANN001
|
||||
def _weighted_choice(materials: list[BrollMaterial]) -> BrollMaterial:
|
||||
"""
|
||||
加权随机选择素材
|
||||
|
||||
@@ -50,9 +53,9 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
|
||||
total_weight = sum(weights)
|
||||
|
||||
if total_weight == 0:
|
||||
return random.choice(materials)
|
||||
return random.choice(materials) # nosec B311 素材抽样,非加密场景
|
||||
|
||||
r = random.uniform(0, total_weight)
|
||||
r = random.uniform(0, total_weight) # nosec B311 素材抽样,非加密场景
|
||||
cumulative = 0.0
|
||||
for m, w in zip(materials, weights, strict=True):
|
||||
cumulative += w
|
||||
@@ -74,10 +77,9 @@ async def _try_fallback_to_parent(
|
||||
- 若 scene 含 '-',取后半部分作为 parent_name(如 '电路施工-电路施工' -> '电路施工')
|
||||
- 若不含 '-',直接以整个 scene 作为 parent_name
|
||||
|
||||
匹配策略(逐级降级):
|
||||
匹配策略:
|
||||
1. 精确匹配 level=2 分类 name
|
||||
2. 模糊匹配(LIKE %parent_name%),兼容 "电路施工" → "电路施工镜"
|
||||
3. 去掉常见后缀(镜、阶段等)再精确匹配
|
||||
2. 模糊匹配(LIKE %parent_name%)
|
||||
|
||||
返回:
|
||||
随机选中的一个 level=3 子分类,或 None
|
||||
@@ -88,37 +90,20 @@ async def _try_fallback_to_parent(
|
||||
parent_name = normalized_scene
|
||||
|
||||
# 1. 精确匹配
|
||||
parent = await broll_category.get_by_name_and_level(
|
||||
db, name=parent_name, level=2
|
||||
)
|
||||
parent = await broll_category.get_by_name_and_level(db, name=parent_name, level=2)
|
||||
|
||||
# 2. 模糊匹配(兼容 "电路施工" → "电路施工镜")
|
||||
# 2. 模糊匹配
|
||||
if parent is None:
|
||||
parent = await broll_category.get_by_name_like_and_level(
|
||||
db, name=parent_name, level=2
|
||||
)
|
||||
|
||||
# 3. 去掉常见后缀再试
|
||||
if parent is None:
|
||||
for suffix in ("镜", "阶段"):
|
||||
if not parent_name.endswith(suffix):
|
||||
candidate = parent_name + suffix
|
||||
parent = await broll_category.get_by_name_and_level(
|
||||
db, name=candidate, level=2
|
||||
)
|
||||
if parent:
|
||||
break
|
||||
parent = await broll_category.get_by_name_like_and_level(db, name=parent_name, level=2)
|
||||
|
||||
if parent is None:
|
||||
return None
|
||||
|
||||
children = await broll_category.get_children_by_parent_id(
|
||||
db, parent_id=parent.id, level=3
|
||||
)
|
||||
children = await broll_category.get_children_by_parent_id(db, parent_id=parent.id, level=3)
|
||||
if not children:
|
||||
return None
|
||||
|
||||
return random.choice(children)
|
||||
return random.choice(children) # nosec B311 素材抽样,非加密场景
|
||||
|
||||
|
||||
async def match_material(
|
||||
@@ -161,46 +146,34 @@ async def match_material(
|
||||
normalized = _normalize_scene(scene)
|
||||
|
||||
# 1. 查找三级分类(精确匹配 -> 全量内存匹配兜底 -> 顺序颠倒 -> 上级回退)
|
||||
category = await broll_category.get_by_name_and_level(
|
||||
db, name=normalized, level=3
|
||||
)
|
||||
category = await broll_category.get_by_name_and_level(db, name=normalized, level=3)
|
||||
# 精确匹配失败时,全量查询后在内存标准化匹配(兼容数据库 name 含不可见字符)
|
||||
if category is None:
|
||||
all_categories = await broll_category.get_by_level(db, level=3)
|
||||
for c in all_categories:
|
||||
if _normalize_scene(c.name) == normalized:
|
||||
category = c
|
||||
logger.info(
|
||||
f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'"
|
||||
)
|
||||
logger.info(f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'")
|
||||
break
|
||||
# 若仍失败,尝试将 "A-B" 倒序为 "B-A" 再匹配
|
||||
if category is None:
|
||||
parts = normalized.rsplit("-", 1)
|
||||
if len(parts) == 2:
|
||||
reversed_name = f"{parts[1]}-{parts[0]}"
|
||||
category = await broll_category.get_by_name_and_level(
|
||||
db, name=reversed_name, level=3
|
||||
)
|
||||
category = await broll_category.get_by_name_and_level(db, name=reversed_name, level=3)
|
||||
if category:
|
||||
logger.info(
|
||||
f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
|
||||
)
|
||||
logger.info(f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'")
|
||||
# 若仍失败,回退到上级分类随机选取
|
||||
if category is None:
|
||||
category = await _try_fallback_to_parent(db, normalized)
|
||||
if category:
|
||||
logger.info(
|
||||
f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
|
||||
)
|
||||
logger.info(f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'")
|
||||
if category is None:
|
||||
logger.warning(f"素材匹配失败: 未找到分类 '{normalized}' (原始 scene: '{scene}')")
|
||||
return None
|
||||
|
||||
# 2. 查询该分类下所有 active 素材(先不过滤时长,用于日志诊断)
|
||||
all_materials = await broll_material.get_active_by_categories(
|
||||
db, category_ids=[category.id]
|
||||
)
|
||||
all_materials = await broll_material.get_active_by_categories(db, category_ids=[category.id])
|
||||
if not all_materials:
|
||||
logger.warning(f"素材匹配失败: 分类 '{normalized}' 下无任何可用素材")
|
||||
return None
|
||||
@@ -224,7 +197,7 @@ async def match_material(
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"proj:{project_id}:used_materials"
|
||||
used_urls = set(await redis.smembers(key))
|
||||
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
|
||||
|
||||
@@ -243,8 +216,8 @@ async def match_material(
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"proj:{project_id}:used_materials"
|
||||
await redis.sadd(key, chosen.url)
|
||||
await redis.expire(key, _USED_MATERIALS_TTL)
|
||||
await cast(Awaitable[int], redis.sadd(key, chosen.url))
|
||||
await cast(Awaitable[int], redis.expire(key, _USED_MATERIALS_TTL))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 去重记录失败: {e}")
|
||||
|
||||
@@ -282,10 +255,8 @@ async def batch_match(
|
||||
unique_names = list(set(normalized_scenes))
|
||||
|
||||
# 2. 批量查询分类:优先精确查询,失败时全量内存匹配兜底
|
||||
categories = await broll_category.get_by_names_and_level(
|
||||
db, names=unique_names, level=3
|
||||
)
|
||||
category_map: dict[str, object] = {}
|
||||
categories = await broll_category.get_by_names_and_level(db, names=unique_names, level=3)
|
||||
category_map: dict[str, BrollCategory] = {}
|
||||
for c in categories:
|
||||
category_map[_normalize_scene(c.name)] = c
|
||||
|
||||
@@ -305,16 +276,14 @@ async def batch_match(
|
||||
if len(parts) == 2:
|
||||
reversed_map[name] = f"{parts[1]}-{parts[0]}"
|
||||
|
||||
scene_to_category: dict[str, object] = {}
|
||||
scene_to_category: dict[str, BrollCategory] = {}
|
||||
for name in unique_names:
|
||||
if name in category_map:
|
||||
scene_to_category[name] = category_map[name]
|
||||
elif name in reversed_map and reversed_map[name] in category_map:
|
||||
rev = reversed_map[name]
|
||||
scene_to_category[name] = category_map[rev]
|
||||
logger.info(
|
||||
f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
|
||||
)
|
||||
logger.info(f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'")
|
||||
|
||||
# 3. 未匹配的 scene 回退到上级分类随机选取
|
||||
unmatched = [name for name in unique_names if name not in scene_to_category]
|
||||
@@ -322,15 +291,11 @@ async def batch_match(
|
||||
fallback_cat = await _try_fallback_to_parent(db, name)
|
||||
if fallback_cat:
|
||||
scene_to_category[name] = fallback_cat
|
||||
logger.info(
|
||||
f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
|
||||
)
|
||||
logger.info(f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'")
|
||||
|
||||
# 4. 收集所有需要的 category_id
|
||||
needed_category_ids = [
|
||||
scene_to_category[name].id
|
||||
for name in unique_names
|
||||
if name in scene_to_category
|
||||
scene_to_category[name].id for name in unique_names if name in scene_to_category
|
||||
]
|
||||
|
||||
# 4. 批量查询素材(1 次 DB)
|
||||
@@ -339,7 +304,7 @@ async def batch_match(
|
||||
)
|
||||
|
||||
# 按 category_id 分组,方便内存过滤
|
||||
materials_by_category: dict[int, list] = {}
|
||||
materials_by_category: dict[int, list[BrollMaterial]] = {}
|
||||
for m in all_materials:
|
||||
materials_by_category.setdefault(m.category_id, []).append(m)
|
||||
|
||||
@@ -349,13 +314,13 @@ async def batch_match(
|
||||
try:
|
||||
redis = get_redis_client()
|
||||
key = f"proj:{project_id}:used_materials"
|
||||
used_urls = set(await redis.smembers(key))
|
||||
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
|
||||
|
||||
# 6. 内存中逐个匹配
|
||||
results: list[dict | None] = []
|
||||
chosen_materials: list = [] # 记录选中的素材,用于批量更新
|
||||
chosen_materials: list[BrollMaterial] = [] # 记录选中的素材,用于批量更新
|
||||
|
||||
for idx, scene_name in enumerate(normalized_scenes):
|
||||
required_duration = scenes[idx]["duration"]
|
||||
|
||||
@@ -25,7 +25,7 @@ import logging
|
||||
import math
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.exceptions import InsufficientPointsException
|
||||
from app.models.point_batch import PointBatch
|
||||
from app.models.point_transaction import PointTransaction
|
||||
from app.models.user_point import UserPoint
|
||||
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
|
||||
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
|
||||
|
||||
|
||||
def _load_points_config() -> dict:
|
||||
def _load_points_config() -> dict[str, Any]:
|
||||
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
|
||||
if not _CONFIG_PATH.exists():
|
||||
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
|
||||
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
with open(_CONFIG_PATH, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
|
||||
merged: dict[str, dict] = {}
|
||||
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
|
||||
return merged
|
||||
|
||||
|
||||
POINTS_CONFIG: dict[str, dict] = _load_points_config()
|
||||
POINTS_CONFIG: dict[str, Any] = _load_points_config()
|
||||
|
||||
|
||||
def get_recharge_options() -> list[dict]:
|
||||
"""获取充值档位配置(由后端控制,支持积分赠送)"""
|
||||
return POINTS_CONFIG.get("_recharge_options", [])
|
||||
options = POINTS_CONFIG.get("_recharge_options", [])
|
||||
if isinstance(options, list):
|
||||
return options
|
||||
return []
|
||||
|
||||
|
||||
def get_chargeable_source_types() -> list[str]:
|
||||
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
|
||||
return [
|
||||
key for key, cfg in POINTS_CONFIG.items()
|
||||
key
|
||||
for key, cfg in POINTS_CONFIG.items()
|
||||
if not key.startswith("_") and cfg.get("mode") != "free"
|
||||
]
|
||||
|
||||
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
|
||||
|
||||
# ── 余额查询 ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
||||
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
|
||||
result = await db.execute(
|
||||
select(UserPoint).where(UserPoint.user_id == user_id)
|
||||
)
|
||||
result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
if not up:
|
||||
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
||||
from sqlalchemy import func as _func
|
||||
|
||||
available_result = await db.execute(
|
||||
select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
|
||||
.where(
|
||||
select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
|
||||
PointBatch.user_id == user_id,
|
||||
PointBatch.remaining > 0,
|
||||
PointBatch.expired_at > _now(),
|
||||
@@ -221,6 +224,7 @@ async def check_balance(
|
||||
|
||||
# ── 充值 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def recharge(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
@@ -247,8 +251,7 @@ async def recharge(
|
||||
# 幂等保护:同一笔订单(order_id)只能充值一次
|
||||
if order_id:
|
||||
existing_result = await db.execute(
|
||||
select(PointTransaction)
|
||||
.where(
|
||||
select(PointTransaction).where(
|
||||
PointTransaction.source_id == str(order_id),
|
||||
PointTransaction.type == "recharge",
|
||||
)
|
||||
@@ -259,9 +262,7 @@ async def recharge(
|
||||
return existing_tx
|
||||
|
||||
# 1. 获取或创建用户积分账户
|
||||
result = await db.execute(
|
||||
select(UserPoint).where(UserPoint.user_id == user_id)
|
||||
)
|
||||
result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
if not up:
|
||||
@@ -353,7 +354,7 @@ async def consume(
|
||||
直接扣费(后置计费)。
|
||||
|
||||
业务执行成功后调用,按实际消耗直接扣除余额。
|
||||
默认不允许欠费(余额不足时抛出 ValueError)。
|
||||
默认不允许欠费(余额不足时抛出 InsufficientPointsException)。
|
||||
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
|
||||
|
||||
:param points: 实际消耗积分(正整数)
|
||||
@@ -383,9 +384,7 @@ async def consume(
|
||||
|
||||
# 2. 获取用户积分账户(加锁)
|
||||
result = await db.execute(
|
||||
select(UserPoint)
|
||||
.where(UserPoint.user_id == user_id)
|
||||
.with_for_update()
|
||||
select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
|
||||
)
|
||||
up = result.scalar_one_or_none()
|
||||
|
||||
@@ -404,7 +403,7 @@ async def consume(
|
||||
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
|
||||
available = sum(b.remaining for b in batches)
|
||||
if not allow_negative and available < points:
|
||||
raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
||||
raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
|
||||
|
||||
remaining_to_deduct = points
|
||||
for batch in batches:
|
||||
@@ -440,6 +439,7 @@ async def consume(
|
||||
|
||||
# ── 过期回收 ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def expire_batches(db: AsyncSession) -> int:
|
||||
"""
|
||||
回收过期积分批次。返回过期积分总数。
|
||||
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
|
||||
|
||||
# 获取用户积分账户(加锁)
|
||||
up_result = await db.execute(
|
||||
select(UserPoint)
|
||||
.where(UserPoint.user_id == batch.user_id)
|
||||
.with_for_update()
|
||||
select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
|
||||
)
|
||||
up = up_result.scalar_one_or_none()
|
||||
if not up:
|
||||
|
||||
@@ -104,7 +104,9 @@ class QiniuService:
|
||||
# 项目前缀
|
||||
PROJECT_PREFIX = "meijiaka-zy"
|
||||
|
||||
def generate_key(self, file_type: str, original_filename: str, user_id: str = None) -> str:
|
||||
def generate_key(
|
||||
self, file_type: str, original_filename: str, user_id: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
生成规范的文件存储路径
|
||||
|
||||
@@ -143,7 +145,7 @@ class QiniuService:
|
||||
return mime_type in allowed_types
|
||||
|
||||
def get_upload_token(
|
||||
self, bucket: str, key: str, expires: int = 3600, policy: dict = None
|
||||
self, bucket: str, key: str, expires: int = 3600, policy: dict | None = None
|
||||
) -> str:
|
||||
"""
|
||||
生成上传凭证(客户端直传使用)
|
||||
@@ -197,9 +199,9 @@ class QiniuService:
|
||||
def upload_file(
|
||||
self,
|
||||
local_path: str,
|
||||
key: str = None,
|
||||
key: str | None = None,
|
||||
file_type: str = "audio",
|
||||
user_id: str = None,
|
||||
user_id: str | None = None,
|
||||
check_duplicate: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -222,15 +224,15 @@ class QiniuService:
|
||||
"is_duplicate": 是否复用已有文件
|
||||
}
|
||||
"""
|
||||
local_path = Path(local_path)
|
||||
if not local_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {local_path}")
|
||||
local_path_obj = Path(local_path)
|
||||
if not local_path_obj.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {local_path_obj}")
|
||||
|
||||
# 根据文件类型获取对应的 bucket 和 domain
|
||||
bucket, domain = self._get_bucket_and_domain(file_type)
|
||||
|
||||
# 计算文件 MD5 哈希
|
||||
file_md5 = self._calculate_file_hash(local_path)
|
||||
file_md5 = self._calculate_file_hash(local_path_obj)
|
||||
|
||||
# 检查是否已存在相同文件
|
||||
if check_duplicate:
|
||||
@@ -248,20 +250,20 @@ class QiniuService:
|
||||
|
||||
# 自动生成 Key
|
||||
if key is None:
|
||||
key = self.generate_key(file_type, local_path.name, user_id)
|
||||
key = self.generate_key(file_type, local_path_obj.name, user_id)
|
||||
|
||||
# 生成上传 Token
|
||||
token = self.get_upload_token(bucket, key)
|
||||
|
||||
# 使用分片上传
|
||||
ret, info = put_file(up_token=token, key=key, file_path=str(local_path))
|
||||
ret, info = put_file(up_token=token, key=key, file_path=str(local_path_obj))
|
||||
|
||||
if ret is None:
|
||||
raise Exception(f"上传失败: {info}")
|
||||
|
||||
# 获取文件信息
|
||||
mime_type, _ = mimetypes.guess_type(str(local_path))
|
||||
fsize = local_path.stat().st_size
|
||||
mime_type, _ = mimetypes.guess_type(str(local_path_obj))
|
||||
fsize = local_path_obj.stat().st_size
|
||||
|
||||
return {
|
||||
"key": ret["key"],
|
||||
@@ -277,8 +279,8 @@ class QiniuService:
|
||||
stream: BinaryIO,
|
||||
key: str,
|
||||
mime_type: str = "application/octet-stream",
|
||||
bucket: str = None,
|
||||
domain: str = None,
|
||||
bucket: str | None = None,
|
||||
domain: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
上传文件流到七牛云
|
||||
@@ -317,7 +319,9 @@ class QiniuService:
|
||||
|
||||
return {"key": ret["key"], "hash": ret["hash"], "url": self.get_file_url(domain, key)}
|
||||
|
||||
def upload_audio(self, local_path: str, user_id: str = None, key: str = None) -> dict:
|
||||
def upload_audio(
|
||||
self, local_path: str, user_id: str | None = None, key: str | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
上传音频文件(专用接口)
|
||||
|
||||
@@ -336,7 +340,9 @@ class QiniuService:
|
||||
|
||||
return self.upload_file(local_path=local_path, key=key, file_type="audio", user_id=user_id)
|
||||
|
||||
def upload_video(self, local_path: str, user_id: str = None, key: str = None) -> dict:
|
||||
def upload_video(
|
||||
self, local_path: str, user_id: str | None = None, key: str | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
上传视频文件(专用接口)
|
||||
|
||||
@@ -454,9 +460,9 @@ class QiniuService:
|
||||
async def upload_file_async(
|
||||
self,
|
||||
local_path: str,
|
||||
key: str = None,
|
||||
key: str | None = None,
|
||||
file_type: str = "audio",
|
||||
user_id: str = None,
|
||||
user_id: str | None = None,
|
||||
check_duplicate: bool = True,
|
||||
) -> dict:
|
||||
"""异步版本 upload_file"""
|
||||
@@ -469,19 +475,21 @@ class QiniuService:
|
||||
stream: BinaryIO,
|
||||
key: str,
|
||||
mime_type: str = "application/octet-stream",
|
||||
bucket: str = None,
|
||||
domain: str = None,
|
||||
bucket: str | None = None,
|
||||
domain: str | None = None,
|
||||
) -> dict:
|
||||
"""异步版本 upload_stream"""
|
||||
return await asyncio.to_thread(
|
||||
self.upload_stream, stream, key, mime_type, bucket, domain
|
||||
)
|
||||
return await asyncio.to_thread(self.upload_stream, stream, key, mime_type, bucket, domain)
|
||||
|
||||
async def upload_audio_async(self, local_path: str, user_id: str = None, key: str = None) -> dict:
|
||||
async def upload_audio_async(
|
||||
self, local_path: str, user_id: str | None = None, key: str | None = None
|
||||
) -> dict:
|
||||
"""异步版本 upload_audio"""
|
||||
return await asyncio.to_thread(self.upload_audio, local_path, user_id, key)
|
||||
|
||||
async def upload_video_async(self, local_path: str, user_id: str = None, key: str = None) -> dict:
|
||||
async def upload_video_async(
|
||||
self, local_path: str, user_id: str | None = None, key: str | None = None
|
||||
) -> dict:
|
||||
"""异步版本 upload_video"""
|
||||
return await asyncio.to_thread(self.upload_video, local_path, user_id, key)
|
||||
|
||||
|
||||
@@ -7,9 +7,16 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.ai.model_router import get_model_router
|
||||
from app.ai.prompts import load_prompt_file, load_script_user_prompt
|
||||
from app.core.exceptions import (
|
||||
AIEmptyResponseException,
|
||||
AIParseErrorException,
|
||||
AITimeoutException,
|
||||
PromptNotFoundException,
|
||||
)
|
||||
from app.schemas.script import ScriptShot
|
||||
from app.services.ai_response_utils import (
|
||||
safe_parse_ai_json_response,
|
||||
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
|
||||
class ScriptService:
|
||||
"""脚本生成服务"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
|
||||
|
||||
|
||||
|
||||
def _load_prompt(self, name: str) -> str:
|
||||
"""加载 Prompt 模板"""
|
||||
prompt_file = self.prompts_dir / f"{name}.txt"
|
||||
@@ -58,7 +62,7 @@ class ScriptService:
|
||||
# 加载 Prompt
|
||||
system_prompt = load_prompt_file(category, filename)
|
||||
if not system_prompt:
|
||||
raise ValueError(f"未找到提示词: category={category}, filename={filename}")
|
||||
raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
|
||||
|
||||
# 用户提示词
|
||||
user_prompt = load_script_user_prompt(
|
||||
@@ -75,24 +79,26 @@ class ScriptService:
|
||||
)
|
||||
|
||||
if not result.content or not result.content.strip():
|
||||
raise ValueError("AI 返回内容为空,请检查模型配置或重试")
|
||||
raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
|
||||
|
||||
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
|
||||
|
||||
if not success:
|
||||
raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
||||
raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
|
||||
|
||||
try:
|
||||
shots_data = validate_and_normalize_shots(parsed_data)
|
||||
|
||||
if not shots_data:
|
||||
raise ValueError("AI 返回的分镜数据为空或格式不正确")
|
||||
raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
|
||||
|
||||
shots = [ScriptShot(**shot) for shot in shots_data]
|
||||
return shots
|
||||
|
||||
except (AIEmptyResponseException, AIParseErrorException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"分镜数据处理失败: {str(e)}")
|
||||
raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
|
||||
|
||||
async def polish_content(
|
||||
self,
|
||||
@@ -144,21 +150,23 @@ class ScriptService:
|
||||
)
|
||||
return result.content.strip()
|
||||
except TimeoutError:
|
||||
raise ValueError("润色请求超时,请重试")
|
||||
raise AITimeoutException("润色请求超时,请重试")
|
||||
except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"润色失败: {str(e)}")
|
||||
raise AIParseErrorException(f"润色失败: {str(e)}")
|
||||
|
||||
async def check_model_health(self) -> dict:
|
||||
"""检查模型健康状态"""
|
||||
model_router = await get_model_router()
|
||||
health_results = await model_router.health_check()
|
||||
|
||||
models = []
|
||||
models: list[dict[str, Any]] = []
|
||||
available_count = 0
|
||||
recommended = None
|
||||
recommended: dict[str, Any] | None = None
|
||||
|
||||
for _provider_id, health in health_results.items():
|
||||
model_info = {
|
||||
model_info: dict[str, Any] = {
|
||||
"id": health.id,
|
||||
"name": health.name,
|
||||
"is_available": health.is_available,
|
||||
@@ -169,9 +177,12 @@ class ScriptService:
|
||||
|
||||
if health.is_available:
|
||||
available_count += 1
|
||||
if recommended is None or health.response_time < recommended.get(
|
||||
"response_time", float("inf")
|
||||
):
|
||||
current_best = (
|
||||
float("inf")
|
||||
if recommended is None
|
||||
else float(recommended.get("response_time") or float("inf"))
|
||||
)
|
||||
if health.response_time < current_best:
|
||||
recommended = model_info
|
||||
|
||||
total = len(models)
|
||||
@@ -188,7 +199,6 @@ class ScriptService:
|
||||
"""测试指定模型连接"""
|
||||
model_router = await get_model_router()
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
|
||||
@@ -76,9 +76,7 @@ class B2MSMSService:
|
||||
if self.secret_key:
|
||||
key_bytes = self.secret_key.encode("utf-8")
|
||||
if len(key_bytes) not in (16, 24, 32):
|
||||
raise SMSError(
|
||||
f"AES 密钥长度必须是 16/24/32 字节,当前 {len(key_bytes)} 字节"
|
||||
)
|
||||
raise SMSError(f"AES 密钥长度必须是 16/24/32 字节,当前 {len(key_bytes)} 字节")
|
||||
self._secret_key_bytes = key_bytes
|
||||
|
||||
if not all([self.app_id, self.secret_key, self.base_url]):
|
||||
@@ -245,6 +243,7 @@ class B2MSMSService:
|
||||
|
||||
# ── 便捷函数 ──────────────────────────────────────────
|
||||
|
||||
|
||||
def get_sms_service() -> B2MSMSService:
|
||||
"""获取短信服务实例"""
|
||||
return B2MSMSService()
|
||||
|
||||
@@ -207,8 +207,9 @@ class ViduService:
|
||||
error_type=PlatformErrorType.BAD_REQUEST,
|
||||
)
|
||||
|
||||
logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
|
||||
return result.data or {}
|
||||
clone_data = result.data or {}
|
||||
logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
|
||||
return clone_data
|
||||
|
||||
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
|
||||
"""Vidu 声音复刻是同步接口,无独立查询。
|
||||
@@ -270,6 +271,8 @@ class ViduService:
|
||||
result_data = status.result or {}
|
||||
return {
|
||||
"state": ViduAdapter.denormalize_state(status.state),
|
||||
"creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
|
||||
"creations": (
|
||||
[{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
|
||||
),
|
||||
"message": status.error_message,
|
||||
}
|
||||
|
||||
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
|
||||
error_type=PlatformErrorType.BAD_REQUEST,
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"{task_name}超过最大轮询次数: task_id={task_id}, "
|
||||
f"retries={retries}"
|
||||
)
|
||||
logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
|
||||
raise PlatformError(
|
||||
f"{task_name}超时,请稍后重试",
|
||||
platform="volcengine_caption",
|
||||
|
||||
@@ -78,9 +78,7 @@ class VolcengineMediakitService:
|
||||
)
|
||||
|
||||
if not response.success:
|
||||
raise RuntimeError(
|
||||
response.error_message or "抠图失败"
|
||||
)
|
||||
raise RuntimeError(response.error_message or "抠图失败")
|
||||
|
||||
result_image_url = (response.data or {}).get("image_url", "")
|
||||
return RemoveBackgroundResult(
|
||||
|
||||
@@ -69,7 +69,9 @@ class WechatPayService:
|
||||
self.notify_url = settings.WXPAY_NOTIFY_URL
|
||||
|
||||
if not all([self.mchid, self.appid, self.api_key]):
|
||||
raise WechatPayError("微信支付配置不完整:WXPAY_MCHID / WXPAY_APPID / WXPAY_API_KEY 未配置")
|
||||
raise WechatPayError(
|
||||
"微信支付配置不完整:WXPAY_MCHID / WXPAY_APPID / WXPAY_API_KEY 未配置"
|
||||
)
|
||||
|
||||
# ── 签名与验签 ──────────────────────────────────────
|
||||
|
||||
@@ -85,7 +87,9 @@ class WechatPayService:
|
||||
5. MD5 加密,结果转大写
|
||||
"""
|
||||
# 过滤空值和 sign
|
||||
filtered = {k: str(v) for k, v in params.items() if v is not None and v != "" and k != "sign"}
|
||||
filtered = {
|
||||
k: str(v) for k, v in params.items() if v is not None and v != "" and k != "sign"
|
||||
}
|
||||
# ASCII 升序排序
|
||||
sorted_items = sorted(filtered.items(), key=lambda x: x[0])
|
||||
# 拼接字符串
|
||||
|
||||
@@ -41,13 +41,15 @@ async def get_audio_duration(url: str, timeout: float = 10.0) -> float:
|
||||
|
||||
try:
|
||||
if header[:3] == b"ID3" or header[:2] == b"\xff\xfb" or header[:2] == b"\xff\xf3":
|
||||
audio = MP3(data)
|
||||
audio: MP3 | WAVE = MP3(data)
|
||||
elif header[:4] == b"RIFF":
|
||||
audio = WAVE(data)
|
||||
else:
|
||||
# fallback:先尝试 MP3(大多数 TTS 返回 mp3)
|
||||
audio = MP3(data)
|
||||
|
||||
if audio.info is None:
|
||||
raise ValueError("音频信息解析失败")
|
||||
duration = audio.info.length
|
||||
if duration is None or duration <= 0:
|
||||
raise ValueError("音频时长解析失败")
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
内容指纹工具
|
||||
============
|
||||
|
||||
用于 AI 第三方平台(如 Vidu)的审核结果缓存与防重复提交。
|
||||
|
||||
核心逻辑:
|
||||
- 对提交的音频/视频/图片 URL + 任务参数生成 SHA256 指纹。
|
||||
- 如果相同内容近期因审核失败被缓存,则直接返回错误,不再调用第三方平台。
|
||||
- 仅规范化 URL(去掉 token 等动态参数),不下载大文件本身。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
# Vidu 内容安全/审核类错误码
|
||||
# 这些错误在内容不变的情况下重试也没用,应该被缓存。
|
||||
VIDU_AUDIT_ERROR_CODES = frozenset(
|
||||
{
|
||||
"TaskPromptPolicyViolation", # Prompt 触发安审风控
|
||||
"AuditSubmitIllegal", # 输入未通过安全审核
|
||||
"CreationPolicyViolation", # 生成物触发风控
|
||||
"PhotoAuditNotPass", # 图片审核不通过
|
||||
"AuditFailed", # 审核失败
|
||||
"ImageCheckBodyJointsFailed", # 输入图人体检测失败
|
||||
"ImageCheckFaceFailed", # 输入图人脸检测失败
|
||||
"ImageObjectsUndetected", # 人体或人脸有遮挡
|
||||
"FaceDetectFailure", # 人脸检测失败
|
||||
"FaceDetectNotPass", # 人脸检测不通过
|
||||
"NoFaceDetected", # 未检测到人脸
|
||||
"MultiFaceDetected", # 检测到多张人脸
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 常见动态 query 参数,生成指纹时应忽略
|
||||
_DYNAMIC_QUERY_PARAMS = frozenset(
|
||||
{
|
||||
"token",
|
||||
"e", # 七牛过期时间戳
|
||||
"t", # 时间戳
|
||||
"sign",
|
||||
"x-oss-signature",
|
||||
"x-oss-expires",
|
||||
"x-oss-access-key-id",
|
||||
"response-content-disposition",
|
||||
"v", # 版本号/缓存戳
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def normalize_url(url: str | None) -> str:
|
||||
"""规范化 URL,去掉动态参数,确保同一资源不同 token 得到相同指纹。
|
||||
|
||||
Args:
|
||||
url: 原始 URL,可能包含七牛私有 token、时间戳等动态参数
|
||||
|
||||
Returns:
|
||||
规范化后的 URL 字符串,None 或空字符串返回空串
|
||||
"""
|
||||
if not url:
|
||||
return ""
|
||||
|
||||
parsed = urlparse(url)
|
||||
query_params = parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
for key in list(query_params.keys()):
|
||||
if key.lower() in _DYNAMIC_QUERY_PARAMS:
|
||||
del query_params[key]
|
||||
|
||||
query = urlencode(sorted(query_params.items()), doseq=True)
|
||||
path = parsed.path or "/"
|
||||
|
||||
if query:
|
||||
return f"{parsed.scheme}://{parsed.netloc}{path}?{query}"
|
||||
return f"{parsed.scheme}://{parsed.netloc}{path}"
|
||||
|
||||
|
||||
def compute_content_fingerprint(
|
||||
task_type: str,
|
||||
*,
|
||||
video_url: str | None = None,
|
||||
audio_url: str | None = None,
|
||||
ref_photo_url: str | None = None,
|
||||
text: str | None = None,
|
||||
voice_id: str | None = None,
|
||||
) -> str:
|
||||
"""计算内容指纹。
|
||||
|
||||
指纹字段选择原则:只包含会影响 Vidu 审核结果的输入内容。
|
||||
不包含 callback_url、speed、volume、payload 等业务/技术参数。
|
||||
|
||||
Args:
|
||||
task_type: 任务类型,如 "lip_sync", "tts", "clone_voice"
|
||||
video_url: 视频 URL
|
||||
audio_url: 音频 URL
|
||||
ref_photo_url: 参考图片 URL
|
||||
text: 文本/Prompt
|
||||
voice_id: 音色 ID
|
||||
|
||||
Returns:
|
||||
SHA256 十六进制指纹字符串
|
||||
"""
|
||||
parts = [
|
||||
task_type.strip().lower(),
|
||||
normalize_url(video_url),
|
||||
normalize_url(audio_url),
|
||||
normalize_url(ref_photo_url),
|
||||
(text or "").strip(),
|
||||
(voice_id or "").strip().lower(),
|
||||
]
|
||||
raw = "|".join(parts)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def is_vidu_audit_error(err_code: str | None) -> bool:
|
||||
"""判断是否为 Vidu 审核类错误码。"""
|
||||
if not err_code:
|
||||
return False
|
||||
return err_code.strip() in VIDU_AUDIT_ERROR_CODES
|
||||
|
||||
|
||||
def extract_vidu_error_code(message: str | None) -> str | None:
|
||||
"""从 Vidu 错误信息中提取错误码。
|
||||
|
||||
Vidu 错误信息格式通常为:"ErrorCode: 中文描述"
|
||||
"""
|
||||
if not message:
|
||||
return None
|
||||
candidate = message.split(":")[0].strip()
|
||||
return candidate or None
|
||||
+13
-15
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "meijiaka-ai-api"
|
||||
version = "1.8.2"
|
||||
version = "1.9.1"
|
||||
description = "美家卡智影 - AI 视频创作后端 API"
|
||||
authors = [{ name = "Meijiaka Team" }]
|
||||
readme = "README.md"
|
||||
@@ -15,9 +15,10 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
# Web 框架 (FastAPI 0.116+ 修复 Starlette 安全漏洞)
|
||||
"fastapi>=0.136.1",
|
||||
"fastapi>=0.115.8",
|
||||
"uvicorn[standard]~=0.32.0",
|
||||
"python-multipart~=0.0.20",
|
||||
"python-multipart>=0.0.27",
|
||||
"pyasn1>=0.6.3", # 安全修复:间接依赖,强制升级
|
||||
|
||||
# 认证 & 安全
|
||||
"passlib[bcrypt]~=1.7.4",
|
||||
@@ -47,14 +48,14 @@ dependencies = [
|
||||
"volcengine-python-sdk[ark]~=5.0.0",
|
||||
|
||||
# HTTP 客户端
|
||||
"httpx~=0.28.0",
|
||||
"aiohttp>=3.13.4", # 安全修复:修复 CVE-2025-XXXX 系列漏洞
|
||||
"httpx>=0.28.0",
|
||||
"aiohttp>=3.14.0", # 安全修复:修复 CVE-2025-XXXX 系列漏洞
|
||||
|
||||
# 对象存储
|
||||
"qiniu~=7.13.0",
|
||||
|
||||
# 工具
|
||||
"pyjwt~=2.10.0",
|
||||
"pyjwt>=2.13.0",
|
||||
|
||||
"pyyaml~=6.0.2",
|
||||
"orjson>=3.11.0", # 安全修复:修复 CVE-2025-XXXX
|
||||
@@ -68,12 +69,12 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest~=8.3.0",
|
||||
"pytest-asyncio~=0.24.0",
|
||||
"pytest-cov~=6.0.0",
|
||||
"ruff~=0.8.0",
|
||||
"black~=24.10.0",
|
||||
"mypy~=1.14.0",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=6.0.0",
|
||||
"ruff>=0.8.0",
|
||||
"black>=26.3.1",
|
||||
"mypy>=1.14.0",
|
||||
"bandit[toml]~=1.8.0", # 安全扫描
|
||||
"pip-audit~=2.7.0", # 漏洞检测
|
||||
"pre-commit~=4.0.0", # Git 钩子
|
||||
@@ -99,7 +100,6 @@ ignore = ["E501", "E402", "N802", "N803", "N806", "N815", "B008", "B904"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.13"
|
||||
strict = false
|
||||
warn_return_any = false
|
||||
warn_unused_configs = true
|
||||
ignore_missing_imports = true
|
||||
@@ -112,7 +112,6 @@ disallow_incomplete_defs = false
|
||||
# ========== 重构防护网:新代码严格模式 ==========
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["app.schemas.*", "app.crud.*", "app.scheduler.handlers.*"]
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_defs = true
|
||||
@@ -121,7 +120,6 @@ disallow_incomplete_defs = true
|
||||
# Redis 客户端 typing 问题(Awaitable[T] | T),暂不严格检查
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["app.scheduler.registry", "app.scheduler.slot_manager"]
|
||||
strict = false
|
||||
check_untyped_defs = false
|
||||
disallow_untyped_defs = false
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml -o requirements.lock
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohappyeyeballs==2.6.2
|
||||
# via aiohttp
|
||||
aiohttp==3.13.5
|
||||
aiohttp==3.14.1
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
@@ -27,7 +27,7 @@ bcrypt==4.2.1
|
||||
# via
|
||||
# meijiaka-ai-api (pyproject.toml)
|
||||
# passlib
|
||||
certifi==2026.4.22
|
||||
certifi==2026.5.20
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
@@ -37,19 +37,19 @@ cffi==2.0.0
|
||||
# via cryptography
|
||||
charset-normalizer==3.4.7
|
||||
# via requests
|
||||
click==8.3.3
|
||||
click==8.4.1
|
||||
# via uvicorn
|
||||
cryptography==48.0.0
|
||||
cryptography==48.0.1
|
||||
# via volcengine-python-sdk
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
fastapi==0.136.1
|
||||
fastapi==0.136.3
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
greenlet==3.5.0
|
||||
greenlet==3.5.1
|
||||
# via sqlalchemy
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -57,20 +57,20 @@ h11==0.16.0
|
||||
# uvicorn
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httptools==0.7.1
|
||||
httptools==0.8.0
|
||||
# via uvicorn
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# meijiaka-ai-api (pyproject.toml)
|
||||
# openai
|
||||
# volcengine-python-sdk
|
||||
idna==3.13
|
||||
idna==3.18
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
jiter==0.14.0
|
||||
jiter==0.15.0
|
||||
# via openai
|
||||
mako==1.3.12
|
||||
# via alembic
|
||||
@@ -88,12 +88,16 @@ orjson==3.11.9
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
passlib==1.7.4
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
propcache==0.4.1
|
||||
pillow==12.2.0
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
propcache==0.5.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psycopg2-binary==2.9.12
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
pyasn1==0.6.3
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
pycparser==3.0
|
||||
# via cffi
|
||||
pydantic==2.9.2
|
||||
@@ -107,7 +111,7 @@ pydantic-core==2.23.4
|
||||
# via pydantic
|
||||
pydantic-settings==2.6.1
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
pyjwt==2.10.1
|
||||
pyjwt==2.13.0
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
python-dateutil==2.9.0.post0
|
||||
# via volcengine-python-sdk
|
||||
@@ -115,7 +119,7 @@ python-dotenv==1.2.2
|
||||
# via
|
||||
# pydantic-settings
|
||||
# uvicorn
|
||||
python-multipart==0.0.27
|
||||
python-multipart==0.0.32
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
@@ -125,7 +129,7 @@ qiniu==7.13.2
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
redis==5.2.1
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
requests==2.33.1
|
||||
requests==2.34.2
|
||||
# via qiniu
|
||||
six==1.17.0
|
||||
# via
|
||||
@@ -133,15 +137,15 @@ six==1.17.0
|
||||
# volcengine-python-sdk
|
||||
sniffio==1.3.1
|
||||
# via openai
|
||||
sqlalchemy==2.0.49
|
||||
sqlalchemy==2.0.50
|
||||
# via
|
||||
# meijiaka-ai-api (pyproject.toml)
|
||||
# alembic
|
||||
starlette==1.0.0
|
||||
starlette==1.3.0
|
||||
# via fastapi
|
||||
tenacity==9.0.0
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
tqdm==4.67.3
|
||||
tqdm==4.68.2
|
||||
# via openai
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
@@ -162,11 +166,11 @@ uvicorn==0.32.1
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
uvloop==0.22.1
|
||||
# via uvicorn
|
||||
volcengine-python-sdk==5.0.26
|
||||
volcengine-python-sdk==5.0.34
|
||||
# via meijiaka-ai-api (pyproject.toml)
|
||||
watchfiles==1.1.1
|
||||
watchfiles==1.2.0
|
||||
# via uvicorn
|
||||
websockets==16.0
|
||||
# via uvicorn
|
||||
yarl==1.23.0
|
||||
yarl==1.24.2
|
||||
# via aiohttp
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
异常体系单元测试
|
||||
================
|
||||
|
||||
验证 AppException / InsufficientPointsException 的结构化字段,
|
||||
确保前端可以通过 error_code 识别错误类型。
|
||||
"""
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.exceptions import (
|
||||
AIEmptyResponseException,
|
||||
AIParseErrorException,
|
||||
AITimeoutException,
|
||||
AppException,
|
||||
BusinessException,
|
||||
InsufficientPointsException,
|
||||
NotFoundException,
|
||||
PromptNotFoundException,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
|
||||
class TestAppException:
|
||||
"""业务异常基类"""
|
||||
|
||||
def test_app_exception_has_error_code(self) -> None:
|
||||
exc = AppException(
|
||||
status_code=400,
|
||||
message="参数错误",
|
||||
error_code="validation_error",
|
||||
)
|
||||
assert exc.status_code == 400
|
||||
assert exc.message == "参数错误"
|
||||
assert exc.error_code == "validation_error"
|
||||
assert exc.detail == {"message": "参数错误", "error_code": "validation_error"}
|
||||
|
||||
def test_app_exception_detail_can_be_dict(self) -> None:
|
||||
exc = AppException(
|
||||
status_code=422,
|
||||
message="字段校验失败",
|
||||
detail={"fields": {"name": "required"}},
|
||||
error_code="validation_error",
|
||||
)
|
||||
assert exc.detail == {
|
||||
"fields": {"name": "required"},
|
||||
"message": "字段校验失败",
|
||||
"error_code": "validation_error",
|
||||
}
|
||||
|
||||
def test_subclasses_without_error_code(self) -> None:
|
||||
exc = NotFoundException("资源不存在")
|
||||
assert exc.status_code == 404
|
||||
assert exc.message == "资源不存在"
|
||||
assert exc.error_code is None
|
||||
|
||||
|
||||
class TestInsufficientPointsException:
|
||||
"""积分不足异常"""
|
||||
|
||||
def test_default_fields(self) -> None:
|
||||
exc = InsufficientPointsException()
|
||||
assert exc.status_code == 402
|
||||
assert exc.message == "积分不足"
|
||||
assert exc.error_code == "insufficient_points"
|
||||
assert isinstance(exc, HTTPException)
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = InsufficientPointsException("余额不足,请先充值")
|
||||
assert exc.message == "余额不足,请先充值"
|
||||
assert exc.error_code == "insufficient_points"
|
||||
|
||||
def test_detail_structure(self) -> None:
|
||||
exc = InsufficientPointsException("需要 10 积分")
|
||||
assert exc.detail == {
|
||||
"message": "需要 10 积分",
|
||||
"error_code": "insufficient_points",
|
||||
}
|
||||
|
||||
|
||||
class TestOtherSubclasses:
|
||||
"""其他常用子类"""
|
||||
|
||||
def test_validation_exception(self) -> None:
|
||||
exc = ValidationException("字段缺失")
|
||||
assert exc.status_code == 422
|
||||
assert exc.message == "字段缺失"
|
||||
|
||||
def test_business_exception(self) -> None:
|
||||
exc = BusinessException("业务状态不允许")
|
||||
assert exc.status_code == 400
|
||||
assert exc.message == "业务状态不允许"
|
||||
|
||||
|
||||
class TestAIStructuredExceptions:
|
||||
"""AI 相关结构化异常"""
|
||||
|
||||
def test_prompt_not_found_exception(self) -> None:
|
||||
exc = PromptNotFoundException("未找到提示词")
|
||||
assert exc.status_code == 404
|
||||
assert exc.error_code == "prompt_not_found"
|
||||
|
||||
def test_ai_empty_response_exception(self) -> None:
|
||||
exc = AIEmptyResponseException("AI 返回为空")
|
||||
assert exc.status_code == 500
|
||||
assert exc.error_code == "empty_result"
|
||||
|
||||
def test_ai_parse_error_exception(self) -> None:
|
||||
exc = AIParseErrorException("解析失败")
|
||||
assert exc.status_code == 500
|
||||
assert exc.error_code == "parse_error"
|
||||
|
||||
def test_ai_timeout_exception(self) -> None:
|
||||
exc = AITimeoutException("请求超时")
|
||||
assert exc.status_code == 504
|
||||
assert exc.error_code == "timeout"
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
积分 Service 单元测试
|
||||
=====================
|
||||
|
||||
不依赖数据库,仅测试配置加载、计费计算、充值档位等纯函数逻辑。
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services import point_service as ps
|
||||
|
||||
|
||||
class TestPointsConfig:
|
||||
"""积分配置加载"""
|
||||
|
||||
def test_points_config_loaded(self) -> None:
|
||||
assert "script" in ps.POINTS_CONFIG
|
||||
assert "tts" in ps.POINTS_CONFIG
|
||||
assert "video" in ps.POINTS_CONFIG
|
||||
|
||||
def test_get_chargeable_source_types(self) -> None:
|
||||
types = ps.get_chargeable_source_types()
|
||||
assert "script" in types
|
||||
assert "tts" in types
|
||||
# 免费业务不应出现在扣费列表
|
||||
assert "caption" not in types
|
||||
|
||||
def test_get_recharge_options_returns_list(self) -> None:
|
||||
options = ps.get_recharge_options()
|
||||
assert isinstance(options, list)
|
||||
if options:
|
||||
assert "points" in options[0]
|
||||
|
||||
|
||||
class TestCalculateCost:
|
||||
"""后置计费计算"""
|
||||
|
||||
def test_fixed_cost(self) -> None:
|
||||
assert ps._calculate_cost("script") == 5
|
||||
assert ps._calculate_cost("polish") == 1
|
||||
assert ps._calculate_cost("voice_clone") == 200
|
||||
|
||||
def test_free_cost(self) -> None:
|
||||
free_types = [
|
||||
key
|
||||
for key, cfg in ps.POINTS_CONFIG.items()
|
||||
if not key.startswith("_") and cfg.get("mode") == "free"
|
||||
]
|
||||
for source_type in free_types:
|
||||
assert ps._calculate_cost(source_type) == 0
|
||||
|
||||
def test_unknown_source_type(self) -> None:
|
||||
with pytest.raises(ValueError, match="未知的消费类型"):
|
||||
ps._calculate_cost("not_exists")
|
||||
|
||||
def test_tts_duration_cost(self) -> None:
|
||||
# tts: divisor=5,按 5 秒为单位计费
|
||||
assert ps._calculate_cost("tts", {"seconds": 4}) == 1
|
||||
assert ps._calculate_cost("tts", {"seconds": 5}) == 1
|
||||
assert ps._calculate_cost("tts", {"seconds": 6}) == 2
|
||||
assert ps._calculate_cost("tts", {"seconds": 5.1}) == 2
|
||||
|
||||
def test_video_duration_cost(self) -> None:
|
||||
# video: multiplier=1,按秒向上取整
|
||||
assert ps._calculate_cost("video", {"seconds": 10}) == 10
|
||||
assert ps._calculate_cost("video", {"seconds": 10.2}) == 11
|
||||
|
||||
|
||||
class TestEstimateMaxCost:
|
||||
"""预估上限计算"""
|
||||
|
||||
def test_tts_estimate_by_char_count(self) -> None:
|
||||
cfg = ps.POINTS_CONFIG["tts"]
|
||||
seconds_per_char = cfg["estimation"]["seconds_per_char"]
|
||||
char_count = 100
|
||||
estimated_seconds = char_count * seconds_per_char
|
||||
expected = max(
|
||||
cfg["min_points"],
|
||||
math.ceil(estimated_seconds / cfg["divisor"]),
|
||||
)
|
||||
result = ps._estimate_max_cost("tts", {"char_count": char_count})
|
||||
assert result == expected
|
||||
|
||||
def test_video_estimate_by_input_seconds(self) -> None:
|
||||
with pytest.raises(ValueError, match="input_seconds"):
|
||||
ps._estimate_max_cost("video", {})
|
||||
|
||||
assert ps._estimate_max_cost("video", {"input_seconds": 10}) == 10
|
||||
assert ps._estimate_max_cost("video", {"input_seconds": 10.2}) == 11
|
||||
|
||||
def test_unknown_source_type(self) -> None:
|
||||
with pytest.raises(ValueError, match="未知的消费类型"):
|
||||
ps._estimate_max_cost("not_exists")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user