""" 积分系统 Service 层 =================== 核心能力: 1. 余额查询 2. 充值(直接到账 / 微信回调后到账) 3. 消费(后置扣费:执行业务 → 出结果 → 直接扣费) 4. 过期回收 5. 流水记录 设计原则: - 所有业务操作在一个事务内完成(balance + batch + transaction 三者原子性)。 - FIFO 批次消耗:按 expired_at 升序扣减。 - 后置扣费模式:先执行业务,出结果后按实际消耗扣费。 - 允许欠费(单次业务实际消耗超出预估上限),但欠费后不可继续使用。 注意:本 Service 不自行 commit,由调用方(API 层)通过 FastAPI Depends 注入的 Session 统一提交。所有操作在调用方事务内原子执行。 """ from __future__ import annotations import logging import math from datetime import UTC, datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING, Any import yaml from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) from app.core.exceptions import InsufficientPointsException from app.models.point_batch import PointBatch from app.models.point_transaction import PointTransaction from app.models.user_point import UserPoint if TYPE_CHECKING: from uuid import UUID # ── 配置加载 ────────────────────────────────────────── _CONFIG_PATH = Path(__file__).resolve().parents[2] / "config" / "points-config.yaml" def _load_points_config() -> dict[str, Any]: """加载积分计费配置。服务启动时读取一次,后续内存中使用。""" if not _CONFIG_PATH.exists(): raise FileNotFoundError(f"积分配置文件不存在: {_CONFIG_PATH}") with open(_CONFIG_PATH, encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} # 合并为统一的查询字典:source_type -> {"mode": "fixed|duration|free", ...} merged: dict[str, dict] = {} for key, points in cfg.get("fixed_costs", {}).items(): merged[key] = {"mode": "fixed", "points": points} for key, rule in cfg.get("duration_based_costs", {}).items(): merged[key] = {"mode": "duration", **rule} for key in cfg.get("free_services", []): merged[key] = {"mode": "free", "points": 0} # 保留充值档位原始配置 merged["_recharge_options"] = cfg.get("recharge_options", []) return merged POINTS_CONFIG: dict[str, Any] = _load_points_config() def get_recharge_options() -> list[dict]: """获取充值档位配置(由后端控制,支持积分赠送)""" options = POINTS_CONFIG.get("_recharge_options", []) if isinstance(options, list): return options return [] def get_chargeable_source_types() -> list[str]: """获取所有需要扣费的业务类型列表(排除免费业务)""" return [ key for key, cfg in POINTS_CONFIG.items() if not key.startswith("_") and cfg.get("mode") != "free" ] EXPIRATION_DAYS = 180 def _now() -> datetime: """返回带时区的当前时间""" return datetime.now(UTC) def _calculate_cost(source_type: str, param: dict | None = None) -> int: """根据消费类型和实际结果参数计算所需积分(后置扣费时使用)""" if source_type not in POINTS_CONFIG: raise ValueError(f"未知的消费类型: {source_type}") cfg = POINTS_CONFIG[source_type] mode = cfg["mode"] if mode == "free": return 0 if mode == "fixed": return cfg["points"] if mode == "duration": if param is None: raise ValueError(f"消费类型 {source_type} 需要提供参数才能计算积分") seconds = param.get("seconds", 0) min_points = cfg.get("min_points", 1) if "divisor" in cfg: return max(min_points, math.ceil(seconds / cfg["divisor"])) if "multiplier" in cfg: return max(min_points, math.ceil(seconds) * cfg["multiplier"]) raise ValueError(f"消费类型 {source_type} 缺少计算规则") def _estimate_max_cost(source_type: str, param: dict | None = None) -> int: """ 预估消费上限(执行业务前检查余额用)。 按秒计费业务取保守估计,确保实际消耗不会超过预估上限。 """ if source_type not in POINTS_CONFIG: raise ValueError(f"未知的消费类型: {source_type}") cfg = POINTS_CONFIG[source_type] mode = cfg["mode"] if mode == "free": return 0 if mode == "fixed": return cfg["points"] if mode == "duration": if param is None: raise ValueError(f"消费类型 {source_type} 需要提供参数才能预估积分") est = cfg.get("estimation", {}) min_points = cfg.get("min_points", 1) if "seconds_per_char" in est: # TTS 模式:字数 → 预估秒数 char_count = param.get("char_count", 0) if char_count <= 0: raise ValueError("TTS 预估需要提供 char_count 参数") estimated_seconds = char_count * est["seconds_per_char"] if "divisor" in cfg: return max(min_points, math.ceil(estimated_seconds / cfg["divisor"])) if est.get("use_input_seconds"): # 视频模式:直接使用输入秒数作为预估上限 seconds = param.get("input_seconds", 0) if seconds <= 0: raise ValueError("video 预估需要提供 input_seconds 参数") if "multiplier" in cfg: return max(min_points, math.ceil(seconds) * cfg["multiplier"]) raise ValueError(f"消费类型 {source_type} 缺少预估规则") # ── 余额查询 ────────────────────────────────────────── async def get_user_balance(db: AsyncSession, user_id: UUID | str) -> dict: """获取用户积分余额快照(实时计算,排除已过期批次)。""" result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id)) up = result.scalar_one_or_none() if not up: return { "balance": 0, "total_recharged": 0, "total_consumed": 0, "total_expired": 0, } # 实时计算可用余额(排除已过期批次),避免 expire_batches 延迟时数据不一致 from sqlalchemy import func as _func available_result = await db.execute( select(_func.coalesce(_func.sum(PointBatch.remaining), 0)).where( PointBatch.user_id == user_id, PointBatch.remaining > 0, PointBatch.expired_at > _now(), ) ) available_balance = available_result.scalar() or 0 return { "balance": available_balance, "total_recharged": up.total_recharged, "total_consumed": up.total_consumed, "total_expired": up.total_expired, } async def check_balance( db: AsyncSession, user_id: UUID | str, required_points: int = 0, ) -> dict: """ 检查用户余额是否足够。 :param required_points: 需要的积分数量 :return: {"sufficient": bool, "balance": int, "required": int} """ balance_info = await get_user_balance(db, user_id) balance = balance_info["balance"] return { "sufficient": balance >= required_points, "balance": balance, "required": required_points, } # ── 充值 ────────────────────────────────────────────── async def recharge( db: AsyncSession, *, user_id: UUID | str, points: int, source: str, description: str = "", order_id: int | None = None, batch_expired_at: datetime | None = None, ) -> PointTransaction: """ 直接给用户账户充值积分。 :param points: 正整数,充值积分数量 :param source: wxpay / invite / gift / compensation :param order_id: 关联的充值订单 ID(仅 wxpay 时填) :param batch_expired_at: 该批次过期时间,默认 180 天后 """ if points <= 0: raise ValueError("充值积分必须大于 0") now = _now() # 幂等保护:同一笔订单(order_id)只能充值一次 if order_id: existing_result = await db.execute( select(PointTransaction).where( PointTransaction.source_id == str(order_id), PointTransaction.type == "recharge", ) ) existing_tx = existing_result.scalar_one_or_none() if existing_tx: logger.warning(f"[Points] 订单 {order_id} 已充值过,跳过重复充值") return existing_tx # 1. 获取或创建用户积分账户 result = await db.execute(select(UserPoint).where(UserPoint.user_id == user_id)) up = result.scalar_one_or_none() if not up: up = UserPoint( user_id=user_id, balance=0, total_recharged=0, total_consumed=0, total_expired=0, ) db.add(up) await db.flush() # 2. 增加余额 balance_before = up.balance up.balance += points up.total_recharged += points # 3. 写入批次(欠费充值时,先用新积分偿还欠费,剩余部分才写入批次) debt = max(0, -balance_before) # 欠费金额 batch_remaining = max(0, points - debt) # 实际可用的批次积分 batch_id = None if batch_remaining > 0: expired_at = batch_expired_at or (now + timedelta(days=EXPIRATION_DAYS)) batch = PointBatch( user_id=user_id, amount=batch_remaining, remaining=batch_remaining, expired_at=expired_at, source=source, ) db.add(batch) await db.flush() batch_id = batch.id # 4. 写流水 tx = PointTransaction( user_id=user_id, type="recharge", amount=points, balance_before=balance_before, balance_after=up.balance, source_type=source, source_id=str(order_id) if order_id else None, batch_id=batch_id, category=_CATEGORY_MAP.get(source, "充值"), description=description or f"{source} 充值 {points} 积分", ) db.add(tx) return tx # ── 消费(后置扣费)─────────────────────────────────── # source_type → category 映射(用于流水分类展示) _CATEGORY_MAP: dict[str, str] = { "script": "脚本生成", "polish": "文案润色", "title": "标题生成", "tts": "配音合成", "voice_clone": "声音复刻", "video": "视频生成", "compose": "压制成片", "subtitle_burn": "字幕烧录", "cover_design": "封面设计", "cover_avatar": "封面形象", "wxpay": "充值", "compensation": "充值", "invite": "充值", "gift": "充值", } async def consume( db: AsyncSession, *, user_id: UUID | str, points: int, source_type: str, source_id: str, description: str = "", duration: float | None = None, category: str | None = None, allow_negative: bool = False, ) -> PointTransaction: """ 直接扣费(后置计费)。 业务执行成功后调用,按实际消耗直接扣除余额。 默认不允许欠费(余额不足时抛出 InsufficientPointsException)。 Scheduler 后置扣费等场景可设置 allow_negative=True,允许余额变负。 :param points: 实际消耗积分(正整数) :param source_type: 消费来源类型 :param source_id: 关联的任务 ID 或订单 ID :param allow_negative: 是否允许扣费后余额为负 :return: 消费流水记录 """ if points <= 0: raise ValueError("消费积分必须大于 0") # 加锁顺序约定:先 PointBatch 后 UserPoint,避免与 expire_batches 死锁 # expire_batches 的锁顺序也是 PointBatch → UserPoint # 1. FIFO 扣减批次 remaining(先加锁,与 expire_batches 顺序一致) result = await db.execute( select(PointBatch) .where( PointBatch.user_id == user_id, PointBatch.remaining > 0, PointBatch.expired_at > _now(), ) .order_by(PointBatch.expired_at.asc()) .with_for_update() ) batches: list[PointBatch] = list(result.scalars().all()) # 2. 获取用户积分账户(加锁) result = await db.execute( select(UserPoint).where(UserPoint.user_id == user_id).with_for_update() ) up = result.scalar_one_or_none() if not up: # 没有积分账户也允许消费(形成欠费) up = UserPoint( user_id=user_id, balance=0, total_recharged=0, total_consumed=0, total_expired=0, ) db.add(up) await db.flush() # 3. 余额检查:用实时可用余额(未过期批次 remaining 总和),避免 expire_batches 延迟导致超扣 available = sum(b.remaining for b in batches) if not allow_negative and available < points: raise InsufficientPointsException(f"积分不足,当前可用余额 {available},需要 {points} 积分") remaining_to_deduct = points for batch in batches: if remaining_to_deduct <= 0: break deduct = min(batch.remaining, remaining_to_deduct) batch.remaining -= deduct remaining_to_deduct -= deduct # 4. 更新用户账户(允许欠费:balance 可能变负) balance_before = up.balance up.balance -= points up.total_consumed += points # 4. 写消费流水 tx = PointTransaction( user_id=user_id, type="consume", amount=points, balance_before=balance_before, balance_after=up.balance, source_type=source_type, source_id=source_id, batch_id=batches[0].id if batches else None, duration=duration, category=category or _CATEGORY_MAP.get(source_type, source_type), description=description or f"消费 {source_type} {points} 积分", ) db.add(tx) return tx # ── 过期回收 ────────────────────────────────────────── async def expire_batches(db: AsyncSession) -> int: """ 回收过期积分批次。返回过期积分总数。 这是一个批量维护操作,建议由定时任务调用。 使用 FOR UPDATE 锁定批次和用户积分账户,防止并发回收冲突。 """ now = _now() # 1. 获取过期批次(加锁) result = await db.execute( select(PointBatch) .where( PointBatch.expired_at <= now, PointBatch.remaining > 0, ) .with_for_update() ) expired_batches: list[PointBatch] = list(result.scalars().all()) total_expired = 0 for batch in expired_batches: recoverable = batch.remaining if recoverable <= 0: continue # 获取用户积分账户(加锁) up_result = await db.execute( select(UserPoint).where(UserPoint.user_id == batch.user_id).with_for_update() ) up = up_result.scalar_one_or_none() if not up: continue balance_before = up.balance up.balance -= recoverable up.total_expired += recoverable batch.remaining -= recoverable total_expired += recoverable # 写过期流水 tx = PointTransaction( user_id=batch.user_id, type="expire", amount=recoverable, balance_before=balance_before, balance_after=up.balance, source_type=batch.source, source_id=None, batch_id=batch.id, category="过期回收", description=f"积分批次过期回收 {recoverable} 积分", ) db.add(tx) return total_expired