SARepo 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -0,0 +1,50 @@
1
+
2
+ from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
+
4
+ from SARepo.sa_repo import Spec
5
+ from .base import Page, PageRequest
6
+
7
+ T = TypeVar("T")
8
+
9
+ class CrudRepository(Protocol, Generic[T]):
10
+ model: Type[T]
11
+ def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]: ...
12
+ def get(self, id_: Any, *, include_deleted: bool = False) -> T: ...
13
+ def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]: ...
14
+ def add(self, entity: T) -> T: ...
15
+ def update(self, entity: T) -> T: ...
16
+ def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool: ...
17
+ def delete_by_id(self, id_: Any) -> bool: ...
18
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: ...
19
+ def get_all_by_column(
20
+ self,
21
+ column_name: str,
22
+ value: Any,
23
+ *,
24
+ limit: Optional[int] = None,
25
+ order_by=None,
26
+ include_deleted: bool = False,
27
+ **extra_filters
28
+ ) -> list[T]: ...
29
+ def find_all_by_column(
30
+ self,
31
+ column_name: str,
32
+ value: Any,
33
+ *,
34
+ limit: Optional[int] = None,
35
+ order_by=None,
36
+ include_deleted: bool = False,
37
+ **extra_filters
38
+ ) -> list[T]: ...
39
+ def get_or_create(
40
+ self,
41
+ defaults: Optional[dict] = None,
42
+ **unique_filters
43
+ ) -> tuple[T, bool]: ...
44
+ def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
45
+ def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
46
+ def aggregate_min(self, column_name: str, **filters): ...
47
+ def aggregate_max(self, column_name: str, **filters): ...
48
+ def aggregate_sum(self, column_name: str, **filters): ...
49
+ def count(self, **filters) -> int: ...
50
+ def restore(self, id_: Any) -> bool: ...
@@ -28,22 +28,38 @@ class SARepository(Generic[T]):
28
28
  def _select(self):
29
29
  return select(self.model)
30
30
 
31
- def getAll(self, limit: Optional[int] = None) -> List[T]:
31
+ def _has_soft_delete(self) -> bool:
32
+ return hasattr(self.model, "is_deleted")
33
+
34
+ def _apply_alive_filter(self, stmt, include_deleted: bool) :
35
+ if self._has_soft_delete() and not include_deleted:
36
+ stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
37
+ return stmt
38
+
39
+ def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
32
40
  stmt = select(self.model)
41
+ stmt = self._apply_alive_filter(stmt, include_deleted)
33
42
  if limit is not None:
34
43
  stmt = stmt.limit(limit)
35
44
  result = self.session.execute(stmt)
36
- return result.scalars().all()
45
+ return result.scalars().all()
37
46
 
38
- def get(self, id_: Any) -> T:
47
+ def get(self, id_: Any, *, include_deleted: bool = False) -> T:
39
48
  obj = self.session.get(self.model, id_)
40
49
  if not obj:
41
50
  raise NotFoundError(f"{self.model.__name__}({id_}) not found")
51
+ if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
52
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found") # скрываем как «нет»
42
53
  return obj
43
54
 
44
- def try_get(self, id_: Any) -> Optional[T]:
45
- return self.session.get(self.model, id_)
46
-
55
+ def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
56
+ obj = self.session.get(self.model, id_)
57
+ if not obj:
58
+ return None
59
+ if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
60
+ return None
61
+ return obj
62
+
47
63
  def add(self, entity: T) -> T:
48
64
  self.session.add(entity)
49
65
  self.session.flush()
@@ -55,7 +71,13 @@ class SARepository(Generic[T]):
55
71
  self.session.refresh(entity)
56
72
  return entity
57
73
 
58
- def remove(self, entity: T) -> None:
74
+ def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
75
+ if entity is None and id is None:
76
+ raise ValueError("remove() requires either entity or id")
77
+
78
+ if id is not None:
79
+ return self._delete_by_id(id)
80
+
59
81
  insp = inspect(entity, raiseerr=False)
60
82
  if not (insp and (insp.persistent or insp.pending)):
61
83
  pk = getattr(entity, "id", None)
@@ -63,30 +85,35 @@ class SARepository(Generic[T]):
63
85
  raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
64
86
  entity = self.session.get(self.model, pk)
65
87
  if entity is None:
66
- return
88
+ return False
67
89
  if hasattr(entity, "is_deleted"):
68
90
  setattr(entity, "is_deleted", True)
69
91
  else:
70
92
  self.session.delete(entity)
93
+
94
+ return True
71
95
 
