SARepo 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -0,0 +1,48 @@
1
+
2
+ from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
+ from .base import Page, PageRequest
4
+
5
# Type variable for the entity class a repository manages (see ``model: Type[T]``).
T = TypeVar("T")
6
+
7
class CrudRepository(Protocol, Generic[T]):
    """Structural (duck-typed) interface of a synchronous repository over ``T``.

    A class satisfies this protocol by exposing members with matching
    signatures; it does not need to inherit from it. Only signatures are
    specified here -- behavior is defined by the implementations.
    """

    # The entity class whose instances the repository manages.
    model: Type[T]

    # ``limit`` defaults to None so that ``repo.getAll()`` type-checks against
    # the protocol, matching implementations that treat the limit as optional.
    def getAll(self, limit: Optional[int] = None) -> List[T]: ...
    def get(self, id_: Any) -> T: ...
    def try_get(self, id_: Any) -> Optional[T]: ...
    def add(self, entity: T) -> T: ...
    def update(self, entity: T) -> T: ...
    def remove(self, entity: T) -> None: ...
    def delete_by_id(self, id_: Any) -> bool: ...
    def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...

    def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]: ...

    # Same signature as get_all_by_column (an alias in the implementations).
    def find_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]: ...

    # Returns ``(entity, created)``.
    def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]: ...

    # Rows are returned as plain dicts; values are bound via ``params``.
    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...

    def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
    # The result type depends on the column's type, hence ``Any``.
    def aggregate_min(self, column_name: str, **filters) -> Any: ...
    def aggregate_max(self, column_name: str, **filters) -> Any: ...
    def aggregate_sum(self, column_name: str, **filters) -> Any: ...
    def count(self, **filters) -> int: ...
    def restore(self, id_: Any) -> bool: ...
@@ -0,0 +1,387 @@
1
+
2
+ from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
+ from sqlalchemy.orm import Session
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy import inspect, select, func, text
6
+ from .base import PageRequest, Page, NotFoundError
7
+
8
# Type variable for the mapped model class a repository wraps.
T = TypeVar("T")
# A "specification": ``page()`` calls ``spec(stmt)`` on its base SELECT and
# uses the returned value as the statement to count and paginate.
Spec = Callable
10
+
11
class SARepository(Generic[T]):
    """Synchronous repository implementation for SQLAlchemy 2.x.

    Wraps one mapped model class and an ORM ``Session``. Write operations only
    ``flush()``; committing or rolling back is left to the caller.

    Soft delete: if the model has an ``is_deleted`` attribute, ``remove()``
    sets it to ``True`` instead of deleting the row. The column lookups,
    ``get_or_create``, aggregates and ``count`` skip such rows
    (``get_all_by_column`` and ``count`` accept ``include_deleted=True``).
    ``getAll``, ``get``/``try_get`` and ``page`` do not filter them.
    """

    def __init__(self, model: Type[T], session: Session):
        self.model = model
        self.session = session

    # ------------------------------------------------------------------ helpers

    def _resolve_column(self, column_name: str):
        """Return ``getattr(model, column_name)``.

        Raises:
            ValueError: If the model has no attribute with that name.
        """
        # NOTE(review): any model attribute resolves here, not only mapped
        # columns; callers are expected to pass column names.
        try:
            return getattr(self.model, column_name)
        except AttributeError as e:
            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e

    def _apply_filters(self, stmt, **filters):
        """Apply keyword equality filters via ``filter_by`` (no-op when empty)."""
        if filters:
            stmt = stmt.filter_by(**filters)
        return stmt

    def _exclude_deleted(self, stmt):
        """Add ``is_deleted == False`` to ``stmt`` if the model has that attribute."""
        if hasattr(self.model, "is_deleted"):
            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
        return stmt

    def _aggregate(self, fn, column_name: str, /, **filters):
        """Return ``fn(column)`` over non-deleted rows matching ``filters``.

        ``fn`` is a SQL function factory such as ``func.avg``. The leading
        parameters are positional-only so a model column named ``fn`` can
        still be used as a filter keyword.
        """
        col = self._resolve_column(column_name)
        stmt = self._exclude_deleted(select(fn(col)))
        stmt = self._apply_filters(stmt, **filters)
        return self.session.execute(stmt).scalar()

    def _select(self):
        """Base ``SELECT`` of the model used by ``page()``."""
        return select(self.model)

    # -------------------------------------------------------------------- reads

    def getAll(self, limit: Optional[int] = None) -> List[T]:
        """Return all rows of the model (at most ``limit`` when given).

        Soft-deleted rows are included.
        """
        stmt = select(self.model)
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    def get(self, id_: Any) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: If no such row exists.
        """
        obj = self.session.get(self.model, id_)
        # Compare against None: a model may define __bool__/__len__, so
        # truthiness is not a reliable "found" test.
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    def try_get(self, id_: Any) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``."""
        return self.session.get(self.model, id_)

    # ------------------------------------------------------------------- writes

    def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so DB-generated values are loaded."""
        self.session.add(entity)
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def remove(self, entity: T) -> None:
        """Delete ``entity``, or soft-delete it when the model has ``is_deleted``.

        An entity that is neither persistent nor pending is looked up by its
        ``id`` attribute; if no such row exists this is a no-op. The change is
        not flushed here.

        Raises:
            ValueError: If the entity is not in the session and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE(review): assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = self.session.get(self.model, pk)
            if entity is None:
                return
        if hasattr(entity, "is_deleted"):
            setattr(entity, "is_deleted", True)
        else:
            self.session.delete(entity)

    def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with primary key ``id_``; return ``False`` if absent."""
        obj = self.session.get(self.model, id_)
        if obj is None:
            return False
        self.remove(obj)
        return True

    # --------------------------------------------------------------- pagination

    def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
        """Return one page of results plus the total row count.

        Args:
            page: Page index and size. The offset is ``page.page * page.size``,
                so page indices start at 0.
            spec: Optional callable that receives the base SELECT and returns
                the statement to use.
            order_by: Optional ORDER BY clause.
        """
        stmt = self._select()
        if spec:
            stmt = spec(stmt)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        # The total is counted over the whole (unpaginated) statement.
        total = self.session.execute(
            select(func.count()).select_from(stmt.subquery())
        ).scalar_one()
        items = self.session.execute(
            stmt.offset(page.page * page.size).limit(page.size)
        ).scalars().all()
        return Page(items, total, page.page, page.size)

    # ----------------------------------------------------------- column lookups

    def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Return entities for which ``model.<column_name> == value``.

        Args:
            column_name: Model attribute to compare.
            value: Value compared with ``col == value``.
            limit: Optional maximum number of rows.
            order_by: Optional ORDER BY clause.
            include_deleted: Also return soft-deleted rows.
            **extra_filters: Extra equality filters passed to ``filter_by``.

        Raises:
            ValueError: If ``column_name`` is not an attribute of the model.
        """
        col = self._resolve_column(column_name)
        stmt = select(self.model).where(col == value)
        if not include_deleted:
            stmt = self._exclude_deleted(stmt)
        stmt = self._apply_filters(stmt, **extra_filters)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        if limit is not None:
            stmt = stmt.limit(limit)
        return self.session.execute(stmt).scalars().all()

    # Alias
    def find_all_by_column(self, *args, **kwargs):
        """Alias of ``get_all_by_column``."""
        return self.get_all_by_column(*args, **kwargs)

    def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]:
        """Return ``(obj, created)``.

        ``unique_filters`` identify the row. ``defaults`` supply extra field
        values used only when a new row is created; they override
        ``unique_filters`` on key collisions. Soft-deleted rows are not
        considered matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If the filters match several rows.
        """
        stmt = self._exclude_deleted(select(self.model).filter_by(**unique_filters))
        obj = self.session.execute(stmt).scalar_one_or_none()
        if obj is not None:
            return obj, False
        # NOTE(review): not race-safe. A concurrent insert between the SELECT
        # and the flush can produce a duplicate unless a unique constraint exists.
        payload = {**unique_filters, **(defaults or {})}
        obj = self.model(**payload)  # type: ignore[call-arg]
        self.session.add(obj)
        self.session.flush()
        self.session.refresh(obj)
        return obj, True

    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
        """Execute raw SQL through ``text()`` and return the rows as dicts.

        Pass values through ``params`` with ``:name`` placeholders rather than
        formatting them into ``sql``; only then is the query injection-safe.
        """
        res = self.session.execute(text(sql), params or {})
        # mappings() yields dict-like RowMapping objects.
        return [dict(row) for row in res.mappings().all()]

    # --------------------------------------------------------------- aggregates

    def aggregate_avg(self, column_name: str, **filters) -> Optional[float]:
        """AVG of ``column_name`` over non-deleted rows matching ``filters``."""
        return self._aggregate(func.avg, column_name, **filters)

    def aggregate_min(self, column_name: str, **filters):
        """MIN of ``column_name`` over non-deleted rows matching ``filters``."""
        return self._aggregate(func.min, column_name, **filters)

    def aggregate_max(self, column_name: str, **filters):
        """MAX of ``column_name`` over non-deleted rows matching ``filters``."""
        return self._aggregate(func.max, column_name, **filters)

    def aggregate_sum(self, column_name: str, **filters):
        """SUM of ``column_name`` over non-deleted rows matching ``filters``."""
        return self._aggregate(func.sum, column_name, **filters)

    def count(self, **filters) -> int:
        """Count rows matching ``filters``.

        Pass ``include_deleted=True`` to also count soft-deleted rows.
        """
        # Pop the flag before anything else so it never reaches filter_by(),
        # even for models that have no ``is_deleted`` column.
        include_deleted = bool(filters.pop("include_deleted", False))
        stmt = select(func.count()).select_from(self.model)
        if not include_deleted:
            stmt = self._exclude_deleted(stmt)
        stmt = self._apply_filters(stmt, **filters)
        return int(self.session.execute(stmt).scalar_one())

    def restore(self, id_: Any) -> bool:
        """Undo a soft delete by setting ``is_deleted`` to ``False``.

        Returns ``True`` only if the row existed and was deleted. The change
        is flushed.

        Raises:
            RuntimeError: If the model has no ``is_deleted`` attribute.
        """
        if not hasattr(self.model, "is_deleted"):
            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
        obj = self.session.get(self.model, id_)
        if obj is None:
            return False
        if getattr(obj, "is_deleted", False):
            setattr(obj, "is_deleted", False)
            self.session.flush()
            return True
        return False
204
+
205
class SAAsyncRepository(Generic[T]):
    """Async repository implementation for SQLAlchemy 2.x.

    Same operations as the synchronous ``SARepository``, awaited on an
    ``AsyncSession``. Write operations only ``flush()``; committing is left to
    the caller.

    Soft delete: if the model has an ``is_deleted`` attribute, ``remove()``
    sets it to ``True`` instead of deleting. The column lookups,
    ``get_or_create``, aggregates and ``count`` skip such rows
    (``get_all_by_column`` and ``count`` accept ``include_deleted=True``).
    """

    def __init__(self, model: Type[T], session: AsyncSession):
        self.model = model
        self.session = session

    # ------------------------------------------------------------------ helpers

    def _resolve_column(self, column_name: str):
        """Return ``getattr(model, column_name)``.

        Raises:
            ValueError: If the model has no attribute with that name.
        """
        try:
            return getattr(self.model, column_name)
        except AttributeError as e:
            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e

    def _apply_filters(self, stmt, **filters):
        """Apply keyword equality filters via ``filter_by`` (no-op when empty)."""
        if filters:
            stmt = stmt.filter_by(**filters)
        return stmt

    def _exclude_deleted(self, stmt):
        """Add ``is_deleted == False`` to ``stmt`` if the model has that attribute."""
        if hasattr(self.model, "is_deleted"):
            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
        return stmt

    async def _aggregate(self, fn, column_name: str, /, **filters):
        """Return ``fn(column)`` over non-deleted rows matching ``filters``.

        ``fn`` is a SQL function factory such as ``func.avg``. The leading
        parameters are positional-only so a column named ``fn`` can still be
        used as a filter keyword.
        """
        col = self._resolve_column(column_name)
        stmt = self._exclude_deleted(select(fn(col)))
        stmt = self._apply_filters(stmt, **filters)
        return (await self.session.execute(stmt)).scalar()

    def _select(self):
        """Base ``SELECT`` of the model used by ``page()``."""
        return select(self.model)

    # -------------------------------------------------------------------- reads

    async def getAll(self, limit: Optional[int] = None) -> List[T]:
        """Return all rows of the model (at most ``limit``); includes soft-deleted rows."""
        stmt = select(self.model)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def get(self, id_: Any) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: If no such row exists.
        """
        obj = await self.session.get(self.model, id_)
        # Compare against None: model truthiness may be overridden.
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    async def try_get(self, id_: Any) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``."""
        return await self.session.get(self.model, id_)

    # ------------------------------------------------------------------- writes

    async def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so DB-generated values are loaded."""
        self.session.add(entity)
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def remove(self, entity: T) -> None:
        """Delete ``entity``, or soft-delete it when the model has ``is_deleted``.

        An entity that is neither persistent nor pending is looked up by its
        ``id`` attribute; if no such row exists this is a no-op. Not flushed.

        Raises:
            ValueError: If the entity is not in the session and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE(review): assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = await self.session.get(self.model, pk)
            if entity is None:
                return
        if hasattr(entity, "is_deleted"):
            setattr(entity, "is_deleted", True)
        else:
            await self.session.delete(entity)

    async def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with primary key ``id_``; return ``False`` if absent."""
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return False
        await self.remove(obj)
        return True

    # --------------------------------------------------------------- pagination

    async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
        """Return one page of results plus the total row count.

        The offset is ``page.page * page.size`` (page indices start at 0).
        ``spec`` receives the base SELECT and returns the statement to use.
        """
        stmt = self._select()
        if spec:
            stmt = spec(stmt)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        # The total is counted over the whole (unpaginated) statement.
        total = (await self.session.execute(
            select(func.count()).select_from(stmt.subquery())
        )).scalar_one()
        res = await self.session.execute(
            stmt.offset(page.page * page.size).limit(page.size)
        )
        items = res.scalars().all()
        return Page(items, total, page.page, page.size)

    # ----------------------------------------------------------- column lookups

    async def get_all_by_column(
        self,
        column_name: str,
        value: Any,
        *,
        limit: Optional[int] = None,
        order_by=None,
        include_deleted: bool = False,
        **extra_filters
    ) -> list[T]:
        """Return entities for which ``model.<column_name> == value``.

        ``include_deleted`` also returns soft-deleted rows. ``extra_filters``
        are passed to ``filter_by``.

        Raises:
            ValueError: If ``column_name`` is not an attribute of the model.
        """
        col = self._resolve_column(column_name)
        stmt = select(self.model).where(col == value)
        if not include_deleted:
            stmt = self._exclude_deleted(stmt)
        stmt = self._apply_filters(stmt, **extra_filters)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        if limit is not None:
            stmt = stmt.limit(limit)
        res = await self.session.execute(stmt)
        return res.scalars().all()

    # Alias
    async def find_all_by_column(self, *args, **kwargs):
        """Alias of ``get_all_by_column``."""
        return await self.get_all_by_column(*args, **kwargs)

    async def get_or_create(
        self,
        defaults: Optional[dict] = None,
        **unique_filters
    ) -> tuple[T, bool]:
        """Return ``(obj, created)``.

        ``unique_filters`` identify the row. ``defaults`` are extra values used
        only on creation and override ``unique_filters`` on key collisions.
        Soft-deleted rows are not considered matches.
        """
        stmt = self._exclude_deleted(select(self.model).filter_by(**unique_filters))
        obj = (await self.session.execute(stmt)).scalar_one_or_none()
        if obj is not None:
            return obj, False
        # NOTE(review): not race-safe without a unique constraint.
        payload = {**unique_filters, **(defaults or {})}
        obj = self.model(**payload)  # type: ignore[call-arg]
        self.session.add(obj)
        await self.session.flush()
        await self.session.refresh(obj)
        return obj, True

    async def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
        """Execute raw SQL through ``text()`` and return rows as dicts.

        Bind values via ``params`` and ``:name`` placeholders.
        """
        res = await self.session.execute(text(sql), params or {})
        return [dict(row) for row in res.mappings().all()]

    # --------------------------------------------------------------- aggregates

    async def aggregate_avg(self, column_name: str, **filters) -> Optional[float]:
        """AVG of ``column_name`` over non-deleted rows matching ``filters``."""
        return await self._aggregate(func.avg, column_name, **filters)

    async def aggregate_min(self, column_name: str, **filters):
        """MIN of ``column_name`` over non-deleted rows matching ``filters``."""
        return await self._aggregate(func.min, column_name, **filters)

    async def aggregate_max(self, column_name: str, **filters):
        """MAX of ``column_name`` over non-deleted rows matching ``filters``."""
        return await self._aggregate(func.max, column_name, **filters)

    async def aggregate_sum(self, column_name: str, **filters):
        """SUM of ``column_name`` over non-deleted rows matching ``filters``."""
        return await self._aggregate(func.sum, column_name, **filters)

    async def count(self, **filters) -> int:
        """Count rows matching ``filters``; ``include_deleted=True`` counts soft-deleted rows too."""
        # Pop first so the flag never reaches filter_by().
        include_deleted = bool(filters.pop("include_deleted", False))
        stmt = select(func.count()).select_from(self.model)
        if not include_deleted:
            stmt = self._exclude_deleted(stmt)
        stmt = self._apply_filters(stmt, **filters)
        return int((await self.session.execute(stmt)).scalar_one())

    async def restore(self, id_: Any) -> bool:
        """Undo a soft delete; return ``True`` only if the row existed and was deleted.

        Raises:
            RuntimeError: If the model has no ``is_deleted`` attribute.
        """
        if not hasattr(self.model, "is_deleted"):
            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return False
        if getattr(obj, "is_deleted", False):
            setattr(obj, "is_deleted", False)
            await self.session.flush()
            return True
        return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SARepo
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x
5
5
  Author: nurbergenovv
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
 
2
2
  [project]
3
3
  name = "SARepo"
4
- version = "0.1.3"
4
+ version = "0.1.4"
5
5
  description = "Minimal, explicit Repository & Unit-of-Work layer over SQLAlchemy 2.x"
6
6
  readme = "README.md"
7
7
  requires-python = ">=3.11"
@@ -0,0 +1,145 @@
1
+ import pytest
2
+ from typing import Optional
3
+ from sqlalchemy import create_engine, String, Integer, Boolean, Text
4
+ from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
5
+
6
+ from SARepo.sa_repo import SARepository
7
+ from SARepo.base import NotFoundError
8
+
9
+
10
+
11
class Base(DeclarativeBase):
    """Declarative base shared by the test models."""
13
+
14
+
15
class Item(Base):
    """Test model covering string, nullable, integer and soft-delete columns."""

    __tablename__ = "items"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    title: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
    # Nullable so the seed data can include a row without a city.
    city: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
    amount: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    age: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
    # Presence of this attribute enables the repository's soft-delete behavior.
    is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
    note: Mapped[Optional[str]] = mapped_column(Text)
25
+
26
+
27
+
28
@pytest.fixture()
def session():
    """Yield a Session bound to a fresh in-memory SQLite database.

    The engine is disposed once the test finishes so its pooled connection is
    released rather than left for the garbage collector.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    try:
        Base.metadata.create_all(engine)
        SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
        with SessionLocal() as s:
            yield s
    finally:
        engine.dispose()
35
+
36
+
37
@pytest.fixture()
def repo(session):
    """Provide an ``SARepository`` for ``Item`` bound to the test session."""
    item_repository = SARepository(Item, session)
    return item_repository
40
+
41
+
42
def seed(session, repo):
    """Populate the database with varied Items for filter and aggregate tests."""
    # (title, city, amount, age, note) -- insertion order fixes the ids.
    rows = (
        ("alpha", "Almaty", 10, 20, "A"),
        ("beta", "Astana", 20, 30, "B"),
        ("gamma", "Almaty", 30, 40, "G"),
        ("delta", "Shymkent", 40, 50, "D"),
        ("omega", None, 50, 60, "O"),
    )
    for title, city, amount, age, note in rows:
        repo.add(Item(title=title, city=city, amount=amount, age=age, note=note))
    session.commit()
54
+
55
+
56
+
57
def test_find_all_by_column_basic(session, repo):
    """find_all_by_column filters by value, honours limit/order_by, rejects unknown columns."""
    seed(session, repo)

    almaty_items = repo.find_all_by_column("city", "Almaty")
    assert {item.title for item in almaty_items} == {"alpha", "gamma"}

    top_one = repo.find_all_by_column("city", "Almaty", limit=1, order_by=Item.amount.desc())
    assert len(top_one) == 1 and top_one[0].title == "gamma"

    with pytest.raises(ValueError):
        repo.find_all_by_column("no_such_column", "x")
69
+
70
+
71
def test_get_or_create(session, repo):
    """Existing rows are returned untouched; missing rows are created with defaults."""
    seed(session, repo)

    existing, existing_created = repo.get_or_create(title="alpha", defaults={"city": "Kokshetau"})
    assert existing_created is False
    # defaults must not be applied to a row that already exists
    assert existing.city == "Almaty"

    fresh, fresh_created = repo.get_or_create(title="sigma", defaults={"city": "Aktau", "amount": 77})
    assert fresh_created is True
    assert (fresh.title, fresh.city, fresh.amount) == ("sigma", "Aktau", 77)
81
+
82
+
83
def test_raw_query(session, repo):
    """raw_query binds :name parameters and returns rows as plain dicts."""
    seed(session, repo)

    rows = repo.raw_query(
        "SELECT id, title, city FROM items WHERE lower(title) LIKE :p",
        {"p": "%a%"},
    )
    assert isinstance(rows, list) and all(isinstance(r, dict) for r in rows)
    # Every seeded title contains an "a", so all five rows must match. An exact
    # comparison catches both missing and unexpected rows.
    assert {r["title"] for r in rows} == {"alpha", "beta", "gamma", "delta", "omega"}
    # Each dict is keyed by exactly the selected column names.
    assert all(set(r) == {"id", "title", "city"} for r in rows)
93
+
94
+
95
def test_aggregates_avg_min_max_sum_ignore_soft_deleted_by_default(session, repo):
    """Aggregates skip soft-deleted rows and accept extra equality filters."""
    seed(session, repo)

    # Soft-delete "omega" (amount=50); it must not influence the aggregates.
    target = repo.find_all_by_column("title", "omega")[0]
    target.is_deleted = True
    session.commit()

    results = {
        "avg": repo.aggregate_avg("amount"),
        "min": repo.aggregate_min("amount"),
        "max": repo.aggregate_max("amount"),
        "sum": repo.aggregate_sum("amount"),
    }

    remaining = (10, 20, 30, 40)
    assert results["avg"] == pytest.approx(sum(remaining) / len(remaining))
    assert results["min"] == 10
    assert results["max"] == 40
    assert results["sum"] == sum(remaining)

    # Only alpha (age 20) and gamma (age 40) are in Almaty.
    assert repo.aggregate_avg("age", city="Almaty") == pytest.approx((20 + 40) / 2)
114
+
115
+
116
def test_count_with_and_without_deleted(session, repo):
    """count() hides soft-deleted rows unless include_deleted=True is passed."""
    seed(session, repo)

    for title in ("alpha", "beta"):
        victim = repo.find_all_by_column("title", title)[0]
        victim.is_deleted = True
    session.commit()

    assert repo.count() == 3
    assert repo.count(include_deleted=True) == 5
    # alpha is deleted, so only gamma remains in Almaty.
    assert repo.count(city="Almaty") == 1
129
+
130
+
131
def test_restore(session, repo):
    """restore() undoes a soft delete once and reports False when there is nothing to undo."""
    seed(session, repo)

    target = repo.find_all_by_column("title", "beta")[0]
    target.is_deleted = True
    session.commit()

    first_attempt = repo.restore(target.id)
    assert first_attempt is True
    session.commit()

    second_attempt = repo.restore(target.id)
    assert second_attempt is False

    assert repo.count() == 5
@@ -1,16 +0,0 @@
1
-
2
- from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
3
- from .base import Page, PageRequest
4
-
5
# Type variable for the entity class a repository manages (see ``model: Type[T]``).
T = TypeVar("T")
6
-
7
class CrudRepository(Protocol, Generic[T]):
    """Structural (duck-typed) interface of a synchronous repository over ``T``.

    A class satisfies this protocol by exposing members with matching
    signatures; no inheritance is required.
    """

    # The entity class whose instances the repository manages.
    model: Type[T]

    # ``limit`` defaults to None so ``repo.getAll()`` type-checks against the
    # protocol, matching implementations that treat the limit as optional.
    def getAll(self, limit: Optional[int] = None) -> List[T]: ...
    def get(self, id_: Any) -> T: ...
    def try_get(self, id_: Any) -> Optional[T]: ...
    def add(self, entity: T) -> T: ...
    def update(self, entity: T) -> T: ...
    def remove(self, entity: T) -> None: ...
    def delete_by_id(self, id_: Any) -> bool: ...
    def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...
@@ -1,152 +0,0 @@
1
-
2
- from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
3
- from sqlalchemy.orm import Session
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
- from sqlalchemy import inspect, select, func
6
- from .base import PageRequest, Page, NotFoundError
7
-
8
# Type variable for the mapped model class a repository wraps.
T = TypeVar("T")
# aliased to match specs.Spec; page() calls spec(stmt) on its base SELECT and
# uses the returned value as the statement to count and paginate.
Spec = Callable
10
-
11
class SARepository(Generic[T]):
    """Synchronous repository implementation for SQLAlchemy 2.x.

    Wraps one mapped model class and an ORM ``Session``. Writes are flushed,
    never committed. Models with an ``is_deleted`` attribute are soft-deleted
    by ``remove()``.
    """
    def __init__(self, model: Type[T], session: Session):
        self.model = model
        self.session = session

    def _select(self):
        """Base ``SELECT`` of the model used by ``page()``."""
        return select(self.model)

    def getAll(self, limit: Optional[int] = None) -> List[T]:
        """Return all rows of the model, at most ``limit`` when given."""
        stmt = select(self.model)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = self.session.execute(stmt)
        return result.scalars().all()

    def get(self, id_: Any) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: If no such row exists.
        """
        obj = self.session.get(self.model, id_)
        # Compare against None: model truthiness may be overridden.
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    def try_get(self, id_: Any) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``."""
        return self.session.get(self.model, id_)

    def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so DB-generated values are loaded."""
        self.session.add(entity)
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        self.session.flush()
        self.session.refresh(entity)
        return entity

    def remove(self, entity: T) -> None:
        """Delete ``entity``, or soft-delete it when the model has ``is_deleted``.

        An entity not in the session is looked up by its ``id`` attribute;
        a missing row makes this a no-op.

        Raises:
            ValueError: If the entity is not in the session and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE(review): assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = self.session.get(self.model, pk)
            if entity is None:
                return
        if hasattr(entity, "is_deleted"):
            setattr(entity, "is_deleted", True)
        else:
            self.session.delete(entity)

    def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with primary key ``id_``; return ``False`` if absent."""
        obj = self.session.get(self.model, id_)
        if obj is None:
            return False
        self.remove(obj)
        return True

    def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
        """Return one page of results plus the total row count.

        The offset is ``page.page * page.size`` (page indices start at 0).
        ``spec`` receives the base SELECT and returns the statement to use.
        """
        stmt = self._select()
        if spec:
            stmt = spec(stmt)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        # The total is counted over the whole (unpaginated) statement.
        total = self.session.execute(
            select(func.count()).select_from(stmt.subquery())
        ).scalar_one()
        items = self.session.execute(
            stmt.offset(page.page * page.size).limit(page.size)
        ).scalars().all()
        return Page(items, total, page.page, page.size)
81
-
82
class SAAsyncRepository(Generic[T]):
    """Async repository implementation for SQLAlchemy 2.x.

    Wraps one mapped model class and an ``AsyncSession``. Writes are flushed,
    never committed. Models with an ``is_deleted`` attribute are soft-deleted
    by ``remove()``.
    """
    def __init__(self, model: Type[T], session: AsyncSession):
        self.model = model
        self.session = session

    async def getAll(self, limit: Optional[int] = None) -> List[T]:
        """Return all rows of the model, at most ``limit`` when given."""
        stmt = select(self.model)
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    def _select(self):
        """Base ``SELECT`` of the model used by ``page()``."""
        return select(self.model)

    async def get(self, id_: Any) -> T:
        """Return the entity with primary key ``id_``.

        Raises:
            NotFoundError: If no such row exists.
        """
        obj = await self.session.get(self.model, id_)
        # Compare against None: model truthiness may be overridden.
        if obj is None:
            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
        return obj

    async def try_get(self, id_: Any) -> Optional[T]:
        """Return the entity with primary key ``id_``, or ``None``."""
        return await self.session.get(self.model, id_)

    async def add(self, entity: T) -> T:
        """Add ``entity``, flush, and refresh it so DB-generated values are loaded."""
        self.session.add(entity)
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def update(self, entity: T) -> T:
        """Flush pending changes and reload ``entity`` from the database."""
        await self.session.flush()
        await self.session.refresh(entity)
        return entity

    async def remove(self, entity: T) -> None:
        """Delete ``entity``, or soft-delete it when the model has ``is_deleted``.

        An entity not in the session is looked up by its ``id`` attribute;
        a missing row makes this a no-op.

        Raises:
            ValueError: If the entity is not in the session and has no ``id``.
        """
        insp = inspect(entity, raiseerr=False)
        if not (insp and (insp.persistent or insp.pending)):
            # NOTE(review): assumes the primary-key attribute is named ``id``.
            pk = getattr(entity, "id", None)
            if pk is None:
                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
            entity = await self.session.get(self.model, pk)
            if entity is None:
                return
        if hasattr(entity, "is_deleted"):
            setattr(entity, "is_deleted", True)
        else:
            await self.session.delete(entity)

    async def delete_by_id(self, id_: Any) -> bool:
        """Remove the entity with primary key ``id_``; return ``False`` if absent."""
        obj = await self.session.get(self.model, id_)
        if obj is None:
            return False
        await self.remove(obj)
        return True

    async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
        """Return one page of results plus the total row count.

        The offset is ``page.page * page.size`` (page indices start at 0).
        ``spec`` receives the base SELECT and returns the statement to use.
        """
        stmt = self._select()
        if spec:
            stmt = spec(stmt)
        if order_by is not None:
            stmt = stmt.order_by(order_by)
        # The total is counted over the whole (unpaginated) statement.
        total = (await self.session.execute(
            select(func.count()).select_from(stmt.subquery())
        )).scalar_one()
        res = await self.session.execute(
            stmt.offset(page.page * page.size).limit(page.size)
        )
        items = res.scalars().all()
        return Page(items, total, page.page, page.size)
@@ -1,73 +0,0 @@
1
- import pytest
2
- from sqlalchemy import create_engine, String
3
- from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
4
- from SARepo.sa_repo import SARepository
5
- from SARepo.base import PageRequest, NotFoundError
6
-
7
class Base(DeclarativeBase):
    """Declarative base for the test models."""
8
-
9
class Item(Base):
    """Minimal test model: an autoincrement id and a required title."""
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    title: Mapped[str] = mapped_column(String(100), nullable=False)
13
-
14
def test_crud_and_pagination():
    """add() persists rows and page() returns the requested slice plus totals."""
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    try:
        Base.metadata.create_all(engine)
        SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

        with SessionLocal() as session:
            repo = SARepository(Item, session)
            for i in range(25):
                repo.add(Item(title=f"t{i}"))
            session.commit()

            page = repo.page(PageRequest(1, 10))
            assert page.total == 25
            assert len(page.items) == 10
            assert page.page == 1
            assert page.pages == 3
            # Check that the returned objects are among the inserted ones.
            got = {it.title for it in page.items}
            assert got.issubset({f"t{i}" for i in range(25)})
    finally:
        # Release the engine's pooled SQLite connection.
        engine.dispose()
33
-
34
def test_get_all():
    """getAll() returns every row, or at most ``limit`` rows when given."""
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    try:
        Base.metadata.create_all(engine)
        SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

        with SessionLocal() as session:
            repo = SARepository(Item, session)
            expected = {f"get-all-{i}" for i in range(10)}
            for i in range(10):
                repo.add(Item(title=f"get-all-{i}"))
            session.commit()

            items = repo.getAll()
            assert isinstance(items, list)
            assert len(items) == 10
            assert {i.title for i in items} == expected

            items2 = repo.getAll(5)
            assert isinstance(items2, list)
            assert len(items2) == 5
            # getAll() issues LIMIT without ORDER BY, so SQL does not define
            # *which* five rows come back -- only that they are distinct rows.
            titles = {i.title for i in items2}
            assert len(titles) == 5 and titles.issubset(expected)
    finally:
        engine.dispose()
56
-
57
def test_get_and_try_get():
    """get() raises NotFoundError for missing ids; try_get() returns None instead."""
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    try:
        Base.metadata.create_all(engine)
        SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

        with SessionLocal() as session:
            repo = SARepository(Item, session)
            obj = repo.add(Item(title="one"))
            session.commit()

            same = repo.get(obj.id)
            assert same.id == obj.id and same.title == "one"

            assert repo.try_get(9999) is None

            with pytest.raises(NotFoundError):
                repo.get(9999)
    finally:
        engine.dispose()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes