Files
meijiaka-zy/python-api/app/services/point_service.py
T
小鱼开发 51521fc0dd feat(payment): 微信支付 APIv2 + 积分充值 + SMS 短信 + 双 Token 认证
- 微信支付从 APIv3 降级为 APIv2(MD5/XML)
- 积分系统:充值下单、微信回调、消费冻结/结算/退款
- SMS B2M 短信验证码服务
- 双 Token 认证(Access 30min + Refresh 30days)
- SSE 单设备踢人
- 用户设备管理、积分账户模型
- Alembic 迁移脚本
2026-05-07 18:43:02 +08:00

503 lines
15 KiB
Python

"""
积分系统 Service 层
===================
核心能力:
1. 余额查询(含冻结)
2. 充值(直接到账 / 微信回调后到账)
3. 消费(预扣 → 结算/退还)
4. 过期回收
5. 流水记录
设计原则:
- 所有业务操作在一个事务内完成(balance + batch + transaction 三者原子性)。
- FIFO 批次消耗:按 expired_at 升序扣减。
- 消费不直接扣除可用余额,而是先冻结,后根据结果结算或退回。
注意:本 Service 不自行 commit,由调用方(API 层)通过 FastAPI Depends 注入的
Session 统一提交。所有操作在调用方事务内原子执行。
"""
from __future__ import annotations
import math
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.point_transaction import point_transaction
from app.models.point_batch import PointBatch
from app.models.point_transaction import PointTransaction
from app.models.user_point import UserPoint
if TYPE_CHECKING:
from uuid import UUID
# ── 常量 ──────────────────────────────────────────────
EXPIRATION_DAYS = 180
POINTS_COST: dict[str, int | None] = {
"script": 5,
"polish": 1,
"title": 1,
"voice_clone": 200,
"tts": None, # ceil(seconds / 5)
"video": None, # seconds * 5
}
def _now() -> datetime:
"""返回带时区的当前时间"""
return datetime.now(UTC)
def _calculate_cost(source_type: str, param: dict | None = None) -> int:
"""根据消费类型计算所需积分"""
if source_type not in POINTS_COST:
raise ValueError(f"未知的消费类型: {source_type}")
base = POINTS_COST[source_type]
if base is not None:
return base
if param is None:
raise ValueError(f"消费类型 {source_type} 需要提供参数才能计算积分")
if source_type == "tts":
seconds = param.get("seconds", 0)
return max(1, math.ceil(seconds / 5))
if source_type == "video":
seconds = param.get("seconds", 0)
return max(1, seconds * 5)
raise ValueError(f"消费类型 {source_type} 缺少计算规则")
# ── 余额查询 ──────────────────────────────────────────
async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict:
"""获取用户积分余额快照"""
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
up = result.scalar_one_or_none()
if not up:
return {
"balance": 0,
"frozen": 0,
"available": 0,
"total_recharged": 0,
"total_consumed": 0,
"total_expired": 0,
"total_refunded": 0,
}
return {
"balance": up.balance,
"frozen": up.frozen,
"available": up.balance - up.frozen,
"total_recharged": up.total_recharged,
"total_consumed": up.total_consumed,
"total_expired": up.total_expired,
"total_refunded": up.total_refunded,
}
# ── 充值 ──────────────────────────────────────────────
async def recharge(
db: AsyncSession,
*,
user_id: UUID | str,
points: int,
source: str,
description: str = "",
order_id: int | None = None,
batch_expired_at: datetime | None = None,
) -> PointTransaction:
"""
直接给用户账户充值积分。
:param points: 正整数,充值积分数量
:param source: wxpay / invite / gift / compensation
:param order_id: 关联的充值订单 ID(仅 wxpay 时填)
:param batch_expired_at: 该批次过期时间,默认 180 天后
"""
if points <= 0:
raise ValueError("充值积分必须大于 0")
now = _now()
# 1. 获取或创建用户积分账户
result = await db.execute(
select(UserPoint).where(UserPoint.user_id == user_id)
)
up = result.scalar_one_or_none()
if not up:
up = UserPoint(
user_id=user_id,
balance=0,
frozen=0,
total_recharged=0,
total_consumed=0,
total_expired=0,
total_refunded=0,
)
db.add(up)
await db.flush()
# 2. 增加余额
balance_before = up.balance
up.balance += points
up.total_recharged += points
# 3. 写入批次
expired_at = batch_expired_at or (now + timedelta(days=EXPIRATION_DAYS))
batch = PointBatch(
user_id=user_id,
amount=points,
remaining=points,
frozen=0,
expired_at=expired_at,
source=source,
)
db.add(batch)
await db.flush()
# 4. 写流水
tx = PointTransaction(
user_id=user_id,
type="recharge",
amount=points,
balance_before=balance_before,
balance_after=up.balance,
source_type=source,
source_id=str(order_id) if order_id else None,
batch_id=batch.id,
description=description or f"{source} 充值 {points} 积分",
)
db.add(tx)
return tx
# ── 消费预扣 ──────────────────────────────────────────
async def freeze_for_consumption(
db: AsyncSession,
*,
user_id: UUID | str,
source_type: str,
source_id: str,
param: dict | None = None,
description: str = "",
) -> tuple[PointTransaction, list[PointBatch]]:
"""
消费前预扣积分。
:return: (流水记录, 被扣减的批次列表)
:raises ValueError: 积分不足时抛出
"""
points = _calculate_cost(source_type, param)
# 1. 获取用户积分(加锁防并发超扣)
result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == user_id)
.with_for_update()
)
up = result.scalar_one_or_none()
if not up or (up.balance - up.frozen) < points:
raise ValueError("积分不足")
# 2. 获取可用批次(按过期时间升序,加锁)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.user_id == user_id,
PointBatch.remaining - PointBatch.frozen > 0,
PointBatch.expired_at > _now(),
)
.order_by(PointBatch.expired_at.asc())
.with_for_update()
)
batches: list[PointBatch] = list(result.scalars().all())
remaining_to_deduct = points
affected_batches: list[PointBatch] = []
for batch in batches:
if remaining_to_deduct <= 0:
break
available = batch.remaining - batch.frozen
if available <= 0:
continue
deduct = min(available, remaining_to_deduct)
batch.frozen += deduct
remaining_to_deduct -= deduct
affected_batches.append(batch)
if remaining_to_deduct > 0:
raise ValueError("积分不足(批次计算异常)")
# 3. 更新账户冻结额
balance_before = up.balance
up.balance -= points
up.frozen += points
# 4. 写流水
tx = PointTransaction(
user_id=user_id,
type="consume",
amount=-points, # 负数表示支出
balance_before=balance_before,
balance_after=up.balance,
source_type=source_type,
source_id=source_id,
batch_id=affected_batches[0].id if affected_batches else None,
description=description or f"消费 {source_type} 预扣 {points} 积分",
)
db.add(tx)
return tx, affected_batches
# ── 消费结算(成功)───────────────────────────────────
async def settle_consumption(
db: AsyncSession,
*,
user_id: UUID | str,
source_type: str,
source_id: str,
actual_points: int | None = None,
description: str = "",
) -> PointTransaction | None:
"""
消费成功后结算。
:param actual_points: 实际消耗积分(如果比预扣少,差额退回余额)
:return: 差额退款流水(如果有),否则 None
"""
# 1. 查找预扣流水
txs = await point_transaction.get_by_source(
db, user_id=user_id, source_type=source_type, source_id=source_id
)
freeze_tx = next((t for t in txs if t.type == "consume"), None)
if not freeze_tx:
raise ValueError("未找到预扣记录")
frozen_points = abs(freeze_tx.amount)
# 如果没有实际消耗值,视为全额消耗
if actual_points is None:
actual_points = frozen_points
if actual_points > frozen_points:
raise ValueError("实际消耗不能大于预扣金额")
# 2. 获取用户积分账户(加锁)
result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == user_id)
.with_for_update()
)
up = result.scalar_one_or_none()
if not up:
raise ValueError("用户积分账户不存在")
# 3. 解冻用户账户
up.frozen -= frozen_points
if up.frozen < 0:
up.frozen = 0
# 4. 获取冻结中的批次(加锁)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.user_id == user_id,
PointBatch.frozen > 0,
)
.order_by(PointBatch.expired_at.asc())
.with_for_update()
)
batches: list[PointBatch] = list(result.scalars().all())
# 5. 按 FIFO 顺序从批次中扣减实际消耗
remaining_to_consume = actual_points
for batch in batches:
if remaining_to_consume <= 0:
break
batch_frozen = batch.frozen
batch.frozen = 0
consume_from_batch = min(batch_frozen, remaining_to_consume)
batch.remaining -= consume_from_batch
remaining_to_consume -= consume_from_batch
# 6. 记录实际消耗
up.total_consumed += actual_points
# 7. 如果有差额,退回余额
refund_tx = None
if actual_points < frozen_points:
refund_points = frozen_points - actual_points
up.balance += refund_points
up.total_refunded += refund_points
# 差额退回流水
refund_tx = PointTransaction(
user_id=user_id,
type="refund",
amount=refund_points,
balance_before=up.balance - refund_points,
balance_after=up.balance,
source_type=source_type,
source_id=source_id,
batch_id=batches[0].id if batches else None,
description=description or f"{source_type} 消费差额退回 {refund_points} 积分",
)
db.add(refund_tx)
return refund_tx
# ── 消费失败退还 ──────────────────────────────────────
async def refund_consumption(
db: AsyncSession,
*,
user_id: UUID | str,
source_type: str,
source_id: str,
description: str = "",
) -> PointTransaction:
"""
消费失败时全额退还预扣积分。
"""
# 1. 查找预扣流水
txs = await point_transaction.get_by_source(
db, user_id=user_id, source_type=source_type, source_id=source_id
)
freeze_tx = next((t for t in txs if t.type == "consume"), None)
if not freeze_tx:
raise ValueError("未找到预扣记录")
frozen_points = abs(freeze_tx.amount)
# 2. 获取用户积分账户(加锁)
result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == user_id)
.with_for_update()
)
up = result.scalar_one_or_none()
if not up:
raise ValueError("用户积分账户不存在")
# 3. 解冻并退回余额
up.frozen -= frozen_points
if up.frozen < 0:
up.frozen = 0
up.balance += frozen_points
up.total_refunded += frozen_points
# 4. 解冻批次(remaining 不变,因为预扣时没有扣减 remaining)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.user_id == user_id,
PointBatch.frozen > 0,
)
.order_by(PointBatch.expired_at.asc())
.with_for_update()
)
batches: list[PointBatch] = list(result.scalars().all())
for batch in batches:
batch.frozen = 0
# 5. 写退款流水
tx = PointTransaction(
user_id=user_id,
type="refund",
amount=frozen_points,
balance_before=up.balance - frozen_points,
balance_after=up.balance,
source_type=source_type,
source_id=source_id,
batch_id=batches[0].id if batches else None,
description=description or f"{source_type} 消费失败退回 {frozen_points} 积分",
)
db.add(tx)
return tx
# ── 过期回收 ──────────────────────────────────────────
async def expire_batches(db: AsyncSession) -> int:
"""
回收过期积分批次。返回过期积分总数。
这是一个批量维护操作,建议由定时任务调用。
使用 FOR UPDATE 锁定批次和用户积分账户,防止并发回收冲突。
"""
now = _now()
# 1. 获取过期批次(加锁)
result = await db.execute(
select(PointBatch)
.where(
PointBatch.expired_at <= now,
PointBatch.remaining > 0,
)
.with_for_update()
)
expired_batches: list[PointBatch] = list(result.scalars().all())
total_expired = 0
for batch in expired_batches:
# 实际可回收 = remaining - frozen(冻结部分等结算时再处理)
recoverable = batch.remaining - batch.frozen
if recoverable <= 0:
continue
# 获取用户积分账户(加锁)
up_result = await db.execute(
select(UserPoint)
.where(UserPoint.user_id == batch.user_id)
.with_for_update()
)
up = up_result.scalar_one_or_none()
if not up:
continue
balance_before = up.balance
up.balance -= recoverable
up.total_expired += recoverable
batch.remaining -= recoverable
total_expired += recoverable
# 写过期流水
tx = PointTransaction(
user_id=batch.user_id,
type="expire",
amount=-recoverable,
balance_before=balance_before,
balance_after=up.balance,
source_type=batch.source,
source_id=None,
batch_id=batch.id,
description=f"积分批次过期回收 {recoverable} 积分",
)
db.add(tx)
return total_expired