chore(release): bump version to 1.9.1 and apply pending changes

This commit is contained in:
小鱼开发
2026-06-16 15:17:30 +08:00
parent 9a71584d6c
commit c6a40331d4
152 changed files with 9396 additions and 10267 deletions
+537
View File
@@ -0,0 +1,537 @@
diff --git a/python-api/app/api/v1/caption.py b/python-api/app/api/v1/caption.py
index afae831..99a771a 100644
--- a/python-api/app/api/v1/caption.py
+++ b/python-api/app/api/v1/caption.py
@@ -24,19 +24,6 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/caption", tags=["Caption"])
-
-
-
-
-
-
-
-
-
-
-
-
-
@router.post("/ata/align")
async def auto_align_caption(
request_body: AutoAlignSubmitRequest,
@@ -88,9 +75,3 @@ async def auto_align_caption(
except Exception as e:
logger.error(f"自动打轴异常: {e}")
raise HTTPException(status_code=500, detail="字幕打轴失败,请稍后重试")
-
-
-
-
-
-
diff --git a/python-api/app/api/v1/points.py b/python-api/app/api/v1/points.py
index b2c7208..23bafa9 100644
--- a/python-api/app/api/v1/points.py
+++ b/python-api/app/api/v1/points.py
@@ -6,7 +6,7 @@
"""
import logging
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_db
+from app.core.exceptions import InsufficientPointsException
from app.crud.point_recharge_order import point_recharge_order
from app.crud.point_transaction import point_transaction
from app.models.user import User
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/points", tags=["Points"])
# ── 余额查询 ──────────────────────────────────────────
+
@router.get("/balance", response_model=ApiResponse[PointBalanceResponse])
async def get_balance(
db: AsyncSession = Depends(get_db),
@@ -47,6 +49,7 @@ async def get_balance(
# ── 流水查询 ──────────────────────────────────────────
+
@router.get("/transactions", response_model=ApiResponse[PointTransactionListResponse])
async def list_transactions(
pagination: PaginationParams = Depends(),
@@ -127,12 +130,13 @@ async def list_transactions(
# ── 充值 ──────────────────────────────────────────────
+
@router.post("/recharge", response_model=ApiResponse[RechargeResponse])
async def create_recharge_order(
request: RechargeRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
- http_request: Request = None,
+ http_request: Request = None, # type: ignore[assignment]
):
"""
创建积分充值订单(微信支付 Native 扫码)
@@ -303,9 +307,7 @@ async def handle_wxpay_notify(
return _wx_response()
# 查找订单
- order = await point_recharge_order.get_by_out_trade_no(
- db, out_trade_no=out_trade_no
- )
+ order = await point_recharge_order.get_by_out_trade_no(db, out_trade_no=out_trade_no)
if not order:
logger.error(f"[WechatPay] 回调订单不存在: {out_trade_no}")
return _wx_response()
@@ -400,9 +402,7 @@ async def query_recharge_status(
wxpay = get_wxpay_service()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
- wx_result = await wxpay.query_order(
- client, out_trade_no=order.out_trade_no
- )
+ wx_result = await wxpay.query_order(client, out_trade_no=order.out_trade_no)
order.query_result = str(wx_result)
trade_state = wx_result.get("trade_state", "")
@@ -465,6 +465,7 @@ async def query_recharge_status(
# ── 充值档位查询 ──────────────────────────────────────
+
@router.get("/recharge-options", response_model=ApiResponse[list[dict]])
async def get_recharge_options(
current_user: User = Depends(get_current_user),
@@ -480,6 +481,7 @@ async def get_recharge_options(
# ── 扣费业务类型查询 ──────────────────────────────────
+
@router.get("/chargeable-types", response_model=ApiResponse[list[str]])
async def get_chargeable_types(
current_user: User = Depends(get_current_user),
@@ -496,6 +498,7 @@ async def get_chargeable_types(
# ── 积分规则查询 ──────────────────────────────────────
+
@router.get("/rules", response_model=ApiResponse[list[dict]])
async def get_points_rules(
current_user: User = Depends(get_current_user),
@@ -530,9 +533,9 @@ async def get_points_rules(
# ── 积分预估查询 ──────────────────────────────────────
-
# ── 今日消费统计 ──────────────────────────────────────
+
@router.get("/today-consumed", response_model=ApiResponse[dict])
async def get_today_consumed(
db: AsyncSession = Depends(get_db),
@@ -545,6 +548,7 @@ async def get_today_consumed(
# ── 直接消费扣费(前端/Rust 层调用)───────────────────
+
@router.post("/consume", response_model=ApiResponse[dict])
async def consume_points(
request: ConsumeRequest,
@@ -569,12 +573,12 @@ async def consume_points(
source_type=request.source_type,
source_id=request.source_id,
description=f"【{request.description or request.source_type}】",
- allow_negative=False,
+ allow_negative=request.allow_negative,
)
- except ValueError as e:
+ except InsufficientPointsException:
# 余额不足(在同一事务内判断,避免竞态)
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
+ raise
+ except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
await db.commit()
@@ -587,5 +591,3 @@ async def consume_points(
},
message="消费成功",
)
-
-
diff --git a/python-api/app/api/v1/script.py b/python-api/app/api/v1/script.py
index 5e7b18f..26a11a3 100644
--- a/python-api/app/api/v1/script.py
+++ b/python-api/app/api/v1/script.py
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.ai.model_router import get_model_router
from app.ai.prompts import list_categories, list_prompt_files, load_prompt, render_template
from app.api.deps import get_current_user
+from app.core.exceptions import AITimeoutException, InsufficientPointsException
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -71,9 +72,8 @@ async def polish_content(
required_points = ps._calculate_cost("polish")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -99,11 +99,11 @@ async def polish_content(
data=polished,
message=f"{type_name}润色完成",
)
+ except InsufficientPointsException:
+ raise
except HTTPException:
raise
except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
logger.warning(f"[Polish] 润色失败: {e}")
raise HTTPException(status_code=500, detail="润色失败,请检查输入内容后重试")
except Exception as e:
@@ -111,9 +111,6 @@ async def polish_content(
raise HTTPException(status_code=500, detail=f"{type_name}润色失败,请稍后重试")
-
-
-
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
async def generate_title(
request: GenerateTitleRequest,
@@ -146,7 +143,11 @@ async def generate_title(
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
# 渲染用户提示词
- title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
+ title_type_desc = (
+ "大标题(主标题,提炼核心卖点,吸睛)"
+ if request.title_type == "main"
+ else "小标题(副标题,补充说明或制造悬念)"
+ )
user_prompt = render_template(
user_template,
title_type=request.title_type,
@@ -163,9 +164,8 @@ async def generate_title(
required_points = ps._calculate_cost("title")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -179,10 +179,10 @@ async def generate_title(
title = result.content.strip() if result.content else ""
# 去除可能的引号
- title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
+ title = title.strip('"').strip("'").strip("「」").strip("『』").strip("《》")
# 截断到最大长度
if len(title) > request.max_length:
- title = title[:request.max_length]
+ title = title[: request.max_length]
# 扣费
points = ps._calculate_cost("title")
@@ -200,15 +200,13 @@ async def generate_title(
data=GenerateTitleResponse(title=title),
message="标题生成成功",
)
+ except InsufficientPointsException:
+ raise
except HTTPException:
raise
except TimeoutError:
logger.warning("[generate_title] 标题生成超时")
- raise HTTPException(status_code=500, detail="标题生成超时,请重试")
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
+ raise AITimeoutException("标题生成超时,请稍后重试")
except Exception as e:
logger.error(f"[generate_title] 标题生成失败: {e}")
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
diff --git a/python-api/app/api/v1/tasks.py b/python-api/app/api/v1/tasks.py
index ae17fe3..55965ca 100644
--- a/python-api/app/api/v1/tasks.py
+++ b/python-api/app/api/v1/tasks.py
@@ -18,6 +18,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator
from app.api.deps import get_current_user
+from app.core.exceptions import InsufficientPointsException
from app.core.redis_client import get_redis_client
from app.db.session import AsyncSessionLocal
from app.models.user import User
@@ -38,7 +39,6 @@ class ScriptParams(BaseModel):
category: str = Field(..., min_length=1, description="大类代码")
filename: str = Field(..., min_length=1, description="提示词文件名")
-
@field_validator("category")
@classmethod
def validate_category(cls, v: str) -> str:
@@ -96,7 +96,9 @@ class VideoParams(BaseModel):
volume: int = Field(default=0, ge=0, le=10, description="音量")
ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
- total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
+ total_planned_duration: float = Field(
+ ..., gt=0, description="所有分镜规划时长之和(秒),用于预检"
+ )
batch_id: str | None = Field(default=None, description="批次ID(可选)")
@field_validator("video_url")
@@ -134,6 +136,7 @@ class TaskStatusResponse(BaseModel):
total: int = Field(0, description="总子任务数")
result: dict | None = Field(None, description="任务结果(完成时)")
error: str | None = Field(None, description="错误信息(失败时)")
+ error_code: str | None = Field(None, description="错误码(失败时,如 content_violation")
created_at: str = Field("", description="任务创建时间(ISO格式)")
@@ -222,9 +225,8 @@ async def create_task(
f"[API] 积分不足: user={user_id}, type={task_type}, "
f"required={required_points}, balance={check['balance']}"
)
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
# ── 3. 写入 Redis ──────────────────────────────────
@@ -246,18 +248,41 @@ async def create_task(
params=validated_params,
)
await registry.add_running(task_id)
-
- logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
- return TaskCreateResponse(
- task_id=task_id,
- status="running",
- message=f"{task_type} 任务已创建",
- )
-
except Exception as e:
logger.error(f"[API] Failed to update registry: {e}")
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
+ # ── 4. 脚本生成:Redis 写入成功后再扣费 ─────────────
+ if task_type == "script" and required_points > 0:
+ try:
+ async with AsyncSessionLocal() as db:
+ await ps.consume(
+ db,
+ user_id=user_id,
+ points=required_points,
+ source_type="script",
+ source_id=task_id,
+ description="【脚本生成】",
+ )
+ await db.commit()
+ except InsufficientPointsException:
+ # 余额不足:将任务标记为失败,避免无费执行
+ await registry.update(task_id, status="failed", message="积分不足")
+ await registry.remove_running(task_id)
+ raise
+ except Exception as e:
+ logger.error(f"[API] 脚本任务扣费失败: {e}")
+ await registry.update(task_id, status="failed", message="扣费失败")
+ await registry.remove_running(task_id)
+ raise HTTPException(status_code=500, detail="扣费失败,请稍后重试")
+
+ logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
+ return TaskCreateResponse(
+ task_id=task_id,
+ status="running",
+ message=f"{task_type} 任务已创建",
+ )
+
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
@@ -294,6 +319,7 @@ async def list_tasks(
total=task.total,
result=None, # 列表查询不返回 result,避免数据过大
error=task.error,
+ error_code=task.error_code,
created_at=task.created_at,
)
)
@@ -337,6 +363,7 @@ async def get_task_status(
total=task.total,
result=task.result,
error=task.error,
+ error_code=task.error_code,
created_at=task.created_at,
)
diff --git a/python-api/app/api/v1/vidu.py b/python-api/app/api/v1/vidu.py
index c0fb04a..401d8da 100644
--- a/python-api/app/api/v1/vidu.py
+++ b/python-api/app/api/v1/vidu.py
@@ -44,10 +44,12 @@ async def vidu_callback(request: Request):
headers_dict = dict(request.headers)
# 使用 APP_BASE_URL 构建 callback_url,确保与提交任务时传给 Vidu 的一致
- #Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
+ # Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
app_base_url = get_settings().app_base_url
callback_url = f"{app_base_url}/api/v1/vidu/callback" if app_base_url else str(request.url)
- logger.info(f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}")
+ logger.info(
+ f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}"
+ )
try:
task_status = await gateway.handle_webhook(
@@ -64,15 +66,13 @@ async def vidu_callback(request: Request):
logger.error(f"[Vidu] 回调处理失败: {e}")
raise HTTPException(status_code=500, detail="回调处理失败,请稍后重试")
- logger.info(f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}")
+ logger.info(
+ f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}"
+ )
# 2. 通过 platform_task_id 反查 Async Engine 内部 task_id,更新 TaskRegistry
- platform_task_id = (
- task_status.result.get("task_id") if task_status.result else None
- )
- video_url = (
- task_status.result.get("video_url") if task_status.result else None
- )
+ platform_task_id = task_status.result.get("task_id") if task_status.result else None
+ video_url = task_status.result.get("video_url") if task_status.result else None
logger.info(f"[Vidu] 准备反查 internal_task_id: platform_task_id={platform_task_id}")
@@ -121,8 +121,6 @@ async def vidu_callback(request: Request):
f"platform={platform_task_id}"
)
else:
- logger.warning(
- f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}"
- )
+ logger.warning(f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}")
return success_response(message="回调已接收")
diff --git a/python-api/app/api/v1/voice.py b/python-api/app/api/v1/voice.py
index ba8cac9..dad2b7b 100644
--- a/python-api/app/api/v1/voice.py
+++ b/python-api/app/api/v1/voice.py
@@ -10,12 +10,13 @@ import logging
import re
import time
import uuid
+
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
-from app.core.exceptions import PlatformError
+from app.core.exceptions import InsufficientPointsException, PlatformError
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -49,7 +50,9 @@ class TTSSynthesizeRequest(BaseModel):
class VoiceCloneSubmitRequest(BaseModel):
"""声音复刻提交请求"""
- source_audio_url: str | None = Field(None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)")
+ source_audio_url: str | None = Field(
+ None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)"
+ )
source_video_url: str | None = Field(None, description="源视频 URL(可选)")
video_id: str | None = Field(None, description="历史作品ID(可选)")
voice_name: str | None = Field(None, description="自定义音色名称(≤20字符)")
@@ -111,7 +114,7 @@ async def synthesize_speech(
# 宽松预检:余额为负或零时阻止,避免浪费第三方资源
balance_info = await ps.get_user_balance(db, current_user.id)
if balance_info["balance"] <= 0:
- raise HTTPException(status_code=402, detail="余额不足,请先充值")
+ raise InsufficientPointsException("余额不足,请先充值")
try:
audio_url = await service.synthesize(
@@ -137,10 +140,8 @@ async def synthesize_speech(
allow_negative=True,
)
await db.commit()
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- logger.error(f"[Voice] TTS 扣费失败: {e}")
+ except InsufficientPointsException:
+ raise
except Exception as e:
logger.error(f"[Voice] TTS 扣费失败: {e}")
@@ -165,7 +166,6 @@ async def synthesize_speech(
raise HTTPException(status_code=500, detail="语音合成失败,请稍后重试")
-
def _normalize_voice_id(name: str | None) -> str:
"""
将用户输入的名称规范化为 Vidu 合法的 voice_id。
@@ -220,9 +220,8 @@ async def submit_clone_task(
required_points = ps._calculate_cost("voice_clone")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -244,10 +243,8 @@ async def submit_clone_task(
description="【声音复刻】",
)
await db.commit()
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- logger.error(f"[Voice] 克隆扣费失败: {e}")
+ except InsufficientPointsException:
+ raise
except Exception as e:
logger.error(f"[Voice] 克隆扣费失败: {e}")
@@ -292,5 +289,3 @@ async def query_clone_task(
),
message="克隆已完成",
)
-
-
+146
View File
@@ -0,0 +1,146 @@
diff --git a/python-api/app/core/exceptions.py b/python-api/app/core/exceptions.py
index d9970d5..837f8d3 100644
--- a/python-api/app/core/exceptions.py
+++ b/python-api/app/core/exceptions.py
@@ -24,9 +24,16 @@ class AppException(HTTPException):
status_code: int,
message: str = "操作失败",
detail: dict | None = None,
+ *,
+ error_code: str | None = None,
):
- super().__init__(status_code=status_code, detail=detail or {})
+ body = detail or {}
+ body["message"] = message
+ if error_code:
+ body["error_code"] = error_code
+ super().__init__(status_code=status_code, detail=body)
self.message = message
+ self.error_code = error_code
class NotFoundException(AppException):
@@ -44,7 +51,7 @@ class ValidationException(AppException):
def __init__(self, message: str = "参数验证失败"):
super().__init__(
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
message=message,
)
@@ -79,6 +86,17 @@ class BusinessException(AppException):
)
+class InsufficientPointsException(AppException):
+ """积分不足"""
+
+ def __init__(self, message: str = "积分不足"):
+ super().__init__(
+ status_code=status.HTTP_402_PAYMENT_REQUIRED,
+ message=message,
+ error_code="insufficient_points",
+ )
+
+
class ModelUnavailableException(AppException):
"""AI 模型不可用"""
@@ -99,6 +117,50 @@ class TaskFailedException(AppException):
)
+class PromptNotFoundException(AppException):
+ """提示词文件不存在"""
+
+ def __init__(self, message: str = "未找到提示词"):
+ super().__init__(
+ status_code=status.HTTP_404_NOT_FOUND,
+ message=message,
+ error_code="prompt_not_found",
+ )
+
+
+class AIEmptyResponseException(AppException):
+ """AI 返回内容为空"""
+
+ def __init__(self, message: str = "AI 返回内容为空"):
+ super().__init__(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ message=message,
+ error_code="empty_result",
+ )
+
+
+class AIParseErrorException(AppException):
+ """AI 返回内容解析失败"""
+
+ def __init__(self, message: str = "AI 返回格式解析失败"):
+ super().__init__(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ message=message,
+ error_code="parse_error",
+ )
+
+
+class AITimeoutException(AppException):
+ """AI 调用超时"""
+
+ def __init__(self, message: str = "AI 请求超时,请稍后重试"):
+ super().__init__(
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+ message=message,
+ error_code="timeout",
+ )
+
+
# ═══════════════════════════════════════════════════════════════
# 第三方平台异常(PlatformError 体系)
# ═══════════════════════════════════════════════════════════════
@@ -111,14 +173,15 @@ class PlatformErrorType:
确保前端和网关能够统一处理。
"""
- RATE_LIMIT = "rate_limit" # 429,可重试
- AUTH_FAILED = "auth_failed" # 401/403,不可重试
- TIMEOUT = "timeout" # 连接/读取超时,可重试
- SERVER_ERROR = "server_error" # 第三方 5xx,可重试
- BAD_REQUEST = "bad_request" # 参数错误,不可重试
+ RATE_LIMIT = "rate_limit" # 429,可重试
+ AUTH_FAILED = "auth_failed" # 401/403,不可重试
+ TIMEOUT = "timeout" # 连接/读取超时,可重试
+ SERVER_ERROR = "server_error" # 第三方 5xx,可重试
+ BAD_REQUEST = "bad_request" # 参数错误,不可重试
QUOTA_EXHAUSTED = "quota_exhausted" # 额度用完,不可重试(或延迟重试)
- NOT_FOUND = "not_found" # 资源不存在,不可重试
- UNKNOWN = "unknown" # 兜底
+ NOT_FOUND = "not_found" # 资源不存在,不可重试
+ CONTENT_VIOLATION = "content_violation" # 内容安全/审核不通过,不可重试
+ UNKNOWN = "unknown" # 兜底
class PlatformError(Exception):
@@ -145,12 +208,14 @@ class PlatformError(Exception):
retryable: bool = False,
error_type: str = PlatformErrorType.UNKNOWN,
status_code: int | None = None,
+ raw_code: str | None = None,
):
super().__init__(message)
self.platform = platform
self.retryable = retryable
self.error_type = error_type
self.status_code = status_code
+ self.raw_code = raw_code
def to_http_status(self) -> int:
"""根据 error_type 和 retryable 返回标准 HTTP 状态码"""
@@ -161,6 +226,7 @@ class PlatformError(Exception):
PlatformErrorType.AUTH_FAILED: 401,
PlatformErrorType.BAD_REQUEST: 400,
PlatformErrorType.NOT_FOUND: 404,
+ PlatformErrorType.CONTENT_VIOLATION: 400,
}
if self.error_type in mapping:
return mapping[self.error_type]
+345
View File
@@ -0,0 +1,345 @@
diff --git a/python-api/app/ai/providers/vidu_provider.py b/python-api/app/ai/providers/vidu_provider.py
index cab5902..fccfbbf 100644
--- a/python-api/app/ai/providers/vidu_provider.py
+++ b/python-api/app/ai/providers/vidu_provider.py
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
-def _map_vidu_error(status: int, message: str) -> PlatformError:
- """把 Vidu HTTP 错误映射为标准 PlatformError"""
+# Vidu 错误码分类
+_VIDU_AUDIT_ERROR_CODES = {
+ "TaskPromptPolicyViolation",
+ "AuditSubmitIllegal",
+ "CreationPolicyViolation",
+ "PhotoAuditNotPass",
+ "AuditFailed",
+ "ImageCheckBodyJointsFailed",
+ "ImageCheckFaceFailed",
+ "ImageObjectsUndetected",
+ "FaceDetectFailure",
+ "FaceDetectNotPass",
+ "NoFaceDetected",
+ "MultiFaceDetected",
+}
+
+_VIDU_RETRYABLE_ERROR_CODES = {
+ "InternalServiceFailure",
+ "ModelUnavailable",
+ "Unknown",
+}
+
+_VIDU_RATE_LIMIT_ERROR_CODES = {
+ "QuotaExceeded",
+ "TooManyRequests",
+ "SystemThrottling",
+ "OperationInProcess",
+}
+
+
+def _extract_vidu_error_code(message: str | None) -> str | None:
+ """从 Vidu 错误信息中提取错误码"""
+ if not message:
+ return None
+ # Vidu 错误码格式:"ErrorCode: 中文描述"
+ return message.split(":")[0].strip() or None
+
+
+def _map_vidu_error(
+ status: int,
+ message: str,
+ *,
+ err_code: str | None = None,
+) -> PlatformError:
+ """把 Vidu HTTP 错误映射为标准 PlatformError
+
+ 优先根据 Vidu 业务错误码(err_code)判断类型,HTTP status 仅作为兜底。
+ """
+ raw_code = err_code or _extract_vidu_error_code(message)
+
+ # 1. 内容安全/审核类:不可重试
+ if raw_code in _VIDU_AUDIT_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=False,
+ error_type=PlatformErrorType.CONTENT_VIOLATION,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 2. 平台内部/模型不可用:可重试
+ if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=True,
+ error_type=PlatformErrorType.SERVER_ERROR,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 3. 限流类:可重试
+ if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
+ return PlatformError(
+ message=message,
+ platform="vidu",
+ retryable=True,
+ error_type=PlatformErrorType.RATE_LIMIT,
+ status_code=status,
+ raw_code=raw_code,
+ )
+
+ # 4. HTTP status 兜底
mapping = {
429: (PlatformErrorType.RATE_LIMIT, True),
401: (PlatformErrorType.AUTH_FAILED, False),
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
retryable=retryable,
error_type=error_type,
status_code=status,
+ raw_code=raw_code,
)
@@ -66,7 +149,9 @@ class ViduProvider:
from app.core.platform_config import get_platform_config_loader
platform_config = get_platform_config_loader().get_platform("vidu")
- self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
+ self.base_url = (
+ platform_config.base_url if platform_config else "https://api.vidu.cn"
+ ).rstrip("/")
if not self.api_key:
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
@@ -135,9 +220,12 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu TTS] 网络错误: {e}")
@@ -182,9 +270,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Clone] 网络错误: {e}")
@@ -238,9 +331,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body)
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu LipSync] 网络错误: {e}")
@@ -264,9 +362,14 @@ class ViduProvider:
resp = await self.client.get(url)
data = resp.json()
if resp.status_code != 200:
- msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
- logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
- raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
+ err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
+ msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
+ logger.error(
+ f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
+ )
+ raise _map_vidu_error(
+ resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
+ )
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Query] 网络错误: {e}")
diff --git a/python-api/app/ai/providers/volcengine_caption_provider.py b/python-api/app/ai/providers/volcengine_caption_provider.py
index 0f2f271..09ddcc7 100644
--- a/python-api/app/ai/providers/volcengine_caption_provider.py
+++ b/python-api/app/ai/providers/volcengine_caption_provider.py
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
if code is not None and code in error_mapping:
error_type, retryable = error_mapping[code]
return PlatformError(
- message, platform="volcengine_caption",
- retryable=retryable, error_type=error_type,
+ message,
+ platform="volcengine_caption",
+ retryable=retryable,
+ error_type=error_type,
status_code=status,
)
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
}
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
return PlatformError(
- message, platform="volcengine_caption",
- retryable=retryable, error_type=error_type,
+ message,
+ platform="volcengine_caption",
+ retryable=retryable,
+ error_type=error_type,
status_code=status,
)
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
max_lines: int = 1,
) -> dict[str, Any]:
"""提交字幕生成任务,返回 {id: task_id}"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"language": language,
"caption_type": caption_type,
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
+ raise _map_caption_error(
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
+ ) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
+ f"字幕服务网络错误: {e}",
+ platform="volcengine_caption",
+ retryable=True,
+ error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
blocking: bool = False,
) -> dict[str, Any]:
"""查询字幕任务结果,返回原始 JSON"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"id": task_id,
"blocking": 1 if blocking else 0,
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
- raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
+ raise _map_caption_error(
+ e.response.status_code, f"HTTP错误: {e.response.status_code}"
+ ) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
- f"字幕服务网络错误: {e}", platform="volcengine_caption",
- retryable=True, error_type=PlatformErrorType.TIMEOUT,
+ f"字幕服务网络错误: {e}",
+ platform="volcengine_caption",
+ retryable=True,
+ error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
sta_punc_mode: int = 3,
) -> dict[str, Any]:
"""提交自动字幕打轴任务,返回 {id: task_id}"""
- params = {
+ params: dict[str, str | int] = {
"appid": self.appid,
"caption_type": caption_type,
"sta_punc_mode": sta_punc_mode,
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
response.raise_for_status()
data = response.json()
if "id" not in data:
- raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
+ raise _map_caption_error(
+ 500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
+ )
return data
except PlatformError:
raise
diff --git a/python-api/app/ai/providers/volcengine_provider.py b/python-api/app/ai/providers/volcengine_provider.py
index 0e2a5d5..9f029a0 100644
--- a/python-api/app/ai/providers/volcengine_provider.py
+++ b/python-api/app/ai/providers/volcengine_provider.py
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
if status == 429 or "rate limit" in message.lower():
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
- error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
+ error_type=PlatformErrorType.RATE_LIMIT,
+ status_code=status,
)
elif status in (401, 403) or "authentication" in message.lower():
return PlatformError(
- message, platform="volcengine_ark", retryable=False,
- error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=False,
+ error_type=PlatformErrorType.AUTH_FAILED,
+ status_code=status,
)
elif status and status >= 500:
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
- error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
+ error_type=PlatformErrorType.SERVER_ERROR,
+ status_code=status,
)
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
return PlatformError(
- message, platform="volcengine_ark", retryable=True,
+ message,
+ platform="volcengine_ark",
+ retryable=True,
error_type=PlatformErrorType.TIMEOUT,
)
else:
return PlatformError(
- message, platform="volcengine_ark", retryable=False,
+ message,
+ platform="volcengine_ark",
+ retryable=False,
error_type=PlatformErrorType.UNKNOWN,
)
+331
View File
@@ -0,0 +1,331 @@
diff --git a/python-api/app/services/point_service.py b/python-api/app/services/point_service.py
index 0e5d79e..709f78d 100644
--- a/python-api/app/services/point_service.py
+++ b/python-api/app/services/point_service.py
@@ -25,7 +25,7 @@ import logging
import math
from datetime import UTC, datetime, timedelta
from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import yaml
from sqlalchemy import select
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
+from app.core.exceptions import InsufficientPointsException
from app.models.point_batch import PointBatch
from app.models.point_transaction import PointTransaction
from app.models.user_point import UserPoint
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
-def _load_points_config() -> dict:
+def _load_points_config() -> dict[str, Any]:
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
if not _CONFIG_PATH.exists():
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
- with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
+ with open(_CONFIG_PATH, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
merged: dict[str, dict] = {}
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
return merged
-POINTS_CONFIG: dict[str, dict] = _load_points_config()
+POINTS_CONFIG: dict[str, Any] = _load_points_config()
def get_recharge_options() -> list[dict]:
"""获取充值档位配置(由后端控制,支持积分赠送)"""
- return POINTS_CONFIG.get("_recharge_options", [])
+ options = POINTS_CONFIG.get("_recharge_options", [])
+ if isinstance(options, list):
+ return options
+ return []
def get_chargeable_source_types() -> list[str]:
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
return [
- key for key, cfg in POINTS_CONFIG.items()
+ key
+ for key, cfg in POINTS_CONFIG.items()
if not key.startswith("_") and cfg.get("mode") != "free"
]
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
# ── 余额查询 ──────────────────────────────────────────
+
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
- result = await db.execute(
- select(UserPoint).where(UserPoint.user_id == user_id)
- )
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
from sqlalchemy import func as _func
available_result = await db.execute(
- select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
- .where(
+ select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
PointBatch.user_id == user_id,
PointBatch.remaining > 0,
PointBatch.expired_at > _now(),
@@ -221,6 +224,7 @@ async def check_balance(
# ── 充值 ──────────────────────────────────────────────
+
async def recharge(
db: AsyncSession,
*,
@@ -247,8 +251,7 @@ async def recharge(
# 幂等保护:同一笔订单(order_id)只能充值一次
if order_id:
existing_result = await db.execute(
- select(PointTransaction)
- .where(
+ select(PointTransaction).where(
PointTransaction.source_id == str(order_id),
PointTransaction.type == "recharge",
)
@@ -259,9 +262,7 @@ async def recharge(
return existing_tx
# 1. 获取或创建用户积分账户
- result = await db.execute(
- select(UserPoint).where(UserPoint.user_id == user_id)
- )
+ result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -353,7 +354,7 @@ async def consume(
直接扣费(后置计费)。
业务执行成功后调用,按实际消耗直接扣除余额。
- 默认不允许欠费(余额不足时抛出 ValueError)。
+ 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
:param points: 实际消耗积分(正整数)
@@ -383,9 +384,7 @@ async def consume(
# 2. 获取用户积分账户(加锁)
result = await db.execute(
- select(UserPoint)
- .where(UserPoint.user_id == user_id)
- .with_for_update()
+ select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
)
up = result.scalar_one_or_none()
@@ -404,7 +403,7 @@ async def consume(
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
available = sum(b.remaining for b in batches)
if not allow_negative and available < points:
- raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
+ raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
remaining_to_deduct = points
for batch in batches:
@@ -440,6 +439,7 @@ async def consume(
# ── 过期回收 ──────────────────────────────────────────
+
async def expire_batches(db: AsyncSession) -> int:
"""
回收过期积分批次。返回过期积分总数。
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
# 获取用户积分账户(加锁)
up_result = await db.execute(
- select(UserPoint)
- .where(UserPoint.user_id == batch.user_id)
- .with_for_update()
+ select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
)
up = up_result.scalar_one_or_none()
if not up:
diff --git a/python-api/app/services/script_service.py b/python-api/app/services/script_service.py
index 49aa4b1..60f58d5 100644
--- a/python-api/app/services/script_service.py
+++ b/python-api/app/services/script_service.py
@@ -7,9 +7,16 @@ import asyncio
import logging
import time
from pathlib import Path
+from typing import Any
from app.ai.model_router import get_model_router
from app.ai.prompts import load_prompt_file, load_script_user_prompt
+from app.core.exceptions import (
+ AIEmptyResponseException,
+ AIParseErrorException,
+ AITimeoutException,
+ PromptNotFoundException,
+)
from app.schemas.script import ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
class ScriptService:
"""脚本生成服务"""
-
def __init__(self):
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
-
-
def _load_prompt(self, name: str) -> str:
"""加载 Prompt 模板"""
prompt_file = self.prompts_dir / f"{name}.txt"
@@ -58,7 +62,7 @@ class ScriptService:
# 加载 Prompt
system_prompt = load_prompt_file(category, filename)
if not system_prompt:
- raise ValueError(f"未找到提示词: category={category}, filename={filename}")
+ raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
# 用户提示词
user_prompt = load_script_user_prompt(
@@ -75,24 +79,26 @@ class ScriptService:
)
if not result.content or not result.content.strip():
- raise ValueError("AI 返回内容为空,请检查模型配置或重试")
+ raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
if not success:
- raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
+ raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
try:
shots_data = validate_and_normalize_shots(parsed_data)
if not shots_data:
- raise ValueError("AI 返回的分镜数据为空或格式不正确")
+ raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
shots = [ScriptShot(**shot) for shot in shots_data]
return shots
+ except (AIEmptyResponseException, AIParseErrorException):
+ raise
except Exception as e:
- raise ValueError(f"分镜数据处理失败: {str(e)}")
+ raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
async def polish_content(
self,
@@ -144,21 +150,23 @@ class ScriptService:
)
return result.content.strip()
except TimeoutError:
- raise ValueError("润色请求超时,请重试")
+ raise AITimeoutException("润色请求超时,请重试")
+ except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
+ raise
except Exception as e:
- raise ValueError(f"润色失败: {str(e)}")
+ raise AIParseErrorException(f"润色失败: {str(e)}")
async def check_model_health(self) -> dict:
"""检查模型健康状态"""
model_router = await get_model_router()
health_results = await model_router.health_check()
- models = []
+ models: list[dict[str, Any]] = []
available_count = 0
- recommended = None
+ recommended: dict[str, Any] | None = None
for _provider_id, health in health_results.items():
- model_info = {
+ model_info: dict[str, Any] = {
"id": health.id,
"name": health.name,
"is_available": health.is_available,
@@ -169,9 +177,12 @@ class ScriptService:
if health.is_available:
available_count += 1
- if recommended is None or health.response_time < recommended.get(
- "response_time", float("inf")
- ):
+ current_best = (
+ float("inf")
+ if recommended is None
+ else float(recommended.get("response_time") or float("inf"))
+ )
+ if health.response_time < current_best:
recommended = model_info
total = len(models)
@@ -188,7 +199,6 @@ class ScriptService:
"""测试指定模型连接"""
model_router = await get_model_router()
-
start_time = time.time()
try:
diff --git a/python-api/app/services/vidu_service.py b/python-api/app/services/vidu_service.py
index 054823f..61bbdcd 100644
--- a/python-api/app/services/vidu_service.py
+++ b/python-api/app/services/vidu_service.py
@@ -207,8 +207,9 @@ class ViduService:
error_type=PlatformErrorType.BAD_REQUEST,
)
- logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
- return result.data or {}
+ clone_data = result.data or {}
+ logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
+ return clone_data
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
"""Vidu 声音复刻是同步接口,无独立查询。
@@ -270,6 +271,8 @@ class ViduService:
result_data = status.result or {}
return {
"state": ViduAdapter.denormalize_state(status.state),
- "creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
+ "creations": (
+ [{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
+ ),
"message": status.error_message,
}
diff --git a/python-api/app/services/volcengine_caption_service.py b/python-api/app/services/volcengine_caption_service.py
index 83f59b4..9a565a5 100644
--- a/python-api/app/services/volcengine_caption_service.py
+++ b/python-api/app/services/volcengine_caption_service.py
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
error_type=PlatformErrorType.BAD_REQUEST,
)
- logger.warning(
- f"{task_name}超过最大轮询次数: task_id={task_id}, "
- f"retries={retries}"
- )
+ logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
raise PlatformError(
f"{task_name}超时,请稍后重试",
platform="volcengine_caption",
+15 -23
View File
@@ -9,7 +9,7 @@
**美家卡智影**是一款面向桌面端的 AI 视频创作应用,采用"Python 后端 API + Tauri 桌面前端"的混合架构。
- **产品标识**: `cn.meijiaka.ai-video` / `cn.meijiaka.ai-zy`
- **版本**: `1.8.2`
- **版本**: `1.9.1`
- **核心功能**: AI 脚本生成、AI 配音合成(TTS)、声音复刻、视频生成(Vidu)、视频字幕生成、压制成片(FFmpeg)、项目本地持久化
### 技术栈总览
@@ -25,7 +25,7 @@
| 前端构建 | Vite 7 |
| 状态管理 | Zustand 5 + Immer 11 |
| 路由 | `react-router-dom` |
| 数据请求 | 自研智能路由客户端 + SWR |
| 数据请求 | HTTP 客户端 (`src/api/client.ts`) + SWR |
| 测试(后端) | pytest + pytest-asyncio |
| 测试(前端) | Vitest 4 + jsdom + `@testing-library/react` |
| 部署 | Docker + Docker Compose + Nginx |
@@ -42,7 +42,7 @@
├── tauri-app/ # Tauri 桌面前端
├── docs/ # 项目文档(架构设计、API 对接指南等)
├── scripts/ # 辅助脚本
├── .gitlab-ci.yml # GitLab CI/CD 配置
├── .github/workflows/ # GitHub Actions CI/CD 配置
└── AGENTS.md # 本文档
```
@@ -51,7 +51,7 @@
```
python-api/
├── app/ # 主应用代码
│ ├── api/v1/ # API 路由(按领域拆分:auth, script, voice, vidu, caption, tasks, upload, materials, system
│ ├── api/v1/ # API 路由(按领域拆分:auth, script, voice, vidu, caption, tasks, upload, materials, system, bgm_music, cover_background, image, points, update, events
│ ├── core/ # 核心工具(配置加载、安全、异常、Redis 客户端、健康检查)
│ ├── db/ # 数据库配置与会话管理
│ ├── models/ # SQLAlchemy ORM 模型(BaseModel 提供 UUID 主键 + 时间戳)
@@ -62,7 +62,7 @@ python-api/
│ ├── crud/ # 数据库 CRUD 封装
│ ├── config.py # Pydantic Settings 配置管理
│ └── main.py # FastAPI 应用入口(含 lifespan 管理)
├── config/ # 运行时配置文件(platform-config.yaml, materials.json
├── config/ # 运行时配置文件(platform-config.yaml, points-config.yaml
├── alembic/ # 数据库迁移脚本
├── nginx/ # Nginx 反向代理配置(含 acme.sh SSL 证书脚本)
├── Dockerfile # 多阶段构建镜像(builder + production
@@ -81,12 +81,11 @@ python-api/
tauri-app/
├── src/ # React 前端源码
│ ├── api/ # API 客户端与模块
│ │ ├── client.ts # 智能路由客户端(HTTP / IPC 自动选择,camelCase ↔ snake_case 自动转换)
│ │ ├── generated/ # OpenAPI 生成的 TypeScript 类型
│ │ ├── client.ts # HTTP 客户端(camelCase ↔ snake_case 自动转换)
│ │ └── modules/ # 按领域拆分的 API 模块
│ ├── components/ # 可复用组件(PascalCase 文件夹)
│ ├── pages/ # 页面级组件(PascalCase 文件夹)
│ ├── store/ # Zustand 状态管理(含 __tests__
│ ├── store/ # Zustand 状态管理
│ ├── hooks/ # 自定义 React Hooks
│ ├── utils/ # 工具函数
│ ├── styles/ # CSS 变量与全局样式
@@ -97,11 +96,8 @@ tauri-app/
│ │ ├── lib.rs # Tauri Builder、Command 定义、公共类型
│ │ ├── ffmpeg_cmd.rs # FFmpeg 命令封装
│ │ ├── video_processing.rs # 压制成片业务逻辑
│ │ ├── api_proxy.rs # Python API 代理
│ │ ├── auth.rs # 认证命令
│ │ ├── avatar_cache.rs # 头像缓存
│ │ ├── storage/ # 本地存储引擎(项目、认证、配置、头像等)
│ │ ├── commands/ # Tauri IPC 命令(按领域拆分)
│ │ ├── commands/ # Tauri IPC 命令(按领域拆分asset, auth_state, cover_avatar, file, product, project, video_compose, voice 等
│ │ └── utils.rs # 通用工具
│ ├── Cargo.toml # Rust 依赖
│ ├── tauri.conf.json # Tauri 应用配置(窗口、CSP、打包、sidecar)
@@ -231,8 +227,7 @@ npm run format:check # Prettier --check
npm run stylelint # Stylelint
npm run stylelint:fix # Stylelint --fix
# OpenAPI 类型生成
npm run gen:api # 根据 openapi.json 生成 TypeScript 类型
# 注:当前项目未配置 OpenAPI 自动生成 TypeScript 类型的脚本
```
---
@@ -280,7 +275,7 @@ npm run gen:api # 根据 openapi.json 生成 TypeScript 类型
1. **判断是否需要本地能力**(FFmpeg、文件系统、系统调用)。
2. **不需要** → 直接在 `tauri-app/src/api/modules/` 使用 `client.get/post/put/delete` 调用 Python HTTP API。
3. **需要**将 endpoint 加入 `tauri-app/src/api/client.ts` 的 IPC 处理逻辑,并在 `tauri-app/src-tauri/src/commands/``lib.rs` 中实现对应的 `#[tauri::command]` 处理器。
3. **需要** `tauri-app/src/api/modules/` 中通过 `@tauri-apps/api/core``invoke` 调用 Rust 命令,并在 `tauri-app/src-tauri/src/commands/``lib.rs` 中实现对应的 `#[tauri::command]` 处理器。
### 语义层防护网(后端)
@@ -304,7 +299,6 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
- **组件测试**: `@testing-library/react` + `@testing-library/jest-dom`
- **文件位置**:
- 全局 setup: `src/__tests__/setup.ts`
- Store 测试: `src/store/__tests__/*.test.tsx`
- 组件/页面测试: 建议放在被测文件同目录或 `__tests__` 子目录中
- **Mock 策略**: `setup.ts` 中已全局 mock `localStorage``@tauri-apps/api/core``invoke` 方法、`window.__TAURI_INTERNALS__`。每个测试后自动调用 `vi.clearAllMocks()`
@@ -327,13 +321,11 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
### 部署流程
#### 测试环境(GitLab CI
#### 测试环境(GitHub Actions
`.gitlab-ci.yml` 定义了 `deploy-backend` 任务
1. 在部署服务器拉取代码(`master` 分支触发)。
2. 构建 api + scheduler 镜像(`docker-compose.test.yml`
3. 启动服务,健康检查 `curl http://localhost:8081/health`
4. 清理 7 天前的旧镜像。
`.github/workflows/` 存放 CI/CD 工作流
- `release.yml`Tauri 桌面端 Release 打包工作流(按 tag 触发,支持 macOS / Windows 平台及 sidecar 二进制下载)。
- 后端 api + scheduler 的测试/生产部署目前通过手动执行 `docker-compose.test.yml` / `docker-compose.prod.yml` 完成
#### 生产环境
@@ -384,7 +376,7 @@ Makefile 中 `lint-semantic` 目标会检查以下规则:
- **ORM**: SQLAlchemy 2.0(异步,asyncpg 驱动)
- **迁移工具**: Alembic
- **基础模型**: `app.models.base.BaseModel` 提供 UUID 主键、`created_at``updated_at`
- **当前模型**: `User`(用户/设备认证)
- **当前模型**: 包括 `User`(用户/设备认证)`UserDevice``UserPoint``PointRechargeOrder``PointTransaction``BrollCategory``BrollMaterial``BrollTag``CoverBackground``BgmMusic``Update`
- **迁移注意**: Alembic 使用同步连接(psycopg2),会自动将 `+asyncpg` 替换掉。
---
+1 -1
View File
@@ -1 +1 @@
1.8.3
1.9.1
+11 -17
View File
@@ -115,26 +115,20 @@ lint-semantic: ## 语义层禁词检查(防止供应商术语泄漏到业务
echo "❌ API 层发现 element_id(应使用 provider_element_id 或 human_id"; \
exit 1; \
fi
@# Scheduler 层禁止 task_id 作为内部变量/Redis key(读取 Provider 返回除外)
@errs=$$(grep -rn '\btask_id\b' app/scheduler --include='*.py' \
| grep -v 'job_id' \
| grep -v '__pycache__' \
| grep -v '\.get("task_id")' \
| grep -v 'result.get("task_id")' \
| grep -v 'task_type' \
| grep -v '"task_id"' \
| grep -v "'task_id'"); \
if [ -n "$$errs" ]; then \
echo "$$errs"; \
echo "❌ Scheduler 层发现 task_id(应使用 job_id"; \
exit 1; \
fi
@# Scheduler 层 Redis key 必须使用 job: 而非 task:
@errs=$$(grep -rn 'redis.*task:' app/scheduler --include='*.py' \
@# Scheduler 层统一使用 task 命名,禁止混用 job
@errs=$$(grep -rn '\bjob_id\b' app/scheduler --include='*.py' \
| grep -v '__pycache__'); \
if [ -n "$$errs" ]; then \
echo "$$errs"; \
echo "❌ Scheduler Redis key 使用 task:(应使用 job:"; \
echo "❌ Scheduler 层发现 job_id(应使用 task_id"; \
exit 1; \
fi
@# Scheduler 层 Redis key 必须使用 task: 而非 job:
@errs=$$(grep -rn 'redis.*job:' app/scheduler --include='*.py' \
| grep -v '__pycache__'); \
if [ -n "$$errs" ]; then \
echo "$$errs"; \
echo "❌ Scheduler Redis key 使用 job:(应使用 task:"; \
exit 1; \
fi
@echo "✅ 语义层检查通过"
+17 -9
View File
@@ -113,7 +113,7 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
result = await self.provider.clone_voice(
audio_url=payload["audio_url"],
voice_id=payload["voice_id"],
text=payload.get("text"),
text=payload.get("text") or "",
)
return AdapterResponse(
success=True,
@@ -219,6 +219,7 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
) -> bool:
"""验证 Vidu 回调 HMAC-SHA256 签名"""
import logging
logger = logging.getLogger(__name__)
# HTTP 头大小写不敏感:建立小写 key 的查找表
@@ -233,6 +234,9 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
if not all([signature, algorithm, access_key, signed_headers_str, date]):
logger.warning(f"[Vidu] 签名验证失败: 缺少必要头, headers={list(headers.keys())}")
return False
assert signature is not None
assert signed_headers_str is not None
assert date is not None
if algorithm != "hmac-sha256":
logger.warning(f"[Vidu] 签名验证失败: 不支持的算法 {algorithm}")
return False
@@ -256,17 +260,15 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
canonical_query_string = parsed.query or ""
signing_string = (
f"POST\n"
f"{http_uri}\n"
f"{canonical_query_string}\n"
f"vidu\n"
f"{date}\n"
f"POST\n" f"{http_uri}\n" f"{canonical_query_string}\n" f"vidu\n" f"{date}\n"
)
for name in header_names:
signing_string += f"{name}:{header_values[name]}\n"
expected = base64.b64encode(
hmac.new(secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256).digest()
hmac.new(
secret.encode("utf-8"), signing_string.encode("utf-8"), hashlib.sha256
).digest()
).decode("utf-8")
if not hmac.compare_digest(signature, expected):
@@ -307,6 +309,12 @@ class ViduAdapter(PlatformAdapter, SyncCapable, TaskCapable, CallbackCapable):
return TaskStatus(
state=self.normalize_state(state),
result={"video_url": video_url, "creations": creations, "task_id": task_id} if video_url else {"task_id": task_id},
error_message=(data.get("err_code") or data.get("message")) if state == "failed" else None,
result=(
{"video_url": video_url, "creations": creations, "task_id": task_id}
if video_url
else {"task_id": task_id}
),
error_message=(
(data.get("err_code") or data.get("message")) if state == "failed" else None
),
)
@@ -69,11 +69,11 @@ class VolcengineArkAdapter(PlatformAdapter, SyncCapable):
)
elif method == Method.EMBEDDING:
result = await self.provider.create_embeddings(
embedding_result: dict[str, Any] = await self.provider.create_embeddings(
texts=payload["texts"],
model=payload.get("model"),
)
return AdapterResponse(success=True, data=result)
return AdapterResponse(success=True, data=embedding_result)
else:
return AdapterResponse(
+12 -10
View File
@@ -24,7 +24,9 @@ logger = logging.getLogger(__name__)
class LLMGateway:
"""LLM 调用网关"""
def __init__(self, adapters: dict[str, SyncCapable], fallback_chains: dict[str, list[str]] | None = None):
def __init__(
self, adapters: dict[str, SyncCapable], fallback_chains: dict[str, list[str]] | None = None
):
self.adapters = adapters
self.fallback_chains = fallback_chains or {}
@@ -55,15 +57,18 @@ class LLMGateway:
for mid in models_to_try:
adapter = self._get_adapter(platform)
try:
result = await adapter.call(Method.CHAT, {
"prompt": prompt,
"model": mid,
**kwargs,
})
result = await adapter.call(
Method.CHAT,
{
"prompt": prompt,
"model": mid,
**kwargs,
},
)
if result.success:
if mid != model_id:
logger.warning(f"[LLMGateway] 模型降级成功: {model_id}{mid}")
return result.data
return result.data or {}
else:
last_error = PlatformError(
result.error_message or f"模型 {mid} 调用失败",
@@ -82,6 +87,3 @@ class LLMGateway:
platform=platform,
retryable=False,
)
+23 -22
View File
@@ -13,6 +13,7 @@ from app.ai.adapters.constants import Method
from app.ai.providers.base import GenerationResult, ModelHealth, ProviderError
from app.ai.providers.volcengine_provider import VolcengineProvider
from app.core.config_loader import AIModelConfigLoader, get_config_loader
from app.core.exceptions import AppException, PlatformError
from app.platform_gateway import PlatformGateway
logger = logging.getLogger(__name__)
@@ -26,9 +27,7 @@ class _PlatformInstance:
self.gateway = gateway
self.provider_id = config.get("id", "")
async def generate(
self, model_name: str, prompt: str, **kwargs
) -> GenerationResult:
async def generate(self, model_name: str, prompt: str, **kwargs) -> GenerationResult:
"""调用生成(通过 PlatformGateway"""
if self.gateway:
result = await self.gateway.call_sync(
@@ -41,9 +40,7 @@ class _PlatformInstance:
},
)
if not result.success:
raise ProviderError(
result.error_message or f"{self.provider_id} 调用失败"
)
raise ProviderError(result.error_message or f"{self.provider_id} 调用失败")
data = result.data or {}
return GenerationResult(
content=data.get("content", ""),
@@ -64,7 +61,11 @@ class _PlatformInstance:
id=model_name or self.provider_id,
name=self.provider_id,
is_available=adapter_result.success,
response_time=adapter_result.data.get("response_time_ms", 0) if adapter_result.data else 0,
response_time=(
adapter_result.data.get("response_time_ms", 0)
if adapter_result.data
else 0
),
last_error=adapter_result.error_message,
)
except Exception as e:
@@ -89,7 +90,6 @@ class ModelRouter:
- 模型自动选择
"""
def __init__(self):
self.platforms: dict[str, _PlatformInstance] = {}
self._config_loader: AIModelConfigLoader | None = None
@@ -249,11 +249,7 @@ class ModelRouter:
if task_type:
model_id = self.select_model_for_task(task_type)
if model_id is None:
models = (
self._config_loader.get_enabled_models()
if self._config_loader
else []
)
models = self._config_loader.get_enabled_models() if self._config_loader else []
if models:
model_id = models[0].id
else:
@@ -270,9 +266,9 @@ class ModelRouter:
params = {**model.default_params, **kwargs}
try:
return await platform.generate(
prompt=prompt, model_name=model.model_name, **params
)
return await platform.generate(prompt=prompt, model_name=model.model_name, **params)
except (PlatformError, AppException):
raise
except Exception as e:
raise ProviderError(f"模型 {model_id} 生成失败: {e}") from e
@@ -290,7 +286,11 @@ class ModelRouter:
id=model.id,
name=model.display_name,
is_available=adapter_result.success,
response_time=adapter_result.data.get("response_time_ms", 0) if adapter_result.data else 0,
response_time=(
adapter_result.data.get("response_time_ms", 0)
if adapter_result.data
else 0
),
last_error=adapter_result.error_message,
)
else:
@@ -306,11 +306,12 @@ class ModelRouter:
# fallback: 直接通过 PlatformInstance
results = {}
if model_id:
model = self._config_loader.get_model(model_id) if self._config_loader else None
if model:
platform = self.platforms.get(model.platform_id)
if platform:
results[model_id] = await platform.health_check(model.model_name)
target_model = self._config_loader.get_model(model_id) if self._config_loader else None
if target_model is None:
raise ProviderError(f"模型不存在: {model_id}")
platform = self.platforms.get(target_model.platform_id)
if platform:
results[model_id] = await platform.health_check(target_model.model_name)
else:
if self._config_loader:
for model in self._config_loader.get_enabled_models():
+13 -10
View File
@@ -94,10 +94,12 @@ def list_categories() -> list[dict]:
for cat_meta in meta.get("categories", []):
cat_code = cat_meta["code"]
cat_name = cat_meta.get("name", cat_code)
categories.append({
"code": cat_code,
"name": cat_name,
})
categories.append(
{
"code": cat_code,
"name": cat_name,
}
)
return categories
@@ -135,11 +137,13 @@ def list_prompt_files(category: str) -> list[dict]:
else:
label = name
desc = ""
files.append({
"filename": f.name,
"label": label.strip(),
"desc": desc.strip(),
})
files.append(
{
"filename": f.name,
"label": label.strip(),
"desc": desc.strip(),
}
)
return files
@@ -174,7 +178,6 @@ def load_system_prompt(category: str, subcategory: str) -> str:
return load_prompt_file(category, chosen["filename"])
def load_script_user_prompt(
topic: str,
extra_params: str | None = None,
@@ -50,12 +50,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
【分镜固定结构规则】
开篇的分镜为:一段人物出镜
中间内容全部用空镜,空镜(内置完整素材库标题)与文案内容需匹配
@@ -29,7 +29,7 @@
合同签署
厨卫原始毛坯状态-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
装修合同核对-现场交底
卧室原始状态-翻新基础
厨卫原始状态-翻新基础
@@ -66,12 +66,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
验收合格签字确认-全屋验收
@@ -42,7 +42,7 @@
墙面纯色面漆涂刷-面漆涂刷
乳胶漆调配-面漆涂刷
卫生间陶粒回填
防水翻车漏水-施工翻车
防水翻车漏水-施工翻车
轻钢龙骨骨架搭建-吊顶造型
木龙骨基础框架固定-吊顶造型
全屋定制板材检查
@@ -53,11 +53,11 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
【分镜固定结构规则】
开篇的分镜为:(可选用讨好装修师傅、恶搞开篇或施工翻车镜,最好能贴近话术内容和主题)+ 一段人物出镜 + 一段空镜补充,不得有 2 段人物出镜
分点阐述全部用空镜,空镜(素材库标题)与文案内容需匹配
@@ -94,7 +94,7 @@ duration: “分镜时长”(如 “5s”,时长为 "配音文案" 的字数
{
"id": 3,
"type": "empty_shot",
"scene": "墙体掉落-施工翻车",
"scene": "墙体掉落-施工翻车",
"voiceover": "今天我告诉你几个监工关键时间点,你必须在场。",
"duration": "5.00s"
}
@@ -51,10 +51,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -214,12 +214,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -36,10 +36,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -199,12 +199,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -39,10 +39,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -202,12 +202,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -52,10 +52,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -215,12 +215,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -41,10 +41,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -204,12 +204,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -42,10 +42,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -205,12 +205,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -41,10 +41,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -204,12 +204,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -127,12 +127,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
踢脚线安装验收-软装进场
【分镜固定结构规则】
开篇的分镜为:一段网红开篇(可选用恶搞开篇或施工翻车镜,最好能贴近硬装收尾、软装进场、装修细节避坑主题,优先选工地恶搞、墙面空鼓、硬装完工全屋全景等相关)+ 一段人物出镜 + 一段空镜补充,不得有 2 段人物出镜
@@ -44,10 +44,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -207,12 +207,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -43,10 +43,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -206,12 +206,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -41,10 +41,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -204,12 +204,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -43,10 +43,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -206,12 +206,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -70,10 +70,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -233,12 +233,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -48,10 +48,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -211,12 +211,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -36,10 +36,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -199,12 +199,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -43,10 +43,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -206,12 +206,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -61,10 +61,10 @@ type为segment=人物出镜;type为empty_shot=从下方内置素材库选匹
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -224,12 +224,12 @@ type为segment=人物出镜;type为empty_shot=从下方内置素材库选匹
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
@@ -63,10 +63,10 @@
客厅原始墙面-毛坯基础
强弱电箱原始特写-毛坯基础
毛坯全屋广角全景-毛坯基础
阳台原始结构空镜-毛坯基础
阳台原始结构-毛坯基础
墙面点位弹线-现场交底
开关插座定位-现场交底
开工仪式简单镜头-现场交底
开工仪式-现场交底
施工方案现场讲解-现场交底
甲乙工长三方对接-现场交底
给排水点位标记-现场交底
@@ -226,12 +226,12 @@
暴力拆除-恶搞开篇
炫技-恶搞开篇
贴砖恶搞-恶搞开篇
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙体掉落-施工翻车
墙面开裂-施工翻车
墙面空鼓-施工翻车
水管错位-施工翻车
电线乱接-施工翻车
防水翻车漏水-施工翻车
墙面漆面细节查验-全屋验收
柜体开合顺畅度检查-全屋验收
踢脚线安装验收-软装进场
+2 -1
View File
@@ -14,8 +14,9 @@ from app.ai.providers.base import (
# 火山方舟官方 SDK Provider
# 需要: pip install 'volcengine-python-sdk[ark]'
try:
from app.ai.providers.volcengine_provider import VolcengineProvider
from app.ai.providers.volcengine_provider import VolcengineProvider as _VolcengineProvider
VolcengineProvider: type | None = _VolcengineProvider
VOLCENGINE_AVAILABLE = True
except ImportError:
VOLCENGINE_AVAILABLE = False
+118 -15
View File
@@ -24,8 +24,90 @@ from app.core.exceptions import PlatformError, PlatformErrorType
logger = logging.getLogger(__name__)
def _map_vidu_error(status: int, message: str) -> PlatformError:
"""把 Vidu HTTP 错误映射为标准 PlatformError"""
# Vidu 错误码分类
_VIDU_AUDIT_ERROR_CODES = {
"TaskPromptPolicyViolation",
"AuditSubmitIllegal",
"CreationPolicyViolation",
"PhotoAuditNotPass",
"AuditFailed",
"ImageCheckBodyJointsFailed",
"ImageCheckFaceFailed",
"ImageObjectsUndetected",
"FaceDetectFailure",
"FaceDetectNotPass",
"NoFaceDetected",
"MultiFaceDetected",
}
_VIDU_RETRYABLE_ERROR_CODES = {
"InternalServiceFailure",
"ModelUnavailable",
"Unknown",
}
_VIDU_RATE_LIMIT_ERROR_CODES = {
"QuotaExceeded",
"TooManyRequests",
"SystemThrottling",
"OperationInProcess",
}
def _extract_vidu_error_code(message: str | None) -> str | None:
"""从 Vidu 错误信息中提取错误码"""
if not message:
return None
# Vidu 错误码格式:"ErrorCode: 中文描述"
return message.split(":")[0].strip() or None
def _map_vidu_error(
status: int,
message: str,
*,
err_code: str | None = None,
) -> PlatformError:
"""把 Vidu HTTP 错误映射为标准 PlatformError
优先根据 Vidu 业务错误码(err_code)判断类型HTTP status 仅作为兜底
"""
raw_code = err_code or _extract_vidu_error_code(message)
# 1. 内容安全/审核类:不可重试
if raw_code in _VIDU_AUDIT_ERROR_CODES:
return PlatformError(
message=message,
platform="vidu",
retryable=False,
error_type=PlatformErrorType.CONTENT_VIOLATION,
status_code=status,
raw_code=raw_code,
)
# 2. 平台内部/模型不可用:可重试
if raw_code in _VIDU_RETRYABLE_ERROR_CODES:
return PlatformError(
message=message,
platform="vidu",
retryable=True,
error_type=PlatformErrorType.SERVER_ERROR,
status_code=status,
raw_code=raw_code,
)
# 3. 限流类:可重试
if raw_code in _VIDU_RATE_LIMIT_ERROR_CODES:
return PlatformError(
message=message,
platform="vidu",
retryable=True,
error_type=PlatformErrorType.RATE_LIMIT,
status_code=status,
raw_code=raw_code,
)
# 4. HTTP status 兜底
mapping = {
429: (PlatformErrorType.RATE_LIMIT, True),
401: (PlatformErrorType.AUTH_FAILED, False),
@@ -43,6 +125,7 @@ def _map_vidu_error(status: int, message: str) -> PlatformError:
retryable=retryable,
error_type=error_type,
status_code=status,
raw_code=raw_code,
)
@@ -66,7 +149,9 @@ class ViduProvider:
from app.core.platform_config import get_platform_config_loader
platform_config = get_platform_config_loader().get_platform("vidu")
self.base_url = (platform_config.base_url if platform_config else "https://api.vidu.cn").rstrip("/")
self.base_url = (
platform_config.base_url if platform_config else "https://api.vidu.cn"
).rstrip("/")
if not self.api_key:
raise ValueError("Vidu API Key 未配置,请在 .env 中设置 VIDU_API_KEY")
@@ -135,9 +220,12 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
logger.error(f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}")
raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}")
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
logger.error(
f"[Vidu TTS] 请求失败: url={url}, status={resp.status_code}, response={data}"
)
raise _map_vidu_error(resp.status_code, f"Vidu TTS error: {msg}", err_code=err_code)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu TTS] 网络错误: {e}")
@@ -182,9 +270,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body, timeout=httpx.Timeout(120.0, connect=5.0))
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
logger.error(f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}")
raise _map_vidu_error(resp.status_code, f"Vidu clone error: {msg}")
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
logger.error(
f"[Vidu Clone] 请求失败: url={url}, status={resp.status_code}, response={data}"
)
raise _map_vidu_error(
resp.status_code, f"Vidu clone error: {msg}", err_code=err_code
)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Clone] 网络错误: {e}")
@@ -238,9 +331,14 @@ class ViduProvider:
resp = await self.client.post(url, json=body)
data = resp.json()
if resp.status_code != 200 or data.get("state") == "failed":
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
logger.error(f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}")
raise _map_vidu_error(resp.status_code, f"Vidu lip-sync error: {msg}")
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
logger.error(
f"[Vidu LipSync] 请求失败: url={url}, status={resp.status_code}, response={data}"
)
raise _map_vidu_error(
resp.status_code, f"Vidu lip-sync error: {msg}", err_code=err_code
)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu LipSync] 网络错误: {e}")
@@ -264,9 +362,14 @@ class ViduProvider:
resp = await self.client.get(url)
data = resp.json()
if resp.status_code != 200:
msg = data.get("err_code") or data.get("message") or f"HTTP {resp.status_code}"
logger.error(f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}")
raise _map_vidu_error(resp.status_code, f"Vidu query task error: {msg}")
err_code = data.get("err_code") or _extract_vidu_error_code(data.get("message"))
msg = err_code or data.get("message") or f"HTTP {resp.status_code}"
logger.error(
f"[Vidu Query] 请求失败: url={url}, status={resp.status_code}, response={data}"
)
raise _map_vidu_error(
resp.status_code, f"Vidu query task error: {msg}", err_code=err_code
)
return data
except (httpx.NetworkError, httpx.TimeoutException) as e:
logger.error(f"[Vidu Query] 网络错误: {e}")
@@ -37,8 +37,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
if code is not None and code in error_mapping:
error_type, retryable = error_mapping[code]
return PlatformError(
message, platform="volcengine_caption",
retryable=retryable, error_type=error_type,
message,
platform="volcengine_caption",
retryable=retryable,
error_type=error_type,
status_code=status,
)
@@ -53,8 +55,10 @@ def _map_caption_error(status: int, message: str, code: int | None = None) -> Pl
}
error_type, retryable = http_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
return PlatformError(
message, platform="volcengine_caption",
retryable=retryable, error_type=error_type,
message,
platform="volcengine_caption",
retryable=retryable,
error_type=error_type,
status_code=status,
)
@@ -124,7 +128,7 @@ class VolcengineCaptionProvider:
max_lines: int = 1,
) -> dict[str, Any]:
"""提交字幕生成任务,返回 {id: task_id}"""
params = {
params: dict[str, str | int] = {
"appid": self.appid,
"language": language,
"caption_type": caption_type,
@@ -150,11 +154,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
raise _map_caption_error(
e.response.status_code, f"HTTP错误: {e.response.status_code}"
) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
f"字幕服务网络错误: {e}", platform="volcengine_caption",
retryable=True, error_type=PlatformErrorType.TIMEOUT,
f"字幕服务网络错误: {e}",
platform="volcengine_caption",
retryable=True,
error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"提交任务失败: {str(e)}") from e
@@ -165,7 +173,7 @@ class VolcengineCaptionProvider:
blocking: bool = False,
) -> dict[str, Any]:
"""查询字幕任务结果,返回原始 JSON"""
params = {
params: dict[str, str | int] = {
"appid": self.appid,
"id": task_id,
"blocking": 1 if blocking else 0,
@@ -182,11 +190,15 @@ class VolcengineCaptionProvider:
except PlatformError:
raise
except httpx.HTTPStatusError as e:
raise _map_caption_error(e.response.status_code, f"HTTP错误: {e.response.status_code}") from e
raise _map_caption_error(
e.response.status_code, f"HTTP错误: {e.response.status_code}"
) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
f"字幕服务网络错误: {e}", platform="volcengine_caption",
retryable=True, error_type=PlatformErrorType.TIMEOUT,
f"字幕服务网络错误: {e}",
platform="volcengine_caption",
retryable=True,
error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_caption_error(500, f"查询任务失败: {str(e)}") from e
@@ -201,7 +213,7 @@ class VolcengineCaptionProvider:
sta_punc_mode: int = 3,
) -> dict[str, Any]:
"""提交自动字幕打轴任务,返回 {id: task_id}"""
params = {
params: dict[str, str | int] = {
"appid": self.appid,
"caption_type": caption_type,
"sta_punc_mode": sta_punc_mode,
@@ -218,7 +230,9 @@ class VolcengineCaptionProvider:
response.raise_for_status()
data = response.json()
if "id" not in data:
raise _map_caption_error(500, f"提交打轴任务失败: {data.get('message', '未知错误')}")
raise _map_caption_error(
500, f"提交打轴任务失败: {data.get('message', '未知错误')}"
)
return data
except PlatformError:
raise
@@ -34,8 +34,10 @@ def _map_mediakit_error(status: int, message: str, code: int | None = None) -> P
}
error_type, retryable = error_mapping.get(status, (PlatformErrorType.UNKNOWN, False))
return PlatformError(
message, platform="volcengine_mediakit",
retryable=retryable, error_type=error_type,
message,
platform="volcengine_mediakit",
retryable=retryable,
error_type=error_type,
status_code=status,
)
@@ -167,8 +169,10 @@ class VolcengineMediakitProvider:
) from e
except (httpx.NetworkError, httpx.TimeoutException) as e:
raise PlatformError(
f"MediaKit 网络错误: {e}", platform="volcengine_mediakit",
retryable=True, error_type=PlatformErrorType.TIMEOUT,
f"MediaKit 网络错误: {e}",
platform="volcengine_mediakit",
retryable=True,
error_type=PlatformErrorType.TIMEOUT,
) from e
except Exception as e:
raise _map_mediakit_error(500, f"抠图失败: {str(e)}") from e
@@ -291,27 +291,40 @@ class VolcengineProvider(LLMProvider):
if status == 429 or "rate limit" in message.lower():
return PlatformError(
message, platform="volcengine_ark", retryable=True,
error_type=PlatformErrorType.RATE_LIMIT, status_code=status,
message,
platform="volcengine_ark",
retryable=True,
error_type=PlatformErrorType.RATE_LIMIT,
status_code=status,
)
elif status in (401, 403) or "authentication" in message.lower():
return PlatformError(
message, platform="volcengine_ark", retryable=False,
error_type=PlatformErrorType.AUTH_FAILED, status_code=status,
message,
platform="volcengine_ark",
retryable=False,
error_type=PlatformErrorType.AUTH_FAILED,
status_code=status,
)
elif status and status >= 500:
return PlatformError(
message, platform="volcengine_ark", retryable=True,
error_type=PlatformErrorType.SERVER_ERROR, status_code=status,
message,
platform="volcengine_ark",
retryable=True,
error_type=PlatformErrorType.SERVER_ERROR,
status_code=status,
)
elif "timeout" in message.lower() or isinstance(e, TimeoutError):
return PlatformError(
message, platform="volcengine_ark", retryable=True,
message,
platform="volcengine_ark",
retryable=True,
error_type=PlatformErrorType.TIMEOUT,
)
else:
return PlatformError(
message, platform="volcengine_ark", retryable=False,
message,
platform="volcengine_ark",
retryable=False,
error_type=PlatformErrorType.UNKNOWN,
)
+3 -1
View File
@@ -5,6 +5,8 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import select
@@ -20,7 +22,7 @@ security = HTTPBearer(auto_error=False)
# 数据库依赖
async def get_db() -> AsyncSession:
async def get_db() -> AsyncGenerator[AsyncSession]:
"""获取数据库 Session"""
async for session in db_session():
yield session
+2 -2
View File
@@ -80,7 +80,7 @@ async def send_code(
async def login(
request: MobileLoginRequest,
db: AsyncSession = Depends(get_db),
http_request: Request = None,
http_request: Request = None, # type: ignore[assignment]
):
"""
手机号验证码登录
@@ -133,7 +133,7 @@ async def login(
async def login_password(
request: PasswordLoginRequest,
db: AsyncSession = Depends(get_db),
http_request: Request = None,
http_request: Request = None, # type: ignore[assignment]
):
"""
手机号密码登录
-19
View File
@@ -24,19 +24,6 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/caption", tags=["Caption"])
@router.post("/ata/align")
async def auto_align_caption(
request_body: AutoAlignSubmitRequest,
@@ -88,9 +75,3 @@ async def auto_align_caption(
except Exception as e:
logger.error(f"自动打轴异常: {e}")
raise HTTPException(status_code=500, detail="字幕打轴失败,请稍后重试")
+12 -10
View File
@@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_db
from app.config import get_settings
from app.core.exceptions import InsufficientPointsException
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
from app.services import point_service as ps
@@ -33,6 +34,7 @@ settings = get_settings()
# ── Dependencies ──
async def get_mediakit_service(request: Request) -> VolcengineMediakitService:
"""FastAPI Depends:从 app.state 获取全局 VolcengineMediakitService 实例。"""
service = getattr(request.app.state, "volcengine_mediakit_service", None)
@@ -46,6 +48,7 @@ async def get_mediakit_service(request: Request) -> VolcengineMediakitService:
# ── Schemas ──
class ImageUploadResponse(BaseModel):
"""图片上传响应"""
@@ -64,11 +67,15 @@ class RemoveBackgroundRequest(BaseModel):
"""抠图请求"""
image_url: str = Field(..., description="原始图片 URL")
scene: str = Field(default="human", description="场景类型:general(通用)、human(人物,默认白色描边)或 product(商品)")
scene: str = Field(
default="human",
description="场景类型:general(通用)、human(人物,默认白色描边)或 product(商品)",
)
# ── Endpoints ──
@router.post("/upload/image", response_model=ApiResponse[ImageUploadResponse])
async def upload_image(
file: UploadFile = File(..., description="图片文件"),
@@ -178,9 +185,8 @@ async def remove_background(
required_points = ps._calculate_cost("cover_avatar")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
raise HTTPException(
status_code=402,
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
raise InsufficientPointsException(
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -193,9 +199,7 @@ async def remove_background(
)
if not result.image_url:
logger.error(
f"[RemoveBackground] 抠图返回空 URL: raw={result.raw}"
)
logger.error(f"[RemoveBackground] 抠图返回空 URL: raw={result.raw}")
raise HTTPException(status_code=500, detail="抠图失败:未返回结果图片 URL")
logger.info(f"[RemoveBackground] 抠图成功: {result.image_url[:80]}...")
@@ -256,7 +260,5 @@ async def remove_background(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(
f"[RemoveBackground] 抠图失败: image_url={req.image_url[:80]}..., error={e}"
)
logger.error(f"[RemoveBackground] 抠图失败: image_url={req.image_url[:80]}..., error={e}")
raise HTTPException(status_code=500, detail=f"抠图失败: {e}")
+2 -5
View File
@@ -59,9 +59,7 @@ async def batch_match_materials_endpoint(
根据分镜列表一次性匹配所有素材自动进行项目级去重
"""
raw_scenes = [
{"scene": s.scene, "duration": s.duration} for s in request.scenes
]
raw_scenes = [{"scene": s.scene, "duration": s.duration} for s in request.scenes]
results = await batch_match(
db,
@@ -70,8 +68,7 @@ async def batch_match_materials_endpoint(
)
matched: list[MaterialInfo | None] = [
MaterialInfo(url=r["url"], duration=r["duration"]) if r else None
for r in results
MaterialInfo(url=r["url"], duration=r["duration"]) if r else None for r in results
]
await db.commit()
+19 -15
View File
@@ -6,7 +6,7 @@
"""
import logging
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_db
from app.core.exceptions import InsufficientPointsException
from app.crud.point_recharge_order import point_recharge_order
from app.crud.point_transaction import point_transaction
from app.models.user import User
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/points", tags=["Points"])
# ── 余额查询 ──────────────────────────────────────────
@router.get("/balance", response_model=ApiResponse[PointBalanceResponse])
async def get_balance(
db: AsyncSession = Depends(get_db),
@@ -47,6 +49,7 @@ async def get_balance(
# ── 流水查询 ──────────────────────────────────────────
@router.get("/transactions", response_model=ApiResponse[PointTransactionListResponse])
async def list_transactions(
pagination: PaginationParams = Depends(),
@@ -127,12 +130,13 @@ async def list_transactions(
# ── 充值 ──────────────────────────────────────────────
@router.post("/recharge", response_model=ApiResponse[RechargeResponse])
async def create_recharge_order(
request: RechargeRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
http_request: Request = None,
http_request: Request = None, # type: ignore[assignment]
):
"""
创建积分充值订单微信支付 Native 扫码
@@ -213,6 +217,7 @@ async def create_recharge_order(
logger.error(f"[Points] 微信统一下单未返回 code_url: {wx_result}")
order.status = "failed"
order.error_msg = "微信未返回二维码链接"
await db.commit()
raise HTTPException(status_code=500, detail="微信支付下单失败")
order.prepay_id = wx_result.get("prepay_id")
@@ -303,9 +308,7 @@ async def handle_wxpay_notify(
return _wx_response()
# 查找订单
order = await point_recharge_order.get_by_out_trade_no(
db, out_trade_no=out_trade_no
)
order = await point_recharge_order.get_by_out_trade_no(db, out_trade_no=out_trade_no)
if not order:
logger.error(f"[WechatPay] 回调订单不存在: {out_trade_no}")
return _wx_response()
@@ -371,6 +374,7 @@ async def handle_wxpay_notify(
await db.rollback()
logger.exception(f"[WechatPay] 订单 {out_trade_no} 充值积分失败: {e}")
# 记录错误但不抛出,返回 SUCCESS 避免微信重试
order.status = "failed"
order.error_msg = f"充值积分失败: {e}"
await db.commit()
@@ -400,9 +404,7 @@ async def query_recharge_status(
wxpay = get_wxpay_service()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
wx_result = await wxpay.query_order(
client, out_trade_no=order.out_trade_no
)
wx_result = await wxpay.query_order(client, out_trade_no=order.out_trade_no)
order.query_result = str(wx_result)
trade_state = wx_result.get("trade_state", "")
@@ -465,6 +467,7 @@ async def query_recharge_status(
# ── 充值档位查询 ──────────────────────────────────────
@router.get("/recharge-options", response_model=ApiResponse[list[dict]])
async def get_recharge_options(
current_user: User = Depends(get_current_user),
@@ -480,6 +483,7 @@ async def get_recharge_options(
# ── 扣费业务类型查询 ──────────────────────────────────
@router.get("/chargeable-types", response_model=ApiResponse[list[str]])
async def get_chargeable_types(
current_user: User = Depends(get_current_user),
@@ -496,6 +500,7 @@ async def get_chargeable_types(
# ── 积分规则查询 ──────────────────────────────────────
@router.get("/rules", response_model=ApiResponse[list[dict]])
async def get_points_rules(
current_user: User = Depends(get_current_user),
@@ -530,9 +535,9 @@ async def get_points_rules(
# ── 积分预估查询 ──────────────────────────────────────
# ── 今日消费统计 ──────────────────────────────────────
@router.get("/today-consumed", response_model=ApiResponse[dict])
async def get_today_consumed(
db: AsyncSession = Depends(get_db),
@@ -545,6 +550,7 @@ async def get_today_consumed(
# ── 直接消费扣费(前端/Rust 层调用)───────────────────
@router.post("/consume", response_model=ApiResponse[dict])
async def consume_points(
request: ConsumeRequest,
@@ -569,12 +575,12 @@ async def consume_points(
source_type=request.source_type,
source_id=request.source_id,
description=f"{request.description or request.source_type}",
allow_negative=False,
allow_negative=request.allow_negative,
)
except ValueError as e:
except InsufficientPointsException:
# 余额不足(在同一事务内判断,避免竞态)
if "积分不足" in str(e):
raise HTTPException(status_code=402, detail=str(e))
raise
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
await db.commit()
@@ -587,5 +593,3 @@ async def consume_points(
},
message="消费成功",
)
+71 -19
View File
@@ -17,6 +17,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.ai.model_router import get_model_router
from app.ai.prompts import list_categories, list_prompt_files, load_prompt, render_template
from app.api.deps import get_current_user
from app.core.exceptions import (
AITimeoutException,
InsufficientPointsException,
PlatformError,
PlatformErrorType,
)
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -33,6 +39,51 @@ router = APIRouter()
logger = logging.getLogger(__name__)
def _map_platform_error(e: PlatformError) -> HTTPException:
"""把第三方平台错误映射为用户友好的 HTTP 异常(带标准 error_code)。"""
if e.error_type == PlatformErrorType.CONTENT_VIOLATION:
return HTTPException(
status_code=400,
detail={
"message": "人物分镜台词未通过安全审核,请修改后重试",
"error_code": "content_violation",
},
)
if e.error_type == PlatformErrorType.RATE_LIMIT:
return HTTPException(
status_code=429,
detail={
"message": "当前请求过于频繁,请稍后再试",
"error_code": "rate_limit",
},
)
if e.error_type == PlatformErrorType.TIMEOUT:
return AITimeoutException("服务响应超时,请稍后重试")
if e.error_type == PlatformErrorType.AUTH_FAILED:
return HTTPException(
status_code=401,
detail={
"message": "第三方服务认证失败,请稍后重试或联系客服",
"error_code": "auth_failed",
},
)
if e.error_type == PlatformErrorType.SERVER_ERROR:
return HTTPException(
status_code=503,
detail={
"message": "第三方服务繁忙,请稍后重试",
"error_code": "server_error",
},
)
return HTTPException(
status_code=400,
detail={
"message": "请求失败,请检查后重试",
"error_code": e.error_type or "unknown",
},
)
@router.get("/categories", response_model=ApiResponse[list[CategoryItem]])
async def get_categories():
"""
@@ -71,9 +122,8 @@ async def polish_content(
required_points = ps._calculate_cost("polish")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
raise HTTPException(
status_code=402,
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
raise InsufficientPointsException(
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -99,11 +149,13 @@ async def polish_content(
data=polished,
message=f"{type_name}润色完成",
)
except InsufficientPointsException:
raise
except HTTPException:
raise
except PlatformError as e:
raise _map_platform_error(e)
except ValueError as e:
if "积分不足" in str(e):
raise HTTPException(status_code=402, detail=str(e))
logger.warning(f"[Polish] 润色失败: {e}")
raise HTTPException(status_code=500, detail="润色失败,请检查输入内容后重试")
except Exception as e:
@@ -111,9 +163,6 @@ async def polish_content(
raise HTTPException(status_code=500, detail=f"{type_name}润色失败,请稍后重试")
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
async def generate_title(
request: GenerateTitleRequest,
@@ -146,7 +195,11 @@ async def generate_title(
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
# 渲染用户提示词
title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
title_type_desc = (
"大标题(主标题,提炼核心卖点,吸睛)"
if request.title_type == "main"
else "小标题(副标题,补充说明或制造悬念)"
)
user_prompt = render_template(
user_template,
title_type=request.title_type,
@@ -163,9 +216,8 @@ async def generate_title(
required_points = ps._calculate_cost("title")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
raise HTTPException(
status_code=402,
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
raise InsufficientPointsException(
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -179,10 +231,10 @@ async def generate_title(
title = result.content.strip() if result.content else ""
# 去除可能的引号
title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
title = title.strip('"').strip("'").strip("「」").strip("『』").strip("《》")
# 截断到最大长度
if len(title) > request.max_length:
title = title[:request.max_length]
title = title[: request.max_length]
# 扣费
points = ps._calculate_cost("title")
@@ -200,15 +252,15 @@ async def generate_title(
data=GenerateTitleResponse(title=title),
message="标题生成成功",
)
except InsufficientPointsException:
raise
except HTTPException:
raise
except PlatformError as e:
raise _map_platform_error(e)
except TimeoutError:
logger.warning("[generate_title] 标题生成超时")
raise HTTPException(status_code=500, detail="标题生成超时,请重试")
except ValueError as e:
if "积分不足" in str(e):
raise HTTPException(status_code=402, detail=str(e))
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
raise AITimeoutException("标题生成超时,请稍后重试")
except Exception as e:
logger.error(f"[generate_title] 标题生成失败: {e}")
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
-3
View File
@@ -46,6 +46,3 @@ async def system_version():
},
message="获取版本成功",
)
+32 -14
View File
@@ -15,9 +15,10 @@ import uuid
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from app.api.deps import get_current_user
from app.core.exceptions import InsufficientPointsException
from app.core.redis_client import get_redis_client
from app.db.session import AsyncSessionLocal
from app.models.user import User
@@ -38,7 +39,6 @@ class ScriptParams(BaseModel):
category: str = Field(..., min_length=1, description="大类代码")
filename: str = Field(..., min_length=1, description="提示词文件名")
@field_validator("category")
@classmethod
def validate_category(cls, v: str) -> str:
@@ -69,6 +69,12 @@ class SubtitleParams(BaseModel):
raise ValueError("video_path 不能为空")
return v.strip()
@model_validator(mode="after")
def validate_auto_align(self) -> "SubtitleParams":
if self.mode == "auto_align" and (not self.audio_text or not self.audio_text.strip()):
raise ValueError("auto_align 模式必须提供 audio_text")
return self
class TTSParams(BaseModel):
"""TTS 语音合成参数"""
@@ -96,7 +102,9 @@ class VideoParams(BaseModel):
volume: int = Field(default=0, ge=0, le=10, description="音量")
ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
total_planned_duration: float = Field(
..., gt=0, description="所有分镜规划时长之和(秒),用于预检"
)
batch_id: str | None = Field(default=None, description="批次ID(可选)")
@field_validator("video_url")
@@ -106,6 +114,12 @@ class VideoParams(BaseModel):
raise ValueError("video_url 不能为空")
return v.strip()
@model_validator(mode="after")
def validate_audio_or_text(self) -> "VideoParams":
if not self.audio_url and not self.text:
raise ValueError("audio_url 和 text 必须至少填一个")
return self
class TaskCreateRequest(BaseModel):
"""创建任务请求"""
@@ -134,6 +148,7 @@ class TaskStatusResponse(BaseModel):
total: int = Field(0, description="总子任务数")
result: dict | None = Field(None, description="任务结果(完成时)")
error: str | None = Field(None, description="错误信息(失败时)")
error_code: str | None = Field(None, description="错误码(失败时,如 content_violation")
created_at: str = Field("", description="任务创建时间(ISO格式)")
@@ -175,6 +190,9 @@ async def create_task(
validated_params = {
"category": script_validated.category,
"filename": script_validated.filename,
"user_id": user_id,
"required_points": required_points,
"project_id": project_id,
}
elif task_type == "subtitle":
@@ -222,9 +240,8 @@ async def create_task(
f"[API] 积分不足: user={user_id}, type={task_type}, "
f"required={required_points}, balance={check['balance']}"
)
raise HTTPException(
status_code=402,
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
raise InsufficientPointsException(
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
# ── 3. 写入 Redis ──────────────────────────────────
@@ -246,18 +263,17 @@ async def create_task(
params=validated_params,
)
await registry.add_running(task_id)
logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
return TaskCreateResponse(
task_id=task_id,
status="running",
message=f"{task_type} 任务已创建",
)
except Exception as e:
logger.error(f"[API] Failed to update registry: {e}")
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
return TaskCreateResponse(
task_id=task_id,
status="running",
message=f"{task_type} 任务已创建",
)
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
@@ -294,6 +310,7 @@ async def list_tasks(
total=task.total,
result=None, # 列表查询不返回 result,避免数据过大
error=task.error,
error_code=task.error_code,
created_at=task.created_at,
)
)
@@ -337,6 +354,7 @@ async def get_task_status(
total=task.total,
result=task.result,
error=task.error,
error_code=task.error_code,
created_at=task.created_at,
)
+143 -14
View File
@@ -7,13 +7,15 @@
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.models.update import AppRelease, ReleasePackage
from app.schemas.update import (
PackageInfo,
ReleaseCreate,
ReleaseListItem,
ReleaseResponse,
@@ -38,9 +40,7 @@ async def check_update(
如果无需更新返回 204如果有更新返回 Tauri 标准格式的 JSON
"""
# 查询最新版本
result = await db.execute(
select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1)
)
result = await db.execute(select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1))
latest: AppRelease | None = result.scalar_one_or_none()
if not latest:
@@ -52,11 +52,13 @@ async def check_update(
# 查询对应平台的包(优先返回 updater 用的包:有 signature 的 .app.tar.gz / .exe
result = await db.execute(
select(ReleasePackage).where(
select(ReleasePackage)
.where(
ReleasePackage.release_id == latest.id,
ReleasePackage.platform == target,
ReleasePackage.architecture == arch,
).order_by(
)
.order_by(
# 有 signature 的排前面(updater 包),空 signature 的排后面(dmg 安装包)
ReleasePackage.signature.desc()
)
@@ -86,6 +88,133 @@ async def check_update(
)
def _parse_user_agent(user_agent: str | None) -> tuple[str, str] | None:
"""
User-Agent 解析 Tauri 平台标识和架构
Tauri updater 使用的 platform 值为darwin / windows / linux
architecture 值为x86_64 / aarch64 / i686
"""
if not user_agent:
return None
ua = user_agent.lower()
# Windows
if "windows" in ua:
platform = "windows"
if "arm64" in ua or "aarch64" in ua:
arch = "aarch64"
elif "win64" in ua or "x64" in ua:
arch = "x86_64"
else:
arch = "x86_64"
return platform, arch
# macOS
if "macintosh" in ua or "mac os x" in ua:
platform = "darwin"
if "arm64" in ua or "aarch64" in ua:
arch = "aarch64"
elif "intel" in ua:
arch = "x86_64"
else:
# 现代 Mac 默认按 Apple Silicon 处理;
# 若浏览器/Rosetta 环境未暴露 arm64,可让用户手动选择
arch = "aarch64"
return platform, arch
# Linux
if "linux" in ua:
platform = "linux"
if "aarch64" in ua or "arm64" in ua:
arch = "aarch64"
elif "x86_64" in ua or "x64" in ua:
arch = "x86_64"
else:
arch = "x86_64"
return platform, arch
return None
@router.get("/download")
async def download_latest(
request: Request,
target: str | None = Query(None, description="平台:darwin / windows / linux"),
arch: str | None = Query(None, description="架构:x86_64 / aarch64 / i686"),
db: AsyncSession = Depends(get_db),
):
"""
统一下载入口自动匹配最新版本和当前环境安装包
优先级
1. 查询参数 target + arch最可靠推荐前端显式传入
2. User-Agent 解析兜底
返回 302 重定向到对应安装包的存储地址
"""
# 1. 确定平台和架构
if target and arch:
platform = target.lower()
architecture = arch.lower()
else:
parsed = _parse_user_agent(request.headers.get("user-agent"))
if not parsed:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无法识别您的操作系统,请通过官网或应用商店下载对应版本",
)
platform, architecture = parsed
# 2. 查询最新版本
result = await db.execute(select(AppRelease).order_by(AppRelease.release_date.desc()).limit(1))
latest: AppRelease | None = result.scalar_one_or_none()
if not latest:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="暂无可用下载",
)
# 3. 查询该平台所有包(不限制架构,因为 macOS 常用 universal 包会同时写入 x86_64/aarch64
result = await db.execute(
select(ReleasePackage).where(
ReleasePackage.release_id == latest.id,
ReleasePackage.platform == platform,
)
)
platform_pkgs = list(result.scalars().all())
if not platform_pkgs:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"版本 {latest.version} 暂无可用的 {platform} 安装包",
)
# 4. 优先选择用户安装包,而不是 updater 用的 .app.tar.gz
# 优先级:.dmg / .exe / .msi / .AppImage > .app.tar.gz > 其他
def _install_pkg_priority(pkg: ReleasePackage) -> int:
name = pkg.filename.lower()
if name.endswith(".dmg"):
return 1
if name.endswith(".exe"):
return 2
if name.endswith(".msi"):
return 3
if name.endswith(".appimage"):
return 4
if name.endswith(".app.tar.gz"):
return 10
return 5
# 先尝试精确架构 + 高优先级安装包
exact_arch_pkgs = [p for p in platform_pkgs if p.architecture == architecture]
candidate_pkgs = exact_arch_pkgs or platform_pkgs
pkg = min(candidate_pkgs, key=_install_pkg_priority)
return RedirectResponse(url=pkg.file_url)
@router.post("/releases", response_model=ReleaseResponse, status_code=status.HTTP_201_CREATED)
async def create_release(
release: ReleaseCreate,
@@ -139,14 +268,14 @@ async def create_release(
mandatory=new_release.mandatory,
created_at=new_release.created_at,
packages=[
{
"platform": p.platform,
"architecture": p.architecture,
"filename": p.filename,
"file_url": p.file_url,
"file_size": p.file_size,
"signature": p.signature,
}
PackageInfo(
platform=p.platform,
architecture=p.architecture,
filename=p.filename,
file_url=p.file_url,
file_size=p.file_size,
signature=p.signature,
)
for p in new_release.packages
],
)
+4 -7
View File
@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
settings = get_settings()
class UploadResponse(BaseModel):
"""上传响应"""
@@ -103,8 +102,8 @@ async def upload_video(
domain=domain,
)
url = result.get("url")
key = result.get("key")
url = result.get("url") or ""
key = result.get("key") or ""
if not url:
raise HTTPException(status_code=500, detail="上传到七牛云失败:未返回 URL")
@@ -126,8 +125,6 @@ async def upload_video(
raise HTTPException(status_code=500, detail=f"上传失败: {e}")
@router.post("/audio", response_model=ApiResponse[UploadResponse])
async def upload_audio(
file: UploadFile = File(..., description="音频文件"),
@@ -198,8 +195,8 @@ async def upload_audio(
domain=domain,
)
url = result.get("url")
key = result.get("key")
url = result.get("url") or ""
key = result.get("key") or ""
if not url:
raise HTTPException(status_code=500, detail="上传到七牛云失败:未返回 URL")
+31 -18
View File
@@ -18,6 +18,10 @@ from app.core.exceptions import PlatformError
from app.core.redis_client import get_redis_client
from app.platform_gateway import PlatformGateway
from app.schemas.common import success_response
from app.utils.content_fingerprint import (
extract_vidu_error_code,
is_vidu_audit_error,
)
logger = logging.getLogger(__name__)
@@ -44,10 +48,12 @@ async def vidu_callback(request: Request):
headers_dict = dict(request.headers)
# 使用 APP_BASE_URL 构建 callback_url,确保与提交任务时传给 Vidu 的一致
#Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
# (Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
app_base_url = get_settings().app_base_url
callback_url = f"{app_base_url}/api/v1/vidu/callback" if app_base_url else str(request.url)
logger.info(f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}")
logger.info(
f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}"
)
try:
task_status = await gateway.handle_webhook(
@@ -64,15 +70,13 @@ async def vidu_callback(request: Request):
logger.error(f"[Vidu] 回调处理失败: {e}")
raise HTTPException(status_code=500, detail="回调处理失败,请稍后重试")
logger.info(f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}")
logger.info(
f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}"
)
# 2. 通过 platform_task_id 反查 Async Engine 内部 task_id,更新 TaskRegistry
platform_task_id = (
task_status.result.get("task_id") if task_status.result else None
)
video_url = (
task_status.result.get("video_url") if task_status.result else None
)
platform_task_id = task_status.result.get("task_id") if task_status.result else None
video_url = task_status.result.get("video_url") if task_status.result else None
logger.info(f"[Vidu] 准备反查 internal_task_id: platform_task_id={platform_task_id}")
@@ -105,12 +109,23 @@ async def vidu_callback(request: Request):
result={"video_url": video_url, "state": "success"},
)
elif task_status.state == "failed":
await registry.update(
internal_task_id,
status="failed",
message="视频生成失败",
error=task_status.error_message or "视频生成失败",
)
error_message = task_status.error_message or "视频生成失败"
err_code = extract_vidu_error_code(error_message)
is_audit = err_code and is_vidu_audit_error(err_code)
update_kwargs = {
"status": "failed",
"message": (
"人物分镜台词未通过安全审核,请修改后重试"
if is_audit
else "视频生成失败"
),
"error": error_message,
}
if is_audit:
update_kwargs["error_code"] = "content_violation"
await registry.update(internal_task_id, **update_kwargs)
logger.info(
f"[Vidu] 回调已更新 TaskRegistry: task={internal_task_id}, "
f"state={task_status.state}, video_url={video_url}"
@@ -121,8 +136,6 @@ async def vidu_callback(request: Request):
f"platform={platform_task_id}"
)
else:
logger.warning(
f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}"
)
logger.warning(f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}")
return success_response(message="回调已接收")
+29 -30
View File
@@ -10,12 +10,13 @@ import logging
import re
import time
import uuid
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.core.exceptions import PlatformError
from app.core.exceptions import InsufficientPointsException, PlatformError
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -49,7 +50,9 @@ class TTSSynthesizeRequest(BaseModel):
class VoiceCloneSubmitRequest(BaseModel):
"""声音复刻提交请求"""
source_audio_url: str | None = Field(None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)")
source_audio_url: str | None = Field(
None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)"
)
source_video_url: str | None = Field(None, description="源视频 URL(可选)")
video_id: str | None = Field(None, description="历史作品ID(可选)")
voice_name: str | None = Field(None, description="自定义音色名称(≤20字符)")
@@ -111,7 +114,7 @@ async def synthesize_speech(
# 宽松预检:余额为负或零时阻止,避免浪费第三方资源
balance_info = await ps.get_user_balance(db, current_user.id)
if balance_info["balance"] <= 0:
raise HTTPException(status_code=402, detail="余额不足,请先充值")
raise InsufficientPointsException("余额不足,请先充值")
try:
audio_url = await service.synthesize(
@@ -122,7 +125,7 @@ async def synthesize_speech(
pitch=request.pitch,
)
# 探测音频时长并扣费
# 探测音频时长并扣费(计费成功才返回结果)
try:
seconds = await get_audio_duration(audio_url)
points = ps._calculate_cost("tts", {"seconds": seconds})
@@ -137,24 +140,26 @@ async def synthesize_speech(
allow_negative=True,
)
await db.commit()
except ValueError as e:
if "积分不足" in str(e):
raise HTTPException(status_code=402, detail=str(e))
logger.error(f"[Voice] TTS 扣费失败: {e}")
return success_response(
data={
"audio_url": audio_url,
"format": "mp3",
"text": request.text,
"voice_id": request.voice_id or DEFAULT_VOICE_ID,
"consumed_points": points,
"duration": seconds,
},
message="合成成功",
)
except InsufficientPointsException:
raise
except Exception as e:
logger.error(f"[Voice] TTS 扣费失败: {e}")
return success_response(
data={
"audio_url": audio_url,
"format": "mp3",
"text": request.text,
"voice_id": request.voice_id or DEFAULT_VOICE_ID,
"consumed_points": points,
"duration": seconds,
},
message="合成成功",
)
raise HTTPException(
status_code=500,
detail="语音合成计费失败,请稍后重试",
)
except HTTPException:
raise
@@ -165,7 +170,6 @@ async def synthesize_speech(
raise HTTPException(status_code=500, detail="语音合成失败,请稍后重试")
def _normalize_voice_id(name: str | None) -> str:
"""
将用户输入的名称规范化为 Vidu 合法的 voice_id
@@ -220,9 +224,8 @@ async def submit_clone_task(
required_points = ps._calculate_cost("voice_clone")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
raise HTTPException(
status_code=402,
detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
raise InsufficientPointsException(
f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -244,10 +247,8 @@ async def submit_clone_task(
description="【声音复刻】",
)
await db.commit()
except ValueError as e:
if "积分不足" in str(e):
raise HTTPException(status_code=402, detail=str(e))
logger.error(f"[Voice] 克隆扣费失败: {e}")
except InsufficientPointsException:
raise
except Exception as e:
logger.error(f"[Voice] 克隆扣费失败: {e}")
@@ -292,5 +293,3 @@ async def query_clone_task(
),
message="克隆已完成",
)
+11 -17
View File
@@ -24,7 +24,7 @@ class Settings(BaseSettings):
# 应用基础配置
APP_NAME: str = Field(default="美家卡智影 API", description="应用名称")
APP_VERSION: str = Field(default="1.8.2", description="应用版本")
APP_VERSION: str = Field(default="1.9.1", description="应用版本")
DEBUG: bool = Field(default=False, description="调试模式")
ENV: Literal["development", "staging", "production"] = Field(
default="development", description="运行环境"
@@ -45,7 +45,9 @@ class Settings(BaseSettings):
description="数据库连接字符串(PostgreSQL)",
)
DATABASE_POOL_SIZE: int = Field(default=10, description="数据库连接池常驻连接数")
DATABASE_MAX_OVERFLOW: int = Field(default=10, description="连接池临时溢出上限(建议 ≤ pool_size)")
DATABASE_MAX_OVERFLOW: int = Field(
default=10, description="连接池临时溢出上限(建议 ≤ pool_size)"
)
DATABASE_POOL_RECYCLE: int = Field(
default=1800, description="连接回收时间(秒),防止长连接被数据库静默断开"
)
@@ -73,7 +75,7 @@ class Settings(BaseSettings):
# 安全配置
SECRET_KEY: str = Field(
...,
default="",
description="JWT 签名密钥(生产环境必须修改)",
)
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(
@@ -107,7 +109,9 @@ class Settings(BaseSettings):
VOLCENGINE_CAPTION_TOKEN: str | None = Field(default=None, description="火山字幕 Token")
# 火山引擎 MediaKit 服务(背景移除等多媒体处理)
VOLCENGINE_MEDIAKIT_TOKEN: str | None = Field(default=None, description="火山引擎 MediaKit Token")
VOLCENGINE_MEDIAKIT_TOKEN: str | None = Field(
default=None, description="火山引擎 MediaKit Token"
)
# Vidu 密钥(base_url 已从 Settings 移除,改用 config/platform-config.yaml 配置)
VIDU_API_KEY: str | None = Field(default=None, description="Vidu API Key")
@@ -124,9 +128,7 @@ class Settings(BaseSettings):
WXPAY_MCHID: str | None = Field(default=None, description="微信支付商户号")
WXPAY_APPID: str | None = Field(default=None, description="微信支付 AppID")
WXPAY_API_KEY: str | None = Field(default=None, description="微信支付 APIv2 密钥")
WXPAY_NOTIFY_URL: str | None = Field(
default=None, description="微信支付回调地址(完整 URL"
)
WXPAY_NOTIFY_URL: str | None = Field(default=None, description="微信支付回调地址(完整 URL")
# B2M 短信平台配置
SMS_APP_ID: str | None = Field(default=None, description="B2M 短信平台 AppID")
@@ -134,16 +136,12 @@ class Settings(BaseSettings):
SMS_BASE_URL: str | None = Field(
default=None, description="B2M 短信平台接口地址(如 http://sms.b2m.cn:8080"
)
SMS_EXTENDED_CODE: str | None = Field(
default=None, description="B2M 短信平台扩展码(选填)"
)
SMS_EXTENDED_CODE: str | None = Field(default=None, description="B2M 短信平台扩展码(选填)")
SMS_CODE_WHITELIST: str = Field(
default="",
description="免验证码登录白名单(逗号分隔的手机号,如 13800138000,13900139000",
)
# 文件上传限制(字节)
UPLOAD_MAX_VIDEO_SIZE: int = Field(
default=500 * 1024 * 1024, description="视频最大上传大小(字节)"
@@ -187,11 +185,7 @@ class Settings(BaseSettings):
"""免验证码登录白名单(去重、去空格)"""
if not self.SMS_CODE_WHITELIST:
return set()
return {
mobile.strip()
for mobile in self.SMS_CODE_WHITELIST.split(",")
if mobile.strip()
}
return {mobile.strip() for mobile in self.SMS_CODE_WHITELIST.split(",") if mobile.strip()}
@lru_cache
+75 -9
View File
@@ -24,9 +24,16 @@ class AppException(HTTPException):
status_code: int,
message: str = "操作失败",
detail: dict | None = None,
*,
error_code: str | None = None,
):
super().__init__(status_code=status_code, detail=detail or {})
body = detail or {}
body["message"] = message
if error_code:
body["error_code"] = error_code
super().__init__(status_code=status_code, detail=body)
self.message = message
self.error_code = error_code
class NotFoundException(AppException):
@@ -44,7 +51,7 @@ class ValidationException(AppException):
def __init__(self, message: str = "参数验证失败"):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
message=message,
)
@@ -79,6 +86,17 @@ class BusinessException(AppException):
)
class InsufficientPointsException(AppException):
"""积分不足"""
def __init__(self, message: str = "积分不足"):
super().__init__(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
message=message,
error_code="insufficient_points",
)
class ModelUnavailableException(AppException):
"""AI 模型不可用"""
@@ -99,6 +117,50 @@ class TaskFailedException(AppException):
)
class PromptNotFoundException(AppException):
"""提示词文件不存在"""
def __init__(self, message: str = "未找到提示词"):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
message=message,
error_code="prompt_not_found",
)
class AIEmptyResponseException(AppException):
"""AI 返回内容为空"""
def __init__(self, message: str = "AI 返回内容为空"):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=message,
error_code="empty_result",
)
class AIParseErrorException(AppException):
"""AI 返回内容解析失败"""
def __init__(self, message: str = "AI 返回格式解析失败"):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=message,
error_code="parse_error",
)
class AITimeoutException(AppException):
"""AI 调用超时"""
def __init__(self, message: str = "AI 请求超时,请稍后重试"):
super().__init__(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
message=message,
error_code="timeout",
)
# ═══════════════════════════════════════════════════════════════
# 第三方平台异常(PlatformError 体系)
# ═══════════════════════════════════════════════════════════════
@@ -111,14 +173,15 @@ class PlatformErrorType:
确保前端和网关能够统一处理
"""
RATE_LIMIT = "rate_limit" # 429,可重试
AUTH_FAILED = "auth_failed" # 401/403,不可重试
TIMEOUT = "timeout" # 连接/读取超时,可重试
SERVER_ERROR = "server_error" # 第三方 5xx,可重试
BAD_REQUEST = "bad_request" # 参数错误,不可重试
RATE_LIMIT = "rate_limit" # 429,可重试
AUTH_FAILED = "auth_failed" # 401/403,不可重试
TIMEOUT = "timeout" # 连接/读取超时,可重试
SERVER_ERROR = "server_error" # 第三方 5xx,可重试
BAD_REQUEST = "bad_request" # 参数错误,不可重试
QUOTA_EXHAUSTED = "quota_exhausted" # 额度用完,不可重试(或延迟重试)
NOT_FOUND = "not_found" # 资源不存在,不可重试
UNKNOWN = "unknown" # 兜底
NOT_FOUND = "not_found" # 资源不存在,不可重试
CONTENT_VIOLATION = "content_violation" # 内容安全/审核不通过,不可重试
UNKNOWN = "unknown" # 兜底
class PlatformError(Exception):
@@ -145,12 +208,14 @@ class PlatformError(Exception):
retryable: bool = False,
error_type: str = PlatformErrorType.UNKNOWN,
status_code: int | None = None,
raw_code: str | None = None,
):
super().__init__(message)
self.platform = platform
self.retryable = retryable
self.error_type = error_type
self.status_code = status_code
self.raw_code = raw_code
def to_http_status(self) -> int:
"""根据 error_type 和 retryable 返回标准 HTTP 状态码"""
@@ -161,6 +226,7 @@ class PlatformError(Exception):
PlatformErrorType.AUTH_FAILED: 401,
PlatformErrorType.BAD_REQUEST: 400,
PlatformErrorType.NOT_FOUND: 404,
PlatformErrorType.CONTENT_VIOLATION: 400,
}
if self.error_type in mapping:
return mapping[self.error_type]
+4 -16
View File
@@ -143,9 +143,7 @@ class PlatformConfigLoader:
启动时加载全环境只读不支持热重载
"""
DEFAULT_CONFIG_PATH = (
Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"
)
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent / "config" / "platform-config.yaml"
def __init__(self, config_path: str | None = None):
self.config_path = Path(config_path) if config_path else self.DEFAULT_CONFIG_PATH
@@ -157,9 +155,7 @@ class PlatformConfigLoader:
def _load(self) -> None:
"""加载并校验配置文件"""
if not self.config_path.exists():
raise FileNotFoundError(
f"平台配置文件不存在: {self.config_path}"
)
raise FileNotFoundError(f"平台配置文件不存在: {self.config_path}")
try:
with open(self.config_path, encoding="utf-8") as f:
@@ -215,18 +211,10 @@ class PlatformConfigLoader:
return [m for m in self._models.values() if m.is_enabled]
def get_models_by_capability(self, capability: str) -> list[ModelConfig]:
return [
m
for m in self._models.values()
if m.is_enabled and capability in m.capabilities
]
return [m for m in self._models.values() if m.is_enabled and capability in m.capabilities]
def get_models_by_platform(self, platform_id: str) -> list[ModelConfig]:
return [
m
for m in self._models.values()
if m.platform_id == platform_id and m.is_enabled
]
return [m for m in self._models.values() if m.platform_id == platform_id and m.is_enabled]
def get_default_model_for_task(self, task_type: str) -> str | None:
if self._raw is None:
+3 -1
View File
@@ -4,6 +4,8 @@ Redis 客户端
全局 Redis 连接 Scheduler RateLimiter 使用
"""
from typing import Any
from redis.asyncio import Redis
from app.config import get_settings
@@ -19,7 +21,7 @@ def get_redis_client() -> Redis:
settings = get_settings()
# 构建连接参数
client_kwargs = {
client_kwargs: dict[str, Any] = {
"host": settings.REDIS_HOST,
"port": settings.REDIS_PORT,
"db": settings.REDIS_DB,
+1 -3
View File
@@ -70,9 +70,7 @@ class BrollCategoryCRUD(CRUDBase[BrollCategory]):
)
return result.scalar_one_or_none()
async def get_by_level(
self, db: AsyncSession, *, level: int
) -> list[BrollCategory]:
async def get_by_level(self, db: AsyncSession, *, level: int) -> list[BrollCategory]:
"""根据层级获取所有启用的分类"""
result = await db.execute(
select(BrollCategory).where(
+1 -3
View File
@@ -49,9 +49,7 @@ class BrollMaterialCRUD(CRUDBase[BrollMaterial]):
)
return list(result.scalars().all())
async def increment_usage_count(
self, db: AsyncSession, *, material_id: int
) -> None:
async def increment_usage_count(self, db: AsyncSession, *, material_id: int) -> None:
"""
原子递增素材使用次数
+2 -6
View File
@@ -23,9 +23,7 @@ class PointRechargeOrderCRUD(CRUDBase[PointRechargeOrder]):
) -> PointRechargeOrder | None:
"""根据商户订单号查询"""
result = await db.execute(
select(PointRechargeOrder).where(
PointRechargeOrder.out_trade_no == out_trade_no
)
select(PointRechargeOrder).where(PointRechargeOrder.out_trade_no == out_trade_no)
)
return result.scalar_one_or_none()
@@ -34,9 +32,7 @@ class PointRechargeOrderCRUD(CRUDBase[PointRechargeOrder]):
) -> PointRechargeOrder | None:
"""根据微信支付订单号查询"""
result = await db.execute(
select(PointRechargeOrder).where(
PointRechargeOrder.wx_order_no == wx_order_no
)
select(PointRechargeOrder).where(PointRechargeOrder.wx_order_no == wx_order_no)
)
return result.scalar_one_or_none()
+8 -10
View File
@@ -6,6 +6,7 @@
"""
from datetime import datetime, time
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -24,7 +25,7 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
self,
db: AsyncSession,
*,
user_id: str,
user_id: UUID | str,
skip: int = 0,
limit: int = 50,
tx_type: str | None = None,
@@ -55,7 +56,7 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
self,
db: AsyncSession,
*,
user_id: str,
user_id: UUID | str,
tx_type: str | None = None,
category: str | None = None,
source_type: str | None = None,
@@ -105,18 +106,15 @@ class PointTransactionCRUD(CRUDBase[PointTransaction]):
self,
db: AsyncSession,
*,
user_id: str,
user_id: UUID | str,
) -> int:
"""统计用户今日消费积分总和"""
now = datetime.now()
start_of_day = datetime.combine(now.date(), time.min)
stmt = (
select(func.coalesce(func.sum(PointTransaction.amount), 0))
.where(
PointTransaction.user_id == user_id,
PointTransaction.type == "consume",
PointTransaction.created_at >= start_of_day,
)
stmt = select(func.coalesce(func.sum(PointTransaction.amount), 0)).where(
PointTransaction.user_id == user_id,
PointTransaction.type == "consume",
PointTransaction.created_at >= start_of_day,
)
result = await db.execute(stmt)
return result.scalar() or 0
+9 -11
View File
@@ -5,7 +5,11 @@
用户认证相关的数据访问
"""
from typing import Any, cast
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.engine import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
@@ -48,7 +52,7 @@ class UserCRUD(CRUDBase[User]):
return user
async def update_login_info(
self, db: AsyncSession, *, user_id: str, ip: str | None = None
self, db: AsyncSession, *, user_id: UUID | str, ip: str | None = None
) -> User | None:
"""
更新用户最后登录信息
@@ -68,7 +72,7 @@ class UserCRUD(CRUDBase[User]):
return user
async def update_password(
self, db: AsyncSession, *, user_id: str, password_hash: str
self, db: AsyncSession, *, user_id: UUID | str, password_hash: str
) -> User | None:
"""更新用户密码"""
user = await self.get(db, id=user_id)
@@ -80,9 +84,7 @@ class UserCRUD(CRUDBase[User]):
await db.refresh(user)
return user
async def update_extra(
self, db: AsyncSession, *, user_id: str, extra: dict
) -> bool:
async def update_extra(self, db: AsyncSession, *, user_id: UUID | str, extra: dict) -> bool:
"""
原子更新用户 extra 字段JSONB
@@ -90,12 +92,8 @@ class UserCRUD(CRUDBase[User]):
"""
from sqlalchemy import update
stmt = (
update(User)
.where(User.id == user_id)
.values(extra=extra)
)
result = await db.execute(stmt)
stmt = update(User).where(User.id == user_id).values(extra=extra)
result = cast(CursorResult[Any], await db.execute(stmt))
await db.commit()
return result.rowcount > 0
+8 -12
View File
@@ -6,6 +6,8 @@
核心操作是覆盖而非新增使用 INSERT ... ON CONFLICT DO UPDATE 保证原子性
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
@@ -20,18 +22,16 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
def __init__(self) -> None:
super().__init__(UserDevice)
async def get_by_user_id(self, db: AsyncSession, *, user_id: str) -> UserDevice | None:
async def get_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> UserDevice | None:
"""根据用户 ID 获取设备记录"""
result = await db.execute(
select(UserDevice).where(UserDevice.user_id == user_id)
)
result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
return result.scalar_one_or_none()
async def create_or_update(
self,
db: AsyncSession,
*,
user_id: str,
user_id: UUID | str,
device_id: str,
device_name: str | None = None,
os_info: str | None = None,
@@ -80,12 +80,10 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
await db.commit()
# 返回最新的记录
result = await db.execute(
select(UserDevice).where(UserDevice.user_id == user_id)
)
result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
return result.scalar_one()
async def delete_by_user_id(self, db: AsyncSession, *, user_id: str) -> bool:
async def delete_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> bool:
"""根据用户 ID 删除设备记录(登出时使用)"""
device = await self.get_by_user_id(db, user_id=user_id)
if device is None:
@@ -100,9 +98,7 @@ class UserDeviceCRUD(CRUDBase[UserDevice]):
) -> UserDevice | None:
"""根据 Refresh Token 哈希获取设备记录"""
result = await db.execute(
select(UserDevice).where(
UserDevice.refresh_token_hash == refresh_token_hash
)
select(UserDevice).where(UserDevice.refresh_token_hash == refresh_token_hash)
)
return result.scalar_one_or_none()
+9 -9
View File
@@ -149,12 +149,8 @@ async def lifespan(app: FastAPI):
logger.info("VolcengineMediakitAdapter initialized")
if app.state.volcengine_provider:
app.state.volcengine_ark_adapter = VolcengineArkAdapter(
app.state.volcengine_provider
)
app.state.platform_gateway.register(
"volcengine_ark", app.state.volcengine_ark_adapter
)
app.state.volcengine_ark_adapter = VolcengineArkAdapter(app.state.volcengine_provider)
app.state.platform_gateway.register("volcengine_ark", app.state.volcengine_ark_adapter)
logger.info("VolcengineArkAdapter registered to PlatformGateway")
logger.info("PlatformGateway initialized")
@@ -177,9 +173,7 @@ async def lifespan(app: FastAPI):
logger.info("Vidu Service initialized")
if app.state.volcengine_caption_provider:
app.state.volcengine_caption_service = VolcengineCaptionService(
app.state.platform_gateway
)
app.state.volcengine_caption_service = VolcengineCaptionService(app.state.platform_gateway)
logger.info("Volcengine Caption Service initialized")
else:
app.state.volcengine_caption_service = None
@@ -276,12 +270,14 @@ def create_app() -> FastAPI:
"code": exc.status_code or http_status,
"message": str(exc),
"data": None,
"error_code": exc.error_type,
}
if settings.DEBUG:
content["detail"] = {
"platform": exc.platform,
"error_type": exc.error_type,
"retryable": exc.retryable,
"raw_code": exc.raw_code,
}
return _cors_response(request, http_status, content)
@@ -294,6 +290,9 @@ def create_app() -> FastAPI:
exc.detail if isinstance(exc.detail, str) else "请求失败"
)
detail = exc.detail if isinstance(exc.detail, dict) else None
error_code = getattr(exc, "error_code", None)
if not error_code and isinstance(detail, dict):
error_code = detail.get("error_code")
return _cors_response(
request,
@@ -301,6 +300,7 @@ def create_app() -> FastAPI:
{
"code": exc.status_code,
"message": message,
"error_code": error_code,
"detail": detail,
},
)
+4 -12
View File
@@ -33,18 +33,10 @@ class BgmMusic(BaseModelBigInt):
category: Mapped[str] = mapped_column(
String(32), nullable=False, index=True, comment="场景分类"
)
file_path: Mapped[str] = mapped_column(
String(512), nullable=False, comment="相对文件路径"
)
url: Mapped[str] = mapped_column(
String(1024), nullable=False, comment="七牛云 URL"
)
duration: Mapped[float] = mapped_column(
Float, nullable=True, comment="时长(秒)"
)
file_path: Mapped[str] = mapped_column(String(512), nullable=False, comment="相对文件路径")
url: Mapped[str] = mapped_column(String(1024), nullable=False, comment="七牛云 URL")
duration: Mapped[float] = mapped_column(Float, nullable=True, comment="时长(秒)")
status: Mapped[str] = mapped_column(
String(16), default="active", nullable=False, comment="状态: active/inactive"
)
sort_order: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="排序权重"
)
sort_order: Mapped[int] = mapped_column(Integer, default=0, nullable=False, comment="排序权重")
+17 -2
View File
@@ -7,7 +7,16 @@ Tauri updater 插件所需的数据结构。
from datetime import UTC, datetime
from sqlalchemy import Boolean, BigInteger, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.session import Base
@@ -60,5 +69,11 @@ class ReleasePackage(Base):
release: Mapped["AppRelease"] = relationship("AppRelease", back_populates="packages")
__table_args__ = (
UniqueConstraint("release_id", "platform", "architecture", "filename", name="uix_app_pkg_platform_arch_filename"),
UniqueConstraint(
"release_id",
"platform",
"architecture",
"filename",
name="uix_app_pkg_platform_arch_filename",
),
)
+1 -4
View File
@@ -15,7 +15,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import BaseModel
class UserStatus(str, enum.Enum):
class UserStatus(enum.StrEnum):
"""用户状态"""
ACTIVE = "active" # 正常
@@ -117,6 +117,3 @@ class User(BaseModel):
def display_name(self) -> str:
"""对外展示的名称"""
return self.nickname or f"用户_{self.mobile[-4:]}"
+90 -6
View File
@@ -24,6 +24,10 @@ from app.ai.adapters.base import (
TaskStatus,
)
from app.core.exceptions import PlatformError, PlatformErrorType
from app.utils.content_fingerprint import (
compute_content_fingerprint,
is_vidu_audit_error,
)
logger = logging.getLogger(__name__)
@@ -31,6 +35,10 @@ logger = logging.getLogger(__name__)
_TASK_MAPPING_PREFIX = "platform_gateway:task_mapping"
_TASK_MAPPING_TTL = 7 * 24 * 60 * 60 # 7 天
# Redis key 前缀:内容审核失败缓存
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
_AUDIT_REJECTION_TTL = 24 * 60 * 60 # 24 小时
class PlatformGateway:
"""第三方平台统一调用网关"""
@@ -61,10 +69,13 @@ class PlatformGateway:
redis = self._get_redis()
# 正向映射:internal → platform
key = self._task_mapping_key(internal_task_id)
await redis.hset(key, mapping={
"platform": platform,
"platform_task_id": platform_task_id,
})
await redis.hset(
key,
mapping={
"platform": platform,
"platform_task_id": platform_task_id,
},
)
await redis.expire(key, _TASK_MAPPING_TTL)
# 反向映射:platform → internal(供回调查找)
reverse_key = f"{_TASK_MAPPING_PREFIX}:reverse:{platform}:{platform_task_id}"
@@ -82,6 +93,33 @@ class PlatformGateway:
"platform_task_id": data.get("platform_task_id", ""),
}
def _audit_rejection_key(self, fingerprint: str) -> str:
return f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"
async def _get_audit_rejection(self, fingerprint: str) -> str | None:
"""查询该内容指纹是否近期审核失败。
Returns:
失败错误码 "TaskPromptPolicyViolation"未命中返回 None
"""
if not fingerprint:
return None
try:
redis = self._get_redis()
key = self._audit_rejection_key(fingerprint)
return await redis.get(key)
except Exception as e:
logger.warning(f"[PlatformGateway] 查询审核缓存失败: {e}")
return None
async def _set_audit_rejection(self, fingerprint: str, error_code: str) -> None:
"""缓存审核失败结果。"""
if not fingerprint or not error_code:
return
redis = self._get_redis()
key = self._audit_rejection_key(fingerprint)
await redis.setex(key, _AUDIT_REJECTION_TTL, error_code)
async def get_internal_task_id_by_platform_task_id(
self, platform: str, platform_task_id: str
) -> str | None:
@@ -156,15 +194,62 @@ class PlatformGateway:
若提供则直接使用该 ID 建立映射否则自动生成
callback 场景必须传入确保回调能反查到正确的 Registry 记录
"""
internal_task_id = internal_task_id or uuid.uuid4().hex
# 1. 同一 internal_task_id 已提交过,直接返回(幂等)
existing = await self._get_task_mapping(internal_task_id)
if existing:
return internal_task_id
# 2. Vidu 内容指纹防重:相同内容近期审核失败则直接拦截
fingerprint: str | None = None
if platform == "vidu":
fingerprint = compute_content_fingerprint(
task_type=task_type,
video_url=payload.get("video_url"),
audio_url=payload.get("audio_url"),
ref_photo_url=payload.get("ref_photo_url"),
text=payload.get("text"),
voice_id=payload.get("voice_id"),
)
rejected_code = await self._get_audit_rejection(fingerprint)
if rejected_code:
raise PlatformError(
"人物分镜台词未通过安全审核,请修改后重试",
platform=platform,
retryable=False,
error_type=PlatformErrorType.CONTENT_VIOLATION,
raw_code=rejected_code,
)
# 3. 调用平台 Adapter 提交任务
adapter = self._get_task_adapter(platform, task_type)
result = await adapter.submit(task_type, payload, callback_url)
try:
result = await adapter.submit(task_type, payload, callback_url)
except PlatformError as e:
# Vidu 审核类错误:缓存内容指纹,防止重复调用
err_code = e.raw_code
if platform == "vidu" and fingerprint and err_code and is_vidu_audit_error(err_code):
await self._set_audit_rejection(fingerprint, err_code)
raise PlatformError(
"人物分镜台词未通过安全审核,请修改后重试",
platform=platform,
retryable=False,
error_type=PlatformErrorType.CONTENT_VIOLATION,
raw_code=err_code,
) from e
raise
if not result.success:
raw_code = result.error_code
if platform == "vidu" and fingerprint and raw_code and is_vidu_audit_error(raw_code):
await self._set_audit_rejection(fingerprint, raw_code)
raise PlatformError(
result.error_message or "任务提交失败",
platform=platform,
retryable=result.retryable,
error_type=PlatformErrorType.UNKNOWN,
raw_code=raw_code,
)
platform_task_id = (result.data or {}).get("task_id", "")
@@ -175,7 +260,6 @@ class PlatformGateway:
retryable=False,
error_type=PlatformErrorType.UNKNOWN,
)
internal_task_id = internal_task_id or uuid.uuid4().hex
await self._store_task_mapping(internal_task_id, platform, platform_task_id)
logger.info(
f"Task submitted: internal={internal_task_id}, "
+52 -3
View File
@@ -7,6 +7,7 @@ Async Engine 核心调度器
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from app.core.redis_client import get_redis_client
@@ -17,6 +18,13 @@ from app.scheduler.slot_manager import SlotManager
logger = logging.getLogger(__name__)
# 各任务类型最大执行时间(秒),超过后自动标记为 failed
TASK_TIMEOUT_SECONDS = {
"script": 5 * 60,
"subtitle": 10 * 60,
"video": 30 * 60,
}
class AsyncEngine:
"""统一异步作业调度引擎"""
@@ -46,13 +54,50 @@ class AsyncEngine:
logger.debug("Tick: no running tasks")
return
# 2. 按 task_type 分组
# 2. 按 task_type 分组,并处理超时任务
tasks_by_type: dict[str, list[Any]] = {}
timeout_changes: list[StateChange] = []
now = datetime.now(UTC)
for task_id in running_ids:
record = await self.registry.get(task_id)
if not record:
await self.registry.remove_running(task_id)
continue
max_duration = TASK_TIMEOUT_SECONDS.get(record.task_type)
is_timeout = (
max_duration
and record.status == "running"
and record.created_at
and (now - datetime.fromisoformat(record.created_at)).total_seconds() > max_duration
)
if is_timeout:
logger.warning(
f"Task timeout: {task_id}, type={record.task_type}, "
f"created_at={record.created_at}"
)
timeout_changes.append(
StateChange(task_id=task_id, field_path="status", value="failed")
)
timeout_changes.append(
StateChange(
task_id=task_id,
field_path="message",
value="任务执行超时,请稍后重试",
)
)
timeout_changes.append(
StateChange(
task_id=task_id,
field_path="error",
value=f"任务执行超过 {max_duration}",
)
)
await self.registry.remove_running(task_id)
continue
tasks_by_type.setdefault(record.task_type, []).append(record)
# 3. 并行执行各 Handler 的 tick
@@ -63,10 +108,14 @@ class AsyncEngine:
]
)
# 4. 收集并应用状态变更
# 4. 收集并应用状态变更(包含超时任务)
all_changes: list[StateChange] = []
for changes in results:
if changes:
await self._apply_changes(changes)
all_changes.extend(changes)
all_changes.extend(timeout_changes)
if all_changes:
await self._apply_changes(all_changes)
# 5. 清理已结束的作业
await self._cleanup_finished()
@@ -10,6 +10,7 @@ import logging
from typing import Any
from app.ai.prompts.loader import _load_system_meta
from app.core.exceptions import InsufficientPointsException
from app.core.platform_config import get_platform_config_loader
from app.db.session import AsyncSessionLocal
from app.scheduler.handlers.base import AsyncHandler
@@ -38,6 +39,7 @@ def _get_category_name(category: str, filename: str) -> str:
return f"{cat_name} · {label}"
return cat_name
SLOT_KEY = "script:slots"
@@ -64,9 +66,7 @@ class ScriptHandler(AsyncHandler):
def _get_service(self) -> ScriptService:
if self.service is None:
raise RuntimeError(
"ScriptHandler 需要通过构造函数传入 ScriptService 实例"
)
raise RuntimeError("ScriptHandler 需要通过构造函数传入 ScriptService 实例")
return self.service
async def tick(
@@ -83,7 +83,9 @@ class ScriptHandler(AsyncHandler):
changes.extend(await self._process_task(task, registry, slots))
except Exception as e:
logger.exception(f"[Script {task.task_id}] failed")
changes.append(StateChange(task_id=task.task_id, field_path="status", value="failed"))
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(task_id=task.task_id, field_path="error", value=str(e)[:500])
)
@@ -129,30 +131,66 @@ class ScriptHandler(AsyncHandler):
"shot_count": len(shots),
}
changes.append(StateChange(task_id=task.task_id, field_path="status", value="completed"))
# 生成成功后再扣费
user_id = params.get("user_id")
required_points = params.get("required_points", 0)
if user_id and required_points > 0:
try:
async with AsyncSessionLocal() as db:
await ps.consume(
db,
user_id=user_id,
points=required_points,
source_type="script",
source_id=task.task_id,
description="【脚本生成】",
)
await db.commit()
except InsufficientPointsException:
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="message",
value="积分不足",
)
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="error_code",
value="insufficient_points",
)
)
return changes
except Exception as e:
logger.error(f"[ScriptTask {task.task_id}] 扣费失败: {e}")
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="message",
value="扣费失败",
)
)
return changes
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="completed")
)
changes.append(StateChange(task_id=task.task_id, field_path="progress", value=100))
changes.append(
StateChange(task_id=task.task_id, field_path="message", value="脚本生成完成")
)
changes.append(StateChange(task_id=task.task_id, field_path="completed", value=1))
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
changes.append(StateChange(task_id=task.task_id, field_path="result", value=result_data))
# 后置扣费(独立 session,失败不影响任务结果)
try:
async with AsyncSessionLocal() as db:
points = ps._calculate_cost("script")
await ps.consume(
db,
user_id=task.user_id,
points=points,
source_type="script",
source_id=task.task_id,
description="【脚本生成】",
)
await db.commit()
except Exception as e:
logger.error(f"[Script {task.task_id}] 扣费失败: {e}")
changes.append(
StateChange(task_id=task.task_id, field_path="result", value=result_data)
)
except Exception as exc:
logger.exception(f"[ScriptTask {task.task_id}] Failed")
@@ -160,6 +198,8 @@ class ScriptHandler(AsyncHandler):
changes.append(
StateChange(task_id=task.task_id, field_path="message", value=str(exc)[:200])
)
changes.append(StateChange(task_id=task.task_id, field_path="error", value=str(exc)[:500]))
changes.append(
StateChange(task_id=task.task_id, field_path="error", value=str(exc)[:500])
)
return changes
@@ -45,9 +45,7 @@ class SubtitleHandler(AsyncHandler):
def _get_service(self) -> VolcengineCaptionService:
if self.service is None:
raise RuntimeError(
"SubtitleHandler 需要通过构造函数传入 VolcengineCaptionService 实例"
)
raise RuntimeError("SubtitleHandler 需要通过构造函数传入 VolcengineCaptionService 实例")
return self.service
async def tick(
@@ -93,7 +91,9 @@ class SubtitleHandler(AsyncHandler):
"utterances": utterances,
}
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="completed")
StateChange(
task_id=task.task_id, field_path="status", value="completed"
)
)
changes.append(
StateChange(task_id=task.task_id, field_path="progress", value=100)
@@ -106,9 +106,13 @@ class SubtitleHandler(AsyncHandler):
changes.append(
StateChange(task_id=task.task_id, field_path="completed", value=1)
)
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
changes.append(
StateChange(task_id=task.task_id, field_path="result", value=result_payload)
StateChange(task_id=task.task_id, field_path="total", value=1)
)
changes.append(
StateChange(
task_id=task.task_id, field_path="result", value=result_payload
)
)
elif status.state == "failed":
changes.append(
@@ -173,9 +177,13 @@ class SubtitleHandler(AsyncHandler):
if not volc_task_id:
raise ValueError("未返回任务ID")
params["volc_task_id"] = volc_task_id
changes.append(StateChange(task_id=task.task_id, field_path="params", value=params))
changes.append(
StateChange(task_id=task.task_id, field_path="message", value="字幕任务已提交")
StateChange(task_id=task.task_id, field_path="params", value=params)
)
changes.append(
StateChange(
task_id=task.task_id, field_path="message", value="字幕任务已提交"
)
)
except RuntimeError:
logger.error(f"[Subtitle {task.task_id}] service not initialized")
@@ -9,17 +9,27 @@ Video 任务处理器
import logging
from typing import Any
from app.core.exceptions import PlatformError, PlatformErrorType
from app.core.platform_config import get_platform_config_loader
from app.core.redis_client import get_redis_client
from app.scheduler.handlers.base import AsyncHandler
from app.scheduler.models import StateChange
from app.scheduler.registry import TaskRegistry
from app.scheduler.slot_manager import SlotManager
from app.services.vidu_service import ViduService
from app.utils.content_fingerprint import (
compute_content_fingerprint,
extract_vidu_error_code,
is_vidu_audit_error,
)
logger = logging.getLogger(__name__)
SLOT_KEY = "vidu:video_slots"
_AUDIT_REJECTION_PREFIX = "platform_gateway:audit_rejection"
_AUDIT_REJECTION_TTL = 24 * 60 * 60 # 24 小时
def _get_video_max_slots() -> int:
"""从 platform-config.yaml 读取 rate_limit 配置作为 max_slots"""
@@ -45,11 +55,39 @@ class VideoHandler(AsyncHandler):
def _get_service(self) -> ViduService:
if self.service is None:
raise RuntimeError(
"VideoHandler 需要通过构造函数传入 ViduService 实例"
)
raise RuntimeError("VideoHandler 需要通过构造函数传入 ViduService 实例")
return self.service
async def _cache_audit_rejection_if_needed(
self, params: dict[str, Any], error_message: str | None
) -> None:
"""如果失败原因是 Vidu 审核类错误,缓存内容指纹防止重复提交。"""
err_code = extract_vidu_error_code(error_message)
if not err_code or not is_vidu_audit_error(err_code):
return
fingerprint = compute_content_fingerprint(
task_type="lip_sync",
video_url=params.get("video_url"),
audio_url=params.get("audio_url"),
ref_photo_url=params.get("ref_photo_url"),
text=params.get("text"),
voice_id=params.get("voice_id"),
)
if not fingerprint:
return
try:
redis = get_redis_client()
key = f"{_AUDIT_REJECTION_PREFIX}:{fingerprint}"
await redis.setex(key, _AUDIT_REJECTION_TTL, err_code)
logger.info(
f"[Video] 审核失败内容已缓存: fingerprint={fingerprint[:16]}..., "
f"err_code={err_code}"
)
except Exception as e:
logger.warning(f"[Video] 缓存审核失败结果出错: {e}")
async def tick(
self, tasks: list[Any], registry: TaskRegistry, slots: SlotManager
) -> list[StateChange]:
@@ -87,11 +125,7 @@ class VideoHandler(AsyncHandler):
value=1,
)
)
changes.append(
StateChange(
task_id=task.task_id, field_path="total", value=1
)
)
changes.append(StateChange(task_id=task.task_id, field_path="total", value=1))
elif task.status == "failed":
# callback 已标记失败,移除 running
await registry.remove_running(task.task_id)
@@ -150,6 +184,15 @@ class VideoHandler(AsyncHandler):
f"video_url={video_url[:60]}..."
)
elif vidu_state == "failed":
error_message = vidu_status.get("message") or "视频生成失败"
err_code = extract_vidu_error_code(error_message)
is_audit = err_code and is_vidu_audit_error(err_code)
message = (
"人物分镜台词未通过安全审核,请修改后重试"
if is_audit
else "视频生成失败"
)
changes.append(
StateChange(
task_id=task.task_id,
@@ -161,20 +204,29 @@ class VideoHandler(AsyncHandler):
StateChange(
task_id=task.task_id,
field_path="message",
value="视频生成失败",
value=message,
)
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="error",
value=vidu_status.get("message") or "视频生成失败",
value=error_message,
)
)
if is_audit:
changes.append(
StateChange(
task_id=task.task_id,
field_path="error_code",
value="content_violation",
)
)
await self._cache_audit_rejection_if_needed(params, error_message)
await registry.remove_running(task.task_id)
logger.warning(
f"[Video {task.task_id}] 主动查询 Vidu 任务失败: "
f"{vidu_status.get('message')}"
f"{error_message}"
)
else:
# 仍在处理中,继续等待
@@ -182,15 +234,11 @@ class VideoHandler(AsyncHandler):
f"[Video {task.task_id}] 主动查询 Vidu 状态: {vidu_state},继续等待"
)
except Exception as e:
logger.warning(
f"[Video {task.task_id}] 主动查询 Vidu 失败: {e}"
)
logger.warning(f"[Video {task.task_id}] 主动查询 Vidu 失败: {e}")
continue # ← 已提交,不再重复提交
# 提交阶段:占用 slot,提交成功后自动释放
async with slots.acquire_ctx(
SLOT_KEY, task.task_id, self.max_slots
) as acquired:
async with slots.acquire_ctx(SLOT_KEY, task.task_id, self.max_slots) as acquired:
if not acquired:
continue
@@ -227,9 +275,7 @@ class VideoHandler(AsyncHandler):
raise ValueError("未返回任务ID")
params["vidu_task_id"] = vidu_task_id
changes.append(
StateChange(
task_id=task.task_id, field_path="params", value=params
)
StateChange(task_id=task.task_id, field_path="params", value=params)
)
changes.append(
StateChange(
@@ -241,9 +287,7 @@ class VideoHandler(AsyncHandler):
except RuntimeError:
logger.error(f"[Video {task.task_id}] service not initialized")
changes.append(
StateChange(
task_id=task.task_id, field_path="status", value="failed"
)
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(
@@ -259,12 +303,50 @@ class VideoHandler(AsyncHandler):
value="视频处理服务未就绪",
)
)
except PlatformError as e:
error_message = str(e)
is_audit = e.error_type == PlatformErrorType.CONTENT_VIOLATION
if is_audit:
message = "人物分镜台词未通过安全审核,请修改后重试"
elif e.error_type == PlatformErrorType.AUTH_FAILED:
message = "视频服务认证失败,请联系客服"
elif e.error_type == PlatformErrorType.RATE_LIMIT:
message = "视频生成服务繁忙,请稍后重试"
else:
message = "视频生成任务提交失败,请稍后重试"
logger.error(f"[Video {task.task_id}] submit platform error: {e}")
changes.append(
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="message",
value=message,
)
)
changes.append(
StateChange(
task_id=task.task_id,
field_path="error",
value=error_message,
)
)
if is_audit:
changes.append(
StateChange(
task_id=task.task_id,
field_path="error_code",
value="content_violation",
)
)
# 审核类错误在 PlatformGateway 已写缓存,此处幂等补充
await self._cache_audit_rejection_if_needed(params, error_message)
except Exception as e:
logger.error(f"[Video {task.task_id}] submit error: {e}")
changes.append(
StateChange(
task_id=task.task_id, field_path="status", value="failed"
)
StateChange(task_id=task.task_id, field_path="status", value="failed")
)
changes.append(
StateChange(
+1
View File
@@ -24,6 +24,7 @@ class TaskRecord:
total: int = 0
result: dict[str, Any] = field(default_factory=dict)
error: str | None = None
error_code: str | None = None
params: dict[str, Any] | Any = field(default_factory=dict)
created_at: str = ""
+14 -10
View File
@@ -7,8 +7,9 @@
import json
import logging
from collections.abc import Awaitable
from datetime import UTC
from typing import Any
from typing import Any, cast
from redis.asyncio import Redis
@@ -54,8 +55,8 @@ class TaskRegistry:
if params:
data["params"] = json.dumps(params, ensure_ascii=False)
await self.redis.hset(_task_key(task_id), mapping=data)
await self.redis.expire(_task_key(task_id), ttl)
await cast(Awaitable[int], self.redis.hset(_task_key(task_id), mapping=data))
await cast(Awaitable[int], self.redis.expire(_task_key(task_id), ttl))
logger.debug(f"Registry created: {task_id}, type={task_type}")
async def update(self, task_id: str, **fields: Any) -> None:
@@ -68,12 +69,12 @@ class TaskRegistry:
mapping[key] = ""
else:
mapping[key] = str(value)
await self.redis.hset(_task_key(task_id), mapping=mapping)
await cast(Awaitable[int], self.redis.hset(_task_key(task_id), mapping=mapping))
logger.debug(f"Registry updated: {task_id}, fields={list(fields.keys())}")
async def get(self, task_id: str) -> TaskRecord | None:
"""读取完整作业记录"""
data = await self.redis.hgetall(_task_key(task_id))
data = await cast(Awaitable[dict[Any, Any]], self.redis.hgetall(_task_key(task_id)))
if not data:
return None
@@ -88,6 +89,8 @@ class TaskRegistry:
return int(raw)
except ValueError:
return 0
if key in ("error_code", "error") and raw == "":
return None
return raw
parsed = {k: _parse(k, v) for k, v in data.items()}
@@ -107,21 +110,22 @@ class TaskRegistry:
total=parsed.get("total", 0),
result=parsed.get("result", {}),
error=parsed.get("error"),
error_code=parsed.get("error_code"),
params=params,
created_at=parsed.get("created_at", ""),
)
async def add_running(self, task_id: str) -> None:
"""将作业标记为 running(加入全局 running 集合)"""
await self.redis.sadd(KEY_RUNNING_SET, task_id)
await cast(Awaitable[int], self.redis.sadd(KEY_RUNNING_SET, task_id))
async def remove_running(self, task_id: str) -> None:
"""将作业从全局 running 集合移除"""
await self.redis.srem(KEY_RUNNING_SET, task_id)
await cast(Awaitable[int], self.redis.srem(KEY_RUNNING_SET, task_id))
async def get_running_task_ids(self) -> list[str]:
"""获取所有 running 的作业 ID 列表"""
members = await self.redis.smembers(KEY_RUNNING_SET)
members = await cast(Awaitable[set[Any]], self.redis.smembers(KEY_RUNNING_SET))
return list(members)
async def list_running_by_user(self, user_id: str) -> list[TaskRecord]:
@@ -139,5 +143,5 @@ class TaskRegistry:
async def delete(self, task_id: str) -> None:
"""删除作业记录"""
await self.redis.delete(_task_key(task_id))
await self.redis.srem(KEY_RUNNING_SET, task_id)
await cast(Awaitable[int], self.redis.delete(_task_key(task_id)))
await cast(Awaitable[int], self.redis.srem(KEY_RUNNING_SET, task_id))
+7 -3
View File
@@ -6,7 +6,9 @@
"""
import logging
from collections.abc import Awaitable
from contextlib import asynccontextmanager
from typing import cast
from redis.asyncio import Redis
@@ -37,7 +39,9 @@ class SlotManager:
async def acquire(self, slot_key: str, slot_id: str, max_slots: int) -> bool:
"""申请一个槽位。返回 True 表示成功,False 表示槽位已满。"""
try:
result = await self.redis.eval(_ACQUIRE_LUA, 1, slot_key, slot_id, str(max_slots))
result = await cast(
Awaitable[str], self.redis.eval(_ACQUIRE_LUA, 1, slot_key, slot_id, str(max_slots))
)
acquired = result == 1
if acquired:
logger.debug(f"Slot acquired: {slot_key}/{slot_id} (max={max_slots})")
@@ -51,7 +55,7 @@ class SlotManager:
async def release(self, slot_key: str, slot_id: str) -> None:
"""释放一个槽位。"""
try:
await self.redis.srem(slot_key, slot_id)
await cast(Awaitable[int], self.redis.srem(slot_key, slot_id))
logger.debug(f"Slot released: {slot_key}/{slot_id}")
except Exception as e:
logger.warning(f"Slot release error: {slot_key}/{slot_id}: {e}")
@@ -78,6 +82,6 @@ class SlotManager:
async def count(self, slot_key: str) -> int:
"""获取当前已占用的槽位数量。"""
try:
return await self.redis.scard(slot_key)
return await cast(Awaitable[int], self.redis.scard(slot_key))
except Exception:
return 0
+2 -3
View File
@@ -3,7 +3,7 @@
==============
"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class BgmMusicItem(BaseModel):
@@ -18,8 +18,7 @@ class BgmMusicItem(BaseModel):
duration: float | None = Field(default=None, description="时长(秒)")
sort_order: int = Field(default=0, description="排序权重")
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class BgmMusicListResponse(BaseModel):
+16 -12
View File
@@ -6,14 +6,12 @@
{ code: number; data: T; message: string }
"""
from typing import Any, Generic, TypeVar
from typing import Any
from pydantic import BaseModel, Field
T = TypeVar("T")
from pydantic import BaseModel, ConfigDict, Field
class ApiResponse(BaseModel, Generic[T]):
class ApiResponse[T](BaseModel):
"""
统一 API 响应格式
@@ -27,17 +25,18 @@ class ApiResponse(BaseModel, Generic[T]):
data: T | None = Field(default=None, description="响应数据")
message: str = Field(default="success", description="提示信息")
class Config:
json_schema_extra = {
model_config = ConfigDict(
json_schema_extra={
"example": {
"code": 200,
"data": {},
"message": "success",
}
}
)
class PaginatedData(BaseModel, Generic[T]):
class PaginatedData[T](BaseModel):
"""分页数据包装"""
items: list[T] = Field(description="数据列表")
@@ -61,18 +60,23 @@ class PaginationParams(BaseModel):
class ApiErrorResponse(BaseModel):
"""错误响应格式"""
code: int = Field(description="错误")
code: int = Field(description="HTTP 状态")
message: str = Field(description="错误信息")
error_code: str | None = Field(default=None, description="应用级错误码")
detail: dict[str, Any] | None = Field(default=None, description="详细错误信息")
def success_response(data: T | None = None, message: str = "success") -> ApiResponse[T]:
def success_response[T](data: T | None = None, message: str = "success") -> ApiResponse[T]:
"""构造成功响应"""
return ApiResponse(code=200, data=data, message=message)
def error_response(
code: int, message: str, detail: dict[str, Any] | None = None
code: int,
message: str,
detail: dict[str, Any] | None = None,
*,
error_code: str | None = None,
) -> ApiErrorResponse:
"""构造错误响应"""
return ApiErrorResponse(code=code, message=message, detail=detail)
return ApiErrorResponse(code=code, message=message, error_code=error_code, detail=detail)
-3
View File
@@ -25,6 +25,3 @@ class SegmentStatus(StrEnum):
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
+3 -9
View File
@@ -18,9 +18,7 @@ class MatchMaterialRequest(BaseModel):
scene: str = Field(description="分镜场景描述")
duration: float = Field(description="所需时长(秒)", gt=0)
project_id: str | None = Field(
default=None, description="项目ID,用于跨分镜去重"
)
project_id: str | None = Field(default=None, description="项目ID,用于跨分镜去重")
@field_validator("scene")
@classmethod
@@ -56,9 +54,7 @@ class BatchMatchSceneItem(BaseModel):
class BatchMatchMaterialRequest(BaseModel):
"""批量匹配素材请求"""
project_id: str | None = Field(
default=None, description="项目ID,用于跨分镜去重"
)
project_id: str | None = Field(default=None, description="项目ID,用于跨分镜去重")
scenes: list[BatchMatchSceneItem] = Field(
description="分镜场景列表",
min_length=1,
@@ -76,6 +72,4 @@ class BatchMatchMaterialResponse(BaseModel):
"""批量匹配素材响应"""
project_id: str | None = Field(description="项目ID")
results: list[MaterialInfo | None] = Field(
description="匹配结果列表,与 scenes 一一对应"
)
results: list[MaterialInfo | None] = Field(description="匹配结果列表,与 scenes 一一对应")
+16 -3
View File
@@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field
# ── 余额查询 ──────────────────────────────────────────
class PointBalanceResponse(BaseModel):
"""积分余额响应"""
@@ -20,6 +21,7 @@ class PointBalanceResponse(BaseModel):
# ── 流水记录 ──────────────────────────────────────────
class PointTransactionItem(BaseModel):
"""单条流水记录"""
@@ -33,7 +35,10 @@ class PointTransactionItem(BaseModel):
source_type: str | None = Field(None, description="消费来源类型")
source_id: str | None = Field(None, description="消费来源业务 ID")
duration: float | None = Field(None, description="时长(秒),按秒计费业务记录")
category: str | None = Field(None, description="业务分类:脚本生成 / 配音合成 / 视频生成 / 压制成片 / 字幕烧录 / 封面设计 / 充值")
category: str | None = Field(
None,
description="业务分类:脚本生成 / 配音合成 / 视频生成 / 压制成片 / 字幕烧录 / 封面设计 / 充值",
)
description: str | None = None
created_at: datetime
@@ -49,6 +54,7 @@ class PointTransactionListResponse(BaseModel):
# ── 充值 ──────────────────────────────────────────────
class RechargeRequest(BaseModel):
"""充值请求"""
@@ -70,6 +76,7 @@ class RechargeResponse(BaseModel):
# ── 积分预估 ──────────────────────────────────────────
class CostEstimateResponse(BaseModel):
"""积分预估响应"""
@@ -80,6 +87,7 @@ class CostEstimateResponse(BaseModel):
# ── 充值订单 ──────────────────────────────────────────
class RechargeOrderItem(BaseModel):
"""充值订单记录"""
@@ -104,11 +112,16 @@ class RechargeOrderListResponse(BaseModel):
# ── 消费扣费 ──────────────────────────────────────────
class ConsumeRequest(BaseModel):
"""消费扣费请求(供前端/Rust 层调用)"""
points: int = Field(gt=0, description="消耗积分数量")
source_type: str = Field(description="消费来源类型,如 compose / subtitle_burn / cover_design / video")
source_type: str = Field(
description="消费来源类型,如 compose / subtitle_burn / cover_design / video"
)
source_id: str = Field(description="消费来源业务 ID")
description: str | None = Field(default=None, description="消费描述")
allow_negative: bool = Field(default=False, description="是否允许扣费后余额为负(不确定消耗场景用)")
allow_negative: bool = Field(
default=False, description="是否允许扣费后余额为负(不确定消耗场景用)"
)
+3 -2
View File
@@ -3,7 +3,6 @@
===============
"""
from pydantic import BaseModel, ConfigDict, Field
from app.schemas.segment import Segment
@@ -93,7 +92,9 @@ class GenerateTitleRequest(BaseModel):
script_content: str = Field(..., description="脚本内容(utterances 文本拼接)", min_length=1)
title_type: str = Field(..., description="标题类型:main(大标题) / sub(小标题)")
max_length: int = Field(default=8, ge=1, le=100, description="最大字数限制")
usage: str = Field(default="video", description="使用场景:video(视频画面标题) / cover(封面标题)")
usage: str = Field(
default="video", description="使用场景:video(视频画面标题) / cover(封面标题)"
)
class GenerateTitleResponse(BaseModel):
+1 -3
View File
@@ -32,9 +32,7 @@ class Segment(BaseModel):
duration: int | None = Field(default=None, description="时长(秒)")
voice_id: str | None = Field(default=None, description="音色ID(空镜时使用)")
status: SegmentStatus = Field(default=SegmentStatus.PENDING)
provider_task_id: str | None = Field(
default=None, description="供应商任务ID"
)
provider_task_id: str | None = Field(default=None, description="供应商任务ID")
video_url: str | None = Field(default=None, description="生成后的视频URL")
local_path: str | None = Field(default=None, description="本地视频路径")
qiniu_url: str | None = Field(default=None, description="七牛云URL")
+2 -2
View File
@@ -6,7 +6,6 @@ Tauri updater 插件所需的请求/响应模型。
"""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
@@ -58,6 +57,7 @@ class ReleaseListItem(BaseModel):
# Tauri updater 插件所需的 JSON 格式
# ------------------------------------------------------------------
class TauriPlatformInfo(BaseModel):
"""Tauri updater 单平台信息"""
@@ -73,7 +73,7 @@ class TauriUpdateResponse(BaseModel):
version: str = Field(..., description="新版本号")
notes: str = Field(default="", description="更新说明")
pub_date: Optional[str] = Field(default=None, description="发布时间(RFC 3339")
pub_date: str | None = Field(default=None, description="发布时间(RFC 3339")
mandatory: bool = Field(default=False, description="是否强制更新(自定义扩展字段)")
platforms: dict[str, TauriPlatformInfo] = Field(
..., description="平台安装包映射,key 格式:OS-ARCH"
+2 -3
View File
@@ -6,7 +6,7 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class UserInfo(BaseModel):
@@ -30,8 +30,7 @@ class UserProfileResponse(BaseModel):
last_login_at: datetime | None = Field(None, description="最后登录时间")
created_at: datetime = Field(..., description="注册时间")
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class UpdateNicknameRequest(BaseModel):
+4 -1
View File
@@ -220,7 +220,10 @@ def validate_and_normalize_shots(raw_data: Any) -> list[dict[str, Any]]:
duration = item.get("duration")
duration_str = parse_duration(duration) # 返回如 "5s"
try:
normalized["duration"] = int(re.search(r"\d+", duration_str).group())
match = re.search(r"\d+", duration_str)
if match is None:
raise ValueError("时长字符串中未找到数字")
normalized["duration"] = int(match.group())
except (AttributeError, ValueError):
normalized["duration"] = 5
+12 -12
View File
@@ -14,6 +14,7 @@ import hashlib
import logging
import random
from typing import Any
from uuid import UUID
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
@@ -32,9 +33,9 @@ from app.models.user_device import UserDevice
logger = logging.getLogger(__name__)
# ── 短信业务常量(数值类配置不走 .env,内嵌代码)─────
SMS_CODE_LENGTH = 6 # 验证码位数
SMS_CODE_EXPIRE_MINUTES = 5 # 验证码有效期(分钟)
SMS_DAILY_LIMIT = 10 # 单手机号每日发送上限
SMS_CODE_LENGTH = 6 # 验证码位数
SMS_CODE_EXPIRE_MINUTES = 5 # 验证码有效期(分钟)
SMS_DAILY_LIMIT = 10 # 单手机号每日发送上限
# ========== SSE 连接池 ==========
# key: user_id, value: asyncio.Queue(用于向该用户的 SSE 连接发送消息)
@@ -73,6 +74,7 @@ async def _kick_old_device(user_id: str) -> None:
# ========== 验证码校验 ==========
async def verify_sms_code(mobile: str, code: str) -> bool:
"""
校验短信验证码
@@ -145,9 +147,7 @@ async def send_sms_code(mobile: str) -> str:
)
else:
# 配置不完整,记录警告但不打印验证码
logger.warning(
f"[SMS] B2M 短信配置不完整,验证码未发送: 手机号={mobile}"
)
logger.warning(f"[SMS] B2M 短信配置不完整,验证码未发送: 手机号={mobile}")
except SMSError as e:
logger.error(f"[SMS] 短信发送失败: {e}")
# 短信发送失败不影响验证码生成
@@ -159,6 +159,7 @@ async def send_sms_code(mobile: str) -> str:
# ========== Token 工具 ==========
def _hash_refresh_token(token: str) -> str:
"""Refresh Token SHA256 哈希(用于数据库存储)"""
return hashlib.sha256(token.encode()).hexdigest()
@@ -166,6 +167,7 @@ def _hash_refresh_token(token: str) -> str:
# ========== 登录服务 ==========
async def login_with_sms(
db: AsyncSession,
*,
@@ -260,9 +262,7 @@ async def refresh_access_token(
refresh_token_hash = _hash_refresh_token(refresh_token)
# 2. 查设备记录
device = await device_crud.get_by_refresh_token_hash(
db, refresh_token_hash=refresh_token_hash
)
device = await device_crud.get_by_refresh_token_hash(db, refresh_token_hash=refresh_token_hash)
if device is None:
raise ValueError("设备已失效,请重新登录")
@@ -397,7 +397,7 @@ async def reset_password_with_sms(
return True
async def logout(db: AsyncSession, *, user_id: str) -> bool:
async def logout(db: AsyncSession, *, user_id: UUID | str) -> bool:
"""
用户登出
@@ -406,14 +406,14 @@ async def logout(db: AsyncSession, *, user_id: str) -> bool:
2. 注销 SSE 连接
"""
await device_crud.delete_by_user_id(db, user_id=user_id)
unregister_sse_connection(user_id)
unregister_sse_connection(str(user_id))
return True
async def get_current_user_device(
db: AsyncSession,
*,
user_id: str,
user_id: UUID | str,
) -> UserDevice | None:
"""获取当前用户的设备记录"""
return await device_crud.get_by_user_id(db, user_id=user_id)
+32 -67
View File
@@ -9,6 +9,8 @@ import logging
import math
import random
import re
from collections.abc import Awaitable
from typing import Any, cast
from sqlalchemy.ext.asyncio import AsyncSession
@@ -16,6 +18,7 @@ from app.core.exceptions import ValidationException
from app.core.redis_client import get_redis_client
from app.crud import broll_category, broll_material
from app.models.broll_category import BrollCategory
from app.models.broll_material import BrollMaterial
logger = logging.getLogger(__name__)
@@ -28,13 +31,13 @@ def _normalize_scene(scene: str) -> str:
# 去除所有 Unicode 空白字符(空格、全角空格、换行、tab 等)
cleaned = re.sub(r"\s+", "", scene)
# 去除常见中文标点符号(逗号、句号、感叹号、问号、顿号、分号、冒号、引号、括号等)
cleaned = re.sub(r"[,。!?、;:""''()【】《》]+", "", cleaned)
cleaned = re.sub(r"[,。!?、;:" "''()【】《》]+", "", cleaned)
# 去除零宽字符(零宽空格、零宽非连接符、零宽连接符、零宽非断空格等)
cleaned = re.sub(r"[\u200b-\u200f\ufeff]+", "", cleaned)
return cleaned
def _weighted_choice(materials: list) -> object: # noqa: ANN001
def _weighted_choice(materials: list[BrollMaterial]) -> BrollMaterial:
"""
加权随机选择素材
@@ -50,9 +53,9 @@ def _weighted_choice(materials: list) -> object: # noqa: ANN001
total_weight = sum(weights)
if total_weight == 0:
return random.choice(materials)
return random.choice(materials) # nosec B311 素材抽样,非加密场景
r = random.uniform(0, total_weight)
r = random.uniform(0, total_weight) # nosec B311 素材抽样,非加密场景
cumulative = 0.0
for m, w in zip(materials, weights, strict=True):
cumulative += w
@@ -74,10 +77,9 @@ async def _try_fallback_to_parent(
- scene '-'取后半部分作为 parent_name '电路施工-电路施工' -> '电路施工'
- 若不含 '-'直接以整个 scene 作为 parent_name
匹配策略逐级降级
匹配策略
1. 精确匹配 level=2 分类 name
2. 模糊匹配LIKE %parent_name%兼容 "电路施工" "电路施工镜"
3. 去掉常见后缀阶段等再精确匹配
2. 模糊匹配LIKE %parent_name%
返回
随机选中的一个 level=3 子分类 None
@@ -88,37 +90,20 @@ async def _try_fallback_to_parent(
parent_name = normalized_scene
# 1. 精确匹配
parent = await broll_category.get_by_name_and_level(
db, name=parent_name, level=2
)
parent = await broll_category.get_by_name_and_level(db, name=parent_name, level=2)
# 2. 模糊匹配(兼容 "电路施工" → "电路施工镜"
# 2. 模糊匹配
if parent is None:
parent = await broll_category.get_by_name_like_and_level(
db, name=parent_name, level=2
)
# 3. 去掉常见后缀再试
if parent is None:
for suffix in ("", "阶段"):
if not parent_name.endswith(suffix):
candidate = parent_name + suffix
parent = await broll_category.get_by_name_and_level(
db, name=candidate, level=2
)
if parent:
break
parent = await broll_category.get_by_name_like_and_level(db, name=parent_name, level=2)
if parent is None:
return None
children = await broll_category.get_children_by_parent_id(
db, parent_id=parent.id, level=3
)
children = await broll_category.get_children_by_parent_id(db, parent_id=parent.id, level=3)
if not children:
return None
return random.choice(children)
return random.choice(children) # nosec B311 素材抽样,非加密场景
async def match_material(
@@ -161,46 +146,34 @@ async def match_material(
normalized = _normalize_scene(scene)
# 1. 查找三级分类(精确匹配 -> 全量内存匹配兜底 -> 顺序颠倒 -> 上级回退)
category = await broll_category.get_by_name_and_level(
db, name=normalized, level=3
)
category = await broll_category.get_by_name_and_level(db, name=normalized, level=3)
# 精确匹配失败时,全量查询后在内存标准化匹配(兼容数据库 name 含不可见字符)
if category is None:
all_categories = await broll_category.get_by_level(db, level=3)
for c in all_categories:
if _normalize_scene(c.name) == normalized:
category = c
logger.info(
f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'"
)
logger.info(f"素材分类全量内存匹配命中: '{normalized}' -> '{c.name}'")
break
# 若仍失败,尝试将 "A-B" 倒序为 "B-A" 再匹配
if category is None:
parts = normalized.rsplit("-", 1)
if len(parts) == 2:
reversed_name = f"{parts[1]}-{parts[0]}"
category = await broll_category.get_by_name_and_level(
db, name=reversed_name, level=3
)
category = await broll_category.get_by_name_and_level(db, name=reversed_name, level=3)
if category:
logger.info(
f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'"
)
logger.info(f"素材分类顺序颠倒兜底命中: '{normalized}' -> '{reversed_name}'")
# 若仍失败,回退到上级分类随机选取
if category is None:
category = await _try_fallback_to_parent(db, normalized)
if category:
logger.info(
f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'"
)
logger.info(f"素材回退到上级分类命中: '{normalized}' -> '{category.name}'")
if category is None:
logger.warning(f"素材匹配失败: 未找到分类 '{normalized}' (原始 scene: '{scene}')")
return None
# 2. 查询该分类下所有 active 素材(先不过滤时长,用于日志诊断)
all_materials = await broll_material.get_active_by_categories(
db, category_ids=[category.id]
)
all_materials = await broll_material.get_active_by_categories(db, category_ids=[category.id])
if not all_materials:
logger.warning(f"素材匹配失败: 分类 '{normalized}' 下无任何可用素材")
return None
@@ -224,7 +197,7 @@ async def match_material(
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
used_urls = set(await redis.smembers(key))
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
except Exception as e:
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
@@ -243,8 +216,8 @@ async def match_material(
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
await redis.sadd(key, chosen.url)
await redis.expire(key, _USED_MATERIALS_TTL)
await cast(Awaitable[int], redis.sadd(key, chosen.url))
await cast(Awaitable[int], redis.expire(key, _USED_MATERIALS_TTL))
except Exception as e:
logger.warning(f"Redis 去重记录失败: {e}")
@@ -282,10 +255,8 @@ async def batch_match(
unique_names = list(set(normalized_scenes))
# 2. 批量查询分类:优先精确查询,失败时全量内存匹配兜底
categories = await broll_category.get_by_names_and_level(
db, names=unique_names, level=3
)
category_map: dict[str, object] = {}
categories = await broll_category.get_by_names_and_level(db, names=unique_names, level=3)
category_map: dict[str, BrollCategory] = {}
for c in categories:
category_map[_normalize_scene(c.name)] = c
@@ -305,16 +276,14 @@ async def batch_match(
if len(parts) == 2:
reversed_map[name] = f"{parts[1]}-{parts[0]}"
scene_to_category: dict[str, object] = {}
scene_to_category: dict[str, BrollCategory] = {}
for name in unique_names:
if name in category_map:
scene_to_category[name] = category_map[name]
elif name in reversed_map and reversed_map[name] in category_map:
rev = reversed_map[name]
scene_to_category[name] = category_map[rev]
logger.info(
f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'"
)
logger.info(f"批量匹配顺序颠倒兜底命中: '{name}' -> '{rev}'")
# 3. 未匹配的 scene 回退到上级分类随机选取
unmatched = [name for name in unique_names if name not in scene_to_category]
@@ -322,15 +291,11 @@ async def batch_match(
fallback_cat = await _try_fallback_to_parent(db, name)
if fallback_cat:
scene_to_category[name] = fallback_cat
logger.info(
f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'"
)
logger.info(f"批量匹配回退到上级分类命中: '{name}' -> '{fallback_cat.name}'")
# 4. 收集所有需要的 category_id
needed_category_ids = [
scene_to_category[name].id
for name in unique_names
if name in scene_to_category
scene_to_category[name].id for name in unique_names if name in scene_to_category
]
# 4. 批量查询素材(1 次 DB
@@ -339,7 +304,7 @@ async def batch_match(
)
# 按 category_id 分组,方便内存过滤
materials_by_category: dict[int, list] = {}
materials_by_category: dict[int, list[BrollMaterial]] = {}
for m in all_materials:
materials_by_category.setdefault(m.category_id, []).append(m)
@@ -349,13 +314,13 @@ async def batch_match(
try:
redis = get_redis_client()
key = f"proj:{project_id}:used_materials"
used_urls = set(await redis.smembers(key))
used_urls = set(await cast(Awaitable[set[Any]], redis.smembers(key)))
except Exception as e:
logger.warning(f"Redis 去重查询失败,降级为不去重: {e}")
# 6. 内存中逐个匹配
results: list[dict | None] = []
chosen_materials: list = [] # 记录选中的素材,用于批量更新
chosen_materials: list[BrollMaterial] = [] # 记录选中的素材,用于批量更新
for idx, scene_name in enumerate(normalized_scenes):
required_duration = scenes[idx]["duration"]
+22 -24
View File
@@ -25,7 +25,7 @@ import logging
import math
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import yaml
from sqlalchemy import select
@@ -33,6 +33,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
from app.core.exceptions import InsufficientPointsException
from app.models.point_batch import PointBatch
from app.models.point_transaction import PointTransaction
from app.models.user_point import UserPoint
@@ -46,11 +47,11 @@ if TYPE_CHECKING:
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
def _load_points_config() -> dict:
def _load_points_config() -> dict[str, Any]:
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
if not _CONFIG_PATH.exists():
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
with open(_CONFIG_PATH, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
merged: dict[str, dict] = {}
@@ -65,18 +66,22 @@ def _load_points_config() -> dict:
return merged
POINTS_CONFIG: dict[str, dict] = _load_points_config()
POINTS_CONFIG: dict[str, Any] = _load_points_config()
def get_recharge_options() -> list[dict]:
"""获取充值档位配置(由后端控制,支持积分赠送)"""
return POINTS_CONFIG.get("_recharge_options", [])
options = POINTS_CONFIG.get("_recharge_options", [])
if isinstance(options, list):
return options
return []
def get_chargeable_source_types() -> list[str]:
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
return [
key for key, cfg in POINTS_CONFIG.items()
key
for key, cfg in POINTS_CONFIG.items()
if not key.startswith("_") and cfg.get("mode") != "free"
]
@@ -163,11 +168,10 @@ def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
# ── 余额查询 ──────────────────────────────────────────
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -182,8 +186,7 @@ async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
from sqlalchemy import func as _func
available_result = await db.execute(
select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
.where(
select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where(
PointBatch.user_id == user_id,
PointBatch.remaining > 0,
PointBatch.expired_at > _now(),
@@ -221,6 +224,7 @@ async def check_balance(
# ── 充值 ──────────────────────────────────────────────
async def recharge(
db: AsyncSession,
*,
@@ -247,8 +251,7 @@ async def recharge(
# 幂等保护:同一笔订单(order_id)只能充值一次
if order_id:
existing_result = await db.execute(
select(PointTransaction)
.where(
select(PointTransaction).where(
PointTransaction.source_id == str(order_id),
PointTransaction.type == "recharge",
)
@@ -259,9 +262,7 @@ async def recharge(
return existing_tx
# 1. 获取或创建用户积分账户
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id))
up = result.scalar_one_or_none()
if not up:
@@ -353,7 +354,7 @@ async def consume(
直接扣费后置计费
业务执行成功后调用按实际消耗直接扣除余额
默认不允许欠费余额不足时抛出 ValueError
默认不允许欠费余额不足时抛出 InsufficientPointsException
Scheduler 后置扣费等场景可设置 allow_negative=True允许余额变负
:param points: 实际消耗积分正整数
@@ -383,9 +384,7 @@ async def consume(
# 2. 获取用户积分账户(加锁)
result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == user_id)
.with_for_update()
select(UserPoint).where(UserPoint.user_id == user_id).with_for_update()
)
up = result.scalar_one_or_none()
@@ -404,7 +403,7 @@ async def consume(
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
available = sum(b.remaining for b in batches)
if not allow_negative and available < points:
raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分")
remaining_to_deduct = points
for batch in batches:
@@ -440,6 +439,7 @@ async def consume(
# ── 过期回收 ──────────────────────────────────────────
async def expire_batches(db: AsyncSession) -> int:
"""
回收过期积分批次返回过期积分总数
@@ -468,9 +468,7 @@ async def expire_batches(db: AsyncSession) -> int:
# 获取用户积分账户(加锁)
up_result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == batch.user_id)
.with_for_update()
select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update()
)
up = up_result.scalar_one_or_none()
if not up:
+33 -25
View File
@@ -104,7 +104,9 @@ class QiniuService:
# 项目前缀
PROJECT_PREFIX = "meijiaka-zy"
def generate_key(self, file_type: str, original_filename: str, user_id: str = None) -> str:
def generate_key(
self, file_type: str, original_filename: str, user_id: str | None = None
) -> str:
"""
生成规范的文件存储路径
@@ -143,7 +145,7 @@ class QiniuService:
return mime_type in allowed_types
def get_upload_token(
self, bucket: str, key: str, expires: int = 3600, policy: dict = None
self, bucket: str, key: str, expires: int = 3600, policy: dict | None = None
) -> str:
"""
生成上传凭证客户端直传使用
@@ -197,9 +199,9 @@ class QiniuService:
def upload_file(
self,
local_path: str,
key: str = None,
key: str | None = None,
file_type: str = "audio",
user_id: str = None,
user_id: str | None = None,
check_duplicate: bool = True,
) -> dict:
"""
@@ -222,15 +224,15 @@ class QiniuService:
"is_duplicate": 是否复用已有文件
}
"""
local_path = Path(local_path)
if not local_path.exists():
raise FileNotFoundError(f"文件不存在: {local_path}")
local_path_obj = Path(local_path)
if not local_path_obj.exists():
raise FileNotFoundError(f"文件不存在: {local_path_obj}")
# 根据文件类型获取对应的 bucket 和 domain
bucket, domain = self._get_bucket_and_domain(file_type)
# 计算文件 MD5 哈希
file_md5 = self._calculate_file_hash(local_path)
file_md5 = self._calculate_file_hash(local_path_obj)
# 检查是否已存在相同文件
if check_duplicate:
@@ -248,20 +250,20 @@ class QiniuService:
# 自动生成 Key
if key is None:
key = self.generate_key(file_type, local_path.name, user_id)
key = self.generate_key(file_type, local_path_obj.name, user_id)
# 生成上传 Token
token = self.get_upload_token(bucket, key)
# 使用分片上传
ret, info = put_file(up_token=token, key=key, file_path=str(local_path))
ret, info = put_file(up_token=token, key=key, file_path=str(local_path_obj))
if ret is None:
raise Exception(f"上传失败: {info}")
# 获取文件信息
mime_type, _ = mimetypes.guess_type(str(local_path))
fsize = local_path.stat().st_size
mime_type, _ = mimetypes.guess_type(str(local_path_obj))
fsize = local_path_obj.stat().st_size
return {
"key": ret["key"],
@@ -277,8 +279,8 @@ class QiniuService:
stream: BinaryIO,
key: str,
mime_type: str = "application/octet-stream",
bucket: str = None,
domain: str = None,
bucket: str | None = None,
domain: str | None = None,
) -> dict:
"""
上传文件流到七牛云
@@ -317,7 +319,9 @@ class QiniuService:
return {"key": ret["key"], "hash": ret["hash"], "url": self.get_file_url(domain, key)}
def upload_audio(self, local_path: str, user_id: str = None, key: str = None) -> dict:
def upload_audio(
self, local_path: str, user_id: str | None = None, key: str | None = None
) -> dict:
"""
上传音频文件专用接口
@@ -336,7 +340,9 @@ class QiniuService:
return self.upload_file(local_path=local_path, key=key, file_type="audio", user_id=user_id)
def upload_video(self, local_path: str, user_id: str = None, key: str = None) -> dict:
def upload_video(
self, local_path: str, user_id: str | None = None, key: str | None = None
) -> dict:
"""
上传视频文件专用接口
@@ -454,9 +460,9 @@ class QiniuService:
async def upload_file_async(
self,
local_path: str,
key: str = None,
key: str | None = None,
file_type: str = "audio",
user_id: str = None,
user_id: str | None = None,
check_duplicate: bool = True,
) -> dict:
"""异步版本 upload_file"""
@@ -469,19 +475,21 @@ class QiniuService:
stream: BinaryIO,
key: str,
mime_type: str = "application/octet-stream",
bucket: str = None,
domain: str = None,
bucket: str | None = None,
domain: str | None = None,
) -> dict:
"""异步版本 upload_stream"""
return await asyncio.to_thread(
self.upload_stream, stream, key, mime_type, bucket, domain
)
return await asyncio.to_thread(self.upload_stream, stream, key, mime_type, bucket, domain)
async def upload_audio_async(self, local_path: str, user_id: str = None, key: str = None) -> dict:
async def upload_audio_async(
self, local_path: str, user_id: str | None = None, key: str | None = None
) -> dict:
"""异步版本 upload_audio"""
return await asyncio.to_thread(self.upload_audio, local_path, user_id, key)
async def upload_video_async(self, local_path: str, user_id: str = None, key: str = None) -> dict:
async def upload_video_async(
self, local_path: str, user_id: str | None = None, key: str | None = None
) -> dict:
"""异步版本 upload_video"""
return await asyncio.to_thread(self.upload_video, local_path, user_id, key)
+27 -17
View File
@@ -7,9 +7,16 @@ import asyncio
import logging
import time
from pathlib import Path
from typing import Any
from app.ai.model_router import get_model_router
from app.ai.prompts import load_prompt_file, load_script_user_prompt
from app.core.exceptions import (
AIEmptyResponseException,
AIParseErrorException,
AITimeoutException,
PromptNotFoundException,
)
from app.schemas.script import ScriptShot
from app.services.ai_response_utils import (
safe_parse_ai_json_response,
@@ -22,12 +29,9 @@ logger = logging.getLogger(__name__)
class ScriptService:
"""脚本生成服务"""
def __init__(self):
self.prompts_dir = Path(__file__).parent.parent / "ai" / "prompts"
def _load_prompt(self, name: str) -> str:
"""加载 Prompt 模板"""
prompt_file = self.prompts_dir / f"{name}.txt"
@@ -58,7 +62,7 @@ class ScriptService:
# 加载 Prompt
system_prompt = load_prompt_file(category, filename)
if not system_prompt:
raise ValueError(f"未找到提示词: category={category}, filename={filename}")
raise PromptNotFoundException(f"未找到提示词: category={category}, filename={filename}")
# 用户提示词
user_prompt = load_script_user_prompt(
@@ -75,24 +79,26 @@ class ScriptService:
)
if not result.content or not result.content.strip():
raise ValueError("AI 返回内容为空,请检查模型配置或重试")
raise AIEmptyResponseException("AI 返回内容为空,请检查模型配置或重试")
success, parsed_data, error_msg = safe_parse_ai_json_response(result.content)
if not success:
raise ValueError(error_msg or "AI 返回格式错误,无法解析为 JSON")
raise AIParseErrorException(error_msg or "AI 返回格式错误,无法解析为 JSON")
try:
shots_data = validate_and_normalize_shots(parsed_data)
if not shots_data:
raise ValueError("AI 返回的分镜数据为空或格式不正确")
raise AIEmptyResponseException("AI 返回的分镜数据为空或格式不正确")
shots = [ScriptShot(**shot) for shot in shots_data]
return shots
except (AIEmptyResponseException, AIParseErrorException):
raise
except Exception as e:
raise ValueError(f"分镜数据处理失败: {str(e)}")
raise AIParseErrorException(f"分镜数据处理失败: {str(e)}")
async def polish_content(
self,
@@ -144,21 +150,23 @@ class ScriptService:
)
return result.content.strip()
except TimeoutError:
raise ValueError("润色请求超时,请重试")
raise AITimeoutException("润色请求超时,请重试")
except (AIEmptyResponseException, AIParseErrorException, AITimeoutException):
raise
except Exception as e:
raise ValueError(f"润色失败: {str(e)}")
raise AIParseErrorException(f"润色失败: {str(e)}")
async def check_model_health(self) -> dict:
"""检查模型健康状态"""
model_router = await get_model_router()
health_results = await model_router.health_check()
models = []
models: list[dict[str, Any]] = []
available_count = 0
recommended = None
recommended: dict[str, Any] | None = None
for _provider_id, health in health_results.items():
model_info = {
model_info: dict[str, Any] = {
"id": health.id,
"name": health.name,
"is_available": health.is_available,
@@ -169,9 +177,12 @@ class ScriptService:
if health.is_available:
available_count += 1
if recommended is None or health.response_time < recommended.get(
"response_time", float("inf")
):
current_best = (
float("inf")
if recommended is None
else float(recommended.get("response_time") or float("inf"))
)
if health.response_time < current_best:
recommended = model_info
total = len(models)
@@ -188,7 +199,6 @@ class ScriptService:
"""测试指定模型连接"""
model_router = await get_model_router()
start_time = time.time()
try:
+2 -3
View File
@@ -76,9 +76,7 @@ class B2MSMSService:
if self.secret_key:
key_bytes = self.secret_key.encode("utf-8")
if len(key_bytes) not in (16, 24, 32):
raise SMSError(
f"AES 密钥长度必须是 16/24/32 字节,当前 {len(key_bytes)} 字节"
)
raise SMSError(f"AES 密钥长度必须是 16/24/32 字节,当前 {len(key_bytes)} 字节")
self._secret_key_bytes = key_bytes
if not all([self.app_id, self.secret_key, self.base_url]):
@@ -245,6 +243,7 @@ class B2MSMSService:
# ── 便捷函数 ──────────────────────────────────────────
def get_sms_service() -> B2MSMSService:
"""获取短信服务实例"""
return B2MSMSService()
+6 -3
View File
@@ -207,8 +207,9 @@ class ViduService:
error_type=PlatformErrorType.BAD_REQUEST,
)
logger.info(f"[Vidu Clone] 复刻成功: voice_id={result.data.get('voice_id')}")
return result.data or {}
clone_data = result.data or {}
logger.info(f"[Vidu Clone] 复刻成功: voice_id={clone_data.get('voice_id')}")
return clone_data
async def query_clone_task(self, voice_id: str) -> dict[str, Any]:
"""Vidu 声音复刻是同步接口,无独立查询。
@@ -270,6 +271,8 @@ class ViduService:
result_data = status.result or {}
return {
"state": ViduAdapter.denormalize_state(status.state),
"creations": [{"url": result_data.get("video_url")}] if result_data.get("video_url") else [],
"creations": (
[{"url": result_data.get("video_url")}] if result_data.get("video_url") else []
),
"message": status.error_message,
}
@@ -155,10 +155,7 @@ class VolcengineCaptionService:
error_type=PlatformErrorType.BAD_REQUEST,
)
logger.warning(
f"{task_name}超过最大轮询次数: task_id={task_id}, "
f"retries={retries}"
)
logger.warning(f"{task_name}超过最大轮询次数: task_id={task_id}, " f"retries={retries}")
raise PlatformError(
f"{task_name}超时,请稍后重试",
platform="volcengine_caption",
@@ -78,9 +78,7 @@ class VolcengineMediakitService:
)
if not response.success:
raise RuntimeError(
response.error_message or "抠图失败"
)
raise RuntimeError(response.error_message or "抠图失败")
result_image_url = (response.data or {}).get("image_url", "")
return RemoveBackgroundResult(
+6 -2
View File
@@ -69,7 +69,9 @@ class WechatPayService:
self.notify_url = settings.WXPAY_NOTIFY_URL
if not all([self.mchid, self.appid, self.api_key]):
raise WechatPayError("微信支付配置不完整:WXPAY_MCHID / WXPAY_APPID / WXPAY_API_KEY 未配置")
raise WechatPayError(
"微信支付配置不完整:WXPAY_MCHID / WXPAY_APPID / WXPAY_API_KEY 未配置"
)
# ── 签名与验签 ──────────────────────────────────────
@@ -85,7 +87,9 @@ class WechatPayService:
5. MD5 加密结果转大写
"""
# 过滤空值和 sign
filtered = {k: str(v) for k, v in params.items() if v is not None and v != "" and k != "sign"}
filtered = {
k: str(v) for k, v in params.items() if v is not None and v != "" and k != "sign"
}
# ASCII 升序排序
sorted_items = sorted(filtered.items(), key=lambda x: x[0])
# 拼接字符串
+3 -1
View File
@@ -41,13 +41,15 @@ async def get_audio_duration(url: str, timeout: float = 10.0) -> float:
try:
if header[:3] == b"ID3" or header[:2] == b"\xff\xfb" or header[:2] == b"\xff\xf3":
audio = MP3(data)
audio: MP3 | WAVE = MP3(data)
elif header[:4] == b"RIFF":
audio = WAVE(data)
else:
# fallback:先尝试 MP3(大多数 TTS 返回 mp3
audio = MP3(data)
if audio.info is None:
raise ValueError("音频信息解析失败")
duration = audio.info.length
if duration is None or duration <= 0:
raise ValueError("音频时长解析失败")
+133
View File
@@ -0,0 +1,133 @@
"""
内容指纹工具
============
用于 AI 第三方平台 Vidu的审核结果缓存与防重复提交
核心逻辑
- 对提交的音频/视频/图片 URL + 任务参数生成 SHA256 指纹
- 如果相同内容近期因审核失败被缓存则直接返回错误不再调用第三方平台
- 仅规范化 URL去掉 token 等动态参数不下载大文件本身
"""
from __future__ import annotations
import hashlib
from urllib.parse import parse_qs, urlencode, urlparse
# Vidu 内容安全/审核类错误码
# 这些错误在内容不变的情况下重试也没用,应该被缓存。
VIDU_AUDIT_ERROR_CODES = frozenset(
{
"TaskPromptPolicyViolation", # Prompt 触发安审风控
"AuditSubmitIllegal", # 输入未通过安全审核
"CreationPolicyViolation", # 生成物触发风控
"PhotoAuditNotPass", # 图片审核不通过
"AuditFailed", # 审核失败
"ImageCheckBodyJointsFailed", # 输入图人体检测失败
"ImageCheckFaceFailed", # 输入图人脸检测失败
"ImageObjectsUndetected", # 人体或人脸有遮挡
"FaceDetectFailure", # 人脸检测失败
"FaceDetectNotPass", # 人脸检测不通过
"NoFaceDetected", # 未检测到人脸
"MultiFaceDetected", # 检测到多张人脸
}
)
# 常见动态 query 参数,生成指纹时应忽略
_DYNAMIC_QUERY_PARAMS = frozenset(
{
"token",
"e", # 七牛过期时间戳
"t", # 时间戳
"sign",
"x-oss-signature",
"x-oss-expires",
"x-oss-access-key-id",
"response-content-disposition",
"v", # 版本号/缓存戳
}
)
def normalize_url(url: str | None) -> str:
"""规范化 URL,去掉动态参数,确保同一资源不同 token 得到相同指纹。
Args:
url: 原始 URL可能包含七牛私有 token时间戳等动态参数
Returns:
规范化后的 URL 字符串None 或空字符串返回空串
"""
if not url:
return ""
parsed = urlparse(url)
query_params = parse_qs(parsed.query, keep_blank_values=True)
for key in list(query_params.keys()):
if key.lower() in _DYNAMIC_QUERY_PARAMS:
del query_params[key]
query = urlencode(sorted(query_params.items()), doseq=True)
path = parsed.path or "/"
if query:
return f"{parsed.scheme}://{parsed.netloc}{path}?{query}"
return f"{parsed.scheme}://{parsed.netloc}{path}"
def compute_content_fingerprint(
task_type: str,
*,
video_url: str | None = None,
audio_url: str | None = None,
ref_photo_url: str | None = None,
text: str | None = None,
voice_id: str | None = None,
) -> str:
"""计算内容指纹。
指纹字段选择原则只包含会影响 Vidu 审核结果的输入内容
不包含 callback_urlspeedvolumepayload 等业务/技术参数
Args:
task_type: 任务类型 "lip_sync", "tts", "clone_voice"
video_url: 视频 URL
audio_url: 音频 URL
ref_photo_url: 参考图片 URL
text: 文本/Prompt
voice_id: 音色 ID
Returns:
SHA256 十六进制指纹字符串
"""
parts = [
task_type.strip().lower(),
normalize_url(video_url),
normalize_url(audio_url),
normalize_url(ref_photo_url),
(text or "").strip(),
(voice_id or "").strip().lower(),
]
raw = "|".join(parts)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def is_vidu_audit_error(err_code: str | None) -> bool:
"""判断是否为 Vidu 审核类错误码。"""
if not err_code:
return False
return err_code.strip() in VIDU_AUDIT_ERROR_CODES
def extract_vidu_error_code(message: str | None) -> str | None:
"""从 Vidu 错误信息中提取错误码。
Vidu 错误信息格式通常为"ErrorCode: 中文描述"
"""
if not message:
return None
candidate = message.split(":")[0].strip()
return candidate or None
+3 -3
View File
@@ -16,9 +16,9 @@ def validate_file_magic(content: bytes, expected_content_type: str) -> bool:
# 拒绝常见危险文件头
dangerous_signatures = [
(b"MZ", "Windows 可执行文件"), # .exe, .dll
(b"#!", "Shell 脚本"), # bash, python, etc
(b"PK\x03\x04", "ZIP 压缩包"), # .zip, .jar, .docx
(b"MZ", "Windows 可执行文件"), # .exe, .dll
(b"#!", "Shell 脚本"), # bash, python, etc
(b"PK\x03\x04", "ZIP 压缩包"), # .zip, .jar, .docx
(b"<?xml", "XML 文件"),
(b"<html", "HTML 文件"),
(b"<!DO", "HTML 文档"),
+13 -15
View File
@@ -1,6 +1,6 @@
[project]
name = "meijiaka-ai-api"
version = "1.8.2"
version = "1.9.1"
description = "美家卡智影 - AI 视频创作后端 API"
authors = [{ name = "Meijiaka Team" }]
readme = "README.md"
@@ -15,9 +15,10 @@ classifiers = [
dependencies = [
# Web 框架 (FastAPI 0.116+ 修复 Starlette 安全漏洞)
"fastapi>=0.136.1",
"fastapi>=0.115.8",
"uvicorn[standard]~=0.32.0",
"python-multipart~=0.0.20",
"python-multipart>=0.0.27",
"pyasn1>=0.6.3", # 安全修复:间接依赖,强制升级
# 认证 & 安全
"passlib[bcrypt]~=1.7.4",
@@ -47,14 +48,14 @@ dependencies = [
"volcengine-python-sdk[ark]~=5.0.0",
# HTTP 客户端
"httpx~=0.28.0",
"aiohttp>=3.13.4", # 安全修复:修复 CVE-2025-XXXX 系列漏洞
"httpx>=0.28.0",
"aiohttp>=3.14.0", # 安全修复:修复 CVE-2025-XXXX 系列漏洞
# 对象存储
"qiniu~=7.13.0",
# 工具
"pyjwt~=2.10.0",
"pyjwt>=2.13.0",
"pyyaml~=6.0.2",
"orjson>=3.11.0", # 安全修复:修复 CVE-2025-XXXX
@@ -68,12 +69,12 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pytest~=8.3.0",
"pytest-asyncio~=0.24.0",
"pytest-cov~=6.0.0",
"ruff~=0.8.0",
"black~=24.10.0",
"mypy~=1.14.0",
"pytest>=9.0.3",
"pytest-asyncio>=0.24.0",
"pytest-cov>=6.0.0",
"ruff>=0.8.0",
"black>=26.3.1",
"mypy>=1.14.0",
"bandit[toml]~=1.8.0", # 安全扫描
"pip-audit~=2.7.0", # 漏洞检测
"pre-commit~=4.0.0", # Git 钩子
@@ -99,7 +100,6 @@ ignore = ["E501", "E402", "N802", "N803", "N806", "N815", "B008", "B904"]
[tool.mypy]
python_version = "3.13"
strict = false
warn_return_any = false
warn_unused_configs = true
ignore_missing_imports = true
@@ -112,7 +112,6 @@ disallow_incomplete_defs = false
# ========== 重构防护网:新代码严格模式 ==========
[[tool.mypy.overrides]]
module = ["app.schemas.*", "app.crud.*", "app.scheduler.handlers.*"]
strict = true
warn_return_any = true
check_untyped_defs = true
disallow_untyped_defs = true
@@ -121,7 +120,6 @@ disallow_incomplete_defs = true
# Redis 客户端 typing 问题(Awaitable[T] | T),暂不严格检查
[[tool.mypy.overrides]]
module = ["app.scheduler.registry", "app.scheduler.slot_manager"]
strict = false
check_untyped_defs = false
disallow_untyped_defs = false
+24 -20
View File
@@ -1,8 +1,8 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.lock
aiohappyeyeballs==2.6.1
aiohappyeyeballs==2.6.2
# via aiohttp
aiohttp==3.13.5
aiohttp==3.14.1
# via meijiaka-ai-api (pyproject.toml)
aiosignal==1.4.0
# via aiohttp
@@ -27,7 +27,7 @@ bcrypt==4.2.1
# via
# meijiaka-ai-api (pyproject.toml)
# passlib
certifi==2026.4.22
certifi==2026.5.20
# via
# httpcore
# httpx
@@ -37,19 +37,19 @@ cffi==2.0.0
# via cryptography
charset-normalizer==3.4.7
# via requests
click==8.3.3
click==8.4.1
# via uvicorn
cryptography==48.0.0
cryptography==48.0.1
# via volcengine-python-sdk
distro==1.9.0
# via openai
fastapi==0.136.1
fastapi==0.136.3
# via meijiaka-ai-api (pyproject.toml)
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
greenlet==3.5.0
greenlet==3.5.1
# via sqlalchemy
h11==0.16.0
# via
@@ -57,20 +57,20 @@ h11==0.16.0
# uvicorn
httpcore==1.0.9
# via httpx
httptools==0.7.1
httptools==0.8.0
# via uvicorn
httpx==0.28.1
# via
# meijiaka-ai-api (pyproject.toml)
# openai
# volcengine-python-sdk
idna==3.13
idna==3.18
# via
# anyio
# httpx
# requests
# yarl
jiter==0.14.0
jiter==0.15.0
# via openai
mako==1.3.12
# via alembic
@@ -88,12 +88,16 @@ orjson==3.11.9
# via meijiaka-ai-api (pyproject.toml)
passlib==1.7.4
# via meijiaka-ai-api (pyproject.toml)
propcache==0.4.1
pillow==12.2.0
# via meijiaka-ai-api (pyproject.toml)
propcache==0.5.2
# via
# aiohttp
# yarl
psycopg2-binary==2.9.12
# via meijiaka-ai-api (pyproject.toml)
pyasn1==0.6.3
# via meijiaka-ai-api (pyproject.toml)
pycparser==3.0
# via cffi
pydantic==2.9.2
@@ -107,7 +111,7 @@ pydantic-core==2.23.4
# via pydantic
pydantic-settings==2.6.1
# via meijiaka-ai-api (pyproject.toml)
pyjwt==2.10.1
pyjwt==2.13.0
# via meijiaka-ai-api (pyproject.toml)
python-dateutil==2.9.0.post0
# via volcengine-python-sdk
@@ -115,7 +119,7 @@ python-dotenv==1.2.2
# via
# pydantic-settings
# uvicorn
python-multipart==0.0.27
python-multipart==0.0.32
# via meijiaka-ai-api (pyproject.toml)
pyyaml==6.0.3
# via
@@ -125,7 +129,7 @@ qiniu==7.13.2
# via meijiaka-ai-api (pyproject.toml)
redis==5.2.1
# via meijiaka-ai-api (pyproject.toml)
requests==2.33.1
requests==2.34.2
# via qiniu
six==1.17.0
# via
@@ -133,15 +137,15 @@ six==1.17.0
# volcengine-python-sdk
sniffio==1.3.1
# via openai
sqlalchemy==2.0.49
sqlalchemy==2.0.50
# via
# meijiaka-ai-api (pyproject.toml)
# alembic
starlette==1.0.0
starlette==1.3.0
# via fastapi
tenacity==9.0.0
# via meijiaka-ai-api (pyproject.toml)
tqdm==4.67.3
tqdm==4.68.2
# via openai
typing-extensions==4.15.0
# via
@@ -162,11 +166,11 @@ uvicorn==0.32.1
# via meijiaka-ai-api (pyproject.toml)
uvloop==0.22.1
# via uvicorn
volcengine-python-sdk==5.0.26
volcengine-python-sdk==5.0.34
# via meijiaka-ai-api (pyproject.toml)
watchfiles==1.1.1
watchfiles==1.2.0
# via uvicorn
websockets==16.0
# via uvicorn
yarl==1.23.0
yarl==1.24.2
# via aiohttp
View File
+116
View File
@@ -0,0 +1,116 @@
"""
异常体系单元测试
================
验证 AppException / InsufficientPointsException 的结构化字段
确保前端可以通过 error_code 识别错误类型
"""
from fastapi import HTTPException
from app.core.exceptions import (
AIEmptyResponseException,
AIParseErrorException,
AITimeoutException,
AppException,
BusinessException,
InsufficientPointsException,
NotFoundException,
PromptNotFoundException,
ValidationException,
)
class TestAppException:
"""业务异常基类"""
def test_app_exception_has_error_code(self) -> None:
exc = AppException(
status_code=400,
message="参数错误",
error_code="validation_error",
)
assert exc.status_code == 400
assert exc.message == "参数错误"
assert exc.error_code == "validation_error"
assert exc.detail == {"message": "参数错误", "error_code": "validation_error"}
def test_app_exception_detail_can_be_dict(self) -> None:
exc = AppException(
status_code=422,
message="字段校验失败",
detail={"fields": {"name": "required"}},
error_code="validation_error",
)
assert exc.detail == {
"fields": {"name": "required"},
"message": "字段校验失败",
"error_code": "validation_error",
}
def test_subclasses_without_error_code(self) -> None:
exc = NotFoundException("资源不存在")
assert exc.status_code == 404
assert exc.message == "资源不存在"
assert exc.error_code is None
class TestInsufficientPointsException:
"""积分不足异常"""
def test_default_fields(self) -> None:
exc = InsufficientPointsException()
assert exc.status_code == 402
assert exc.message == "积分不足"
assert exc.error_code == "insufficient_points"
assert isinstance(exc, HTTPException)
def test_custom_message(self) -> None:
exc = InsufficientPointsException("余额不足,请先充值")
assert exc.message == "余额不足,请先充值"
assert exc.error_code == "insufficient_points"
def test_detail_structure(self) -> None:
exc = InsufficientPointsException("需要 10 积分")
assert exc.detail == {
"message": "需要 10 积分",
"error_code": "insufficient_points",
}
class TestOtherSubclasses:
"""其他常用子类"""
def test_validation_exception(self) -> None:
exc = ValidationException("字段缺失")
assert exc.status_code == 422
assert exc.message == "字段缺失"
def test_business_exception(self) -> None:
exc = BusinessException("业务状态不允许")
assert exc.status_code == 400
assert exc.message == "业务状态不允许"
class TestAIStructuredExceptions:
"""AI 相关结构化异常"""
def test_prompt_not_found_exception(self) -> None:
exc = PromptNotFoundException("未找到提示词")
assert exc.status_code == 404
assert exc.error_code == "prompt_not_found"
def test_ai_empty_response_exception(self) -> None:
exc = AIEmptyResponseException("AI 返回为空")
assert exc.status_code == 500
assert exc.error_code == "empty_result"
def test_ai_parse_error_exception(self) -> None:
exc = AIParseErrorException("解析失败")
assert exc.status_code == 500
assert exc.error_code == "parse_error"
def test_ai_timeout_exception(self) -> None:
exc = AITimeoutException("请求超时")
assert exc.status_code == 504
assert exc.error_code == "timeout"

Some files were not shown because too many files have changed in this diff Show More