108 lines
3.4 KiB
Python
108 lines
3.4 KiB
Python
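"""
User device CRUD operations (用户设备 CRUD 操作)
================================================

Single-device login constraint: a user may be logged in on only one device at a time.

The core operation therefore *overwrites* the user's device record instead of adding a
new one. It uses a single ``INSERT ... ON CONFLICT DO UPDATE`` statement, keyed on
``user_id``, so the insert-or-overwrite happens atomically.
"""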
|
|
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.user_device import UserDevice
|
|
|
|
|
|
class UserDeviceCRUD(CRUDBase[UserDevice]):
    """Data access object for ``UserDevice`` rows.

    Records are looked up by ``user_id`` or by refresh-token hash, and each lookup
    expects at most one matching row. This reflects the single-device-login model:
    one device record per user.
    """

    def __init__(self) -> None:
        super().__init__(UserDevice)

    async def get_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> UserDevice | None:
        """Return the device record for ``user_id``, or ``None`` if there is none.

        Raises ``sqlalchemy.exc.MultipleResultsFound`` if more than one row matches.
        """
        result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
        return result.scalar_one_or_none()

    async def create_or_update(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        device_id: str,
        device_name: str | None = None,
        os_info: str | None = None,
        app_version: str | None = None,
        refresh_token_hash: str | None = None,
    ) -> UserDevice:
        """Insert or overwrite the user's device record (the core single-device-login step).

        Issues one PostgreSQL ``INSERT ... ON CONFLICT (user_id) DO UPDATE``:

        - no row for ``user_id``  -> INSERT a new row;
        - a row already exists    -> UPDATE it, replacing the previous device's details.

        On conflict, every device column listed below is overwritten with the values
        passed here. Optional arguments left as ``None`` therefore reset those columns
        to NULL. ``created_at`` is only written on INSERT, so it keeps its original value
        on UPDATE.

        Args:
            db: Async session to execute on. This call commits it.
            user_id: Owner of the device record; also the ON CONFLICT target.
            device_id: Identifier of the device to record.
            device_name: Optional device name.
            os_info: Optional operating-system description.
            app_version: Optional client application version.
            refresh_token_hash: Refresh-token hash to store for this device.

        Returns:
            The ``UserDevice`` row for ``user_id`` as stored after the commit.
        """
        from datetime import UTC, datetime

        now = datetime.now(UTC)

        stmt = (
            insert(UserDevice)
            .values(
                user_id=user_id,
                device_id=device_id,
                device_name=device_name,
                os_info=os_info,
                app_version=app_version,
                refresh_token_hash=refresh_token_hash,
                last_active_at=now,
                created_at=now,
                updated_at=now,
            )
            .on_conflict_do_update(
                # NOTE(review): PostgreSQL requires a unique index/constraint matching
                # this conflict target; assumed to exist on UserDevice.user_id, confirm.
                index_elements=["user_id"],
                # created_at is intentionally absent: an overwrite keeps the original
                # creation time.
                set_={
                    "device_id": device_id,
                    "device_name": device_name,
                    "os_info": os_info,
                    "app_version": app_version,
                    "refresh_token_hash": refresh_token_hash,
                    "last_active_at": now,
                    "updated_at": now,
                },
            )
        )

        await db.execute(stmt)
        await db.commit()

        # Re-read the persisted row. The upsert above is a Core statement that bypasses
        # the ORM. If a UserDevice for this user is already in the session's identity
        # map (e.g. loaded earlier via get_by_user_id) and its attributes were not
        # expired (expire_on_commit=False), a plain select would return that object
        # with its stale pre-upsert values. populate_existing forces the loaded row to
        # overwrite them.
        result = await db.execute(
            select(UserDevice)
            .where(UserDevice.user_id == user_id)
            .execution_options(populate_existing=True)
        )
        return result.scalar_one()

    async def delete_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> bool:
        """Delete the user's device record (used on logout) and commit.

        Returns:
            ``True`` if a record existed and was deleted. ``False`` if there was no
            record; in that case nothing is deleted or committed.
        """
        device = await self.get_by_user_id(db, user_id=user_id)
        if device is None:
            return False

        await db.delete(device)
        await db.commit()
        return True

    async def get_by_refresh_token_hash(
        self, db: AsyncSession, *, refresh_token_hash: str
    ) -> UserDevice | None:
        """Return the device record whose stored refresh-token hash equals ``refresh_token_hash``.

        Returns ``None`` when no row matches. Raises
        ``sqlalchemy.exc.MultipleResultsFound`` if several rows match.
        """
        result = await db.execute(
            select(UserDevice).where(UserDevice.refresh_token_hash == refresh_token_hash)
        )
        return result.scalar_one_or_none()
|
|
|
|
|
|
# Shared module-level instance of the user-device CRUD object, exported for callers.
user_device: UserDeviceCRUD = UserDeviceCRUD()
|