SARepo 0.1.3__tar.gz → 0.1.4__tar.gz
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {sarepo-0.1.3 → sarepo-0.1.4}/PKG-INFO +1 -1
- sarepo-0.1.4/SARepo/repo.py +48 -0
- sarepo-0.1.4/SARepo/sa_repo.py +387 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo.egg-info/PKG-INFO +1 -1
- {sarepo-0.1.3 → sarepo-0.1.4}/pyproject.toml +1 -1
- sarepo-0.1.4/tests/test_sync_basic.py +145 -0
- sarepo-0.1.3/SARepo/repo.py +0 -16
- sarepo-0.1.3/SARepo/sa_repo.py +0 -152
- sarepo-0.1.3/tests/test_sync_basic.py +0 -73
- {sarepo-0.1.3 → sarepo-0.1.4}/LICENSE +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/README.md +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo/__init__.py +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo/base.py +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo/models.py +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo/specs.py +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo/uow.py +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo.egg-info/SOURCES.txt +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo.egg-info/dependency_links.txt +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo.egg-info/requires.txt +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/SARepo.egg-info/top_level.txt +0 -0
- {sarepo-0.1.3 → sarepo-0.1.4}/setup.cfg +0 -0
sarepo-0.1.4/SARepo/repo.py
ADDED
@@ -0,0 +1,48 @@
+
+from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
+from .base import Page, PageRequest
+
+T = TypeVar("T")
+
+class CrudRepository(Protocol, Generic[T]):
+    model: Type[T]
+    def getAll(self, limit: Optional[int]) -> List[T]: ...
+    def get(self, id_: Any) -> T: ...
+    def try_get(self, id_: Any) -> Optional[T]: ...
+    def add(self, entity: T) -> T: ...
+    def update(self, entity: T) -> T: ...
+    def remove(self, entity: T) -> None: ...
+    def delete_by_id(self, id_: Any) -> bool: ...
+    def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...
+    def get_all_by_column(
+        self,
+        column_name: str,
+        value: Any,
+        *,
+        limit: Optional[int] = None,
+        order_by=None,
+        include_deleted: bool = False,
+        **extra_filters
+    ) -> list[T]: ...
+    def find_all_by_column(
+        self,
+        column_name: str,
+        value: Any,
+        *,
+        limit: Optional[int] = None,
+        order_by=None,
+        include_deleted: bool = False,
+        **extra_filters
+    ) -> list[T]: ...
+    def get_or_create(
+        self,
+        defaults: Optional[dict] = None,
+        **unique_filters
+    ) -> tuple[T, bool]: ...
+    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]: ...
+    def aggregate_avg(self, column_name: str, **filters) -> Optional[float]: ...
+    def aggregate_min(self, column_name: str, **filters): ...
+    def aggregate_max(self, column_name: str, **filters): ...
+    def aggregate_sum(self, column_name: str, **filters): ...
+    def count(self, **filters) -> int: ...
+    def restore(self, id_: Any) -> bool: ...
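The new CrudRepository protocol widens the 0.1.3 interface with column lookups, get_or_create, raw SQL, aggregates, count and restore. Because it is a typing.Protocol, repositories satisfy it structurally rather than by inheritance. Below is a minimal, hypothetical sketch (not part of the package) of a caller typed against the protocol; the rename_item helper and the title attribute are assumptions for illustration only.

    from typing import Any

    from SARepo.repo import CrudRepository


    def rename_item(repo: CrudRepository, item_id: Any, new_title: str):
        # get() raises NotFoundError when the id does not exist
        item = repo.get(item_id)
        # assumes the mapped model has a `title` column (illustration only)
        item.title = new_title
        # flush + refresh through the repository and return the entity
        return repo.update(item)

Any structurally matching implementation, such as the synchronous SARepository below, can be passed in without subclassing the protocol.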
sarepo-0.1.4/SARepo/sa_repo.py
ADDED
@@ -0,0 +1,387 @@
+
+from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
+from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import inspect, select, func, text
+from .base import PageRequest, Page, NotFoundError
+
+T = TypeVar("T")
+Spec = Callable
+
+class SARepository(Generic[T]):
+    """Synchronous repository implementation for SQLAlchemy 2.x."""
+    def __init__(self, model: Type[T], session: Session):
+        self.model = model
+        self.session = session
+
+    def _resolve_column(self, column_name: str):
+        try:
+            return getattr(self.model, column_name)
+        except AttributeError as e:
+            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e
+
+    def _apply_filters(self, stmt, **filters):
+        if filters:
+            stmt = stmt.filter_by(**filters)
+        return stmt
+
+    def _select(self):
+        return select(self.model)
+
+    def getAll(self, limit: Optional[int] = None) -> List[T]:
+        stmt = select(self.model)
+        if limit is not None:
+            stmt = stmt.limit(limit)
+        result = self.session.execute(stmt)
+        return result.scalars().all()
+
+    def get(self, id_: Any) -> T:
+        obj = self.session.get(self.model, id_)
+        if not obj:
+            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
+        return obj
+
+    def try_get(self, id_: Any) -> Optional[T]:
+        return self.session.get(self.model, id_)
+
+    def add(self, entity: T) -> T:
+        self.session.add(entity)
+        self.session.flush()
+        self.session.refresh(entity)
+        return entity
+
+    def update(self, entity: T) -> T:
+        self.session.flush()
+        self.session.refresh(entity)
+        return entity
+
+    def remove(self, entity: T) -> None:
+        insp = inspect(entity, raiseerr=False)
+        if not (insp and (insp.persistent or insp.pending)):
+            pk = getattr(entity, "id", None)
+            if pk is None:
+                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
+            entity = self.session.get(self.model, pk)
+            if entity is None:
+                return
+        if hasattr(entity, "is_deleted"):
+            setattr(entity, "is_deleted", True)
+        else:
+            self.session.delete(entity)
+
+    def delete_by_id(self, id_: Any) -> bool:
+        obj = self.session.get(self.model, id_)
+        if not obj:
+            return False
+        self.remove(obj)
+        return True
+
+    def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
+        stmt = self._select()
+        if spec:
+            stmt = spec(stmt)
+        if order_by is not None:
+            stmt = stmt.order_by(order_by)
+        total = self.session.execute(
+            select(func.count()).select_from(stmt.subquery())
+        ).scalar_one()
+        items = self.session.execute(
+            stmt.offset(page.page * page.size).limit(page.size)
+        ).scalars().all()
+        return Page(items, total, page.page, page.size)
+
+    def get_all_by_column(
+        self,
+        column_name: str,
+        value: Any,
+        *,
+        limit: Optional[int] = None,
+        order_by=None,
+        include_deleted: bool = False,
+        **extra_filters
+    ) -> list[T]:
+        col = self._resolve_column(column_name)
+        stmt = select(self.model).where(col == value)
+        if not include_deleted and hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        stmt = self._apply_filters(stmt, **extra_filters)
+        if order_by is not None:
+            stmt = stmt.order_by(order_by)
+        if limit is not None:
+            stmt = stmt.limit(limit)
+        res = self.session.execute(stmt)
+        return res.scalars().all()
+
+    # Alias
+    def find_all_by_column(self, *args, **kwargs):
+        return self.get_all_by_column(*args, **kwargs)
+
+    def get_or_create(
+        self,
+        defaults: Optional[dict] = None,
+        **unique_filters
+    ) -> tuple[T, bool]:
+        """
+        Returns (obj, created). unique_filters define uniqueness.
+        defaults supply additional fields when the object is created.
+        """
+        stmt = select(self.model).filter_by(**unique_filters)
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        obj = self.session.execute(stmt).scalar_one_or_none()
+        if obj:
+            return obj, False
+        payload = {**unique_filters, **(defaults or {})}
+        obj = self.model(**payload)  # type: ignore[call-arg]
+        self.session.add(obj)
+        self.session.flush()
+        self.session.refresh(obj)
+        return obj, True
+
+    def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
+        """
+        Safely executes raw SQL (use :name placeholders).
+        Returns a list of dicts (one per row).
+        """
+        res = self.session.execute(text(sql), params or {})
+        # mappings() -> RowMapping (dict-like)
+        return [dict(row) for row in res.mappings().all()]
+
+    def aggregate_avg(self, column_name: str, **filters) -> Optional[float]:
+        col = self._resolve_column(column_name)
+        stmt = select(func.avg(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        stmt = self._apply_filters(stmt, **filters)
+        return self.session.execute(stmt).scalar()
+
+    def aggregate_min(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.min(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return self.session.execute(stmt).scalar()
+
+    def aggregate_max(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.max(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return self.session.execute(stmt).scalar()
+
+    def aggregate_sum(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.sum(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return self.session.execute(stmt).scalar()
+
+    def count(self, **filters) -> int:
+        stmt = select(func.count()).select_from(self.model)
+        if hasattr(self.model, "is_deleted") and not filters.pop("include_deleted", False):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        if filters:
+            stmt = stmt.filter_by(**filters)
+        return int(self.session.execute(stmt).scalar_one())
+
+    def restore(self, id_: Any) -> bool:
+        """
+        For soft delete: sets is_deleted=False. Returns True if the row was restored.
+        """
+        if not hasattr(self.model, "is_deleted"):
+            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
+        obj = self.session.get(self.model, id_)
+        if not obj:
+            return False
+        if getattr(obj, "is_deleted", False):
+            setattr(obj, "is_deleted", False)
+            self.session.flush()
+            return True
+        return False
+
+class SAAsyncRepository(Generic[T]):
+    """Async repository implementation for SQLAlchemy 2.x."""
+    def __init__(self, model: Type[T], session: AsyncSession):
+        self.model = model
+        self.session = session
+
+    def _resolve_column(self, column_name: str):
+        try:
+            return getattr(self.model, column_name)
+        except AttributeError as e:
+            raise ValueError(f"Model {self.model.__name__} has no column '{column_name}'") from e
+
+    def _apply_filters(self, stmt, **filters):
+        if filters:
+            stmt = stmt.filter_by(**filters)
+        return stmt
+
+    def _select(self):
+        return select(self.model)
+
+    async def getAll(self, limit: Optional[int] = None) -> List[T]:
+        stmt = select(self.model)
+        if limit is not None:
+            stmt = stmt.limit(limit)
+        result = await self.session.execute(stmt)
+        return result.scalars().all()
+
+    async def get(self, id_: Any) -> T:
+        obj = await self.session.get(self.model, id_)
+        if not obj:
+            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
+        return obj
+
+    async def try_get(self, id_: Any) -> Optional[T]:
+        return await self.session.get(self.model, id_)
+
+    async def add(self, entity: T) -> T:
+        self.session.add(entity)
+        await self.session.flush()
+        await self.session.refresh(entity)
+        return entity
+
+    async def update(self, entity: T) -> T:
+        await self.session.flush()
+        await self.session.refresh(entity)
+        return entity
+
+    async def remove(self, entity: T) -> None:
+        insp = inspect(entity, raiseerr=False)
+        if not (insp and (insp.persistent or insp.pending)):
+            pk = getattr(entity, "id", None)
+            if pk is None:
+                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
+            entity = await self.session.get(self.model, pk)
+            if entity is None:
+                return
+        if hasattr(entity, "is_deleted"):
+            setattr(entity, "is_deleted", True)
+        else:
+            await self.session.delete(entity)
+
+    async def delete_by_id(self, id_: Any) -> bool:
+        obj = await self.session.get(self.model, id_)
+        if not obj:
+            return False
+        await self.remove(obj)
+        return True
+
+    async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
+        stmt = self._select()
+        if spec:
+            stmt = spec(stmt)
+        if order_by is not None:
+            stmt = stmt.order_by(order_by)
+        total = (await self.session.execute(
+            select(func.count()).select_from(stmt.subquery())
+        )).scalar_one()
+        res = await self.session.execute(
+            stmt.offset(page.page * page.size).limit(page.size)
+        )
+        items = res.scalars().all()
+        return Page(items, total, page.page, page.size)
+
+    async def get_all_by_column(
+        self,
+        column_name: str,
+        value: Any,
+        *,
+        limit: Optional[int] = None,
+        order_by=None,
+        include_deleted: bool = False,
+        **extra_filters
+    ) -> list[T]:
+        col = self._resolve_column(column_name)
+        stmt = select(self.model).where(col == value)
+        if not include_deleted and hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        stmt = self._apply_filters(stmt, **extra_filters)
+        if order_by is not None:
+            stmt = stmt.order_by(order_by)
+        if limit is not None:
+            stmt = stmt.limit(limit)
+        res = await self.session.execute(stmt)
+        return res.scalars().all()
+
+    # Alias
+    async def find_all_by_column(self, *args, **kwargs):
+        return await self.get_all_by_column(*args, **kwargs)
+
+    async def get_or_create(
+        self,
+        defaults: Optional[dict] = None,
+        **unique_filters
+    ) -> tuple[T, bool]:
+        stmt = select(self.model).filter_by(**unique_filters)
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        obj = (await self.session.execute(stmt)).scalar_one_or_none()
+        if obj:
+            return obj, False
+        payload = {**unique_filters, **(defaults or {})}
+        obj = self.model(**payload)  # type: ignore[call-arg]
+        self.session.add(obj)
+        await self.session.flush()
+        await self.session.refresh(obj)
+        return obj, True
+
+    async def raw_query(self, sql: str, params: Optional[dict] = None) -> list[dict]:
+        res = await self.session.execute(text(sql), params or {})
+        return [dict(row) for row in res.mappings().all()]
+
+    async def aggregate_avg(self, column_name: str, **filters) -> Optional[float]:
+        col = self._resolve_column(column_name)
+        stmt = select(func.avg(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        stmt = self._apply_filters(stmt, **filters)
+        return (await self.session.execute(stmt)).scalar()
+
+    async def aggregate_min(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.min(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return (await self.session.execute(stmt)).scalar()
+
+    async def aggregate_max(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.max(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return (await self.session.execute(stmt)).scalar()
+
+    async def aggregate_sum(self, column_name: str, **filters):
+        col = self._resolve_column(column_name)
+        stmt = select(func.sum(col))
+        if hasattr(self.model, "is_deleted"):
+            stmt = stmt.where(self.model.is_deleted == False)
+        stmt = self._apply_filters(stmt, **filters)
+        return (await self.session.execute(stmt)).scalar()
+
+    async def count(self, **filters) -> int:
+        include_deleted = bool(filters.pop("include_deleted", False))
+        stmt = select(func.count()).select_from(self.model)
+        if hasattr(self.model, "is_deleted") and not include_deleted:
+            stmt = stmt.where(self.model.is_deleted == False)  # noqa: E712
+        if filters:
+            stmt = stmt.filter_by(**filters)
+        return int((await self.session.execute(stmt)).scalar_one())
+
+    async def restore(self, id_: Any) -> bool:
+        if not hasattr(self.model, "is_deleted"):
+            raise RuntimeError(f"{self.model.__name__} has no 'is_deleted' field")
+        obj = await self.session.get(self.model, id_)
+        if not obj:
+            return False
+        if getattr(obj, "is_deleted", False):
+            setattr(obj, "is_deleted", False)
+            await self.session.flush()
+            return True
+        return False
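The synchronous class is exercised by the new tests/test_sync_basic.py below. As a quick orientation, here is a hedged usage sketch of the helpers added in 0.1.4 against a model that opts into soft deletes via an is_deleted column; the Article model and the in-memory SQLite engine are assumptions for illustration, not part of the package.

    from sqlalchemy import create_engine, String, Integer, Boolean
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

    from SARepo.sa_repo import SARepository


    class Base(DeclarativeBase):
        pass


    class Article(Base):
        __tablename__ = "articles"
        id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
        title: Mapped[str] = mapped_column(String(100))
        views: Mapped[int] = mapped_column(Integer, default=0)
        # an is_deleted column opts the model into soft deletes
        is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)


    engine = create_engine("sqlite+pysqlite:///:memory:")
    Base.metadata.create_all(engine)

    with sessionmaker(bind=engine)() as session:
        repo = SARepository(Article, session)

        article, created = repo.get_or_create(title="hello", defaults={"views": 3})
        repo.delete_by_id(article.id)              # sets is_deleted=True instead of emitting DELETE
        assert repo.count() == 0                   # soft-deleted rows are excluded by default
        assert repo.count(include_deleted=True) == 1
        repo.restore(article.id)                   # flips is_deleted back to False
        total_views = repo.aggregate_sum("views")  # aggregates also skip soft-deleted rows
        session.commit()

Note that add(), get_or_create() and restore() only flush; committing the transaction is left to the caller, which is why the tests below commit explicitly.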
sarepo-0.1.4/tests/test_sync_basic.py
ADDED
@@ -0,0 +1,145 @@
+import pytest
+from typing import Optional
+from sqlalchemy import create_engine, String, Integer, Boolean, Text
+from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
+
+from SARepo.sa_repo import SARepository
+from SARepo.base import NotFoundError
+
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Item(Base):
+    __tablename__ = "items"
+
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+    title: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
+    city: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)
+    amount: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
+    age: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
+    is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
+    note: Mapped[Optional[str]] = mapped_column(Text)
+
+
+
+@pytest.fixture()
+def session():
+    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
+    Base.metadata.create_all(engine)
+    SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
+    with SessionLocal() as s:
+        yield s
+
+
+@pytest.fixture()
+def repo(session):
+    return SARepository(Item, session)
+
+
+def seed(session, repo):
+    """Populate the database with assorted rows for the aggregate/filter tests."""
+    data = [
+        Item(title="alpha", city="Almaty", amount=10, age=20, note="A"),
+        Item(title="beta", city="Astana", amount=20, age=30, note="B"),
+        Item(title="gamma", city="Almaty", amount=30, age=40, note="G"),
+        Item(title="delta", city="Shymkent", amount=40, age=50, note="D"),
+        Item(title="omega", city=None, amount=50, age=60, note="O"),
+    ]
+    for obj in data:
+        repo.add(obj)
+    session.commit()
+
+
+
+def test_find_all_by_column_basic(session, repo):
+    seed(session, repo)
+
+    items = repo.find_all_by_column("city", "Almaty")
+    titles = {x.title for x in items}
+    assert titles == {"alpha", "gamma"}
+
+    limited = repo.find_all_by_column("city", "Almaty", limit=1, order_by=Item.amount.desc())
+    assert len(limited) == 1 and limited[0].title == "gamma"
+
+    with pytest.raises(ValueError):
+        repo.find_all_by_column("no_such_column", "x")
+
+
+def test_get_or_create(session, repo):
+    seed(session, repo)
+
+    obj, created = repo.get_or_create(title="alpha", defaults={"city": "Kokshetau"})
+    assert created is False
+    assert obj.city == "Almaty"
+
+    obj2, created2 = repo.get_or_create(title="sigma", defaults={"city": "Aktau", "amount": 77})
+    assert created2 is True
+    assert (obj2.title, obj2.city, obj2.amount) == ("sigma", "Aktau", 77)
+
+
+def test_raw_query(session, repo):
+    seed(session, repo)
+
+    rows = repo.raw_query(
+        "SELECT id, title, city FROM items WHERE lower(title) LIKE :p",
+        {"p": "%a%"}
+    )
+    assert isinstance(rows, list) and all(isinstance(r, dict) for r in rows)
+    titles = {r["title"] for r in rows}
+    assert {"alpha", "gamma", "delta"}.issubset(titles)
+
+
+def test_aggregates_avg_min_max_sum_ignore_soft_deleted_by_default(session, repo):
+    seed(session, repo)
+
+    omega = repo.find_all_by_column("title", "omega")[0]
+    omega.is_deleted = True
+    session.commit()
+
+    avg_amount = repo.aggregate_avg("amount")
+    min_amount = repo.aggregate_min("amount")
+    max_amount = repo.aggregate_max("amount")
+    sum_amount = repo.aggregate_sum("amount")
+
+    assert avg_amount == pytest.approx((10 + 20 + 30 + 40) / 4)
+    assert min_amount == 10
+    assert max_amount == 40
+    assert sum_amount == 10 + 20 + 30 + 40
+
+    avg_age_almaty = repo.aggregate_avg("age", city="Almaty")
+    assert avg_age_almaty == pytest.approx((20 + 40) / 2)
+
+
+def test_count_with_and_without_deleted(session, repo):
+    seed(session, repo)
+
+    for t in ("alpha", "beta"):
+        x = repo.find_all_by_column("title", t)[0]
+        x.is_deleted = True
+    session.commit()
+
+    assert repo.count() == 3
+
+    assert repo.count(include_deleted=True) == 5
+
+    assert repo.count(city="Almaty") == 1
+
+
+def test_restore(session, repo):
+    seed(session, repo)
+
+    beta = repo.find_all_by_column("title", "beta")[0]
+    beta.is_deleted = True
+    session.commit()
+
+    ok = repo.restore(beta.id)
+    assert ok is True
+    session.commit()
+
+    ok2 = repo.restore(beta.id)
+    assert ok2 is False
+
+    assert repo.count() == 5
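The new test module covers only the synchronous SARepository. For the async twin added in the same release, here is a hedged sketch; the User model, the aiosqlite URL and the asyncio boilerplate are assumptions for illustration (the aiosqlite driver would need to be installed separately).

    import asyncio

    from sqlalchemy import String
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    from SARepo.sa_repo import SAAsyncRepository


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "users"
        id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
        email: Mapped[str] = mapped_column(String(200), index=True)


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with async_sessionmaker(engine, expire_on_commit=False)() as session:
            repo = SAAsyncRepository(User, session)
            user, created = await repo.get_or_create(email="a@example.com")
            found = await repo.find_all_by_column("email", "a@example.com")
            assert created is True and found[0].id == user.id
            await session.commit()

        await engine.dispose()


    asyncio.run(main())

The async methods mirror the synchronous API one-to-one (getAll, get/try_get, add, update, remove, delete_by_id, page, column lookups, get_or_create, raw_query, aggregates, count, restore), so the patterns in the sync tests translate directly.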
sarepo-0.1.3/SARepo/repo.py
DELETED
@@ -1,16 +0,0 @@
-
-from typing import Generic, List, TypeVar, Type, Optional, Any, Protocol
-from .base import Page, PageRequest
-
-T = TypeVar("T")
-
-class CrudRepository(Protocol, Generic[T]):
-    model: Type[T]
-    def getAll(self, limit: Optional[int]) -> List[T]: ...
-    def get(self, id_: Any) -> T: ...
-    def try_get(self, id_: Any) -> Optional[T]: ...
-    def add(self, entity: T) -> T: ...
-    def update(self, entity: T) -> T: ...
-    def remove(self, entity: T) -> None: ...
-    def delete_by_id(self, id_: Any) -> bool: ...
-    def page(self, page: PageRequest, spec=None, order_by=None) -> Page[T]: ...
sarepo-0.1.3/SARepo/sa_repo.py
DELETED
@@ -1,152 +0,0 @@
-
-from typing import List, Type, Generic, TypeVar, Optional, Sequence, Any, Callable
-from sqlalchemy.orm import Session
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import inspect, select, func
-from .base import PageRequest, Page, NotFoundError
-
-T = TypeVar("T")
-Spec = Callable  # aliased to match specs.Spec
-
-class SARepository(Generic[T]):
-    """Synchronous repository implementation for SQLAlchemy 2.x."""
-    def __init__(self, model: Type[T], session: Session):
-        self.model = model
-        self.session = session
-
-    def _select(self):
-        return select(self.model)
-
-    def getAll(self, limit: Optional[int] = None) -> List[T]:
-        stmt = select(self.model)
-        if limit is not None:
-            stmt = stmt.limit(limit)
-        result = self.session.execute(stmt)
-        return result.scalars().all()
-
-    def get(self, id_: Any) -> T:
-        obj = self.session.get(self.model, id_)
-        if not obj:
-            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
-        return obj
-
-    def try_get(self, id_: Any) -> Optional[T]:
-        return self.session.get(self.model, id_)
-
-    def add(self, entity: T) -> T:
-        self.session.add(entity)
-        self.session.flush()
-        self.session.refresh(entity)
-        return entity
-
-    def update(self, entity: T) -> T:
-        self.session.flush()
-        self.session.refresh(entity)
-        return entity
-
-    def remove(self, entity: T) -> None:
-        insp = inspect(entity, raiseerr=False)
-        if not (insp and (insp.persistent or insp.pending)):
-            pk = getattr(entity, "id", None)
-            if pk is None:
-                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
-            entity = self.session.get(self.model, pk)
-            if entity is None:
-                return
-        if hasattr(entity, "is_deleted"):
-            setattr(entity, "is_deleted", True)
-        else:
-            self.session.delete(entity)
-
-    def delete_by_id(self, id_: Any) -> bool:
-        obj = self.session.get(self.model, id_)
-        if not obj:
-            return False
-        self.remove(obj)
-        return True
-
-    def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
-        stmt = self._select()
-        if spec:
-            stmt = spec(stmt)
-        if order_by is not None:
-            stmt = stmt.order_by(order_by)
-        total = self.session.execute(
-            select(func.count()).select_from(stmt.subquery())
-        ).scalar_one()
-        items = self.session.execute(
-            stmt.offset(page.page * page.size).limit(page.size)
-        ).scalars().all()
-        return Page(items, total, page.page, page.size)
-
-class SAAsyncRepository(Generic[T]):
-    """Async repository implementation for SQLAlchemy 2.x."""
-    def __init__(self, model: Type[T], session: AsyncSession):
-        self.model = model
-        self.session = session
-
-    async def getAll(self, limit: Optional[int] = None) -> List[T]:
-        stmt = select(self.model)
-        if limit is not None:
-            stmt = stmt.limit(limit)
-        result = await self.session.execute(stmt)
-        return result.scalars().all()
-
-    def _select(self):
-        return select(self.model)
-
-    async def get(self, id_: Any) -> T:
-        obj = await self.session.get(self.model, id_)
-        if not obj:
-            raise NotFoundError(f"{self.model.__name__}({id_}) not found")
-        return obj
-
-    async def try_get(self, id_: Any) -> Optional[T]:
-        return await self.session.get(self.model, id_)
-
-    async def add(self, entity: T) -> T:
-        self.session.add(entity)
-        await self.session.flush()
-        await self.session.refresh(entity)
-        return entity
-
-    async def update(self, entity: T) -> T:
-        await self.session.flush()
-        await self.session.refresh(entity)
-        return entity
-
-    async def remove(self, entity: T) -> None:
-        insp = inspect(entity, raiseerr=False)
-        if not (insp and (insp.persistent or insp.pending)):
-            pk = getattr(entity, "id", None)
-            if pk is None:
-                raise ValueError("remove() needs a persistent entity or an entity with a primary key set")
-            entity = await self.session.get(self.model, pk)
-            if entity is None:
-                return
-        if hasattr(entity, "is_deleted"):
-            setattr(entity, "is_deleted", True)
-        else:
-            await self.session.delete(entity)
-
-    async def delete_by_id(self, id_: Any) -> bool:
-        obj = await self.session.get(self.model, id_)
-        if not obj:
-            return False
-        await self.remove(obj)
-        return True
-
-    async def page(self, page: PageRequest, spec: Optional[Spec] = None, order_by=None) -> Page[T]:
-        stmt = self._select()
-        if spec:
-            stmt = spec(stmt)
-        if order_by is not None:
-            stmt = stmt.order_by(order_by)
-        total = (await self.session.execute(
-            select(func.count()).select_from(stmt.subquery())
-        )).scalar_one()
-        res = await self.session.execute(
-            stmt.offset(page.page * page.size).limit(page.size)
-        )
-        items = res.scalars().all()
-        return Page(items, total, page.page, page.size)
sarepo-0.1.3/tests/test_sync_basic.py
DELETED
@@ -1,73 +0,0 @@
-import pytest
-from sqlalchemy import create_engine, String
-from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
-from SARepo.sa_repo import SARepository
-from SARepo.base import PageRequest, NotFoundError
-
-class Base(DeclarativeBase): pass
-
-class Item(Base):
-    __tablename__ = "items"
-    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
-    title: Mapped[str] = mapped_column(String(100), nullable=False)
-
-def test_crud_and_pagination():
-    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
-    Base.metadata.create_all(engine)
-    SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
-
-    with SessionLocal() as session:
-        repo = SARepository(Item, session)
-        for i in range(25):
-            repo.add(Item(title=f"t{i}"))
-        session.commit()
-
-        page = repo.page(PageRequest(1, 10))
-        assert page.total == 25
-        assert len(page.items) == 10
-        assert page.page == 1
-        assert page.pages == 3
-        # check that the expected objects are present
-        got = {it.title for it in page.items}
-        assert got.issubset({f"t{i}" for i in range(25)})
-
-def test_get_all():
-    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
-    Base.metadata.create_all(engine)
-    SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
-
-    with SessionLocal() as session:
-        repo = SARepository(Item, session)
-        for i in range(10):
-            repo.add(Item(title=f"get-all-{i}"))
-        session.commit()
-
-        items = repo.getAll()
-        assert isinstance(items, list)
-        assert len(items) == 10
-        titles = {i.title for i in items}
-        assert titles == {f"get-all-{i}" for i in range(10)}
-
-        items2 = repo.getAll(5)
-        assert isinstance(items2, list)
-        assert len(items2) == 5
-        titles = {i.title for i in items2}
-        assert titles == {f"get-all-{i}" for i in range(5)}
-
-def test_get_and_try_get():
-    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
-    Base.metadata.create_all(engine)
-    SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
-
-    with SessionLocal() as session:
-        repo = SARepository(Item, session)
-        obj = repo.add(Item(title="one"))
-        session.commit()
-
-        same = repo.get(obj.id)
-        assert same.id == obj.id and same.title == "one"
-
-        assert repo.try_get(9999) is None
-
-        with pytest.raises(NotFoundError):
-            repo.get(9999)
The remaining files listed above (LICENSE, README.md, SARepo/__init__.py, SARepo/base.py, SARepo/models.py, SARepo/specs.py, SARepo/uow.py, the SARepo.egg-info metadata files, and setup.cfg) are unchanged between 0.1.3 and 0.1.4.