52 lines
1.2 KiB
Python
52 lines
1.2 KiB
Python
"""
|
|
用户 CRUD 操作
|
|
==============
|
|
|
|
用户认证相关的数据访问。
|
|
"""
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.user import User
|
|
|
|
|
|
class UserCRUD(CRUDBase[User]):
|
|
"""用户数据访问对象"""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__(User)
|
|
|
|
async def get_by_mobile(self, db: AsyncSession, *, mobile: str) -> User | None:
|
|
"""根据手机号获取用户"""
|
|
result = await db.execute(select(User).where(User.mobile == mobile))
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_or_create_by_mobile(
|
|
self, db: AsyncSession, *, mobile: str, nickname: str | None = None
|
|
) -> User:
|
|
"""
|
|
根据手机号获取或创建用户
|
|
|
|
Returns:
|
|
已存在或新创建的用户
|
|
"""
|
|
user = await self.get_by_mobile(db, mobile=mobile)
|
|
|
|
if user is None:
|
|
# 创建新用户
|
|
user = await self.create(
|
|
db,
|
|
obj_in={
|
|
"mobile": mobile,
|
|
"nickname": nickname or f"用户_{mobile[-4:]}",
|
|
},
|
|
)
|
|
|
|
return user
|
|
|
|
|
|
# 导出实例
|
|
user = UserCRUD()
|