""" 积分流水 CRUD ============= 只增不改,用于审计和对账。 """ from datetime import datetime, time from uuid import UUID from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.crud.base import CRUDBase from app.models.point_transaction import PointTransaction class PointTransactionCRUD(CRUDBase[PointTransaction]): """积分流水数据访问对象""" def __init__(self) -> None: super().__init__(PointTransaction) async def get_by_user_id( self, db: AsyncSession, *, user_id: UUID | str, skip: int = 0, limit: int = 50, tx_type: str | None = None, category: str | None = None, source_type: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, ) -> list[PointTransaction]: """根据用户 ID 获取流水记录(支持筛选和分页,按时间倒序)""" stmt = select(PointTransaction).where(PointTransaction.user_id == user_id) if tx_type: stmt = stmt.where(PointTransaction.type == tx_type) if category: stmt = stmt.where(PointTransaction.category == category) if source_type: stmt = stmt.where(PointTransaction.source_type == source_type) if start_time: stmt = stmt.where(PointTransaction.created_at >= start_time) if end_time: stmt = stmt.where(PointTransaction.created_at <= end_time) stmt = stmt.order_by(PointTransaction.created_at.desc()).offset(skip).limit(limit) result = await db.execute(stmt) return list(result.scalars().all()) async def count_by_user_id( self, db: AsyncSession, *, user_id: UUID | str, tx_type: str | None = None, category: str | None = None, source_type: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, ) -> int: """根据筛选条件统计流水记录总数""" from sqlalchemy import func stmt = select(func.count(PointTransaction.id)).where(PointTransaction.user_id == user_id) if tx_type: stmt = stmt.where(PointTransaction.type == tx_type) if category: stmt = stmt.where(PointTransaction.category == category) if source_type: stmt = stmt.where(PointTransaction.source_type == source_type) if start_time: stmt = stmt.where(PointTransaction.created_at >= start_time) if end_time: stmt = stmt.where(PointTransaction.created_at <= end_time) result = await db.execute(stmt) return result.scalar() or 0 async def get_by_source( self, db: AsyncSession, *, user_id: str, source_type: str, source_id: str, ) -> list[PointTransaction]: """根据消费来源查询流水(用于查询某次 AI 调用的扣费记录)""" result = await db.execute( select(PointTransaction) .where( PointTransaction.user_id == user_id, PointTransaction.source_type == source_type, PointTransaction.source_id == source_id, ) .order_by(PointTransaction.created_at.desc()) ) return list(result.scalars().all()) async def sum_consumed_today( self, db: AsyncSession, *, user_id: UUID | str, ) -> int: """统计用户今日消费积分总和""" now = datetime.now() start_of_day = datetime.combine(now.date(), time.min) stmt = select(func.coalesce(func.sum(PointTransaction.amount), 0)).where( PointTransaction.user_id == user_id, PointTransaction.type == "consume", PointTransaction.created_at >= start_of_day, ) result = await db.execute(stmt) return result.scalar() or 0 # 导出实例 point_transaction = PointTransactionCRUD()