SARepo 0.1.6__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -0,0 +1,74 @@
1
+ from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
2
+
3
+ from SARepo.sa_repo import Spec
4
+ from .base import Page, PageRequest
5
+
6
+ T = TypeVar("T")
7
+
8
+
9
+ class CrudRepository(Protocol, Generic[T]):
10
+ model: Type[T]
11
+
12
+ def getAll(self,
13
+ limit: Optional[int] = None,
14
+ *,
15
+ include_deleted: bool = False,
16
+ order_by=None,
17
+ **filters) -> List[T]: ...
18
+
19
+ def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T: ...
20
+
21
+ def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]: ...
22
+
23
+ def add(self, entity: T) -> T: ...
24
+
25
+ def update(self, entity: T) -> T: ...
26
+
27
+ def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool: ...
28
+
29
+ def delete_by_id(self, id_: Any) -> bool: ...
30
+
31
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> \
32
+ Page[T]: ...
33
+
34
+ def get_all_by_column(
35
+ self,
36
+ column_name: str,
37
+ value: Any,
38
+ *,
39
+ limit: Optional[int] = None,
40
+ order_by=None,
41
+ include_deleted: bool = False,
42
+ **extra_filters
43
+ ) -> list[T]: ...
44
+
45
+ def find_all_by_column(
46
+ self,
47
+ column_name: str,
48
+ value: Any,
49
+ *,
50
+ limit: Optional[int] = None,
51
+ order_by=None,
52
+ include_deleted: bool = False,
53
+ **extra_filters
54
+ ) -> list[T]: ...
55
+
56
+ def get_or_create(
57
+ self,
58
+ defaults: Optional[dict] = None,
59
+ **unique_filters
60
+ ) -> tuple[T, bool]: ...
61
+
62
+ def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
63
+
64
+ def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
65
+
66
+ def aggregate_min(self, column_name: str, **filters): ...
67
+
68
+ def aggregate_max(self, column_name: str, **filters): ...
69
+
70
+ def aggregate_sum(self, column_name: str, **filters): ...
71
+
72
+ def count(self, **filters) -> int: ...
73
+
74
+ def restore(self, id_: Any) -> bool: ...
@@ -1,4 +1,3 @@
1
-
2
1
  from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
2
  from sqlalchemy.orm import Session
4
3
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,8 +7,10 @@ from .base import PageRequest, Page, NotFoundError
8
7
  T = TypeVar("T")
9
8
  Spec = Callable
10
9
 
10
+
11
11
  class SARepository(Generic[T]):
12
12
  """Synchronous repository implementation for SQLAlchemy 2.x."""
13
+
13
14
  def __init__(self, model: Type[T], session: Session):
14
15
  self.model = model
15
16
  self.session = session
@@ -31,35 +32,52 @@ class SARepository(Generic[T]):
31
32
  def _has_soft_delete(self) -> bool:
32
33
  return hasattr(self.model, "is_deleted")
33
34
 
34
- def _apply_alive_filter(self, stmt, include_deleted: bool) :
35
+ def _apply_alive_filter(self, stmt, include_deleted: bool):
35
36
  if self._has_soft_delete() and not include_deleted:
36
37
  stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
37
38
  return stmt
38
-
39
- def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
39
+
40
+ def getAll(
41
+ self,
42
+ limit: Optional[int] = None,
43
+ *,
44
+ include_deleted: bool = False,
45
+ order_by=None,
46
+ **filters
47
+ ) -> List[T]:
40
48
  stmt = select(self.model)
49
+ stmt = self._apply_filters(stmt, **filters)
41
50
  stmt = self._apply_alive_filter(stmt, include_deleted)
51
+ if order_by is not None:
52
+ stmt = stmt.order_by(order_by)
42
53
  if limit is not None:
43
54
  stmt = stmt.limit(limit)
44
55
  result = self.session.execute(stmt)
45
- return result.scalars().all()
46
-
47
- def get(self, id_: Any, *, include_deleted: bool = False) -> T:
48
- obj = self.session.get(self.model, id_)
49
- if not obj:
50
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
56
+ return result.scalars().all()
57
+
58
+ def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T:
59
+ if id_ is not None:
60
+ obj = self.session.get(self.model, id_)
61
+ if not obj:
62
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found")
63
+ else:
64
+ stmt = select(self.model)
65
+ stmt = self._apply_filters(stmt, **filters)
66
+ stmt = self._apply_alive_filter(stmt, include_deleted)
67
+ res = self.session.execute(stmt)
68
+ obj = res.scalars().first()
69
+ if not obj:
70
+ raise NotFoundError(f"{self.model.__name__} not found for {filters}")
51
71
  if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
52
- raise NotFoundError(f"{self.model.__name__}({id_}) not found") # скрываем как «нет»
72
+ raise NotFoundError(f"{self.model.__name__}({getattr(obj, 'id', '?')}) deleted")
53
73
  return obj
54
74
 
55
- def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
56
- obj = self.session.get(self.model, id_)
57
- if not obj:
58
- return None
59
- if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
75
+ def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]:
76
+ try:
77
+ return self.get(id_=id_, include_deleted=include_deleted, **filters)
78
+ except NotFoundError:
60
79
  return None
61
- return obj
62
-
80
+
63
81
  def add(self, entity: T) -> T:
64
82
  self.session.add(entity)
65
83
  self.session.flush()
@@ -74,10 +92,10 @@ class SARepository(Generic[T]):
74
92
  def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
75
93
  if entity is None and id is None:
76
94
  raise ValueError("remove() requires either entity or id")
77
-
95
+
78
96
  if id is not None:
79
97
  return self._delete_by_id(id)
80
-
98
+
81
99
  insp = inspect(entity, raiseerr=False)
82
100
  if not (insp and (insp.persistent or insp.pending)):
83
101
  pk = getattr(entity, "id", None)
@@ -90,7 +108,7 @@ class SARepository(Generic[T]):
90
108
  setattr(entity, "is_deleted", True)
91
109
  else:
92
110
  self.session.delete(entity)
93
-
111
+
94
112
  return True
95
113
 
96
114
  def _delete_by_id(self, id_: Any) -> bool:
@@ -100,7 +118,8 @@ class SARepository(Generic[T]):
100
118
  self.remove(obj)
101
119
  return True
102
120
 
103
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]:
121
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> \
122
+ Page[T]:
104
123
  base = self._select()
105
124
  if spec:
106
125
  base = spec(base)
@@ -116,16 +135,16 @@ class SARepository(Generic[T]):
116
135
  base.offset(page.page * page.size).limit(page.size)
117
136
  ).scalars().all()
118
137
  return Page(items, total, page.page, page.size)
119
-
138
+
120
139
  def get_all_by_column(
121
- self,
122
- column_name: str,
123
- value: Any,
124
- *,
125
- limit: Optional[int] = None,
126
- order_by=None,
127
- include_deleted: bool = False,
128
- **extra_filters
140
+ self,
141
+ column_name: str,
142
+ value: Any,
143
+ *,
144
+ limit: Optional[int] = None,
145
+ order_by=None,
146
+ include_deleted: bool = False,
147
+ **extra_filters
129
148
  ) -> list[T]:
130
149
  col = self._resolve_column(column_name)
131
150
  stmt = select(self.model).where(col == value)
@@ -137,15 +156,15 @@ class SARepository(Generic[T]):
137
156
  stmt = stmt.limit(limit)
138
157
  res = self.session.execute(stmt)
139
158
  return res.scalars().all()
140
-
159
+
141
160
  # Alias
142
161
  def find_all_by_column(self, *args, **kwargs):
143
162
  return self.get_all_by_column(*args, **kwargs)
144
-
163
+
145
164
  def get_or_create(
146
- self,
147
- defaults: Optional[dict] = None,
148
- **unique_filters
165
+ self,
166
+ defaults: Optional[dict] = None,
167
+ **unique_filters
149
168
  ) -> tuple[T, bool]:
