SARepo 0.1.5__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -0,0 +1,74 @@
1
+ from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
2
+
3
+ from SARepo.sa_repo import Spec
4
+ from .base import Page, PageRequest
5
+
6
+ T = TypeVar("T")
7
+
8
+
9
+ class CrudRepository(Protocol, Generic[T]):
10
+ model: Type[T]
11
+
12
+ def getAll(self,
13
+ limit: Optional[int] = None,
14
+ *,
15
+ include_deleted: bool = False,
16
+ order_by=None,
17
+ **filters) -> List[T]: ...
18
+
19
+ def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T: ...
20
+
21
+ def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]: ...
22
+
23
+ def add(self, entity: T) -> T: ...
24
+
25
+ def update(self, entity: T) -> T: ...
26
+
27
+ def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool: ...
28
+
29
+ def delete_by_id(self, id_: Any) -> bool: ...
30
+
31
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> \
32
+ Page[T]: ...
33
+
34
+ def get_all_by_column(
35
+ self,
36
+ column_name: str,
37
+ value: Any,
38
+ *,
39
+ limit: Optional[int] = None,
40
+ order_by=None,
41
+ include_deleted: bool = False,
42
+ **extra_filters
43
+ ) -> list[T]: ...
44
+
45
+ def find_all_by_column(
46
+ self,
47
+ column_name: str,
48
+ value: Any,
49
+ *,
50
+ limit: Optional[int] = None,
51
+ order_by=None,
52
+ include_deleted: bool = False,
53
+ **extra_filters
54
+ ) -> list[T]: ...
55
+
56
+ def get_or_create(
57
+ self,
58
+ defaults: Optional[dict] = None,
59
+ **unique_filters
60
+ ) -> tuple[T, bool]: ...
61
+
62
+ def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
63
+
64
+ def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
65
+
66
+ def aggregate_min(self, column_name: str, **filters): ...
67
+
68
+ def aggregate_max(self, column_name: str, **filters): ...
69
+
70
+ def aggregate_sum(self, column_name: str, **filters): ...
71
+
72
+ def count(self, **filters) -> int: ...
73
+
74
+ def restore(self, id_: Any) -> bool: ...
@@ -1,4 +1,3 @@
1
-
2
1
  from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
2
  from sqlalchemy.orm import Session
4
3
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,8 +7,10 @@ from .base import PageRequest, Page, NotFoundError
8
7
  T = TypeVar("T")
9
8
  Spec = Callable
10
9
 
10
+
11
11
  class SARepository(Generic[T]):
12
12
  """Synchronous repository implementation for SQLAlchemy 2.x."""
13
+
13
14
  def __init__(self, model: Type[T], session: Session):
14
15
  self.model = model
15
16
  self.session = session
@@ -31,35 +32,52 @@ class SARepository(Generic[T]):
31
32
  def _has_soft_delete(self) -> bool:
32
33
  return hasattr(self.model, "is_deleted")
33
34
 
34
- def _apply_alive_filter(self, stmt, include_deleted: bool) :
35
+ def _apply_alive_filter(self, stmt, include_deleted: bool):
35
36
  if self._has_soft_delete() and not include_deleted:
36
37
  stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
37
38
  return stmt
38
-
39
- def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
39
+
40
+ def getAll(
41
+ self,
42
+ limit: Optional[int] = None,
43
+ *,
44
+ include_deleted: bool = False,
45
+ order_by=None,
46
+ **filters
47
+ ) -> List[T]:
40
48
  stmt = select(self.model)
49
+ stmt = self._apply_filters(stmt, **filters)
41
50
  stmt = self._apply_alive_filter(stmt, include_deleted)
51
+ if order_by is not None:
52
+ stmt = stmt.order_by(order_by)
42
53
  if limit is not None:
43
54
  stmt = stmt.limit(limit)
44
55
  result = self.session.execute(stmt)
45
- return result.scalars().all()
46
-
47
- def get(self, id_: Any, *, include_deleted: bool = False) -> T:
48
- obj = self.session.get(self.model, id_)
49
- if not obj:
50
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
56
+ return result.scalars().all()
57
+
58
+ def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T:
59
+ if id_ is not None:
60
+ obj = self.session.get(self.model, id_)
61
+ if not obj:
62
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found")
63
+ else:
64
+ stmt = select(self.model)
65
+ stmt = self._apply_filters(stmt, **filters)
66
+ stmt = self._apply_alive_filter(stmt, include_deleted)
67
+ res = self.session.execute(stmt)
68
+ obj = res.scalars().first()
69
+ if not obj:
70
+ raise NotFoundError(f"{self.model.__name__} not found for {filters}")
51
71
  if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
52
- raise NotFoundError(f"{self.model.__name__}({id_}) not found") # скрываем как «нет»
72
+ raise NotFoundError(f"{self.model.__name__}({getattr(obj, 'id', '?')}) deleted")
53
73
  return obj
54
74
 
55
- def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
56
- obj = self.session.get(self.model, id_)
57
- if not obj:
58
- return None
59
- if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
75
+ def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]:
76
+ try:
77
+ return self.get(id_=id_, include_deleted=include_deleted, **filters)
78
+ except NotFoundError:
60
79
  return None
61
- return obj
62
-
80
+
63
81
  def add(self, entity: T) -> T:
64
82
  self.session.add(entity)
65
83
  self.session.flush()
@@ -71,7 +89,13 @@ class SARepository(Generic[T]):
71
89
  self.session.refresh(entity)
72
90
  return entity
73
91
 
74
- def remove(self, entity: T) -> None:
92
+ def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
93
+ if entity is None and id is None:
94
+ raise ValueError("remove() requires either entity or id")
95
+
96
+ if id is not None:
97
+ return self._delete_by_id(id)
98
+
75
99
  insp = inspect(entity, raiseerr=False)
76
100
  if not (insp and (insp.persistent or insp.pending)):
77
101
  pk = getattr(entity, "id", None)
@@ -79,20 +103,23 @@ class SARepository(Generic[T]):
79
103
  raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
80
104
  entity = self.session.get(self.model, pk)
81
105
  if entity is None:
82
- return
106
+ return False
83
107
  if hasattr(entity, "is_deleted"):
84
108
  setattr(entity, "is_deleted", True)
85
109
  else:
86
110
  self.session.delete(entity)
87
111
 
88
- def delete_by_id(self, id_: Any) -> bool:
112
+ return True
113
+
114
+ def _delete_by_id(self, id_: Any) -> bool:
89
115
  obj = self.session.get(self.model, id_)
90
116
  if not obj:
91
117
  return False
92
118
  self.remove(obj)
93
119
  return True
94
120
 
95
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]:
121
+ def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> \
122
+ Page[T]:
96
123
  base = self._select()
97
124
  if spec:
98
125
  base = spec(base)
@@ -108,16 +135,16 @@ class SARepository(Generic[T]):
108
135
  base.offset(page.page * page.size).limit(page.size)
109
136
  ).scalars().all()
110
137
  return Page(items, total, page.page, page.size)
111
-
138
+
112
139
  def get_all_by_column(
113
- self,
114
- column_name: str,
115
- value: Any,
116
- *,
117
- limit: Optional[int] = None,
118
- order_by=None,
119
- include_deleted: bool = False,
120
- **extra_filters
140
+ self,
141
+ column_name: str,
142
+ value: Any,
143
+ *,
144
+ limit: Optional[int] = None,
145
+ order_by=None,
146
+ include_deleted: bool = False,
147
+ **extra_filters
121
148
  ) -> list[T]:
122
149
  col = self._resolve_column(column_name)
123
150
  stmt = select(self.model).where(col == value)
@@ -129,15 +156,15 @@ class SARepository(Generic[T]):
129
156
  stmt = stmt.limit(limit)
130
157
  res = self.session.execute(stmt)
131
158
  return res.scalars().all()
132
-
159
+
133
160
  # Alias
134
161
  def find_all_by_column(self, *args, **kwargs):
135
162
  return self.get_all_by_column(*args, **kwargs)
136
-
163
+
137
164
  def get_or_create(
138
- self,
139
- defaults: Optional[dict] = None,
140
- **unique_filters
165
+ self,
166
+ defaults: Optional[dict] = None,
167
+ **unique_filters
141
168
  ) -> tuple[T, bool]:
142
169
  """
143
170
  Возвращает (obj, created). unique_filters определяют уникальность.
@@ -220,12 +247,14 @@ class SARepository(Generic[T]):
220
247
  return True
221
248
  return False
222
249
 
250
+
223
251
  class SAAsyncRepository(Generic[T]):
224
252
  """Async repository implementation for SQLAlchemy 2.x."""
253
+
225
254
  def __init__(self, model: Type[T], session: AsyncSession):
226
255
  self.model = model
227
256
  self.session = session
228
-
257
+
229
258
  def _resolve_column(self, column_name: str):
230
259
  try:
231
260
  return getattr(self.model, column_name)
@@ -236,41 +265,71 @@ class SAAsyncRepository(Generic[T]):
236
265
  if filters:
237
266
  stmt = stmt.filter_by(**filters)
238
267
  return stmt
239
-
268
+
240
269
  def _select(self):
241
270
  return select(self.model)
242
-
271
+
243
272
  def _has_soft_delete(self) -> bool:
244
273
  return hasattr(self.model, "is_deleted")
245
274
 
246
- def _apply_alive_filter(self, stmt, include_deleted: bool) :
275
+ def _apply_alive_filter(self, stmt, include_deleted: bool):
247
276
  if self._has_soft_delete() and not include_deleted:
248
277
  stmt = stmt.where(self.model.is_deleted == False) # noqa: E712
249
278
  return stmt
250
279
 
251
- async def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
280
+ async def getAll(
281
+ self,
282
+ limit: Optional[int] = None,
283
+ *,
284
+ include_deleted: bool = False,
285
+ order_by=None,
286
+ **filters
287
+ ) -> List[T]:
288
+ """
289
+ Получить все записи с возможностью фильтрации (username='foo', age__gt=20, и т.д.)
290
+ """
252
291
  stmt = select(self.model)
292
+ stmt = self._apply_filters(stmt, **filters)
253
293
  stmt = self._apply_alive_filter(stmt, include_deleted)
294
+ if order_by is not None:
295
+ stmt = stmt.order_by(order_by)
254
296
  if limit is not None:
255
297
  stmt = stmt.limit(limit)
256
298
  result = await self.session.execute(stmt)
257
299
  return result.scalars().all()
258
300
 
259
- async def get(self, id_: Any, *, include_deleted: bool = False) -> T:
260
- obj = await self.session.get(self.model, id_)
261
- if not obj:
262
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
301
+ async def get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> T:
302
+ """
303
+ Получить один объект по id или по произвольным фильтрам.
304
+ Пример:
305
+ await repo.get(id_=1)
306
+ await repo.get(username='ibrahim')
307
+ await repo.get(username__ilike='%rah%')
308
+ """
309
+ if id_ is not None:
310
+ obj = await self.session.get(self.model, id_)
311
+ if not obj:
312
+ raise NotFoundError(f"{self.model.__name__}({id_}) not found")
313
+ else:
314
+ stmt = select(self.model)
315
+ stmt = self._apply_filters(stmt, **filters)
316
+ stmt = self._apply_alive_filter(stmt, include_deleted)
317
+ res = await self.session.execute(stmt)
318
+ obj = res.scalars().first()
319
+ if not obj:
320
+ raise NotFoundError(f"{self.model.__name__} not found for {filters}")
263
321
  if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
264
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
322
+ raise NotFoundError(f"{self.model.__name__}({getattr(obj, 'id', '?')}) deleted")
265
323
  return obj
266
324
 
267
- async def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
268
- obj = await self.session.get(self.model, id_)
269
- if not obj:
270
- return None
271
- if self._has_soft_delete() and not include_deleted and getattr(obj, "is_deleted", False):
325
+ async def try_get(self, id_: Any = None, *, include_deleted: bool = False, **filters) -> Optional[T]:
326
+ """
327
+ Как get(), но не выбрасывает исключение при отсутствии объекта.
328
+ """
329
+ try:
330
+ return await self.get(id_=id_, include_deleted=include_deleted, **filters)
331
+ except NotFoundError:
272
332
  return None
273
- return obj
274
333
 
275
334
  async def add(self, entity: T) -> T:
276
335
  self.session.add(entity)
@@ -283,7 +342,13 @@ class SAAsyncRepository(Generic[T]):
283
342
  await self.session.refresh(entity)
284
343
  return entity
285
344
 
286
- async def remove(self, entity: T) -> None:
345
+ async def remove(self, entity: Optional[T] = None, id: Optional[Any] = None) -> bool:
346
+ if entity is None and id is None:
347
+ raise ValueError("remove() requires either entity or id")
348
+
349
+ if id is not None:
350
+ return await self._delete_by_id(id)
351
+
287
352
  insp = inspect(entity, raiseerr=False)
288
353
  if not (insp and (insp.persistent or insp.pending)):
289
354
  pk = getattr(entity, "id", None)
@@ -291,20 +356,23 @@ class SAAsyncRepository(Generic[T]):
291
356
  raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
292
357
  entity = await self.session.get(self.model, pk)
293
358
  if entity is None:
294
- return
359
+ return False
295
360
  if hasattr(entity, "is_deleted"):
296
361
  setattr(entity, "is_deleted", True)
297
362
  else:
298
363
  await self.session.delete(entity)
299
364
 
300
- async def delete_by_id(self, id_: Any) -> bool:
365
+ return True
366
+
367
+ async def _delete_by_id(self, id_: Any) -> bool:
301
368
  obj = await self.session.get(self.model, id_)
302
369
  if not obj:
303
370
  return False
304
371
  await self.remove(obj)
305
372
  return True
306
373
 
307
- async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: # type: ignore
374
+ async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *,
375
+ include_deleted: bool = False) -> Page[T]: # type: ignore
308
376
  base = self._select()
309
377
  if spec:
310
378
  base = spec(base)
@@ -323,14 +391,14 @@ class SAAsyncRepository(Generic[T]):
323
391
  return Page(items, total, page.page, page.size)
324
392
 
325
393
  async def get_all_by_column(
326
- self,
327
- column_name: str,
328
- value: Any,
329
- *,
330
- limit: Optional[int] = None,
331
- order_by=None,
332
- include_deleted: bool = False,
333
- **extra_filters
394
+ self,
395
+ column_name: str,
396
+ value: Any,
397
+ *,
398
+ limit: Optional[int] = None,
399
+ order_by=None,
400
+ include_deleted: bool = False,
401
+ **extra_filters
334
402
  ) -> list[T]:
335
403
  col = self._resolve_column(column_name)
336
404
  stmt = select(self.model).where(col == value)
@@ -346,11 +414,11 @@ class SAAsyncRepository(Generic[T]):
346
414
  # Alias
347
415
  async def find_all_by_column(self, *args, **kwargs):
348
416
  return await self.get_all_by_column(*args, **kwargs)
349
-
417
+
350
418
  async def get_or_create(
351
- self,
352
- defaults: Optional[dict] = None,
353
- **unique_filters
419
+ self,
420
+ defaults: Optional[dict] = None,
421
+ **unique_filters
354
422
  ) -> tuple[T, bool]:
355
423
  stmt = select(self.model).filter_by(**unique_filters)
356
424
  if hasattr(self.model, "is_deleted"):
@@ -420,4 +488,4 @@ class SAAsyncRepository(Generic[T]):
420
488
  setattr(obj, "is_deleted", False)
421
489
  await self.session.flush()
422
490
  return True
423
- return False
491
+ return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
 
2
2
  [project]
3
3
  name = "SARepo"
4
- version = "0.1.5"
4
+ version = "0.1.7"
5
5
  description = "Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x"
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
@@ -54,20 +54,6 @@ def seed(session, repo):
54
54
 
55
55
 
56
56
 
57
- def test_find_all_by_column_basic(session, repo):
58
- seed(session, repo)
59
-
60
- items = repo.find_all_by_column("city", "Almaty")
61
- titles = {x.title for x in items}
62
- assert titles == {"alpha", "gamma"}
63
-
64
- limited = repo.find_all_by_column("city", "Almaty", limit=1, order_by=Item.amount.desc())
65
- assert len(limited) == 1 and limited[0].title == "gamma"
66
-
67
- with pytest.raises(ValueError):
68
- repo.find_all_by_column("no_such_column", "x")
69
-
70
-
71
57
  def test_get_or_create(session, repo):
72
58
  seed(session, repo)
73
59
 
@@ -132,8 +118,11 @@ def test_restore(session, repo):
132
118
  seed(session, repo)
133
119
 
134
120
  beta = repo.find_all_by_column("title", "beta")[0]
135
- beta.is_deleted = True
121
+ beta.is_deleted = False
136
122
  session.commit()
123
+
124
+ deleted = repo.remove(id=beta.id)
125
+ print("DELETED ", deleted)
137
126
 
138
127
  ok = repo.restore(beta.id)
139
128
  assert ok is True
@@ -1,50 +0,0 @@
1
-
2
- from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
-
4
- from SARepo.sa_repo import Spec
5
- from .base import Page, PageRequest
6
-
7
- T = TypeVar("T")
8
-
9
- class CrudRepository(Protocol, Generic[T]):
10
- model: Type[T]
11
- def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]: ...
12
- def get(self, id_: Any, *, include_deleted: bool = False) -> T: ...
13
- def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]: ...
14
- def add(self, entity: T) -> T: ...
15
- def update(self, entity: T) -> T: ...
16
- def remove(self, entity: T) -> None: ...
17
- def delete_by_id(self, id_: Any) -> bool: ...
18
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]: ...
19
- def get_all_by_column(
20
- self,
21
- column_name: str,
22
- value: Any,
23
- *,
24
- limit: Optional[int] = None,
25
- order_by=None,
26
- include_deleted: bool = False,
27
- **extra_filters
28
- ) -> list[T]: ...
29
- def find_all_by_column(
30
- self,
31
- column_name: str,
32
- value: Any,
33
- *,
34
- limit: Optional[int] = None,
35
- order_by=None,
36
- include_deleted: bool = False,
37
- **extra_filters
38
- ) -> list[T]: ...
39
- def get_or_create(
40
- self,
41
- defaults: Optional[dict] = None,
42
- **unique_filters
43
- ) -> tuple[T, bool]: ...
44
- def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
45
- def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
46
- def aggregate_min(self, column_name: str, **filters): ...
47
- def aggregate_max(self, column_name: str, **filters): ...
48
- def aggregate_sum(self, column_name: str, **filters): ...
49
- def count(self, **filters) -> int: ...
50
- def restore(self, id_: Any) -> bool: ...
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes