"""
User device CRUD operations
===========================

Single-device login constraint: a user may be logged in on only one device at a time.
The core operation therefore *overwrites* the existing record instead of adding a new
one, using ``INSERT ... ON CONFLICT DO UPDATE`` so the write is a single atomic statement.
"""

from datetime import UTC, datetime
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.base import CRUDBase
from app.models.user_device import UserDevice


class UserDeviceCRUD(CRUDBase[UserDevice]):
    """Data-access object for ``UserDevice`` records."""

    def __init__(self) -> None:
        super().__init__(UserDevice)

    async def get_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> UserDevice | None:
        """
        Fetch the device record for a user.

        Args:
            db: Async database session.
            user_id: ID of the owning user.

        Returns:
            The matching ``UserDevice``, or ``None`` if the user has no device record.
        """
        result = await db.execute(select(UserDevice).where(UserDevice.user_id == user_id))
        return result.scalar_one_or_none()

    async def create_or_update(
        self,
        db: AsyncSession,
        *,
        user_id: UUID | str,
        device_id: str,
        device_name: str | None = None,
        os_info: str | None = None,
        app_version: str | None = None,
        refresh_token_hash: str | None = None,
    ) -> UserDevice:
        """
        Create or overwrite the user's device record (the core single-device-login operation).

        Uses PostgreSQL ``ON CONFLICT (user_id) DO UPDATE`` so one SQL statement does:
          - no row for ``user_id``  -> INSERT
          - row already exists      -> UPDATE (the previous device's info is overwritten)

        ``ON CONFLICT (user_id)`` requires a unique index/constraint on ``user_id``.
        The transaction is committed before returning.

        Args:
            db: Async database session (committed by this method).
            user_id: ID of the owning user; the conflict target.
            device_id: Identifier of the device now logged in.
            device_name: Optional human-readable device name.
            os_info: Optional operating-system description.
            app_version: Optional client application version.
            refresh_token_hash: Optional hash of the refresh token issued to this device.

        Returns:
            The up-to-date ``UserDevice`` row for ``user_id``.
        """
        now = datetime.now(UTC)

        # Columns written on both branches. created_at is deliberately excluded here so
        # that it is only set on first insert and keeps its original value on overwrite.
        overwritten = {
            "device_id": device_id,
            "device_name": device_name,
            "os_info": os_info,
            "app_version": app_version,
            "refresh_token_hash": refresh_token_hash,
            "last_active_at": now,
            "updated_at": now,
        }
        stmt = (
            insert(UserDevice)
            .values(user_id=user_id, created_at=now, **overwritten)
            .on_conflict_do_update(index_elements=["user_id"], set_=overwritten)
        )
        await db.execute(stmt)
        await db.commit()

        # The upsert above is a Core-level statement and bypasses the ORM identity map.
        # If this session already holds a UserDevice for this user (e.g. loaded earlier,
        # and the session does not expire objects on commit), a plain SELECT would return
        # that instance with its old attribute values. populate_existing forces the
        # freshly selected row to overwrite them.
        result = await db.execute(
            select(UserDevice)
            .where(UserDevice.user_id == user_id)
            .execution_options(populate_existing=True)
        )
        return result.scalar_one()

    async def delete_by_user_id(self, db: AsyncSession, *, user_id: UUID | str) -> bool:
        """
        Delete the user's device record (used on logout).

        Args:
            db: Async database session (committed when a record is deleted).
            user_id: ID of the owning user.

        Returns:
            ``True`` if a record was found and deleted, ``False`` if none existed.
        """
        device = await self.get_by_user_id(db, user_id=user_id)
        if device is None:
            return False
        # ORM-level delete (rather than a bulk DELETE) so any mapper-level
        # cascades/events configured on UserDevice still apply.
        await db.delete(device)
        await db.commit()
        return True

    async def get_by_refresh_token_hash(
        self, db: AsyncSession, *, refresh_token_hash: str
    ) -> UserDevice | None:
        """
        Fetch the device record holding the given refresh-token hash.

        Args:
            db: Async database session.
            refresh_token_hash: Hash of the refresh token to look up.

        Returns:
            The matching ``UserDevice``, or ``None`` if no record has this hash.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one record shares the hash.
        """
        result = await db.execute(
            select(UserDevice).where(UserDevice.refresh_token_hash == refresh_token_hash)
        )
        return result.scalar_one_or_none()


# Module-level shared instance.
user_device = UserDeviceCRUD()