更新 Python/base_crud.py

This commit is contained in:
2025-03-31 15:56:37 +00:00
parent 424e7a0c62
commit 00fef59ae2

View File

@@ -1,227 +1,227 @@
from typing import Any, Generic, List, Optional, Type, TypeVar from typing import Any, Generic, List, Optional, Type, TypeVar
from sqlmodel import Session, SQLModel, select from sqlmodel import Session, SQLModel, select
T = TypeVar("T", bound=SQLModel) T = TypeVar("T", bound=SQLModel)
class BaseCrud(Generic[T]): class BaseCrud(Generic[T]):
""" """
通用CRUD操作基类类似于Spring Data JPA的CrudRepository 通用CRUD操作基类类似于Spring Data JPA的CrudRepository
泛型T必须是SQLModel的子类 泛型T必须是SQLModel的子类
""" """
def __init__(self, model_class: Type[T], session: Session): def __init__(self, model_class: Type[T], session: Session):
""" """
初始化CRUD操作类 初始化CRUD操作类
Args: Args:
model_class: SQLModel模型类 model_class: SQLModel模型类
session: SQLModel会话对象 session: SQLModel会话对象
""" """
self.model_class = model_class self.model_class = model_class
self.session = session self.session = session
def save(self, entity: T) -> T: def save(self, entity: T) -> T:
""" """
保存实体如果存在则更新否则创建新的 保存实体如果存在则更新否则创建新的
Args: Args:
entity: 要保存的实体 entity: 要保存的实体
Returns: Returns:
保存后的实体 保存后的实体
""" """
self.session.add(entity) self.session.add(entity)
try: try:
self.session.commit() self.session.commit()
self.session.refresh(entity) self.session.refresh(entity)
return entity return entity
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
raise e raise e
def save_all(self, entities: List[T]) -> List[T]: def save_all(self, entities: List[T]) -> List[T]:
""" """
批量保存实体 批量保存实体
Args: Args:
entities: 要保存的实体列表 entities: 要保存的实体列表
Returns: Returns:
保存后的实体列表 保存后的实体列表
""" """
for entity in entities: for entity in entities:
self.session.add(entity) self.session.add(entity)
try: try:
self.session.commit() self.session.commit()
for entity in entities: for entity in entities:
self.session.refresh(entity) self.session.refresh(entity)
return entities return entities
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
raise e raise e
def find_by_id(self, id: Any) -> Optional[T]: def find_by_id(self, id: Any) -> Optional[T]:
""" """
通过ID查找实体 通过ID查找实体
Args: Args:
id: 实体ID id: 实体ID
Returns: Returns:
找到的实体如果不存在则返回None 找到的实体如果不存在则返回None
""" """
return self.session.get(self.model_class, id) return self.session.get(self.model_class, id)
def exists_by_id(self, id: Any) -> bool: def exists_by_id(self, id: Any) -> bool:
""" """
检查指定ID的实体是否存在 检查指定ID的实体是否存在
Args: Args:
id: 实体ID id: 实体ID
Returns: Returns:
如果存在则返回True否则返回False 如果存在则返回True否则返回False
""" """
return self.find_by_id(id) is not None return self.find_by_id(id) is not None
def find_all(self) -> List[T]: def find_all(self) -> List[T]:
""" """
查找所有实体 查找所有实体
Returns: Returns:
实体列表 实体列表
""" """
statement = select(self.model_class) statement = select(self.model_class)
results = self.session.exec(statement).all() results = self.session.exec(statement).all()
return results return results
def find_all_by_ids(self, ids: List[Any]) -> List[T]: def find_all_by_ids(self, ids: List[Any]) -> List[T]:
""" """
通过ID列表查找多个实体 通过ID列表查找多个实体
Args: Args:
ids: ID列表 ids: ID列表
Returns: Returns:
找到的实体列表 找到的实体列表
""" """
if not ids: if not ids:
return [] return []
statement = select(self.model_class).where(self.model_class.id.in_(ids)) statement = select(self.model_class).where(self.model_class.id.in_(ids))
return self.session.exec(statement).all() return self.session.exec(statement).all()
def count(self) -> int: def count(self) -> int:
""" """
计算实体总数 计算实体总数
Returns: Returns:
实体总数 实体总数
""" """
statement = select(self.model_class) statement = select(self.model_class)
return len(self.session.exec(statement).all()) return len(self.session.exec(statement).all())
def delete_by_id(self, id: Any) -> bool: def delete_by_id(self, id: Any) -> bool:
""" """
通过ID删除实体 通过ID删除实体
Args: Args:
id: 实体ID id: 实体ID
Returns: Returns:
是否成功删除 是否成功删除
""" """
entity = self.find_by_id(id) entity = self.find_by_id(id)
if entity is None: if entity is None:
return False return False
return self.delete(entity) return self.delete(entity)
def delete(self, entity: T) -> bool: def delete(self, entity: T) -> bool:
""" """
删除指定实体 删除指定实体
Args: Args:
entity: 要删除的实体 entity: 要删除的实体
Returns: Returns:
是否成功删除 是否成功删除
""" """
try: try:
self.session.delete(entity) self.session.delete(entity)
self.session.commit() self.session.commit()
return True return True
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
raise e raise e
def delete_all(self) -> bool: def delete_all(self) -> bool:
""" """
删除所有实体 删除所有实体
Returns: Returns:
是否成功删除 是否成功删除
""" """
try: try:
statement = select(self.model_class) statement = select(self.model_class)
entities = self.session.exec(statement).all() entities = self.session.exec(statement).all()
for entity in entities: for entity in entities:
self.session.delete(entity) self.session.delete(entity)
self.session.commit() self.session.commit()
return True return True
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
raise e raise e
def delete_all_by_ids(self, ids: List[Any]) -> bool: def delete_all_by_ids(self, ids: List[Any]) -> bool:
""" """
通过ID列表批量删除实体 通过ID列表批量删除实体
Args: Args:
ids: ID列表 ids: ID列表
Returns: Returns:
是否成功删除 是否成功删除
""" """
if not ids: if not ids:
return True return True
try: try:
entities = self.find_all_by_ids(ids) entities = self.find_all_by_ids(ids)
for entity in entities: for entity in entities:
self.session.delete(entity) self.session.delete(entity)
self.session.commit() self.session.commit()
return True return True
except Exception as e: except Exception as e:
self.session.rollback() self.session.rollback()
raise e raise e
def find_by(self, **kwargs) -> List[T]: def find_by(self, **kwargs) -> List[T]:
""" """
根据条件查询实体 根据条件查询实体
Args: Args:
**kwargs: 查询条件字段名= **kwargs: 查询条件字段名=
Returns: Returns:
符合条件的实体列表 符合条件的实体列表
""" """
statement = select(self.model_class) statement = select(self.model_class)
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(self.model_class, key): if hasattr(self.model_class, key):
statement = statement.where(getattr(self.model_class, key) == value) statement = statement.where(getattr(self.model_class, key) == value)
return self.session.exec(statement).all() return self.session.exec(statement).all()
def find_one_by(self, **kwargs) -> Optional[T]: def find_one_by(self, **kwargs) -> Optional[T]:
""" """
根据条件查询单个实体 根据条件查询单个实体
Args: Args:
**kwargs: 查询条件字段名= **kwargs: 查询条件字段名=
Returns: Returns:
符合条件的第一个实体如果不存在则返回None 符合条件的第一个实体如果不存在则返回None
""" """
results = self.find_by(**kwargs) results = self.find_by(**kwargs)
return results[0] if results else None return results[0] if results else None