Files

108 lines
3.4 KiB
Python

"""
用户设备 CRUD 操作
==================
单设备登录约束:一个用户同一时间只能在一个设备上登录。
核心操作是「覆盖」而非「新增」,使用 INSERT ... ON CONFLICT DO UPDATE 保证原子性。
"""
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.user_device import UserDevice
class UserDeviceCRUD(CRUDBase[UserDevice]):
    """Data-access object for user device records.

    Backs the single-device login rule: the upsert in ``create_or_update``
    targets the ``user_id`` conflict key, so logging in from a new device
    overwrites the user's existing row instead of adding another one.
    """

    def __init__(self) -> None:
        super().__init__(UserDevice)

    async def get_by_user_id(
        self, db: AsyncSession, *, user_id: UUID | str
    ) -> UserDevice | None:
        """Return the device record for ``user_id``, or ``None`` if absent.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        query = select(UserDevice).where(UserDevice.user_id == user_id)
        rows = await db.execute(query)
        return rows.scalar_one_or_none()

    async def create_or_update(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        device_id: str,
        device_name: str | None = None,
        os_info: str | None = None,
        app_version: str | None = None,
        refresh_token_hash: str | None = None,
    ) -> UserDevice:
        """Create the user's device record, or overwrite the existing one.

        This is the core operation of single-device login. It issues one
        PostgreSQL ``INSERT ... ON CONFLICT (user_id) DO UPDATE`` statement:

        - no row for ``user_id`` -> INSERT
        - row already exists     -> UPDATE (old device info is overwritten;
          ``created_at`` is left untouched)

        The transaction is committed before the record is read back.

        Args:
            db: Async session used to execute and commit the statement.
            user_id: Owner of the device record (the conflict key).
            device_id: Identifier of the device that is logging in.
            device_name: Optional device name.
            os_info: Optional operating-system description.
            app_version: Optional client application version.
            refresh_token_hash: Optional hash of the issued refresh token.

        Returns:
            The device record as stored after the upsert.
        """
        from datetime import UTC, datetime

        now = datetime.now(UTC)

        # Single source of truth for the columns that get overwritten on
        # conflict, so the INSERT values and the UPDATE set cannot drift apart.
        overwritten = {
            "device_id": device_id,
            "device_name": device_name,
            "os_info": os_info,
            "app_version": app_version,
            "refresh_token_hash": refresh_token_hash,
            "last_active_at": now,
            "updated_at": now,
        }
        upsert = (
            insert(UserDevice)
            .values(user_id=user_id, created_at=now, **overwritten)
            .on_conflict_do_update(index_elements=["user_id"], set_=overwritten)
        )
        await db.execute(upsert)
        await db.commit()

        # The upsert bypasses the ORM identity map. If this session already
        # holds a UserDevice for the user (and does not expire on commit), a
        # plain SELECT would hand back that cached instance with stale
        # attributes. populate_existing forces a refresh from the fetched row.
        reread = (
            select(UserDevice)
            .where(UserDevice.user_id == user_id)
            .execution_options(populate_existing=True)
        )
        rows = await db.execute(reread)
        return rows.scalar_one()

    async def delete_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> bool:
        """Delete the device record for ``user_id`` (used on logout).

        Returns:
            ``True`` if a record was found, deleted and committed; ``False``
            if the user had no device record.
        """
        device = await self.get_by_user_id(db, user_id=user_id)
        if device is None:
            return False

        await db.delete(device)
        await db.commit()
        return True

    async def get_by_refresh_token_hash(
        self, db: AsyncSession, *, refresh_token_hash: str
    ) -> UserDevice | None:
        """Return the device record with this refresh-token hash, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        query = select(UserDevice).where(
            UserDevice.refresh_token_hash == refresh_token_hash
        )
        rows = await db.execute(query)
        return rows.scalar_one_or_none()
# Exported module-level instance of the CRUD class.
user_device = UserDeviceCRUD()