72
- def delete_by_id(self, id_: Any) -> bool:
96
+ def _delete_by_id(self, id_: Any) -> bool:
73
97
  obj = self.session.get(self.model, id_)
74
98
  if not obj:
75
99
  return False
76
100
  self.remove(obj)
77
101
  return True
78
102
 
79
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
80
- stmt = self._select()
103
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]:
104
+ base = self._select()
81
105
  if spec:
82
- stmt = spec(stmt)
106
+ base = spec(base)
107
+ base = self._apply_alive_filter(base, include_deleted)
83
108
  if order_by is not None:
84
- stmt = stmt.order_by(order_by)
109
+ base = base.order_by(order_by)
110
+
85
111
  total = self.session.execute(
86
- select(func.count()).select_from(stmt.subquery())
112
+ select(func.count()).select_from(base.subquery())
87
113
  ).scalar_one()
114
+
88
115
  items = self.session.execute(
89
- stmt.offset(page.page * page.size).limit(page.size)
116
+ base.offset(page.page * page.size).limit(page.size)
90
117
  ).scalars().all()
91
118
  return Page(items, total, page.page, page.size)
92
119
 
@@ -102,8 +129,7 @@ class SARepository(Generic[T]):
102
129
  ) -> list[T]:
103
130
  col = self._resolve_column(column_name)
104
131
  stmt = select(self.model).where(col == value)
105
- if not include_deleted and hasattr(self.model, "is_deleted"):
106
- stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
132
+ stmt = self._apply_alive_filter(stmt, include_deleted)
107
133
  stmt = self._apply_filters(stmt, **extra_filters)
108
134
  if order_by is not None:
109
135
  stmt = stmt.order_by(order_by)
@@ -111,7 +137,7 @@ class SARepository(Generic[T]):
111
137
  stmt = stmt.limit(limit)
112
138
  res = self.session.execute(stmt)
113
139
  return res.scalars().all()
114
-
140
+
115
141
  # Alias
116
142
  def find_all_by_column(self, *args, **kwargs):
117
143
  return self.get_all_by_column(*args, **kwargs)
@@ -222,21 +248,37 @@ class SAAsyncRepository(Generic[T]):
222
248
  def _select(self):
223
249
  return select(self.model)
224
250
 
225
- async def getAll(self, limit: Optional[int] = None) -> List[T]:
251
+ def _has_soft_delete(self) -> bool:
252
+ return hasattr(self.model, "is_deleted")
253
+
254
+ def _apply_alive_filter(self, stmt, include_deleted: bool) :
255
+ if self._has_soft_delete() and not include_deleted:
256
+ stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
257
+ return stmt
258
+
259
+ async def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
226
260
  stmt = select(self.model)
261
+ stmt = self._apply_alive_filter(stmt, include_deleted)
227
262
  if limit is not None:
228
263
  stmt = stmt.limit(limit)
229
264
  result = await self.session.execute(stmt)
230
265
  return result.scalars().all()
231
266
 
232
- async def get(self, id_: Any) -> T:
267
+ async def get(self, id_: Any, *, include_deleted: bool = False) -> T:
233
268
  obj = await self.session.get(self.model, id_)
234
269
  if not obj:
235
270
  raise NotFoundError(f"{self.model.__name__}({id_}) not found")
271
+ if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
272
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found")
236
273
  return obj
237
274
 
238
- async def try_get(self, id_: Any) -> Optional[T]:
239
- return await self.session.get(self.model, id_)
275
+ async def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
276
+ obj = await self.session.get(self.model, id_)
277
+ if not obj:
278
+ return None
279
+ if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
280
+ return None
281
+ return obj
240
282
 
241
283
  async def add(self, entity: T) -> T:
242
284
  self.session.add(entity)
@@ -249,7 +291,13 @@ class SAAsyncRepository(Generic[T]):
249
291
  await self.session.refresh(entity)
250
292
  return entity
251
293
 
252
- async def remove(self, entity: T) -> None:
294
+ async def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
295
+ if entity is None and id is None:
296
+ raise ValueError("remove() requires either entity or id")
297
+
298
+ if id is not None:
299
+ return await self._delete_by_id(id)
300
+
253
301
  insp = inspect(entity, raiseerr=False)
254
302
  if not (insp and (insp.persistent or insp.pending)):
255
303
  pk = getattr(entity, "id", None)
@@ -257,34 +305,40 @@ class SAAsyncRepository(Generic[T]):
257
305
  raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
258
306
  entity = await self.session.get(self.model, pk)
259
307
  if entity is None:
