Files
meijiaka-zy/python-api/app/services/point_service.py
T
小鱼开发 53371aabcd feat(image): 封面形象抠图增加积分消耗(每次 10 积分)
- config/points-config.yaml: 添加 cover_avatar: 10 固定积分
- point_service.py: _CATEGORY_MAP 添加 cover_avatar → 封面形象
- image.py: remove_background 接口前置余额检查 + 后置扣费
- CoverAvatarLibrary.tsx: 上传弹窗显示积分提示,余额不足友好提示
2026-05-23 10:59:47 +08:00

501 lines
16 KiB
Python

"""
积分系统 Service 层
===================
核心能力:
1. 余额查询
2. 充值(直接到账 / 微信回调后到账)
3. 消费(后置扣费:执行业务 → 出结果 → 直接扣费)
4. 过期回收
5. 流水记录
设计原则:
- 所有业务操作在一个事务内完成(balance + batch + transaction 三者原子性)。
- FIFO 批次消耗:按 expired_at 升序扣减。
- 后置扣费模式:先执行业务,出结果后按实际消耗扣费。
- 允许欠费(单次业务实际消耗超出预估上限),但欠费后不可继续使用。
注意:本 Service 不自行 commit,由调用方(API 层)通过 FastAPI Depends 注入的
Session 统一提交。所有操作在调用方事务内原子执行。
"""
from __future__ import annotations
import logging
import math
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING
import yaml
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
from app.models.point_batch import PointBatch
from app.models.point_transaction import PointTransaction
from app.models.user_point import UserPoint
if TYPE_CHECKING:
from uuid import UUID
# ── 配置加载 ──────────────────────────────────────────
_CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml"
def _load_points_config() -> dict:
"""加载积分计费配置。服务启动时读取一次,后续内存中使用。"""
if not _CONFIG_PATH.exists():
raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}")
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
# 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...}
merged: dict[str, dict] = {}
for key, points in cfg.get("fixed_costs", {}).items():
merged[key] = {"mode": "fixed", "points": points}
for key, rule in cfg.get("duration_based_costs", {}).items():
merged[key] = {"mode": "duration", **rule}
for key in cfg.get("free_services", []):
merged[key] = {"mode": "free", "points": 0}
# 保留充值档位原始配置
merged["_recharge_options"] = cfg.get("recharge_options", [])
return merged
POINTS_CONFIG: dict[str, dict] = _load_points_config()
def get_recharge_options() -> list[dict]:
"""获取充值档位配置(由后端控制,支持积分赠送)"""
return POINTS_CONFIG.get("_recharge_options", [])
def get_chargeable_source_types() -> list[str]:
"""获取所有需要扣费的业务类型列表(排除免费业务)"""
return [
key for key, cfg in POINTS_CONFIG.items()
if not key.startswith("_") and cfg.get("mode") != "free"
]
EXPIRATION_DAYS = 180
def _now() -> datetime:
"""返回带时区的当前时间"""
return datetime.now(UTC)
def _calculate_cost(source_type: str, param: dict | None = None) -> int:
"""根据消费类型和实际结果参数计算所需积分(后置扣费时使用)"""
if source_type not in POINTS_CONFIG:
raise ValueError(f"未知的消费类型: {source_type}")
cfg = POINTS_CONFIG[source_type]
mode = cfg["mode"]
if mode == "free":
return 0
if mode == "fixed":
return cfg["points"]
if mode == "duration":
if param is None:
raise ValueError(f"消费类型 {source_type} 需要提供参数才能计算积分")
seconds = param.get("seconds", 0)
min_points = cfg.get("min_points", 1)
if "divisor" in cfg:
return max(min_points, math.ceil(seconds / cfg["divisor"]))
if "multiplier" in cfg:
return max(min_points, math.ceil(seconds) * cfg["multiplier"])
raise ValueError(f"消费类型 {source_type} 缺少计算规则")
def _estimate_max_cost(source_type: str, param: dict | None = None) -> int:
"""
预估消费上限(执行业务前检查余额用)。
按秒计费业务取保守估计,确保实际消耗不会超过预估上限。
"""
if source_type not in POINTS_CONFIG:
raise ValueError(f"未知的消费类型: {source_type}")
cfg = POINTS_CONFIG[source_type]
mode = cfg["mode"]
if mode == "free":
return 0
if mode == "fixed":
return cfg["points"]
if mode == "duration":
if param is None:
raise ValueError(f"消费类型 {source_type} 需要提供参数才能预估积分")
est = cfg.get("estimation", {})
min_points = cfg.get("min_points", 1)
if "seconds_per_char" in est:
# TTS 模式:字数 → 预估秒数
char_count = param.get("char_count", 0)
if char_count <= 0:
raise ValueError("TTS 预估需要提供 char_count 参数")
estimated_seconds = char_count * est["seconds_per_char"]
if "divisor" in cfg:
return max(min_points, math.ceil(estimated_seconds / cfg["divisor"]))
if est.get("use_input_seconds"):
# 视频模式:直接使用输入秒数作为预估上限
seconds = param.get("input_seconds", 0)
if seconds <= 0:
raise ValueError("video 预估需要提供 input_seconds 参数")
if "multiplier" in cfg:
return max(min_points, math.ceil(seconds) * cfg["multiplier"])
raise ValueError(f"消费类型 {source_type} 缺少预估规则")
# ── 余额查询 ──────────────────────────────────────────
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
"""获取用户积分余额快照(实时计算,排除已过期批次)。"""
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
up = result.scalar_one_or_none()
if not up:
return {
"balance": 0,
"total_recharged": 0,
"total_consumed": 0,
"total_expired": 0,
}
# 实时计算可用余额(排除已过期批次),避免 expire_batches 延迟时数据不一致
from sqlalchemy import func as _func
available_result = await db.execute(
select(_func.coalesce(_func.sum(PointBatch.remaining), 0))
.where(
PointBatch.user_id == user_id,
PointBatch.remaining > 0,
PointBatch.expired_at > _now(),
)
)
available_balance = available_result.scalar() or 0
return {
"balance": available_balance,
"total_recharged": up.total_recharged,
"total_consumed": up.total_consumed,
"total_expired": up.total_expired,
}
async def check_balance(
db: AsyncSession,
user_id: UUID | str,
required_points: int = 0,
) -> dict:
"""
检查用户余额是否足够。
:param required_points: 需要的积分数量
:return: {"sufficient": bool, "balance": int, "required": int}
"""
balance_info = await get_user_balance(db, user_id)
balance = balance_info["balance"]
return {
"sufficient": balance >= required_points,
"balance": balance,
"required": required_points,
}
# ── 充值 ──────────────────────────────────────────────
async def recharge(
db: AsyncSession,
*,
user_id: UUID | str,
points: int,
source: str,
description: str = "",
order_id: int | None = None,
batch_expired_at: datetime | None = None,
) -> PointTransaction:
"""
直接给用户账户充值积分。
:param points: 正整数,充值积分数量
:param source: wxpay / invite / gift / compensation
:param order_id: 关联的充值订单 ID(仅 wxpay 时填)
:param batch_expired_at: 该批次过期时间,默认 180 天后
"""
if points <= 0:
raise ValueError("充值积分必须大于 0")
now = _now()
# 幂等保护:同一笔订单(order_id)只能充值一次
if order_id:
existing_result = await db.execute(
select(PointTransaction)
.where(
PointTransaction.source_id == str(order_id),
PointTransaction.type == "recharge",
)
)
existing_tx = existing_result.scalar_one_or_none()
if existing_tx:
logger.warning(f"[Points] 订单 {order_id} 已充值过,跳过重复充值")
return existing_tx
# 1. 获取或创建用户积分账户
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
up = result.scalar_one_or_none()
if not up:
up = UserPoint(
user_id=user_id,
balance=0,
total_recharged=0,
total_consumed=0,
total_expired=0,
)
db.add(up)
await db.flush()
# 2. 增加余额
balance_before = up.balance
up.balance += points
up.total_recharged += points
# 3. 写入批次(欠费充值时,先用新积分偿还欠费,剩余部分才写入批次)
debt = max(0, -balance_before) # 欠费金额
batch_remaining = max(0, points - debt) # 实际可用的批次积分
batch_id = None
if batch_remaining > 0:
expired_at = batch_expired_at or (now + timedelta(days=EXPIRATION_DAYS))
batch = PointBatch(
user_id=user_id,
amount=batch_remaining,
remaining=batch_remaining,
expired_at=expired_at,
source=source,
)
db.add(batch)
await db.flush()
batch_id = batch.id
# 4. 写流水
tx = PointTransaction(
user_id=user_id,
type="recharge",
amount=points,
balance_before=balance_before,
balance_after=up.balance,
source_type=source,
source_id=str(order_id) if order_id else None,
batch_id=batch_id,
category=_CATEGORY_MAP.get(source, "充值"),
description=description or f"{source} 充值 {points} 积分",
)
db.add(tx)
return tx
# ── 消费(后置扣费)───────────────────────────────────
# source_type → category 映射(用于流水分类展示)
_CATEGORY_MAP: dict[str, str] = {
"script": "脚本生成",
"polish": "文案润色",
"title": "标题生成",
"tts": "配音合成",
"voice_clone": "声音复刻",
"video": "视频生成",
"compose": "压制成片",
"subtitle_burn": "字幕烧录",
"cover_design": "封面设计",
"cover_avatar": "封面形象",
"wxpay": "充值",
"compensation": "充值",
"invite": "充值",
"gift": "充值",
}
async def consume(
db: AsyncSession,
*,
user_id: UUID | str,
points: int,
source_type: str,
source_id: str,
description: str = "",
duration: float | None = None,
category: str | None = None,
allow_negative: bool = False,
) -> PointTransaction:
"""
直接扣费(后置计费)。
业务执行成功后调用,按实际消耗直接扣除余额。
默认不允许欠费(余额不足时抛出 ValueError)。
Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。
:param points: 实际消耗积分(正整数)
:param source_type: 消费来源类型
:param source_id: 关联的任务 ID 或订单 ID
:param allow_negative: 是否允许扣费后余额为负
:return: 消费流水记录
"""
if points <= 0:
raise ValueError("消费积分必须大于 0")
# 加锁顺序约定:先 PointBatch 后 UserPoint,避免与 expire_batches 死锁
# expire_batches 的锁顺序也是 PointBatch → UserPoint
# 1. FIFO 扣减批次 remaining(先加锁,与 expire_batches 顺序一致)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.user_id == user_id,
PointBatch.remaining > 0,
PointBatch.expired_at > _now(),
)
.order_by(PointBatch.expired_at.asc())
.with_for_update()
)
batches: list[PointBatch] = list(result.scalars().all())
# 2. 获取用户积分账户(加锁)
result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == user_id)
.with_for_update()
)
up = result.scalar_one_or_none()
if not up:
# 没有积分账户也允许消费(形成欠费)
up = UserPoint(
user_id=user_id,
balance=0,
total_recharged=0,
total_consumed=0,
total_expired=0,
)
db.add(up)
await db.flush()
# 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣
available = sum(b.remaining for b in batches)
if not allow_negative and available < points:
raise ValueError(f"积分不足,当前可用余额 {available},需要 {points} 积分")
remaining_to_deduct = points
for batch in batches:
if remaining_to_deduct <= 0:
break
deduct = min(batch.remaining, remaining_to_deduct)
batch.remaining -= deduct
remaining_to_deduct -= deduct
# 4. 更新用户账户(允许欠费:balance 可能变负)
balance_before = up.balance
up.balance -= points
up.total_consumed += points
# 4. 写消费流水
tx = PointTransaction(
user_id=user_id,
type="consume",
amount=points,
balance_before=balance_before,
balance_after=up.balance,
source_type=source_type,
source_id=source_id,
batch_id=batches[0].id if batches else None,
duration=duration,
category=category or _CATEGORY_MAP.get(source_type, source_type),
description=description or f"消费 {source_type} {points} 积分",
)
db.add(tx)
return tx
# ── 过期回收 ──────────────────────────────────────────
async def expire_batches(db: AsyncSession) -> int:
"""
回收过期积分批次。返回过期积分总数。
这是一个批量维护操作,建议由定时任务调用。
使用 FOR UPDATE 锁定批次和用户积分账户,防止并发回收冲突。
"""
now = _now()
# 1. 获取过期批次(加锁)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.expired_at <= now,
PointBatch.remaining > 0,
)
.with_for_update()
)
expired_batches: list[PointBatch] = list(result.scalars().all())
total_expired = 0
for batch in expired_batches:
recoverable = batch.remaining
if recoverable <= 0:
continue
# 获取用户积分账户(加锁)
up_result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == batch.user_id)
.with_for_update()
)
up = up_result.scalar_one_or_none()
if not up:
continue
balance_before = up.balance
up.balance -= recoverable
up.total_expired += recoverable
batch.remaining -= recoverable
total_expired += recoverable
# 写过期流水
tx = PointTransaction(
user_id=batch.user_id,
type="expire",
amount=recoverable,
balance_before=balance_before,
balance_after=up.balance,
source_type=batch.source,
source_id=None,
batch_id=batch.id,
category="过期回收",
description=f"积分批次过期回收 {recoverable} 积分",
)
db.add(tx)
return total_expired