Files
file-templates/Python/base_crud.py
2025-03-31 15:56:37 +00:00

228 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Any, Generic, List, Optional, Type, TypeVar
from sqlmodel import Session, SQLModel, select
T = TypeVar("T", bound=SQLModel)
class BaseCrud(Generic[T]):
"""
通用CRUD操作基类类似于Spring Data JPA的CrudRepository
泛型T必须是SQLModel的子类
"""
def __init__(self, model_class: Type[T], session: Session):
"""
初始化CRUD操作类
Args:
model_class: SQLModel模型类
session: SQLModel会话对象
"""
self.model_class = model_class
self.session = session
def save(self, entity: T) -> T:
"""
保存实体,如果存在则更新,否则创建新的
Args:
entity: 要保存的实体
Returns:
保存后的实体
"""
self.session.add(entity)
try:
self.session.commit()
self.session.refresh(entity)
return entity
except Exception as e:
self.session.rollback()
raise e
def save_all(self, entities: List[T]) -> List[T]:
"""
批量保存实体
Args:
entities: 要保存的实体列表
Returns:
保存后的实体列表
"""
for entity in entities:
self.session.add(entity)
try:
self.session.commit()
for entity in entities:
self.session.refresh(entity)
return entities
except Exception as e:
self.session.rollback()
raise e
def find_by_id(self, id: Any) -> Optional[T]:
"""
通过ID查找实体
Args:
id: 实体ID
Returns:
找到的实体如果不存在则返回None
"""
return self.session.get(self.model_class, id)
def exists_by_id(self, id: Any) -> bool:
"""
检查指定ID的实体是否存在
Args:
id: 实体ID
Returns:
如果存在则返回True否则返回False
"""
return self.find_by_id(id) is not None
def find_all(self) -> List[T]:
"""
查找所有实体
Returns:
实体列表
"""
statement = select(self.model_class)
results = self.session.exec(statement).all()
return results
def find_all_by_ids(self, ids: List[Any]) -> List[T]:
"""
通过ID列表查找多个实体
Args:
ids: ID列表
Returns:
找到的实体列表
"""
if not ids:
return []
statement = select(self.model_class).where(self.model_class.id.in_(ids))
return self.session.exec(statement).all()
def count(self) -> int:
"""
计算实体总数
Returns:
实体总数
"""
statement = select(self.model_class)
return len(self.session.exec(statement).all())
def delete_by_id(self, id: Any) -> bool:
"""
通过ID删除实体
Args:
id: 实体ID
Returns:
是否成功删除
"""
entity = self.find_by_id(id)
if entity is None:
return False
return self.delete(entity)
def delete(self, entity: T) -> bool:
"""
删除指定实体
Args:
entity: 要删除的实体
Returns:
是否成功删除
"""
try:
self.session.delete(entity)
self.session.commit()
return True
except Exception as e:
self.session.rollback()
raise e
def delete_all(self) -> bool:
"""
删除所有实体
Returns:
是否成功删除
"""
try:
statement = select(self.model_class)
entities = self.session.exec(statement).all()
for entity in entities:
self.session.delete(entity)
self.session.commit()
return True
except Exception as e:
self.session.rollback()
raise e
def delete_all_by_ids(self, ids: List[Any]) -> bool:
"""
通过ID列表批量删除实体
Args:
ids: ID列表
Returns:
是否成功删除
"""
if not ids:
return True
try:
entities = self.find_all_by_ids(ids)
for entity in entities:
self.session.delete(entity)
self.session.commit()
return True
except Exception as e:
self.session.rollback()
raise e
def find_by(self, **kwargs) -> List[T]:
"""
根据条件查询实体
Args:
**kwargs: 查询条件,字段名=值
Returns:
符合条件的实体列表
"""
statement = select(self.model_class)
for key, value in kwargs.items():
if hasattr(self.model_class, key):
statement = statement.where(getattr(self.model_class, key) == value)
return self.session.exec(statement).all()
def find_one_by(self, **kwargs) -> Optional[T]:
"""
根据条件查询单个实体
Args:
**kwargs: 查询条件,字段名=值
Returns:
符合条件的第一个实体如果不存在则返回None
"""
results = self.find_by(**kwargs)
return results[0] if results else None