260
- return
308
+ return False
261
309
  if hasattr(entity, "is_deleted"):
262
310
  setattr(entity, "is_deleted", True)
263
311
  else:
264
312
  await self.session.delete(entity)
313
+
314
+ return True
265
315
 
266
- async def delete_by_id(self, id_: Any) -> bool:
316
+
317
+ async def _delete_by_id(self, id_: Any) -> bool:
267
318
  obj = await self.session.get(self.model, id_)
268
319
  if not obj:
269
320
  return False
270
321
  await self.remove(obj)
271
322
  return True
272
323
 
273
- async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
274
- stmt = self._select()
324
+ async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: # type: ignore
325
+ base = self._select()
275
326
  if spec:
276
- stmt = spec(stmt)
327
+ base = spec(base)
328
+ base = self._apply_alive_filter(base, include_deleted)
277
329
  if order_by is not None:
278
- stmt = stmt.order_by(order_by)
330
+ base = base.order_by(order_by)
331
+
279
332
  total = (await self.session.execute(
280
- select(func.count()).select_from(stmt.subquery())
333
+ select(func.count()).select_from(base.subquery())
281
334
  )).scalar_one()
335
+
282
336
  res = await self.session.execute(
283
- stmt.offset(page.page * page.size).limit(page.size)
337
+ base.offset(page.page * page.size).limit(page.size)
284
338
  )
285
339
  items = res.scalars().all()
286
340
  return Page(items, total, page.page, page.size)
287
-
341
+
288
342
  async def get_all_by_column(
289
343
  self,
290
344
  column_name: str,
@@ -297,8 +351,7 @@ class SAAsyncRepository(Generic[T]):
297
351
  ) -> list[T]:
298
352
  col = self._resolve_column(column_name)
299
353
  stmt = select(self.model).where(col == value)
300
- if not include_deleted and hasattr(self.model, "is_deleted"):
301
- stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
354
+ stmt = self._apply_alive_filter(stmt, include_deleted)
302
355
  stmt = self._apply_filters(stmt, **extra_filters)
303
356
  if order_by is not None:
304
357
  stmt = stmt.order_by(order_by)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
 
2
2
  [project]
3
3
  name = "SARepo"
4
- version = "0.1.4"
4
+ version = "0.1.6"
5
5
  description = "Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x"
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
@@ -132,8 +132,11 @@ def test_restore(session, repo):
132
132
  seed(session, repo)
133
133
 
134
134
  beta = repo.find_all_by_column("title", "beta")[0]
135
- beta.is_deleted = True
135
+ beta.is_deleted = False
136
136
  session.commit()
137
+
138
+ deleted = repo.remove(id=beta.id)
139
+ print("DELETED ", deleted)
137
140
 
138
141
  ok = repo.restore(beta.id)
139
142
  assert ok is True
@@ -1,48 +0,0 @@
1
-
2
- from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
- from .base import Page, PageRequest
4
-
5
- T = TypeVar("T")
6
-
7
- class CrudRepository(Protocol, Generic[T]):
8
- model: Type[T]
9
- def getAll(self, limit: Optional[int]) -> List[T]: ...
10
- def get(self, id_: Any) -> T: ...
11
- def try_get(self, id_: Any) -> Optional[T]: ...
12
- def add(self, entity: T) -> T: ...
13
- def update(self, entity: T) -> T: ...
14
- def remove(self, entity: T) -> None: ...
15
- def delete_by_id(self, id_: Any) -> bool: ...
16
- def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...
17
- def get_all_by_column(
18
- self,
19
- column_name: str,
20
- value: Any,
21
- *,
22
- limit: Optional[int] = None,
23
- order_by=None,
24
- include_deleted: bool = False,
25
- **extra_filters
26
- ) -> list[T]: ...
27
- def find_all_by_column(
28
- self,
29
- column_name: str,
30
- value: Any,
31
- *,
32
- limit: Optional[int] = None,
33
- order_by=None,
34
- include_deleted: bool = False,
35
- **extra_filters
36
- ) -> list[T]: ...
37
- def get_or_create(
38
- self,
39
- defaults: Optional[dict] = None,
40
- **unique_filters
41
- ) -> tuple[T, bool]: ...
42
- def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
43
- def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
44
- def aggregate_min(self, column_name: str, **filters): ...
45
- def aggregate_max(self, column_name: str, **filters): ...
46
- def aggregate_sum(self, column_name: str, **filters): ...
47
- def count(self, **filters) -> int: ...
48
- def restore(self, id_: Any) -> bool: ...
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes