Files
meijiaka-zy/python-api/app/crud/base.py
T
小鱼开发 8e5174c58c fix: 修复轮询接口 CORS 头丢失 + CRUD 类型不匹配
- main.py: 自定义 exception_handler 手动添加 CORS 头,避免 500 响应被浏览器拦截
- crud/base.py: CRUDBase.get 的 id 参数改为 Any,兼容 int/BigInt 主键
- api/v1/points.py: query_recharge_status 去掉 str() 转换,直接传 int order_id
2026-05-08 21:56:56 +08:00

128 lines
4.0 KiB
Python

"""
CRUD 基础类
==========
提供通用的数据访问方法,所有业务 CRUD 必须继承此类。
"""
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.base import BaseModel as AppBaseModel
ModelType = TypeVar("ModelType", bound=AppBaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel, default=Any)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel, default=Any)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
"""
通用 CRUD 基类
所有业务 CRUD 必须继承此类,确保接口统一。
使用示例:
class UserCRUD(CRUDBase[User, UserCreate, UserUpdate]):
def __init__(self):
super().__init__(User)
# 添加业务特定方法...
user = UserCRUD()
"""
def __init__(self, model: type[ModelType]):
"""
Args:
model: SQLAlchemy 模型类
"""
self.model = model
async def get(self, db: AsyncSession, id: Any) -> ModelType | None:
"""根据 ID 获取单个对象"""
result = await db.execute(select(self.model).where(self.model.id == id))
return result.scalar_one_or_none()
async def get_multi(
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
) -> list[ModelType]:
"""获取多个对象(分页)"""
result = await db.execute(select(self.model).offset(skip).limit(limit))
return list(result.scalars().all())
async def create(
self, db: AsyncSession, *, obj_in: CreateSchemaType | dict[str, Any], commit: bool = True
) -> ModelType:
"""创建对象
Args:
db: 数据库会话
obj_in: 对象数据(Pydantic 模型或字典)
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
"""
if isinstance(obj_in, BaseModel):
obj_in = obj_in.model_dump(exclude_unset=True)
db_obj = self.model(**obj_in)
db.add(db_obj)
if commit:
await db.commit()
await db.refresh(db_obj)
else:
# 不提交时刷新以获取默认值(如自增ID),但需在事务中
await db.flush()
await db.refresh(db_obj)
return db_obj
async def update(
self,
db: AsyncSession,
*,
db_obj: ModelType,
obj_in: UpdateSchemaType | dict[str, Any],
commit: bool = True,
) -> ModelType:
"""更新对象
Args:
db: 数据库会话
db_obj: 数据库对象
obj_in: 更新数据(Pydantic 模型或字典)
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
"""
if isinstance(obj_in, BaseModel):
update_data = obj_in.model_dump(exclude_unset=True)
else:
update_data = obj_in
for field, value in update_data.items():
if hasattr(db_obj, field) and value is not None:
setattr(db_obj, field, value)
if commit:
await db.commit()
await db.refresh(db_obj)
else:
await db.flush()
return db_obj
async def delete(self, db: AsyncSession, *, id: str, commit: bool = True) -> ModelType | None:
"""删除对象
Args:
db: 数据库会话
id: 对象ID
commit: 是否自动提交(默认True)。如需在事务中批量操作,设为False由调用方控制提交
"""
obj = await self.get(db, id)
if obj:
await db.delete(obj)
if commit:
await db.commit()
return obj
async def count(self, db: AsyncSession) -> int:
"""统计总数"""
result = await db.execute(select(func.count(self.model.id)))
return result.scalar() or 0