7550559aa0
- 删除8个未使用IPC命令,保留validate_media_path - file.rs返回类型优化为ApiResponse<()> - point_service.consume()注释与签名一致 - VideoGeneration改为拼接成功后扣费 - 添加漏扣费风险注释 - 删除过时测试文件 - 修复camelToSnake连续大写字母问题 - vidu.py import移至模块顶层 Refs: P1-1~P1-6 技术债务清理
483 lines
15 KiB
Python
483 lines
15 KiB
Python
"""
|
|
积分系统 Service 层
|
|
===================
|
|
|
|
核心能力:
|
|
1. 余额查询
|
|
2. 充值(直接到账 / 微信回调后到账)
|
|
3. 消费(后置扣费:执行业务 → 出结果 → 直接扣费)
|
|
4. 过期回收
|
|
5. 流水记录
|
|
|
|
设计原则:
|
|
- 所有业务操作在一个事务内完成(balance + batch + transaction 三者原子性)。
|
|
- FIFO 批次消耗:按 expired_at 升序扣减。
|
|
- 后置扣费模式:先执行业务,出结果后按实际消耗扣费。
|
|
- 允许欠费(单次业务实际消耗超出预估上限),但欠费后不可继续使用。
|
|
|
|
注意:本 Service 不自行 commit,由调用方(API 层)通过 FastAPI Depends 注入的
|
|
Session 统一提交。所有操作在调用方事务内原子执行。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from datetime import UTC, datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import yaml
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.point_batch import PointBatch
|
|
from app.models.point_transaction import PointTransaction
|
|
from app.models.user_point import UserPoint
|
|
|
|
if TYPE_CHECKING:
|
|
from uuid import UUID
|
|
|
|
|
|
# ── 配置加载 ──────────────────────────────────────────
|
|
|
|
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
|
|
|
|
|
|
def _load_points_config() -> dict:
|
|
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
|
|
if not _CONFIG_PATH.exists():
|
|
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
|
|
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
|
cfg = yaml.safe_load(f) or {}
|
|
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
|
|
merged: dict[str, dict] = {}
|
|
for key, points in cfg.get("fixed_costs", {}).items():
|
|
merged[key] = {"mode": "fixed", "points": points}
|
|
for key, rule in cfg.get("duration_based_costs", {}).items():
|
|
merged[key] = {"mode": "duration", **rule}
|
|
for key in cfg.get("free_services", []):
|
|
merged[key] = {"mode": "free", "points": 0}
|
|
# 保留充值档位原始配置
|
|
merged["_recharge_options"] = cfg.get("recharge_options", [])
|
|
return merged
|
|
|
|
|
|
POINTS_CONFIG: dict[str, dict] = _load_points_config()
|
|
|
|
|
|
def get_recharge_options() -> list[dict]:
|
|
"""获取充值档位配置(由后端控制,支持积分赠送)"""
|
|
return POINTS_CONFIG.get("_recharge_options", [])
|
|
|
|
|
|
def get_chargeable_source_types() -> list[str]:
|
|
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
|
|
return [
|
|
key for key, cfg in POINTS_CONFIG.items()
|
|
if not key.startswith("_") and cfg.get("mode") != "free"
|
|
]
|
|
|
|
|
|
EXPIRATION_DAYS = 180
|
|
|
|
|
|
def _now() -> datetime:
|
|
"""返回带时区的当前时间"""
|
|
return datetime.now(UTC)
|
|
|
|
|
|
def _calculate_cost(source_type: str, param: dict | None = None) -> int:
|
|
"""根据消费类型和实际结果参数计算所需积分(后置扣费时使用)"""
|
|
if source_type not in POINTS_CONFIG:
|
|
raise ValueError(f"未知的消费类型: {source_type}")
|
|
|
|
cfg = POINTS_CONFIG[source_type]
|
|
mode = cfg["mode"]
|
|
|
|
if mode == "free":
|
|
return 0
|
|
|
|
if mode == "fixed":
|
|
return cfg["points"]
|
|
|
|
if mode == "duration":
|
|
if param is None:
|
|
raise ValueError(f"消费类型 {source_type} 需要提供参数才能计算积分")
|
|
seconds = param.get("seconds", 0)
|
|
min_points = cfg.get("min_points", 1)
|
|
if "divisor" in cfg:
|
|
return max(min_points, math.ceil(seconds / cfg["divisor"]))
|
|
if "multiplier" in cfg:
|
|
return max(min_points, math.ceil(seconds) * cfg["multiplier"])
|
|
|
|
raise ValueError(f"消费类型 {source_type} 缺少计算规则")
|
|
|
|
|
|
def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
|
|
"""
|
|
预估消费上限(执行业务前检查余额用)。
|
|
|
|
按秒计费业务取保守估计,确保实际消耗不会超过预估上限。
|
|
"""
|
|
if source_type not in POINTS_CONFIG:
|
|
raise ValueError(f"未知的消费类型: {source_type}")
|
|
|
|
cfg = POINTS_CONFIG[source_type]
|
|
mode = cfg["mode"]
|
|
|
|
if mode == "free":
|
|
return 0
|
|
|
|
if mode == "fixed":
|
|
return cfg["points"]
|
|
|
|
if mode == "duration":
|
|
if param is None:
|
|
raise ValueError(f"消费类型 {source_type} 需要提供参数才能预估积分")
|
|
|
|
est = cfg.get("estimation", {})
|
|
min_points = cfg.get("min_points", 1)
|
|
|
|
if "seconds_per_char" in est:
|
|
# TTS 模式:字数 → 预估秒数
|
|
char_count = param.get("char_count", 0)
|
|
if char_count <= 0:
|
|
raise ValueError("TTS 预估需要提供 char_count 参数")
|
|
estimated_seconds = char_count * est["seconds_per_char"]
|
|
if "divisor" in cfg:
|
|
return max(min_points, math.ceil(estimated_seconds / cfg["divisor"]))
|
|
|
|
if est.get("use_input_seconds"):
|
|
# 视频模式:直接使用输入秒数作为预估上限
|
|
seconds = param.get("input_seconds", 0)
|
|
if seconds <= 0:
|
|
raise ValueError("video 预估需要提供 input_seconds 参数")
|
|
if "multiplier" in cfg:
|
|
return max(min_points, math.ceil(seconds) * cfg["multiplier"])
|
|
|
|
raise ValueError(f"消费类型 {source_type} 缺少预估规则")
|
|
|
|
|
|
# ── 余额查询 ──────────────────────────────────────────
|
|
|
|
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
|
|
"""获取用户积分余额快照"""
|
|
result = await db.execute(
|
|
select(UserPoint).where(UserPoint.user_id == user_id)
|
|
)
|
|
up = result.scalar_one_or_none()
|
|
|
|
if not up:
|
|
return {
|
|
"balance": 0,
|
|
"total_recharged": 0,
|
|
"total_consumed": 0,
|
|
"total_expired": 0,
|
|
}
|
|
|
|
return {
|
|
"balance": up.balance,
|
|
"total_recharged": up.total_recharged,
|
|
"total_consumed": up.total_consumed,
|
|
"total_expired": up.total_expired,
|
|
}
|
|
|
|
|
|
async def check_balance(
|
|
db: AsyncSession,
|
|
user_id: UUID | str,
|
|
required_points: int = 0,
|
|
) -> dict:
|
|
"""
|
|
检查用户余额是否足够。
|
|
|
|
:param required_points: 需要的积分数量
|
|
:return: {"sufficient": bool, "balance": int, "required": int}
|
|
"""
|
|
balance_info = await get_user_balance(db, user_id)
|
|
balance = balance_info["balance"]
|
|
return {
|
|
"sufficient": balance >= required_points,
|
|
"balance": balance,
|
|
"required": required_points,
|
|
}
|
|
|
|
|
|
# ── 充值 ──────────────────────────────────────────────
|
|
|
|
async def recharge(
|
|
db: AsyncSession,
|
|
*,
|
|
user_id: UUID | str,
|
|
points: int,
|
|
source: str,
|
|
description: str = "",
|
|
order_id: int | None = None,
|
|
batch_expired_at: datetime | None = None,
|
|
) -> PointTransaction:
|
|
"""
|
|
直接给用户账户充值积分。
|
|
|
|
:param points: 正整数,充值积分数量
|
|
:param source: wxpay / invite / gift / compensation
|
|
:param order_id: 关联的充值订单 ID(仅 wxpay 时填)
|
|
:param batch_expired_at: 该批次过期时间,默认 180 天后
|
|
"""
|
|
if points <= 0:
|
|
raise ValueError("充值积分必须大于 0")
|
|
|
|
now = _now()
|
|
|
|
# 幂等保护:同一笔订单(order_id)只能充值一次
|
|
if order_id:
|
|
existing_result = await db.execute(
|
|
select(PointTransaction)
|
|
.where(
|
|
PointTransaction.source_id == str(order_id),
|
|
PointTransaction.type == "recharge",
|
|
)
|
|
)
|
|
existing_tx = existing_result.scalar_one_or_none()
|
|
if existing_tx:
|
|
logger.warning(f"[Points] 订单 {order_id} 已充值过,跳过重复充值")
|
|
return existing_tx
|
|
|
|
# 1. 获取或创建用户积分账户
|
|
result = await db.execute(
|
|
select(UserPoint).where(UserPoint.user_id == user_id)
|
|
)
|
|
up = result.scalar_one_or_none()
|
|
|
|
if not up:
|
|
up = UserPoint(
|
|
user_id=user_id,
|
|
balance=0,
|
|
total_recharged=0,
|
|
total_consumed=0,
|
|
total_expired=0,
|
|
)
|
|
db.add(up)
|
|
await db.flush()
|
|
|
|
# 2. 增加余额
|
|
balance_before = up.balance
|
|
up.balance += points
|
|
up.total_recharged += points
|
|
|
|
# 3. 写入批次(欠费充值时,先用新积分偿还欠费,剩余部分才写入批次)
|
|
debt = max(0, -balance_before) # 欠费金额
|
|
batch_remaining = max(0, points - debt) # 实际可用的批次积分
|
|
|
|
batch_id = None
|
|
if batch_remaining > 0:
|
|
expired_at = batch_expired_at or (now + timedelta(days=EXPIRATION_DAYS))
|
|
batch = PointBatch(
|
|
user_id=user_id,
|
|
amount=batch_remaining,
|
|
remaining=batch_remaining,
|
|
expired_at=expired_at,
|
|
source=source,
|
|
)
|
|
db.add(batch)
|
|
await db.flush()
|
|
batch_id = batch.id
|
|
|
|
# 4. 写流水
|
|
tx = PointTransaction(
|
|
user_id=user_id,
|
|
type="recharge",
|
|
amount=points,
|
|
balance_before=balance_before,
|
|
balance_after=up.balance,
|
|
source_type=source,
|
|
source_id=str(order_id) if order_id else None,
|
|
batch_id=batch_id,
|
|
category=_CATEGORY_MAP.get(source, "充值"),
|
|
description=description or f"{source} 充值 {points} 积分",
|
|
)
|
|
db.add(tx)
|
|
|
|
return tx
|
|
|
|
|
|
# ── 消费(后置扣费)───────────────────────────────────
|
|
|
|
# source_type → category 映射(用于流水分类展示)
|
|
_CATEGORY_MAP: dict[str, str] = {
|
|
"script": "脚本生成",
|
|
"polish": "文案润色",
|
|
"title": "标题生成",
|
|
"tts": "配音合成",
|
|
"voice_clone": "声音复刻",
|
|
"video": "视频生成",
|
|
"compose": "压制成片",
|
|
"subtitle_burn": "字幕烧录",
|
|
"cover_design": "封面设计",
|
|
"wxpay": "充值",
|
|
"compensation": "充值",
|
|
"invite": "充值",
|
|
"gift": "充值",
|
|
}
|
|
|
|
|
|
async def consume(
|
|
db: AsyncSession,
|
|
*,
|
|
user_id: UUID | str,
|
|
points: int,
|
|
source_type: str,
|
|
source_id: str,
|
|
description: str = "",
|
|
duration: float | None = None,
|
|
category: str | None = None,
|
|
allow_negative: bool = False,
|
|
) -> PointTransaction:
|
|
"""
|
|
直接扣费(后置计费)。
|
|
|
|
业务执行成功后调用,按实际消耗直接扣除余额。
|
|
默认不允许欠费(余额不足时抛出 ValueError)。
|
|
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
|
|
|
|
:param points: 实际消耗积分(正整数)
|
|
:param source_type: 消费来源类型
|
|
:param source_id: 关联的任务 ID 或订单 ID
|
|
:param allow_negative: 是否允许扣费后余额为负
|
|
:return: 消费流水记录
|
|
"""
|
|
if points <= 0:
|
|
raise ValueError("消费积分必须大于 0")
|
|
|
|
# 加锁顺序约定:先 PointBatch 后 UserPoint,避免与 expire_batches 死锁
|
|
# expire_batches 的锁顺序也是 PointBatch → UserPoint
|
|
|
|
# 1. FIFO 扣减批次 remaining(先加锁,与 expire_batches 顺序一致)
|
|
result = await db.execute(
|
|
select(PointBatch)
|
|
.where(
|
|
PointBatch.user_id == user_id,
|
|
PointBatch.remaining > 0,
|
|
PointBatch.expired_at > _now(),
|
|
)
|
|
.order_by(PointBatch.expired_at.asc())
|
|
.with_for_update()
|
|
)
|
|
batches: list[PointBatch] = list(result.scalars().all())
|
|
|
|
# 2. 获取用户积分账户(加锁)
|
|
result = await db.execute(
|
|
select(UserPoint)
|
|
.where(UserPoint.user_id == user_id)
|
|
.with_for_update()
|
|
)
|
|
up = result.scalar_one_or_none()
|
|
|
|
if not up:
|
|
# 没有积分账户也允许消费(形成欠费)
|
|
up = UserPoint(
|
|
user_id=user_id,
|
|
balance=0,
|
|
total_recharged=0,
|
|
total_consumed=0,
|
|
total_expired=0,
|
|
)
|
|
db.add(up)
|
|
await db.flush()
|
|
|
|
# 3. 余额检查(在同一事务内,避免竞态)
|
|
if not allow_negative and up.balance < points:
|
|
raise ValueError(f"积分不足,当前余额 {up.balance},需要 {points} 积分")
|
|
|
|
remaining_to_deduct = points
|
|
for batch in batches:
|
|
if remaining_to_deduct <= 0:
|
|
break
|
|
deduct = min(batch.remaining, remaining_to_deduct)
|
|
batch.remaining -= deduct
|
|
remaining_to_deduct -= deduct
|
|
|
|
# 4. 更新用户账户(允许欠费:balance 可能变负)
|
|
balance_before = up.balance
|
|
up.balance -= points
|
|
up.total_consumed += points
|
|
|
|
# 4. 写消费流水
|
|
tx = PointTransaction(
|
|
user_id=user_id,
|
|
type="consume",
|
|
amount=points,
|
|
balance_before=balance_before,
|
|
balance_after=up.balance,
|
|
source_type=source_type,
|
|
source_id=source_id,
|
|
batch_id=batches[0].id if batches else None,
|
|
duration=duration,
|
|
category=category or _CATEGORY_MAP.get(source_type, source_type),
|
|
description=description or f"消费 {source_type} {points} 积分",
|
|
)
|
|
db.add(tx)
|
|
|
|
return tx
|
|
|
|
|
|
# ── 过期回收 ──────────────────────────────────────────
|
|
|
|
async def expire_batches(db: AsyncSession) -> int:
|
|
"""
|
|
回收过期积分批次。返回过期积分总数。
|
|
|
|
这是一个批量维护操作,建议由定时任务调用。
|
|
使用 FOR UPDATE 锁定批次和用户积分账户,防止并发回收冲突。
|
|
"""
|
|
now = _now()
|
|
|
|
# 1. 获取过期批次(加锁)
|
|
result = await db.execute(
|
|
select(PointBatch)
|
|
.where(
|
|
PointBatch.expired_at <= now,
|
|
PointBatch.remaining > 0,
|
|
)
|
|
.with_for_update()
|
|
)
|
|
expired_batches: list[PointBatch] = list(result.scalars().all())
|
|
|
|
total_expired = 0
|
|
for batch in expired_batches:
|
|
recoverable = batch.remaining
|
|
if recoverable <= 0:
|
|
continue
|
|
|
|
# 获取用户积分账户(加锁)
|
|
up_result = await db.execute(
|
|
select(UserPoint)
|
|
.where(UserPoint.user_id == batch.user_id)
|
|
.with_for_update()
|
|
)
|
|
up = up_result.scalar_one_or_none()
|
|
if not up:
|
|
continue
|
|
|
|
balance_before = up.balance
|
|
up.balance -= recoverable
|
|
up.total_expired += recoverable
|
|
batch.remaining -= recoverable
|
|
total_expired += recoverable
|
|
|
|
# 写过期流水
|
|
tx = PointTransaction(
|
|
user_id=batch.user_id,
|
|
type="expire",
|
|
amount=recoverable,
|
|
balance_before=balance_before,
|
|
balance_after=up.balance,
|
|
source_type=batch.source,
|
|
source_id=None,
|
|
batch_id=batch.id,
|
|
category="过期回收",
|
|
description=f"积分批次过期回收 {recoverable} 积分",
|
|
)
|
|
db.add(tx)
|
|
|
|
return total_expired
|