SARepo 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -0,0 +1,50 @@
1
+
2
+ from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
+
4
+ from SARepo.sa_repo import Spec
5
+ from .base import Page, PageRequest
6
+
7
# Entity type managed by a repository.
T = TypeVar("T")


class CrudRepository(Protocol, Generic[T]):
    """Structural (duck-typed) contract for a SARepo CRUD repository.

    All methods are declared as plain, non-async methods, so this protocol
    describes the synchronous repository API. The ``include_deleted`` flags
    refer to soft-deleted rows (an ``is_deleted`` model attribute in the
    SQLAlchemy implementation).
    """

    # Mapped class the repository reads and writes.
    model: Type[T]

    def getAll(
        self,
        limit: Optional[int] = None,
        *,
        include_deleted: bool = False,
    ) -> List[T]:
        """Return entities, at most ``limit`` of them when a limit is given."""

    def get(self, id_: Any, *, include_deleted: bool = False) -> T:
        """Return the entity whose primary key is ``id_`` (raises when absent)."""

    def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
        """Return the entity whose primary key is ``id_``, or ``None``."""

    def add(self, entity: T) -> T:
        """Persist a new ``entity`` and return it."""

    def update(self, entity: T) -> T:
        """Synchronise pending changes of ``entity`` and return it."""

    def remove(self, entity: T) -> None:
        """Delete (or soft-delete) ``entity``."""

    def delete_by_id(self, id_: Any) -> bool:
        """Delete the entity with key ``id_``; report whether one was found."""

    def page(
        self,
        page: PageRequest,
        spec: Optional[Spec] = None,
        order_by=None,
        *,
        include_deleted: bool = False,
    ) -> Page[T]:
        """Return one page of entities, optionally refined by ``spec``."""

    def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Return entities whose ``column_name`` equals ``value``."""

    def find_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Alias of :meth:`get_all_by_column`."""

    def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]:
        """Return ``(entity, created)`` for the row matching ``unique_filters``."""

    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
        """Execute textual SQL with bound ``params``; return rows as dicts."""

    def aggregate_avg(self, column_name: str, **filters) -> Optional[float]:
        """Return AVG(``column_name``) over the matching rows."""

    def aggregate_min(self, column_name: str, **filters):
        """Return MIN(``column_name``) over the matching rows."""

    def aggregate_max(self, column_name: str, **filters):
        """Return MAX(``column_name``) over the matching rows."""

    def aggregate_sum(self, column_name: str, **filters):
        """Return SUM(``column_name``) over the matching rows."""

    def count(self, **filters) -> int:
        """Return the number of rows matching ``filters``."""

    def restore(self, id_: Any) -> bool:
        """Undo a soft delete of ``id_``; report whether anything changed."""
@@ -0,0 +1,423 @@
1
+
2
+ from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
+ from sqlalchemy.orm import Session
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy import inspect, select, func, text
6
+ from .base import PageRequest, Page, NotFoundError
7
+
8
# Entity type handled by a repository instance.
T = TypeVar("T")
# A "specification": any callable applied to a Select statement, which
# ``page`` invokes as ``spec(base)`` and uses the returned statement.
Spec = Callable
10
+
11
class SARepository(Generic[T]):
    """Synchronous repository implementation for SQLAlchemy 2.x.

    Operates on one mapped ``model`` class through a caller-supplied
    ``Session``. Methods flush when they need database-generated values,
    but never commit: transaction boundaries belong to the caller.

    Soft delete: if ``model`` has an ``is_deleted`` attribute, read methods
    hide rows whose flag is true unless ``include_deleted=True`` is passed,
    and :meth:`remove` sets the flag instead of deleting the row.
    """

    def __init__(self, model: Type[T], session: Session):
        self.model = model
        self.session = session

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _resolve_column(self, column_name: str):
        """Return ``getattr(model, column_name)``.

        Raises:
            ValueError: if the model has no attribute of that name.
        """
        try:
            return getattr(self.model, column_name)
        except AttributeError as e:
            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e

    def _apply_filters(self, stmt, **filters):
        """Add ``filter_by(**filters)`` equality criteria when any are given."""
        if filters:
            stmt = stmt.filter_by(**filters)
        return stmt

    def _select(self):
        """Return a ``SELECT`` of the model's entities."""
        return select(self.model)

    def _has_soft_delete(self) -> bool:
        """Whether the model exposes an ``is_deleted`` attribute."""
        return hasattr(self.model, "is_deleted")

    def _apply_alive_filter(self, stmt, include_deleted: bool):
        """Restrict ``stmt`` to ``is_deleted == False`` rows.

        No-op when ``include_deleted`` is true or the model has no
        soft-delete flag.
        """
        if self._has_soft_delete() and not include_deleted:
            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
        return stmt

    def _aggregate(self, fn, column_name: str, include_deleted: bool, filters: dict):
        """Evaluate ``SELECT fn(column)`` over the filtered rows.

        Shared implementation of the ``aggregate_*`` methods. Returns the
        scalar result, which is ``None`` when no rows match.
        """
        col = self._resolve_column(column_name)
        stmt = self._apply_alive_filter(select(fn(col)), include_deleted)
        stmt = self._apply_filters(stmt, **filters)
        return self.session.execute(stmt).scalar()

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
        """Return all live entities (at most ``limit`` when it is given)."""
        stmt = self._apply_alive_filter(self._select(), include_deleted)
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``.

        Soft-deleted entities count as missing unless ``include_deleted``.
        """
        obj = self.session.get(self.model, id_)
        if obj is None:
            return None
        if not include_deleted and self._has_soft_delete() and getattr(obj, "is_deleted", False):
            return None
        return obj

    def get(self, id_: Any, *, include_deleted: bool = False) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: if it does not exist or is soft-deleted (and
                ``include_deleted`` is false); both cases look identical.
        """
        obj = self.try_get(id_, include_deleted=include_deleted)
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]:
        """Return one page of live entities.

        ``spec`` (if given) receives the base ``SELECT`` and returns a
        refined one. The offset is ``page.page * page.size``.
        """
        base = self._select()
        if spec:
            base = spec(base)
        base = self._apply_alive_filter(base, include_deleted)

        # Count BEFORE adding ORDER BY: ordering never changes the row count,
        # and some backends (e.g. SQL Server) reject ORDER BY inside a
        # derived table that has no TOP/OFFSET.
        total = self.session.execute(
            select(func.count()).select_from(base.subquery())
        ).scalar_one()

        if order_by is not None:
            base = base.order_by(order_by)
        items = self.session.execute(
            base.offset(page.page * page.size).limit(page.size)
        ).scalars().all()
        return Page(items, total, page.page, page.size)

    def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Return live entities whose ``column_name`` equals ``value``.

        ``extra_filters`` are additional ``filter_by`` equality criteria.

        Raises:
            ValueError: if ``column_name`` is not an attribute of the model.
        """
        col = self._resolve_column(column_name)
        stmt = self._apply_alive_filter(self._select().where(col == value), include_deleted)
        stmt = self._apply_filters(stmt, **extra_filters)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    # Alias
    def find_all_by_column(self, *args, **kwargs):
        """Alias of :meth:`get_all_by_column`."""
        return self.get_all_by_column(*args, **kwargs)

    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
        """Execute textual SQL and return every row as a plain ``dict``.

        Values are bound through ``:name`` placeholders in ``params``. This
        is only injection-safe when ``sql`` itself is not assembled from
        untrusted input.
        """
        res = self.session.execute(text(sql), params or {})
        return [dict(row) for row in res.mappings().all()]

    def aggregate_avg(self, column_name: str, *, include_deleted: bool = False, **filters) -> Optional[float]:
        """Return AVG(``column_name``) over live rows matching ``filters``."""
        return self._aggregate(func.avg, column_name, include_deleted, filters)

    def aggregate_min(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return MIN(``column_name``) over live rows matching ``filters``."""
        return self._aggregate(func.min, column_name, include_deleted, filters)

    def aggregate_max(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return MAX(``column_name``) over live rows matching ``filters``."""
        return self._aggregate(func.max, column_name, include_deleted, filters)

    def aggregate_sum(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return SUM(``column_name``) over live rows matching ``filters``."""
        return self._aggregate(func.sum, column_name, include_deleted, filters)

    def count(self, **filters) -> int:
        """Return the number of rows matching the ``filter_by`` ``filters``.

        A keyword ``include_deleted=True`` also counts soft-deleted rows. It
        is always consumed here, so it never reaches ``filter_by`` (even for
        models without soft delete).
        """
        include_deleted = bool(filters.pop("include_deleted", False))
        stmt = select(func.count()).select_from(self.model)
        stmt = self._apply_alive_filter(stmt, include_deleted)
        stmt = self._apply_filters(stmt, **filters)
        return int(self.session.execute(stmt).scalar_one())

    # ------------------------------------------------------------------ #
    # Writes (flushed, never committed)
    # ------------------------------------------------------------------ #

    def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so generated values load."""
        self.session.add(entity)
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def remove(self, entity: T) -> None:
        """Soft-delete ``entity`` if supported, otherwise mark it for deletion.

        An entity that is not persistent or pending in this session is
        re-loaded by its ``id`` attribute; if no such row exists, this
        method does nothing. No flush is issued.

        Raises:
            ValueError: if the entity is not attached and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE: assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = self.session.get(self.model, pk)
            if entity is None:
                return
        if self._has_soft_delete():
            setattr(entity, "is_deleted", True)
        else:
            self.session.delete(entity)

    def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with key ``id_``; return whether it was found."""
        obj = self.session.get(self.model, id_)
        if obj is None:
            return False
        self.remove(obj)
        return True

    def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]:
        """Return ``(obj, created)``.

        ``unique_filters`` identify the row among live rows. When no row
        matches, a new one is built from ``unique_filters`` merged with
        ``defaults``; keys in ``defaults`` win on conflict. The new row is
        flushed and refreshed.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if several live rows match.
        """
        stmt = self._apply_alive_filter(self._select().filter_by(**unique_filters), False)
        obj = self.session.execute(stmt).scalar_one_or_none()
        if obj is not None:
            return obj, False
        payload = {**unique_filters, **(defaults or {})}
        obj = self.model(**payload)  # type: ignore[call-arg]
        self.session.add(obj)
        self.session.flush()
        self.session.refresh(obj)
        return obj, True

    def restore(self, id_: Any) -> bool:
        """Clear the soft-delete flag of ``id_``.

        Returns ``True`` only if the row existed and was deleted.

        Raises:
            RuntimeError: if the model has no ``is_deleted`` field.
        """
        if not self._has_soft_delete():
            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
        obj = self.session.get(self.model, id_)
        if obj is None:
            return False
        if getattr(obj, "is_deleted", False):
            setattr(obj, "is_deleted", False)
            self.session.flush()
            return True
        return False
222
+
223
class SAAsyncRepository(Generic[T]):
    """Async repository implementation for SQLAlchemy 2.x.

    Mirrors the synchronous ``SARepository`` API on an ``AsyncSession``;
    every method that touches the database is a coroutine. Writes are
    flushed but never committed.

    Soft delete: if ``model`` has an ``is_deleted`` attribute, reads hide
    rows whose flag is true unless ``include_deleted=True`` is passed, and
    :meth:`remove` sets the flag instead of deleting the row.
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        self.model = model
        self.session = session

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _resolve_column(self, column_name: str):
        """Return ``getattr(model, column_name)``.

        Raises:
            ValueError: if the model has no attribute of that name.
        """
        try:
            return getattr(self.model, column_name)
        except AttributeError as e:
            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e

    def _apply_filters(self, stmt, **filters):
        """Add ``filter_by(**filters)`` equality criteria when any are given."""
        if filters:
            stmt = stmt.filter_by(**filters)
        return stmt

    def _select(self):
        """Return a ``SELECT`` of the model's entities."""
        return select(self.model)

    def _has_soft_delete(self) -> bool:
        """Whether the model exposes an ``is_deleted`` attribute."""
        return hasattr(self.model, "is_deleted")

    def _apply_alive_filter(self, stmt, include_deleted: bool):
        """Restrict ``stmt`` to ``is_deleted == False`` rows.

        No-op when ``include_deleted`` is true or the model has no
        soft-delete flag.
        """
        if self._has_soft_delete() and not include_deleted:
            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
        return stmt

    async def _aggregate(self, fn, column_name: str, include_deleted: bool, filters: dict):
        """Evaluate ``SELECT fn(column)`` over the filtered rows.

        Shared implementation of the ``aggregate_*`` methods. Returns the
        scalar result, which is ``None`` when no rows match.
        """
        col = self._resolve_column(column_name)
        stmt = self._apply_alive_filter(select(fn(col)), include_deleted)
        stmt = self._apply_filters(stmt, **filters)
        return (await self.session.execute(stmt)).scalar()

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    async def getAll(self, limit: Optional[int] = None, *, include_deleted: bool = False) -> List[T]:
        """Return all live entities (at most ``limit`` when it is given)."""
        stmt = self._apply_alive_filter(self._select(), include_deleted)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def try_get(self, id_: Any, *, include_deleted: bool = False) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``.

        Soft-deleted entities count as missing unless ``include_deleted``.
        """
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return None
        if not include_deleted and self._has_soft_delete() and getattr(obj, "is_deleted", False):
            return None
        return obj

    async def get(self, id_: Any, *, include_deleted: bool = False) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: if it does not exist or is soft-deleted (and
                ``include_deleted`` is false); both cases look identical.
        """
        obj = await self.try_get(id_, include_deleted=include_deleted)
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None, *, include_deleted: bool = False) -> Page[T]:  # type: ignore
        """Return one page of live entities.

        ``spec`` (if given) receives the base ``SELECT`` and returns a
        refined one. The offset is ``page.page * page.size``.
        """
        base = self._select()
        if spec:
            base = spec(base)
        base = self._apply_alive_filter(base, include_deleted)

        # Count BEFORE adding ORDER BY: ordering never changes the row count,
        # and some backends (e.g. SQL Server) reject ORDER BY inside a
        # derived table that has no TOP/OFFSET.
        total = (await self.session.execute(
            select(func.count()).select_from(base.subquery())
        )).scalar_one()

        if order_by is not None:
            base = base.order_by(order_by)
        res = await self.session.execute(
            base.offset(page.page * page.size).limit(page.size)
        )
        items = res.scalars().all()
        return Page(items, total, page.page, page.size)

    async def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Return live entities whose ``column_name`` equals ``value``.

        ``extra_filters`` are additional ``filter_by`` equality criteria.

        Raises:
            ValueError: if ``column_name`` is not an attribute of the model.
        """
        col = self._resolve_column(column_name)
        stmt = self._apply_alive_filter(self._select().where(col == value), include_deleted)
        stmt = self._apply_filters(stmt, **extra_filters)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        if limit is not None:
            stmt = stmt.limit(limit)
        res = await self.session.execute(stmt)
        return res.scalars().all()

    # Alias
    async def find_all_by_column(self, *args, **kwargs):
        """Alias of :meth:`get_all_by_column`."""
        return await self.get_all_by_column(*args, **kwargs)

    async def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
        """Execute textual SQL and return every row as a plain ``dict``.

        Values are bound through ``:name`` placeholders in ``params``. This
        is only injection-safe when ``sql`` itself is not assembled from
        untrusted input.
        """
        res = await self.session.execute(text(sql), params or {})
        return [dict(row) for row in res.mappings().all()]

    async def aggregate_avg(self, column_name: str, *, include_deleted: bool = False, **filters) -> Optional[float]:
        """Return AVG(``column_name``) over live rows matching ``filters``."""
        return await self._aggregate(func.avg, column_name, include_deleted, filters)

    async def aggregate_min(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return MIN(``column_name``) over live rows matching ``filters``."""
        return await self._aggregate(func.min, column_name, include_deleted, filters)

    async def aggregate_max(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return MAX(``column_name``) over live rows matching ``filters``."""
        return await self._aggregate(func.max, column_name, include_deleted, filters)

    async def aggregate_sum(self, column_name: str, *, include_deleted: bool = False, **filters):
        """Return SUM(``column_name``) over live rows matching ``filters``."""
        return await self._aggregate(func.sum, column_name, include_deleted, filters)

    async def count(self, **filters) -> int:
        """Return the number of rows matching the ``filter_by`` ``filters``.

        A keyword ``include_deleted=True`` also counts soft-deleted rows. It
        is always consumed here and never reaches ``filter_by``.
        """
        include_deleted = bool(filters.pop("include_deleted", False))
        stmt = select(func.count()).select_from(self.model)
        stmt = self._apply_alive_filter(stmt, include_deleted)
        stmt = self._apply_filters(stmt, **filters)
        return int((await self.session.execute(stmt)).scalar_one())

    # ------------------------------------------------------------------ #
    # Writes (flushed, never committed)
    # ------------------------------------------------------------------ #

    async def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so generated values load."""
        self.session.add(entity)
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def remove(self, entity: T) -> None:
        """Soft-delete ``entity`` if supported, otherwise mark it for deletion.

        An entity that is not persistent or pending in this session is
        re-loaded by its ``id`` attribute; if no such row exists, this
        method does nothing. No flush is issued.

        Raises:
            ValueError: if the entity is not attached and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE: assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = await self.session.get(self.model, pk)
            if entity is None:
                return
        if self._has_soft_delete():
            setattr(entity, "is_deleted", True)
        else:
            await self.session.delete(entity)

    async def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with key ``id_``; return whether it was found."""
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return False
        await self.remove(obj)
        return True

    async def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]:
        """Return ``(obj, created)``.

        ``unique_filters`` identify the row among live rows. When no row
        matches, a new one is built from ``unique_filters`` merged with
        ``defaults``; keys in ``defaults`` win on conflict. The new row is
        flushed and refreshed.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if several live rows match.
        """
        stmt = self._apply_alive_filter(self._select().filter_by(**unique_filters), False)
        obj = (await self.session.execute(stmt)).scalar_one_or_none()
        if obj is not None:
            return obj, False
        payload = {**unique_filters, **(defaults or {})}
        obj = self.model(**payload)  # type: ignore[call-arg]
        self.session.add(obj)
        await self.session.flush()
        await self.session.refresh(obj)
        return obj, True

    async def restore(self, id_: Any) -> bool:
        """Clear the soft-delete flag of ``id_``.

        Returns ``True`` only if the row existed and was deleted.

        Raises:
            RuntimeError: if the model has no ``is_deleted`` field.
        """
        if not self._has_soft_delete():
            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return False
        if getattr(obj, "is_deleted", False):
            setattr(obj, "is_deleted", False)
            await self.session.flush()
            return True
        return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
 
2
2
  [project]
3
3
  name = "SARepo"
4
- version = "0.1.3"
4
+ version = "0.1.5"
5
5
  description = "Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x"
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
@@ -0,0 +1,145 @@
1
+ import pytest
2
+ from typing import Optional
3
+ from sqlalchemy import create_engine, String, Integer, Boolean, Text
4
+ from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
5
+
6
+ from SARepo.sa_repo import SARepository
7
+ from SARepo.base import NotFoundError
8
+
9
+
10
+
11
class Base(DeclarativeBase):
    """Declarative base class for the models used in these tests."""
13
+
14
+
15
class Item(Base):
    """Test model mapped to the ``items`` table.

    It mixes text, nullable, integer and boolean columns so filters,
    aggregates and soft delete can all be exercised.
    """

    __tablename__ = "items"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    title: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    # Nullable so the seed data can include a row without a city.
    city: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
    # Numeric columns used by the aggregate tests.
    amount: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    age: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    # Soft-delete flag; its presence switches SARepository into soft-delete mode.
    is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
    note: Mapped[Optional[str]] = mapped_column(Text)
25
+
26
+
27
+
28
@pytest.fixture()
def session():
    """Yield a Session bound to a fresh in-memory SQLite database.

    Each test gets its own engine, so no state leaks between tests. The
    engine is disposed on teardown so its connection pool is released
    instead of lingering until garbage collection.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    Base.metadata.create_all(engine)
    SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
    try:
        with SessionLocal() as s:
            yield s
    finally:
        engine.dispose()
35
+
36
+
37
@pytest.fixture()
def repo(session):
    """Provide an ``SARepository`` for ``Item`` bound to the test's session."""
    item_repository = SARepository(Item, session)
    return item_repository
40
+
41
+
42
def seed(session, repo):
    """Fill the database with five items for the aggregate and filter tests."""
    # (title, city, amount, age, note), inserted in this order.
    rows = (
        ("alpha", "Almaty", 10, 20, "A"),
        ("beta", "Astana", 20, 30, "B"),
        ("gamma", "Almaty", 30, 40, "G"),
        ("delta", "Shymkent", 40, 50, "D"),
        ("omega", None, 50, 60, "O"),
    )
    for title, city, amount, age, note in rows:
        repo.add(Item(title=title, city=city, amount=amount, age=age, note=note))
    session.commit()
54
+
55
+
56
+
57
def test_find_all_by_column_basic(session, repo):
    """find_all_by_column filters on one column and honours limit/order_by.

    It must also reject unknown column names with ValueError.
    """
    seed(session, repo)

    almaty = repo.find_all_by_column("city", "Almaty")
    assert {item.title for item in almaty} == {"alpha", "gamma"}

    # Highest amount first, so the single result must be "gamma".
    top = repo.find_all_by_column("city", "Almaty", limit=1, order_by=Item.amount.desc())
    assert len(top) == 1
    assert top[0].title == "gamma"

    with pytest.raises(ValueError):
        repo.find_all_by_column("no_such_column", "x")
69
+
70
+
71
def test_get_or_create(session, repo):
    """get_or_create returns existing rows as-is and creates missing ones."""
    seed(session, repo)

    # "alpha" exists, so the defaults must be ignored.
    existing, was_created = repo.get_or_create(title="alpha", defaults={"city": "Kokshetau"})
    assert was_created is False
    assert existing.city == "Almaty"

    # "sigma" is new, so it is created with the defaults applied.
    fresh, was_created = repo.get_or_create(title="sigma", defaults={"city": "Aktau", "amount": 77})
    assert was_created is True
    assert (fresh.title, fresh.city, fresh.amount) == ("sigma", "Aktau", 77)
81
+
82
+
83
def test_raw_query(session, repo):
    """raw_query binds named parameters and returns rows as plain dicts."""
    seed(session, repo)

    sql = "SELECT id, title, city FROM items WHERE lower(title) LIKE :p"
    rows = repo.raw_query(sql, {"p": "%a%"})

    assert isinstance(rows, list)
    assert all(isinstance(row, dict) for row in rows)
    assert {"alpha", "gamma", "delta"} <= {row["title"] for row in rows}
93
+
94
+
95
def test_aggregates_avg_min_max_sum_ignore_soft_deleted_by_default(session, repo):
    """Aggregates skip soft-deleted rows and accept extra equality filters."""
    seed(session, repo)

    # Soft-delete "omega" (amount 50); it must drop out of every aggregate.
    repo.find_all_by_column("title", "omega")[0].is_deleted = True
    session.commit()

    live_amounts = [10, 20, 30, 40]
    assert repo.aggregate_avg("amount") == pytest.approx(sum(live_amounts) / len(live_amounts))
    assert repo.aggregate_min("amount") == min(live_amounts)
    assert repo.aggregate_max("amount") == max(live_amounts)
    assert repo.aggregate_sum("amount") == sum(live_amounts)

    # Only alpha (age 20) and gamma (age 40) are in Almaty.
    assert repo.aggregate_avg("age", city="Almaty") == pytest.approx((20 + 40) / 2)
114
+
115
+
116
def test_count_with_and_without_deleted(session, repo):
    """count() hides soft-deleted rows unless include_deleted=True."""
    seed(session, repo)

    for title in ("alpha", "beta"):
        repo.find_all_by_column("title", title)[0].is_deleted = True
    session.commit()

    live, everything = repo.count(), repo.count(include_deleted=True)
    assert live == 3
    assert everything == 5

    # alpha is deleted, so only gamma remains in Almaty.
    assert repo.count(city="Almaty") == 1
129
+
130
+
131
def test_restore(session, repo):
    """restore() revives a soft-deleted row once, then reports a no-op."""
    seed(session, repo)

    beta = repo.find_all_by_column("title", "beta")[0]
    beta.is_deleted = True
    session.commit()

    assert repo.restore(beta.id) is True
    session.commit()

    # beta is already live, so a second restore changes nothing.
    assert repo.restore(beta.id) is False

    assert repo.count() == 5
@@ -1,16 +0,0 @@
1
-
2
- from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
- from .base import Page, PageRequest
4
-
5
- T = TypeVar("T")
6
-
7
- class CrudRepository(Protocol, Generic[T]):
8
- model: Type[T]
9
- def getAll(self, limit: Optional[int]) -> List[T]: ...
10
- def get(self, id_: Any) -> T: ...
11
- def try_get(self, id_: Any) -> Optional[T]: ...
12
- def add(self, entity: T) -> T: ...
13
- def update(self, entity: T) -> T: ...
14
- def remove(self, entity: T) -> None: ...
15
- def delete_by_id(self, id_: Any) -> bool: ...
16
- def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...
@@ -1,152 +0,0 @@
1
-
2
- from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
- from sqlalchemy.orm import Session
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
- from sqlalchemy import inspect, select, func
6
- from .base import PageRequest, Page, NotFoundError
7
-
8
- T = TypeVar("T")
9
- Spec = Callable # aliased to match specs.Spec
10
-
11
- class SARepository(Generic[T]):
12
- """Synchronous repository implementation for SQLAlchemy 2.x."""
13
- def __init__(self, model: Type[T], session: Session):
14
- self.model = model
15
- self.session = session
16
-
17
- def _select(self):
18
- return select(self.model)
19
-
20
- def getAll(self, limit: Optional[int] = None) -> List[T]:
21
- stmt = select(self.model)
22
- if limit is not None:
23
- stmt = stmt.limit(limit)
24
- result = self.session.execute(stmt)
25
- return result.scalars().all()
26
-
27
- def get(self, id_: Any) -> T:
28
- obj = self.session.get(self.model, id_)
29
- if not obj:
30
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
31
- return obj
32
-
33
- def try_get(self, id_: Any) -> Optional[T]:
34
- return self.session.get(self.model, id_)
35
-
36
- def add(self, entity: T) -> T:
37
- self.session.add(entity)
38
- self.session.flush()
39
- self.session.refresh(entity)
40
- return entity
41
-
42
- def update(self, entity: T) -> T:
43
- self.session.flush()
44
- self.session.refresh(entity)
45
- return entity
46
-
47
- def remove(self, entity: T) -> None:
48
- insp = inspect(entity, raiseerr=False)
49
- if not (insp and (insp.persistent or insp.pending)):
50
- pk = getattr(entity, "id", None)
51
- if pk is None:
52
- raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
53
- entity = self.session.get(self.model, pk)
54
- if entity is None:
55
- return
56
- if hasattr(entity, "is_deleted"):
57
- setattr(entity, "is_deleted", True)
58
- else:
59
- self.session.delete(entity)
60
-
61
- def delete_by_id(self, id_: Any) -> bool:
62
- obj = self.session.get(self.model, id_)
63
- if not obj:
64
- return False
65
- self.remove(obj)
66
- return True
67
-
68
- def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
69
- stmt = self._select()
70
- if spec:
71
- stmt = spec(stmt)
72
- if order_by is not None:
73
- stmt = stmt.order_by(order_by)
74
- total = self.session.execute(
75
- select(func.count()).select_from(stmt.subquery())
76
- ).scalar_one()
77
- items = self.session.execute(
78
- stmt.offset(page.page * page.size).limit(page.size)
79
- ).scalars().all()
80
- return Page(items, total, page.page, page.size)
81
-
82
- class SAAsyncRepository(Generic[T]):
83
- """Async repository implementation for SQLAlchemy 2.x."""
84
- def __init__(self, model: Type[T], session: AsyncSession):
85
- self.model = model
86
- self.session = session
87
-
88
- async def getAll(self, limit: Optional[int] = None) -> List[T]:
89
- stmt = select(self.model)
90
- if limit is not None:
91
- stmt = stmt.limit(limit)
92
- result = await self.session.execute(stmt)
93
- return result.scalars().all()
94
-
95
- def _select(self):
96
- return select(self.model)
97
-
98
- async def get(self, id_: Any) -> T:
99
- obj = await self.session.get(self.model, id_)
100
- if not obj:
101
- raise NotFoundError(f"{self.model.__name__}({id_}) not found")
102
- return obj
103
-
104
- async def try_get(self, id_: Any) -> Optional[T]:
105
- return await self.session.get(self.model, id_)
106
-
107
- async def add(self, entity: T) -> T:
108
- self.session.add(entity)
109
- await self.session.flush()
110
- await self.session.refresh(entity)
111
- return entity
112
-
113
- async def update(self, entity: T) -> T:
114
- await self.session.flush()
115
- await self.session.refresh(entity)
116
- return entity
117
-
118
- async def remove(self, entity: T) -> None:
119
- insp = inspect(entity, raiseerr=False)
120
- if not (insp and (insp.persistent or insp.pending)):
121
- pk = getattr(entity, "id", None)
122
- if pk is None:
123
- raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
124
- entity = await self.session.get(self.model, pk)
125
- if entity is None:
126
- return
127
- if hasattr(entity, "is_deleted"):
128
- setattr(entity, "is_deleted", True)
129
- else:
130
- await self.session.delete(entity)
131
-
132
- async def delete_by_id(self, id_: Any) -> bool:
133
- obj = await self.session.get(self.model, id_)
134
- if not obj:
135
- return False
136
- await self.remove(obj)
137
- return True
138
-
139
- async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
140
- stmt = self._select()
141
- if spec:
142
- stmt = spec(stmt)
143
- if order_by is not None:
144
- stmt = stmt.order_by(order_by)
145
- total = (await self.session.execute(
146
- select(func.count()).select_from(stmt.subquery())
147
- )).scalar_one()
148
- res = await self.session.execute(
149
- stmt.offset(page.page * page.size).limit(page.size)
150
- )
151
- items = res.scalars().all()
152
- return Page(items, total, page.page, page.size)
@@ -1,73 +0,0 @@
1
- import pytest
2
- from sqlalchemy import create_engine, String
3
- from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
4
- from SARepo.sa_repo import SARepository
5
- from SARepo.base import PageRequest, NotFoundError
6
-
7
- class Base(DeclarativeBase): pass
8
-
9
- class Item(Base):
10
- __tablename__ = "items"
11
- id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
12
- title: Mapped[str] = mapped_column(String(100), nullable=False)
13
-
14
- def test_crud_and_pagination():
15
- engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
16
- Base.metadata.create_all(engine)
17
- SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
18
-
19
- with SessionLocal() as session:
20
- repo = SARepository(Item, session)
21
- for i in range(25):
22
- repo.add(Item(title=f"t{i}"))
23
- session.commit()
24
-
25
- page = repo.page(PageRequest(1, 10))
26
- assert page.total == 25
27
- assert len(page.items) == 10
28
- assert page.page == 1
29
- assert page.pages == 3
30
- # проверим, что конкретные объекты есть
31
- got = {it.title for it in page.items}
32
- assert got.issubset({f"t{i}" for i in range(25)})
33
-
34
- def test_get_all():
35
- engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
36
- Base.metadata.create_all(engine)
37
- SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
38
-
39
- with SessionLocal() as session:
40
- repo = SARepository(Item, session)
41
- for i in range(10):
42
- repo.add(Item(title=f"get-all-{i}"))
43
- session.commit()
44
-
45
- items = repo.getAll()
46
- assert isinstance(items, list)
47
- assert len(items) == 10
48
- titles = {i.title for i in items}
49
- assert titles == {f"get-all-{i}" for i in range(10)}
50
-
51
- items2 = repo.getAll(5)
52
- assert isinstance(items2, list)
53
- assert len(items2) == 5
54
- titles = {i.title for i in items2}
55
- assert titles == {f"get-all-{i}" for i in range(5)}
56
-
57
- def test_get_and_try_get():
58
- engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
59
- Base.metadata.create_all(engine)
60
- SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
61
-
62
- with SessionLocal() as session:
63
- repo = SARepository(Item, session)
64
- obj = repo.add(Item(title="one"))
65
- session.commit()
66
-
67
- same = repo.get(obj.id)
68
- assert same.id == obj.id and same.title == "one"
69
-
70
- assert repo.try_get(9999) is None
71
-
72
- with pytest.raises(NotFoundError):
73
- repo.get(9999)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes