from typing import Any, Generic, List, Optional, Type, TypeVar from sqlmodel import Session, SQLModel, select T = TypeVar("T", bound=SQLModel) class BaseCrud(Generic[T]): """ 通用CRUD操作基类,类似于Spring Data JPA的CrudRepository 泛型T必须是SQLModel的子类 """ def __init__(self, model_class: Type[T], session: Session): """ 初始化CRUD操作类 Args: model_class: SQLModel模型类 session: SQLModel会话对象 """ self.model_class = model_class self.session = session def save(self, entity: T) -> T: """ 保存实体,如果存在则更新,否则创建新的 Args: entity: 要保存的实体 Returns: 保存后的实体 """ self.session.add(entity) try: self.session.commit() self.session.refresh(entity) return entity except Exception as e: self.session.rollback() raise e def save_all(self, entities: List[T]) -> List[T]: """ 批量保存实体 Args: entities: 要保存的实体列表 Returns: 保存后的实体列表 """ for entity in entities: self.session.add(entity) try: self.session.commit() for entity in entities: self.session.refresh(entity) return entities except Exception as e: self.session.rollback() raise e def find_by_id(self, id: Any) -> Optional[T]: """ 通过ID查找实体 Args: id: 实体ID Returns: 找到的实体,如果不存在则返回None """ return self.session.get(self.model_class, id) def exists_by_id(self, id: Any) -> bool: """ 检查指定ID的实体是否存在 Args: id: 实体ID Returns: 如果存在则返回True,否则返回False """ return self.find_by_id(id) is not None def find_all(self) -> List[T]: """ 查找所有实体 Returns: 实体列表 """ statement = select(self.model_class) results = self.session.exec(statement).all() return results def find_all_by_ids(self, ids: List[Any]) -> List[T]: """ 通过ID列表查找多个实体 Args: ids: ID列表 Returns: 找到的实体列表 """ if not ids: return [] statement = select(self.model_class).where(self.model_class.id.in_(ids)) return self.session.exec(statement).all() def count(self) -> int: """ 计算实体总数 Returns: 实体总数 """ statement = select(self.model_class) return len(self.session.exec(statement).all()) def delete_by_id(self, id: Any) -> bool: """ 通过ID删除实体 Args: id: 实体ID Returns: 是否成功删除 """ entity = self.find_by_id(id) if entity is None: return False return self.delete(entity) def delete(self, entity: T) -> bool: """ 删除指定实体 Args: entity: 要删除的实体 Returns: 是否成功删除 """ try: self.session.delete(entity) self.session.commit() return True except Exception as e: self.session.rollback() raise e def delete_all(self) -> bool: """ 删除所有实体 Returns: 是否成功删除 """ try: statement = select(self.model_class) entities = self.session.exec(statement).all() for entity in entities: self.session.delete(entity) self.session.commit() return True except Exception as e: self.session.rollback() raise e def delete_all_by_ids(self, ids: List[Any]) -> bool: """ 通过ID列表批量删除实体 Args: ids: ID列表 Returns: 是否成功删除 """ if not ids: return True try: entities = self.find_all_by_ids(ids) for entity in entities: self.session.delete(entity) self.session.commit() return True except Exception as e: self.session.rollback() raise e def find_by(self, **kwargs) -> List[T]: """ 根据条件查询实体 Args: **kwargs: 查询条件,字段名=值 Returns: 符合条件的实体列表 """ statement = select(self.model_class) for key, value in kwargs.items(): if hasattr(self.model_class, key): statement = statement.where(getattr(self.model_class, key) == value) return self.session.exec(statement).all() def find_one_by(self, **kwargs) -> Optional[T]: """ 根据条件查询单个实体 Args: **kwargs: 查询条件,字段名=值 Returns: 符合条件的第一个实体,如果不存在则返回None """ results = self.find_by(**kwargs) return results[0] if results else None