8e5174c58c
- main.py: 自定义 exception_handler 手动添加 CORS 头,避免 500 响应被浏览器拦截 - crud/base.py: CRUDBase.get 的 id 参数改为 Any,兼容 int/BigInt 主键 - api/v1/points.py: query_recharge_status 去掉 str() 转换,直接传 int order_id
128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
"""
|
|
CRUD 基础类
|
|
==========
|
|
|
|
提供通用的数据访问方法,所有业务 CRUD 必须继承此类。
|
|
"""
|
|
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.base import BaseModel as AppBaseModel
|
|
|
|
ModelType = TypeVar("ModelType", bound=AppBaseModel)
|
|
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel, default=Any)
|
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel, default=Any)
|
|
|
|
|
|
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|
"""
|
|
通用 CRUD 基类
|
|
|
|
所有业务 CRUD 必须继承此类,确保接口统一。
|
|
|
|
使用示例:
|
|
class UserCRUD(CRUDBase[User, UserCreate, UserUpdate]):
|
|
def __init__(self):
|
|
super().__init__(User)
|
|
|
|
# 添加业务特定方法...
|
|
|
|
user = UserCRUD()
|
|
"""
|
|
|
|
def __init__(self, model: type[ModelType]):
|
|
"""
|
|
Args:
|
|
model: SQLAlchemy 模型类
|
|
"""
|
|
self.model = model
|
|
|
|
async def get(self, db: AsyncSession, id: Any) -> ModelType | None:
|
|
"""根据 ID 获取单个对象"""
|
|
result = await db.execute(select(self.model).where(self.model.id == id))
|
|
return result.scalar_one_or_none()
|
|
|
|
async def get_multi(
|
|
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
|
) -> list[ModelType]:
|
|
"""获取多个对象(分页)"""
|
|
result = await db.execute(select(self.model).offset(skip).limit(limit))
|
|
return list(result.scalars().all())
|
|
|
|
async def create(
|
|
self, db: AsyncSession, *, obj_in: CreateSchemaType | dict[str, Any], commit: bool = True
|
|
) -> ModelType:
|
|
"""创建对象
|
|
|
|
Args:
|
|
db: 数据库会话
|
|
obj_in: 对象数据(Pydantic 模型或字典)
|
|
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
|
"""
|
|
if isinstance(obj_in, BaseModel):
|
|
obj_in = obj_in.model_dump(exclude_unset=True)
|
|
db_obj = self.model(**obj_in)
|
|
db.add(db_obj)
|
|
if commit:
|
|
await db.commit()
|
|
await db.refresh(db_obj)
|
|
else:
|
|
# 不提交时刷新以获取默认值(如自增ID),但需在事务中
|
|
await db.flush()
|
|
await db.refresh(db_obj)
|
|
return db_obj
|
|
|
|
async def update(
|
|
self,
|
|
db: AsyncSession,
|
|
*,
|
|
db_obj: ModelType,
|
|
obj_in: UpdateSchemaType | dict[str, Any],
|
|
commit: bool = True,
|
|
) -> ModelType:
|
|
"""更新对象
|
|
|
|
Args:
|
|
db: 数据库会话
|
|
db_obj: 数据库对象
|
|
obj_in: 更新数据(Pydantic 模型或字典)
|
|
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
|
"""
|
|
if isinstance(obj_in, BaseModel):
|
|
update_data = obj_in.model_dump(exclude_unset=True)
|
|
else:
|
|
update_data = obj_in
|
|
for field, value in update_data.items():
|
|
if hasattr(db_obj, field) and value is not None:
|
|
setattr(db_obj, field, value)
|
|
if commit:
|
|
await db.commit()
|
|
await db.refresh(db_obj)
|
|
else:
|
|
await db.flush()
|
|
return db_obj
|
|
|
|
async def delete(self, db: AsyncSession, *, id: str, commit: bool = True) -> ModelType | None:
|
|
"""删除对象
|
|
|
|
Args:
|
|
db: 数据库会话
|
|
id: 对象ID
|
|
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
|
|
"""
|
|
obj = await self.get(db, id)
|
|
if obj:
|
|
await db.delete(obj)
|
|
if commit:
|
|
await db.commit()
|
|
return obj
|
|
|
|
async def count(self, db: AsyncSession) -> int:
|
|
"""统计总数"""
|
|
result = await db.execute(select(func.count(self.model.id)))
|
|
return result.scalar() or 0
|