"""
用户 CRUD 操作 (User CRUD operations)
=====================================

用户认证相关的数据访问。
Data access for user authentication.
"""
|
||
|
||
from datetime import UTC, datetime
from typing import Any, cast
from uuid import UUID

from sqlalchemy import select, update
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.base import CRUDBase
from app.models.user import User
|
||
|
||
|
||
class UserCRUD(CRUDBase[User]):
    """User data access object (authentication-related queries and updates)."""

    def __init__(self) -> None:
        super().__init__(User)

    async def get_by_mobile(self, db: AsyncSession, *, mobile: str) -> User | None:
        """Return the user whose ``mobile`` equals *mobile*, or ``None``.

        Uses ``scalar_one_or_none``, so more than one matching row raises
        instead of returning an arbitrary one.
        """
        result = await db.execute(select(User).where(User.mobile == mobile))
        return result.scalar_one_or_none()

    async def get_or_create_by_mobile(
        self, db: AsyncSession, *, mobile: str, nickname: str | None = None, source: str = "unknown"
    ) -> User:
        """Return the user with *mobile*, creating one if none exists.

        Args:
            db: Async database session.
            mobile: Mobile number to look up or register.
            nickname: Nickname for a newly created user. When falsy, it defaults
                to ``"用户_"`` followed by the last four characters of *mobile*.
            source: Value stored in ``source`` for a newly created user.

        Returns:
            The existing or newly created user.

        Raises:
            IntegrityError: if the insert fails and no user with *mobile* can
                be found after rolling back.
        """
        user = await self.get_by_mobile(db, mobile=mobile)
        if user is not None:
            return user

        try:
            return await self.create(
                db,
                obj_in={
                    "mobile": mobile,
                    "nickname": nickname or f"用户_{mobile[-4:]}",
                    "source": source,
                },
            )
        except IntegrityError:
            # Check-then-insert race: a concurrent request may have inserted the
            # same mobile between the lookup above and this insert. With a unique
            # constraint on ``User.mobile`` (NOTE(review): confirm in the
            # model/migrations), the losing request lands here. A failed flush
            # leaves the session unusable until it is rolled back. Note that
            # rollback() also discards any other uncommitted work in this session.
            await db.rollback()
            user = await self.get_by_mobile(db, mobile=mobile)
            if user is None:
                # Not a duplicate-mobile conflict: surface the original error.
                raise
            return user

    async def update_login_info(
        self, db: AsyncSession, *, user_id: UUID | str, ip: str | None = None
    ) -> User | None:
        """Record a login for the user identified by *user_id*.

        Sets ``last_login_at`` to the current UTC time and, when *ip* is truthy,
        ``last_login_ip`` to *ip*. Then commits the session and refreshes the user.

        Returns:
            The refreshed user, or ``None`` if no user has id *user_id*.
        """
        user = await self.get(db, id=user_id)
        if user is None:
            return None

        user.last_login_at = datetime.now(UTC)
        if ip:
            # None or "" leaves the previously stored IP untouched.
            user.last_login_ip = ip

        await db.commit()
        await db.refresh(user)
        return user

    async def update_password(
        self, db: AsyncSession, *, user_id: UUID | str, password_hash: str
    ) -> User | None:
        """Store *password_hash* for the user, commit, and refresh.

        *password_hash* is written as given; no hashing happens here.

        Returns:
            The refreshed user, or ``None`` if no user has id *user_id*.
        """
        user = await self.get(db, id=user_id)
        if user is None:
            return None

        user.password_hash = password_hash
        await db.commit()
        await db.refresh(user)
        return user

    async def update_extra(
        self, db: AsyncSession, *, user_id: UUID | str, extra: dict[str, Any]
    ) -> bool:
        """Overwrite the user's ``extra`` column with *extra* in one UPDATE and commit.

        The stored value is replaced wholesale, not merged. No row is read
        first, so this method has no read-modify-write window of its own.
        However, callers that build *extra* from a previously loaded value can
        still overwrite concurrent changes.

        Returns:
            ``True`` if a row was updated, ``False`` if no user has id *user_id*.
        """
        stmt = update(User).where(User.id == user_id).values(extra=extra)
        # session.execute() is typed as returning Result; for DML it is a
        # CursorResult, which exposes rowcount.
        result = cast(CursorResult[Any], await db.execute(stmt))
        await db.commit()
        return result.rowcount > 0
|
||
|
||
|
||
# Shared module-level UserCRUD instance exported for callers.
user: UserCRUD = UserCRUD()
|