Files

538 lines
22 KiB
Diff
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
diff --git a/python-api/app/api/v1/caption.py b/python-api/app/api/v1/caption.py
index afae831..99a771a 100644
--- a/python-api/app/api/v1/caption.py
+++ b/python-api/app/api/v1/caption.py
@@ -24,19 +24,6 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/caption", tags=["Caption"])
-
-
-
-
-
-
-
-
-
-
-
-
-
@router.post("/ata/align")
async def auto_align_caption(
request_body: AutoAlignSubmitRequest,
@@ -88,9 +75,3 @@ async def auto_align_caption(
except Exception as e:
logger.error(f"自动打轴异常: {e}")
raise HTTPException(status_code=500, detail="字幕打轴失败,请稍后重试")
-
-
-
-
-
-
diff --git a/python-api/app/api/v1/points.py b/python-api/app/api/v1/points.py
index b2c7208..23bafa9 100644
--- a/python-api/app/api/v1/points.py
+++ b/python-api/app/api/v1/points.py
@@ -6,7 +6,7 @@
"""
import logging
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user, get_db
+from app.core.exceptions import InsufficientPointsException
from app.crud.point_recharge_order import point_recharge_order
from app.crud.point_transaction import point_transaction
from app.models.user import User
@@ -35,6 +36,7 @@ router = APIRouter(prefix="/points", tags=["Points"])
# ── 余额查询 ──────────────────────────────────────────
+
@router.get("/balance", response_model=ApiResponse[PointBalanceResponse])
async def get_balance(
db: AsyncSession = Depends(get_db),
@@ -47,6 +49,7 @@ async def get_balance(
# ── 流水查询 ──────────────────────────────────────────
+
@router.get("/transactions", response_model=ApiResponse[PointTransactionListResponse])
async def list_transactions(
pagination: PaginationParams = Depends(),
@@ -127,12 +130,13 @@ async def list_transactions(
# ── 充值 ──────────────────────────────────────────────
+
@router.post("/recharge", response_model=ApiResponse[RechargeResponse])
async def create_recharge_order(
request: RechargeRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
- http_request: Request = None,
+ http_request: Request = None, # type: ignore[assignment]
):
"""
创建积分充值订单(微信支付 Native 扫码)
@@ -303,9 +307,7 @@ async def handle_wxpay_notify(
return _wx_response()
# 查找订单
- order = await point_recharge_order.get_by_out_trade_no(
- db, out_trade_no=out_trade_no
- )
+ order = await point_recharge_order.get_by_out_trade_no(db, out_trade_no=out_trade_no)
if not order:
logger.error(f"[WechatPay] 回调订单不存在: {out_trade_no}")
return _wx_response()
@@ -400,9 +402,7 @@ async def query_recharge_status(
wxpay = get_wxpay_service()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0, connect=10.0)) as client:
- wx_result = await wxpay.query_order(
- client, out_trade_no=order.out_trade_no
- )
+ wx_result = await wxpay.query_order(client, out_trade_no=order.out_trade_no)
order.query_result = str(wx_result)
trade_state = wx_result.get("trade_state", "")
@@ -465,6 +465,7 @@ async def query_recharge_status(
# ── 充值档位查询 ──────────────────────────────────────
+
@router.get("/recharge-options", response_model=ApiResponse[list[dict]])
async def get_recharge_options(
current_user: User = Depends(get_current_user),
@@ -480,6 +481,7 @@ async def get_recharge_options(
# ── 扣费业务类型查询 ──────────────────────────────────
+
@router.get("/chargeable-types", response_model=ApiResponse[list[str]])
async def get_chargeable_types(
current_user: User = Depends(get_current_user),
@@ -496,6 +498,7 @@ async def get_chargeable_types(
# ── 积分规则查询 ──────────────────────────────────────
+
@router.get("/rules", response_model=ApiResponse[list[dict]])
async def get_points_rules(
current_user: User = Depends(get_current_user),
@@ -530,9 +533,9 @@ async def get_points_rules(
# ── 积分预估查询 ──────────────────────────────────────
-
# ── 今日消费统计 ──────────────────────────────────────
+
@router.get("/today-consumed", response_model=ApiResponse[dict])
async def get_today_consumed(
db: AsyncSession = Depends(get_db),
@@ -545,6 +548,7 @@ async def get_today_consumed(
# ── 直接消费扣费(前端/Rust 层调用)───────────────────
+
@router.post("/consume", response_model=ApiResponse[dict])
async def consume_points(
request: ConsumeRequest,
@@ -569,12 +573,12 @@ async def consume_points(
source_type=request.source_type,
source_id=request.source_id,
description=f"【{request.description or request.source_type}】",
- allow_negative=False,
+ allow_negative=request.allow_negative,
)
- except ValueError as e:
+ except InsufficientPointsException:
# 余额不足(在同一事务内判断,避免竞态)
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
+ raise
+ except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
await db.commit()
@@ -587,5 +591,3 @@ async def consume_points(
},
message="消费成功",
)
-
-
diff --git a/python-api/app/api/v1/script.py b/python-api/app/api/v1/script.py
index 5e7b18f..26a11a3 100644
--- a/python-api/app/api/v1/script.py
+++ b/python-api/app/api/v1/script.py
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.ai.model_router import get_model_router
from app.ai.prompts import list_categories, list_prompt_files, load_prompt, render_template
from app.api.deps import get_current_user
+from app.core.exceptions import AITimeoutException, InsufficientPointsException
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -71,9 +72,8 @@ async def polish_content(
required_points = ps._calculate_cost("polish")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -99,11 +99,11 @@ async def polish_content(
data=polished,
message=f"{type_name}润色完成",
)
+ except InsufficientPointsException:
+ raise
except HTTPException:
raise
except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
logger.warning(f"[Polish] 润色失败: {e}")
raise HTTPException(status_code=500, detail="润色失败,请检查输入内容后重试")
except Exception as e:
@@ -111,9 +111,6 @@ async def polish_content(
raise HTTPException(status_code=500, detail=f"{type_name}润色失败,请稍后重试")
-
-
-
@router.post("/generate-title", response_model=ApiResponse[GenerateTitleResponse])
async def generate_title(
request: GenerateTitleRequest,
@@ -146,7 +143,11 @@ async def generate_title(
usage_note = "- 视频画面上的标题需要精炼,聚焦核心关键词\n- 副标题与主标题形成呼应,补充说明但不喧宾夺主"
# 渲染用户提示词
- title_type_desc = "大标题(主标题,提炼核心卖点,吸睛)" if request.title_type == "main" else "小标题(副标题,补充说明或制造悬念)"
+ title_type_desc = (
+ "大标题(主标题,提炼核心卖点,吸睛)"
+ if request.title_type == "main"
+ else "小标题(副标题,补充说明或制造悬念)"
+ )
user_prompt = render_template(
user_template,
title_type=request.title_type,
@@ -163,9 +164,8 @@ async def generate_title(
required_points = ps._calculate_cost("title")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -179,10 +179,10 @@ async def generate_title(
title = result.content.strip() if result.content else ""
# 去除可能的引号
- title = title.strip('"').strip("'").strip('「」').strip('『』').strip('《》')
+ title = title.strip('"').strip("'").strip("「」").strip("『』").strip("《》")
# 截断到最大长度
if len(title) > request.max_length:
- title = title[:request.max_length]
+ title = title[: request.max_length]
# 扣费
points = ps._calculate_cost("title")
@@ -200,15 +200,13 @@ async def generate_title(
data=GenerateTitleResponse(title=title),
message="标题生成成功",
)
+ except InsufficientPointsException:
+ raise
except HTTPException:
raise
except TimeoutError:
logger.warning("[generate_title] 标题生成超时")
- raise HTTPException(status_code=500, detail="标题生成超时,请重试")
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
+ raise AITimeoutException("标题生成超时,请稍后重试")
except Exception as e:
logger.error(f"[generate_title] 标题生成失败: {e}")
raise HTTPException(status_code=500, detail=f"标题生成失败: {str(e)}")
diff --git a/python-api/app/api/v1/tasks.py b/python-api/app/api/v1/tasks.py
index ae17fe3..55965ca 100644
--- a/python-api/app/api/v1/tasks.py
+++ b/python-api/app/api/v1/tasks.py
@@ -18,6 +18,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, field_validator
from app.api.deps import get_current_user
+from app.core.exceptions import InsufficientPointsException
from app.core.redis_client import get_redis_client
from app.db.session import AsyncSessionLocal
from app.models.user import User
@@ -38,7 +39,6 @@ class ScriptParams(BaseModel):
category: str = Field(..., min_length=1, description="大类代码")
filename: str = Field(..., min_length=1, description="提示词文件名")
-
@field_validator("category")
@classmethod
def validate_category(cls, v: str) -> str:
@@ -96,7 +96,9 @@ class VideoParams(BaseModel):
volume: int = Field(default=0, ge=0, le=10, description="音量")
ref_photo_url: str | None = Field(default=None, description="人脸参考图 URL")
planned_duration: float = Field(..., gt=0, description="该分镜脚本规划时长(秒),用于余额预检")
- total_planned_duration: float = Field(..., gt=0, description="所有分镜规划时长之和(秒),用于预检")
+ total_planned_duration: float = Field(
+ ..., gt=0, description="所有分镜规划时长之和(秒),用于预检"
+ )
batch_id: str | None = Field(default=None, description="批次ID(可选)")
@field_validator("video_url")
@@ -134,6 +136,7 @@ class TaskStatusResponse(BaseModel):
total: int = Field(0, description="总子任务数")
result: dict | None = Field(None, description="任务结果(完成时)")
error: str | None = Field(None, description="错误信息(失败时)")
+ error_code: str | None = Field(None, description="错误码(失败时,如 content_violation")
created_at: str = Field("", description="任务创建时间(ISO格式)")
@@ -222,9 +225,8 @@ async def create_task(
f"[API] 积分不足: user={user_id}, type={task_type}, "
f"required={required_points}, balance={check['balance']}"
)
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
# ── 3. 写入 Redis ──────────────────────────────────
@@ -246,18 +248,41 @@ async def create_task(
params=validated_params,
)
await registry.add_running(task_id)
-
- logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
- return TaskCreateResponse(
- task_id=task_id,
- status="running",
- message=f"{task_type} 任务已创建",
- )
-
except Exception as e:
logger.error(f"[API] Failed to update registry: {e}")
raise HTTPException(status_code=500, detail="创建任务失败:Redis写入错误")
+ # ── 4. 脚本生成:Redis 写入成功后再扣费 ─────────────
+ if task_type == "script" and required_points > 0:
+ try:
+ async with AsyncSessionLocal() as db:
+ await ps.consume(
+ db,
+ user_id=user_id,
+ points=required_points,
+ source_type="script",
+ source_id=task_id,
+ description="【脚本生成】",
+ )
+ await db.commit()
+ except InsufficientPointsException:
+ # 余额不足:将任务标记为失败,避免无费执行
+ await registry.update(task_id, status="failed", message="积分不足")
+ await registry.remove_running(task_id)
+ raise
+ except Exception as e:
+ logger.error(f"[API] 脚本任务扣费失败: {e}")
+ await registry.update(task_id, status="failed", message="扣费失败")
+ await registry.remove_running(task_id)
+ raise HTTPException(status_code=500, detail="扣费失败,请稍后重试")
+
+ logger.info(f"[API] Task created: {task_id}, type={task_type}, user={user_id}")
+ return TaskCreateResponse(
+ task_id=task_id,
+ status="running",
+ message=f"{task_type} 任务已创建",
+ )
+
@router.get("", response_model=list[TaskStatusResponse])
async def list_tasks(
@@ -294,6 +319,7 @@ async def list_tasks(
total=task.total,
result=None, # 列表查询不返回 result,避免数据过大
error=task.error,
+ error_code=task.error_code,
created_at=task.created_at,
)
)
@@ -337,6 +363,7 @@ async def get_task_status(
total=task.total,
result=task.result,
error=task.error,
+ error_code=task.error_code,
created_at=task.created_at,
)
diff --git a/python-api/app/api/v1/vidu.py b/python-api/app/api/v1/vidu.py
index c0fb04a..401d8da 100644
--- a/python-api/app/api/v1/vidu.py
+++ b/python-api/app/api/v1/vidu.py
@@ -44,10 +44,12 @@ async def vidu_callback(request: Request):
headers_dict = dict(request.headers)
# 使用 APP_BASE_URL 构建 callback_url,确保与提交任务时传给 Vidu 的一致
- #Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
+ # (Nginx 反向代理可能导致 request.url 的 scheme 为 http,与 Vidu 签名时的 https 不一致)
app_base_url = get_settings().app_base_url
callback_url = f"{app_base_url}/api/v1/vidu/callback" if app_base_url else str(request.url)
- logger.info(f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}")
+ logger.info(
+ f"[Vidu] 收到回调: request_url={request.url}, callback_url={callback_url}, body={body_bytes.decode('utf-8', errors='replace')[:500]}"
+ )
try:
task_status = await gateway.handle_webhook(
@@ -64,15 +66,13 @@ async def vidu_callback(request: Request):
logger.error(f"[Vidu] 回调处理失败: {e}")
raise HTTPException(status_code=500, detail="回调处理失败,请稍后重试")
- logger.info(f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}")
+ logger.info(
+ f"[Vidu] 回调解析完成: state={task_status.state}, result={task_status.result}, error={task_status.error_message}"
+ )
# 2. 通过 platform_task_id 反查 Async Engine 内部 task_id,更新 TaskRegistry
- platform_task_id = (
- task_status.result.get("task_id") if task_status.result else None
- )
- video_url = (
- task_status.result.get("video_url") if task_status.result else None
- )
+ platform_task_id = task_status.result.get("task_id") if task_status.result else None
+ video_url = task_status.result.get("video_url") if task_status.result else None
logger.info(f"[Vidu] 准备反查 internal_task_id: platform_task_id={platform_task_id}")
@@ -121,8 +121,6 @@ async def vidu_callback(request: Request):
f"platform={platform_task_id}"
)
else:
- logger.warning(
- f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}"
- )
+ logger.warning(f"[Vidu] 回调无法反查内部 task_id: platform={platform_task_id}")
return success_response(message="回调已接收")
diff --git a/python-api/app/api/v1/voice.py b/python-api/app/api/v1/voice.py
index ba8cac9..dad2b7b 100644
--- a/python-api/app/api/v1/voice.py
+++ b/python-api/app/api/v1/voice.py
@@ -10,12 +10,13 @@ import logging
import re
import time
import uuid
+
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
-from app.core.exceptions import PlatformError
+from app.core.exceptions import InsufficientPointsException, PlatformError
from app.db.session import get_db
from app.models.user import User
from app.schemas.common import ApiResponse, success_response
@@ -49,7 +50,9 @@ class TTSSynthesizeRequest(BaseModel):
class VoiceCloneSubmitRequest(BaseModel):
"""声音复刻提交请求"""
- source_audio_url: str | None = Field(None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)")
+ source_audio_url: str | None = Field(
+ None, description="源音频 URL5-30秒,mp3/wav,需公开可访问)"
+ )
source_video_url: str | None = Field(None, description="源视频 URL(可选)")
video_id: str | None = Field(None, description="历史作品ID(可选)")
voice_name: str | None = Field(None, description="自定义音色名称(≤20字符)")
@@ -111,7 +114,7 @@ async def synthesize_speech(
# 宽松预检:余额为负或零时阻止,避免浪费第三方资源
balance_info = await ps.get_user_balance(db, current_user.id)
if balance_info["balance"] <= 0:
- raise HTTPException(status_code=402, detail="余额不足,请先充值")
+ raise InsufficientPointsException("余额不足,请先充值")
try:
audio_url = await service.synthesize(
@@ -137,10 +140,8 @@ async def synthesize_speech(
allow_negative=True,
)
await db.commit()
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- logger.error(f"[Voice] TTS 扣费失败: {e}")
+ except InsufficientPointsException:
+ raise
except Exception as e:
logger.error(f"[Voice] TTS 扣费失败: {e}")
@@ -165,7 +166,6 @@ async def synthesize_speech(
raise HTTPException(status_code=500, detail="语音合成失败,请稍后重试")
-
def _normalize_voice_id(name: str | None) -> str:
"""
将用户输入的名称规范化为 Vidu 合法的 voice_id。
@@ -220,9 +220,8 @@ async def submit_clone_task(
required_points = ps._calculate_cost("voice_clone")
check = await ps.check_balance(db, current_user.id, required_points)
if not check["sufficient"]:
- raise HTTPException(
- status_code=402,
- detail=f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}",
+ raise InsufficientPointsException(
+ f"积分不足,需要 {required_points} 积分,当前余额 {check['balance']}"
)
try:
@@ -244,10 +243,8 @@ async def submit_clone_task(
description="【声音复刻】",
)
await db.commit()
- except ValueError as e:
- if "积分不足" in str(e):
- raise HTTPException(status_code=402, detail=str(e))
- logger.error(f"[Voice] 克隆扣费失败: {e}")
+ except InsufficientPointsException:
+ raise
except Exception as e:
logger.error(f"[Voice] 克隆扣费失败: {e}")
@@ -292,5 +289,3 @@ async def query_clone_task(
),
message="克隆已完成",
)
-
-