Files
meijiaka-zy/python-api/app/crud/point_transaction.py
T

125 lines
4.0 KiB
Python

"""
积分流水 CRUD
=============
只增不改,用于审计和对账。
"""
from datetime import datetime, time
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.point_transaction import PointTransaction
class PointTransactionCRUD(CRUDBase[PointTransaction]):
    """Data-access object for point-ledger (``PointTransaction``) records.

    Provides user-scoped listing, counting, source lookup and a daily
    consumption sum on top of the generic ``CRUDBase`` operations.
    """

    def __init__(self) -> None:
        super().__init__(PointTransaction)

    @staticmethod
    def _build_filters(
        *,
        user_id: UUID | str,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list:
        """Return the WHERE conditions shared by the list and count queries.

        The user condition is always present. The string filters are skipped
        when falsy, so ``None`` and ``""`` both mean "no filter". The time
        bounds are inclusive on both ends.
        """
        conditions = [PointTransaction.user_id == user_id]
        if tx_type:
            conditions.append(PointTransaction.type == tx_type)
        if category:
            conditions.append(PointTransaction.category == category)
        if source_type:
            conditions.append(PointTransaction.source_type == source_type)
        if start_time is not None:
            conditions.append(PointTransaction.created_at >= start_time)
        if end_time is not None:
            conditions.append(PointTransaction.created_at <= end_time)
        return conditions

    async def get_by_user_id(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        skip: int = 0,
        limit: int = 50,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[PointTransaction]:
        """Return one page of a user's ledger records, newest first.

        Args:
            db: Async database session.
            user_id: Owner of the records.
            skip: Number of matching rows to skip (offset).
            limit: Maximum number of rows to return.
            tx_type, category, source_type: Optional equality filters.
                Falsy values are ignored.
            start_time, end_time: Optional inclusive bounds on ``created_at``.

        Returns:
            The matching records ordered by ``created_at`` descending.
        """
        conditions = self._build_filters(
            user_id=user_id,
            tx_type=tx_type,
            category=category,
            source_type=source_type,
            start_time=start_time,
            end_time=end_time,
        )
        stmt = (
            select(PointTransaction)
            .where(*conditions)
            .order_by(PointTransaction.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def count_by_user_id(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> int:
        """Count a user's ledger records matching the given filters.

        Uses the same filter semantics as ``get_by_user_id`` (via the shared
        ``_build_filters``), so the total always agrees with the paged list.

        Returns:
            The number of matching rows, or 0 if the scalar is empty.
        """
        conditions = self._build_filters(
            user_id=user_id,
            tx_type=tx_type,
            category=category,
            source_type=source_type,
            start_time=start_time,
            end_time=end_time,
        )
        stmt = select(func.count(PointTransaction.id)).where(*conditions)
        result = await db.execute(stmt)
        return result.scalar() or 0

    async def get_by_source(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        source_type: str,
        source_id: str,
    ) -> list[PointTransaction]:
        """Return a user's records tied to one consumption source, newest first.

        Intended for looking up the charge records of a specific AI call.

        Args:
            db: Async database session.
            user_id: Owner of the records. Accepts ``UUID`` or ``str`` to
                match the sibling query methods.
            source_type: Source category to match exactly.
            source_id: Source identifier to match exactly.
        """
        stmt = (
            select(PointTransaction)
            .where(
                PointTransaction.user_id == user_id,
                PointTransaction.source_type == source_type,
                PointTransaction.source_id == source_id,
            )
            .order_by(PointTransaction.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def sum_consumed_today(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        now: datetime | None = None,
    ) -> int:
        """Sum ``amount`` over the user's ``"consume"`` records since midnight.

        Args:
            db: Async database session.
            user_id: Owner of the records.
            now: Reference "current time". It defaults to ``datetime.now()``
                (naive server-local time), which is the original behavior.
                If an aware datetime is passed, midnight is computed in that
                same timezone and stays aware.

        Returns:
            The summed amount, or 0 when there are no matching rows.
        """
        # NOTE(review): the naive local-time default must match how
        # `created_at` is stored (naive local vs. UTC/aware). Confirm this,
        # otherwise "today" may be shifted by the UTC offset.
        # NOTE(review): whether consume amounts are stored as positive or
        # negative values is not visible here. Verify the sign callers expect.
        reference = datetime.now() if now is None else now
        start_of_day = datetime.combine(
            reference.date(), time.min, tzinfo=reference.tzinfo
        )
        stmt = select(func.coalesce(func.sum(PointTransaction.amount), 0)).where(
            PointTransaction.user_id == user_id,
            PointTransaction.type == "consume",
            PointTransaction.created_at >= start_of_day,
        )
        result = await db.execute(stmt)
        return result.scalar() or 0
# Module-level shared instance exported for callers to import.
point_transaction: PointTransactionCRUD = PointTransactionCRUD()