150
169
  """
151
170
  Возвращает (obj, created). unique_filters определяют уникальность.
@@ -228,12 +247,14 @@ class SARepository(Generic[T]):
228
247
  return True
229
248
  return False
230
249
 
250
+
231
251
  class SAAsyncRepository(Generic[T]):
232
252
  """Async repository implementation for SQLAlchemy 2.x."""
253
+
233
254
  def __init__(self, model: Type[T], session: AsyncSession):
234
255
  self.model = model
235
256
  self.session = session
236
-
257
+
237
258
  def _resolve_column(self, column_name: str):
238
259
  try:
239
260
  return getattr(self.model, column_name)
@@ -244,41 +265,71 @@ class SAAsyncRepository(Generic[T]):
244
265
  if filters:
245
266
  stmt = stmt.filter_by(**filters)
246
267
  return stmt
247
-
268
+
248
269
  def _select(self):
249
270
  return select(self.model)
250
-
271
+
251
272
  def _has_soft_delete(self) -> bool:
252
273
  return hasattr(self.model, "is_deleted")
253
274
 
254
- def _apply_alive_filter(self, stmt, include_deleted: bool) :
275
+ def _apply_alive_filter(self, stmt, include_deleted: bool):
255
276
  if self._has_soft_delete() and not include_deleted:
256
277
  stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
257
278
  return stmt
258
279
 
259
- async def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
280
+ async def getAll(
281
+ self,
282
+ limit: Optional[int] = None,
283
+ *,
284
+ include_deleted: bool = False,
285
+ order_by=None,
286
+ **filters
287
+ ) -> List[T]:
288
+ """
289
+ Получить все записи с возможностью фильтрации (username='foo', age__gt=20, и т.д.)
290
+ """
260
291
  stmt = select(self.model)
292
+ stmt = self._apply_filters(stmt, **filters)
261
293
  stmt = self._apply_alive_filter(stmt, include_deleted)
294
+ if order_by is not None:
295
+ stmt = stmt.order_by(order_by)
262
296
  if limit is not None:
263
297
  stmt = stmt.limit(limit)
264
298
  result = await self.session.execute(stmt)
265
299
  return result.scalars().all()
266
300
 
267
- async def get(self, id_: Any, *, include_deleted: bool = False) -> T:
268
- obj = await self.session.get(self.model, id_)
269
- if not obj:
270
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
301
+ async def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T:
302
+ """
303
+ Получить один объект по id или по произвольным фильтрам.
304
+ Пример:
305
+ await repo.get(id_=1)
306
+ await repo.get(username='ibrahim')
307
+ await repo.get(username__ilike='%rah%')
308
+ """
309
+ if id_ is not None:
310
+ obj = await self.session.get(self.model, id_)
311
+ if not obj:
312
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found")
313
+ else:
314
+ stmt = select(self.model)
315
+ stmt = self._apply_filters(stmt, **filters)
316
+ stmt = self._apply_alive_filter(stmt, include_deleted)
317
+ res = await self.session.execute(stmt)
318
+ obj = res.scalars().first()
319
+ if not obj:
320
+ raise NotFoundError(f"{self.model.__name__} not found for {filters}")
271
321
  if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
272
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
322
+ raise NotFoundError(f"{self.model.__name__}({getattr(obj, 'id', '?')}) deleted")
273
323
  return obj
274
324
 
275
- async def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
276
- obj = await self.session.get(self.model, id_)
277
- if not obj:
278
- return None
279
- if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
325
+ async def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]:
326
+ """
327
+ Как get(), но не выбрасывает исключение при отсутствии объекта.
328
+ """
329
+ try:
330
+ return await self.get(id_=id_, include_deleted=include_deleted, **filters)
331
+ except NotFoundError:
280
332
  return None
281
- return obj
282
333
 
283
334
  async def add(self, entity: T) -> T:
284
335
  self.session.add(entity)
@@ -294,10 +345,10 @@ class SAAsyncRepository(Generic[T]):
294
345
  async def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
295
346
  if entity is None and id is None:
296
347
  raise ValueError("remove() requires either entity or id")
297
-
348
+
298
349
  if id is not None:
299
350
  return await self._delete_by_id(id)
300
-
351
+
301
352
  insp = inspect(entity, raiseerr=False)
302
353
  if not (insp and (insp.persistent or insp.pending)):
303
354
  pk = getattr(entity, "id", None)
@@ -310,10 +361,9 @@ class SAAsyncRepository(Generic[T]):
310
361
  setattr(entity, "is_deleted", True)
311
362
  else:
312
363
  await self.session.delete(entity)
313
-
364
+
314
365
  return True
315
366
 
316
-
317
367
  async def _delete_by_id(self, id_: Any) -> bool:
318
368
  obj = await self.session.get(self.model, id_)
319
369
  if not obj:
@@ -321,7 +371,8 @@ class SAAsyncRepository(Generic[T]):
321
371
  await self.remove(obj)
322
372
  return True
323
373
 
324
- async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: # type: ignore
374
+ async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *,
375
+ include_deleted: bool = False) -> Page[T]: # type: ignore
325
376
  base = self._select()
326
377
  if spec:
327
378
  base = spec(base)
@@ -340,14 +391,14 @@ class SAAsyncRepository(Generic[T]):
340
391
  return Page(items, total, page.page, page.size)
341
392
 
342
393
  async def get_all_by_column(
343
- self,
344
- column_name: str,
345
- value: Any,
346
- *,
347
- limit: Optional[int] = None,
348
- order_by=None,
349
- include_deleted: bool = False,
350
- **extra_filters
394
+ self,
395
+ column_name: str,
396
+ value: Any,
397
+ *,
398
+ limit: Optional[int] = None,
399
+ order_by=None,
400
+ include_deleted: bool = False,
401
+ **extra_filters
351
402
  ) -> list[T]:
352
403
  col = self._resolve_column(column_name)
353
404
  stmt = select(self.model).where(col == value)
@@ -363,11 +414,11 @@ class SAAsyncRepository(Generic[T]):
363
414
  # Alias
364
415
  async def find_all_by_column(self, *args, **kwargs):
365
416
  return await self.get_all_by_column(*args, **kwargs)
366
-
417
+
367
418
  async def get_or_create(
368
- self,
369
- defaults: Optional[dict] = None,
370
- **unique_filters
419
+ self,
420
+ defaults: Optional[dict] = None,
421
+ **unique_filters
371
422
  ) -> tuple[T, bool]:
372
423
  stmt = select(self.model).filter_by(**unique_filters)
373
424
  if hasattr(self.model, "is_deleted"):
@@ -437,4 +488,4 @@ class SAAsyncRepository(Generic[T]):
437
488
  setattr(obj, "is_deleted", False)
438
489
  await self.session.flush()
439
490
  return True
440
- return False
491
+ return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
 
2
2
  [project]
3
3
  name = "SARepo"
4
- version = "0.1.6"
4
+ version = "0.1.7"
5
5
  description = "Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x"
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
@@ -54,20 +54,6 @@ def seed(session, repo):
54
54
 
55
55
 
56
56
 
57
- def test_find_all_by_column_basic(session, repo):
58
- seed(session, repo)
59
-
60
- items = repo.find_all_by_column("city", "Almaty")
61
- titles = {x.title for x in items}
62
- assert titles == {"alpha", "gamma"}
63
-
64
- limited = repo.find_all_by_column("city", "Almaty", limit=1, order_by=Item.amount.desc())
65
- assert len(limited) == 1 and limited[0].title == "gamma"
66
-
67
- with pytest.raises(ValueError):
68
- repo.find_all_by_column("no_such_column", "x")
69
-
70
-
71
57
  def test_get_or_create(session, repo):
72
58
  seed(session, repo)
73
59
 
@@ -1,50 +0,0 @@
1
-
2
- from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
-
4
- from SARepo.sa_repo import Spec
5
- from .base import Page, PageRequest
6
-
7
- T = TypeVar("T")
8
-
9
- class CrudRepository(Protocol, Generic[T]):
10
- model: Type[T]
11
- def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]: ...
12
- def get(self, id_: Any, *, include_deleted: bool = False) -> T: ...
13
- def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]: ...
14
- def add(self, entity: T) -> T: ...
15
- def update(self, entity: T) -> T: ...
16
- def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool: ...
17
- def delete_by_id(self, id_: Any) -> bool: ...
18
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: ...
19
- def get_all_by_column(
20
- self,
21
- column_name: str,
22
- value: Any,
23
- *,
24
- limit: Optional[int] = None,
25
- order_by=None,
26
- include_deleted: bool = False,
27
- **extra_filters
28
- ) -> list[T]: ...
29
- def find_all_by_column(
30
- self,
31
- column_name: str,
32
- value: Any,
33
- *,
34
- limit: Optional[int] = None,
35
- order_by=None,
36
- include_deleted: bool = False,
37
- **extra_filters
38
- ) -> list[T]: ...
39
- def get_or_create(
40
- self,
41
- defaults: Optional[dict] = None,
42
- **unique_filters
43
- ) -> tuple[T, bool]: ...
44
- def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
45
- def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
46
- def aggregate_min(self, column_name: str, **filters): ...
47
- def aggregate_max(self, column_name: str, **filters): ...
48
- def aggregate_sum(self, column_name: str, **filters): ...
49
- def count(self, **filters) -> int: ...
50
- def restore(self, id_: Any) -> bool: ...
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes