125 lines
4.0 KiB
Python
125 lines
4.0 KiB
Python
"""
|
|
积分流水 CRUD
|
|
=============
|
|
|
|
只增不改,用于审计和对账。
|
|
"""
|
|
|
|
from datetime import datetime, time
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.point_transaction import PointTransaction
|
|
|
|
|
|
class PointTransactionCRUD(CRUDBase[PointTransaction]):
    """Data-access object for point transaction (ledger) records.

    The methods defined here only read or aggregate rows; generic
    operations are whatever ``CRUDBase`` provides.
    """

    def __init__(self) -> None:
        super().__init__(PointTransaction)

    @staticmethod
    def _apply_filters(
        stmt,
        *,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ):
        """Append the optional ledger filters to a SELECT statement.

        Shared by the listing and counting queries so that a page of rows
        and its total are always computed from the same conditions.

        Args:
            stmt: A SQLAlchemy selectable already restricted to one user.
            tx_type: Keep only rows whose ``type`` equals this value.
            category: Keep only rows whose ``category`` equals this value.
            source_type: Keep only rows whose ``source_type`` equals this value.
            start_time: Inclusive lower bound on ``created_at``.
            end_time: Inclusive upper bound on ``created_at``.

        Returns:
            The statement with one extra WHERE clause per supplied filter.
        """
        # String filters use truthiness on purpose: an empty string means
        # "no filter", exactly like None.
        if tx_type:
            stmt = stmt.where(PointTransaction.type == tx_type)
        if category:
            stmt = stmt.where(PointTransaction.category == category)
        if source_type:
            stmt = stmt.where(PointTransaction.source_type == source_type)
        # datetime objects are always truthy, so an explicit None check is
        # equivalent and states the intent more clearly.
        if start_time is not None:
            stmt = stmt.where(PointTransaction.created_at >= start_time)
        if end_time is not None:
            stmt = stmt.where(PointTransaction.created_at <= end_time)
        return stmt

    async def get_by_user_id(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        skip: int = 0,
        limit: int = 50,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[PointTransaction]:
        """Return one page of a user's transactions, newest first.

        Args:
            db: Async session used to execute the query.
            user_id: Owner of the transactions.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            tx_type: Optional filter on ``type``.
            category: Optional filter on ``category``.
            source_type: Optional filter on ``source_type``.
            start_time: Optional inclusive lower bound on ``created_at``.
            end_time: Optional inclusive upper bound on ``created_at``.

        Returns:
            The matching rows ordered by ``created_at`` descending.
        """
        stmt = select(PointTransaction).where(PointTransaction.user_id == user_id)
        stmt = self._apply_filters(
            stmt,
            tx_type=tx_type,
            category=category,
            source_type=source_type,
            start_time=start_time,
            end_time=end_time,
        )
        # created_at is not unique; without a tie-breaker, rows sharing a
        # timestamp may be duplicated or skipped between OFFSET pages.
        stmt = (
            stmt.order_by(
                PointTransaction.created_at.desc(),
                PointTransaction.id.desc(),
            )
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def count_by_user_id(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        tx_type: str | None = None,
        category: str | None = None,
        source_type: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> int:
        """Count a user's transactions that match the given filters.

        Accepts the same filters as ``get_by_user_id`` (minus pagination).

        Args:
            db: Async session used to execute the query.
            user_id: Owner of the transactions.
            tx_type: Optional filter on ``type``.
            category: Optional filter on ``category``.
            source_type: Optional filter on ``source_type``.
            start_time: Optional inclusive lower bound on ``created_at``.
            end_time: Optional inclusive upper bound on ``created_at``.

        Returns:
            The number of matching rows, or 0 when there are none.
        """
        stmt = select(func.count(PointTransaction.id)).where(
            PointTransaction.user_id == user_id
        )
        stmt = self._apply_filters(
            stmt,
            tx_type=tx_type,
            category=category,
            source_type=source_type,
            start_time=start_time,
            end_time=end_time,
        )
        result = await db.execute(stmt)
        return result.scalar() or 0

    async def get_by_source(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        source_type: str,
        source_id: str,
    ) -> list[PointTransaction]:
        """Return a user's transactions for one consumption source, newest first.

        Intended for looking up the charge records of a single AI call.

        Args:
            db: Async session used to execute the query.
            user_id: Owner of the transactions.
            source_type: Value matched against ``source_type``.
            source_id: Value matched against ``source_id``.

        Returns:
            All matching rows ordered by ``created_at`` descending.
        """
        stmt = (
            select(PointTransaction)
            .where(
                PointTransaction.user_id == user_id,
                PointTransaction.source_type == source_type,
                PointTransaction.source_id == source_id,
            )
            .order_by(PointTransaction.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def sum_consumed_today(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
    ) -> int:
        """Sum the ``amount`` of a user's "consume" transactions created today.

        Args:
            db: Async session used to execute the query.
            user_id: Owner of the transactions.

        Returns:
            The summed amount, or 0 when there are no matching rows.
        """
        # "Today" starts at midnight of the process's local clock, as a naive
        # datetime.
        # NOTE(review): if created_at is stored timezone-aware or in UTC, this
        # boundary may be off by the local UTC offset - confirm against the model.
        start_of_day = datetime.combine(datetime.now().date(), time.min)
        stmt = select(func.coalesce(func.sum(PointTransaction.amount), 0)).where(
            PointTransaction.user_id == user_id,
            PointTransaction.type == "consume",
            PointTransaction.created_at >= start_of_day,
        )
        result = await db.execute(stmt)
        return result.scalar() or 0
|
|
|
|
|
|
# Export a shared module-level instance of the CRUD object
|
|
point_transaction = PointTransactionCRUD()
|