rb-commons 0.7.15__py3-none-any.whl → 0.7.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rb_commons/orm/managers.py +654 -578
- {rb_commons-0.7.15.dist-info → rb_commons-0.7.17.dist-info}/METADATA +1 -1
- {rb_commons-0.7.15.dist-info → rb_commons-0.7.17.dist-info}/RECORD +5 -5
- {rb_commons-0.7.15.dist-info → rb_commons-0.7.17.dist-info}/WHEEL +0 -0
- {rb_commons-0.7.15.dist-info → rb_commons-0.7.17.dist-info}/top_level.txt +0 -0
rb_commons/orm/managers.py
CHANGED
@@ -2,11 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import re
|
4
4
|
import uuid
|
5
|
-
from typing import TypeVar, Type, Generic, Optional, List, Dict,
|
6
|
-
from sqlalchemy import select, delete, update, and_, func, desc, inspect, or_, asc, true
|
5
|
+
from typing import TypeVar, Type, Generic, Optional, List, Dict, Union, Sequence, Any, Iterable, Callable
|
6
|
+
from sqlalchemy import select, delete, update, and_, func, desc, inspect, or_, asc, true, cast
|
7
7
|
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
8
8
|
from sqlalchemy.ext.asyncio import AsyncSession
|
9
9
|
from sqlalchemy.orm import declarative_base, selectinload, RelationshipProperty, Load
|
10
|
+
from sqlalchemy.sql.sqltypes import String, Text, Unicode, UnicodeText, Integer, BigInteger, SmallInteger, Float, Numeric
|
10
11
|
from rb_commons.http.exceptions import NotFoundException
|
11
12
|
from rb_commons.orm.exceptions import DatabaseException, InternalException
|
12
13
|
from functools import lru_cache, wraps
|
@@ -14,184 +15,220 @@ from rb_commons.orm.querysets import Q, QJSON
|
|
14
15
|
|
15
16
|
# Type variable for the SQLAlchemy model class a manager operates on.
# NOTE(review): `declarative_base()` here creates a fresh, otherwise unused base
# class purely to serve as the TypeVar bound; concrete models are presumably
# declared on a different Base, so the bound is nominal only — confirm.
ModelType = TypeVar('ModelType', bound=declarative_base())
|
16
17
|
|
18
|
+
|
17
19
|
def with_transaction_error_handling(func):
    """Wrap an async manager method so that any failure rolls back the session.

    The decorated coroutine must be a method whose ``self`` exposes the
    session as ``self.session``. On any exception the session is rolled back,
    and the error is then re-raised as follows:

    * ``InternalException`` / ``DatabaseException`` raised by the method itself
      are re-raised unchanged. Previously they were re-wrapped as
      ``InternalException("Unexpected error: ...")``, which mangled messages
      such as "No fields provided for update".
    * ``IntegrityError`` becomes ``InternalException("Constraint violation: ...")``.
    * Any other ``SQLAlchemyError`` becomes ``DatabaseException("Database error: ...")``.
    * Anything else becomes ``InternalException("Unexpected error: ...")``.

    The original exception is chained with ``raise ... from``. ``functools.wraps``
    keeps the wrapped method's name, docstring and ``__wrapped__``.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except (InternalException, DatabaseException):
            # Already a project-level error: undo the work, but don't double-wrap.
            await self.session.rollback()
            raise
        except IntegrityError as e:
            await self.session.rollback()
            raise InternalException(f"Constraint violation: {e}") from e
        except SQLAlchemyError as e:
            await self.session.rollback()
            raise DatabaseException(f"Database error: {e}") from e
        except Exception as e:
            await self.session.rollback()
            raise InternalException(f"Unexpected error: {e}") from e

    return wrapper
|
34
|
+
|
31
35
|
|
32
36
|
# Type variable for "any callable", so `query_mutator` is typed as returning a
# function of the same type as the one it decorates.
F = TypeVar("F", bound=Callable[..., Any])
|
33
37
|
|
38
|
+
|
34
39
|
def query_mutator(func: F) -> F:
    """Run a query-builder method against a copy of the manager.

    The wrapped method receives ``self._clone()`` in place of ``self``, so the
    caller's manager is never modified. A non-``None`` return value is passed
    through unchanged. A ``None`` return is replaced by the clone, which keeps
    chained builder calls working.
    """

    @wraps(func)
    def wrapper(self: "BaseManager[Any]", *args, **kwargs):
        working_copy = self._clone()
        outcome = func(working_copy, *args, **kwargs)
        if outcome is None:
            return working_copy
        return outcome

    return wrapper
|
44
51
|
|
45
52
|
|
46
53
|
# Aggregate name -> SQL aggregate constructor, looked up by `_order_expr_for_path`
# for "<path>:<agg>" sort tokens. "mean" is an alias of "avg"; "first" applies no
# aggregate and returns the column expression unchanged.
AGG_MAP: dict[str, Callable[[Any], Any]] = dict(
    sum=func.sum,
    avg=func.avg,
    mean=func.avg,
    min=func.min,
    max=func.max,
    count=func.count,
    first=lambda c: c,
)
|
55
62
|
|
63
|
+
|
56
64
|
class BaseManager(Generic[ModelType]):
|
57
|
-
|
65
|
+
model: Type[ModelType]
|
58
66
|
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
67
|
+
def __init__(self, session: AsyncSession) -> None:
    """Create a manager bound to *session*, starting from an empty query.

    ``model`` must be provided by the concrete manager class (the base class
    only annotates it). Its mapper is inspected once here to remember the keys
    of its column attributes, which ``model_to_dict`` uses.
    """
    self.session: AsyncSession = session

    # Query-builder state; _clone() copies each of these.
    self._filtered: bool = False
    self.filters: List[Any] = []
    self._order_by: List[Any] = []
    self._joins: set[str] = set()
    self._limit: Optional[int] = None

    model_mapper = inspect(self.model).mapper
    self._column_keys = [attr.key for attr in model_mapper.column_attrs]
|
69
77
|
|
70
|
-
|
71
|
-
|
78
|
+
def _clone(self) -> "BaseManager[ModelType]":
|
79
|
+
"""
|
72
80
|
Shallow‑copy all mutable query state into a new manager instance.
|
73
81
|
"""
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
82
|
+
clone = self.__class__(self.session)
|
83
|
+
clone.filters = list(self.filters)
|
84
|
+
clone._order_by = list(self._order_by)
|
85
|
+
clone._limit = self._limit
|
86
|
+
clone._joins = set(self._joins)
|
87
|
+
clone._filtered = self._filtered
|
88
|
+
return clone
|
89
|
+
|
90
|
+
async def _smart_commit(self, instance: Optional[ModelType] = None) -> Optional[ModelType]:
|
91
|
+
if not self.session.in_transaction():
|
92
|
+
await self.session.commit()
|
93
|
+
if instance is not None:
|
94
|
+
await self.session.refresh(instance)
|
95
|
+
return instance
|
96
|
+
return None
|
97
|
+
|
98
|
+
def _build_comparison(self, col, operator: str, value: Any):
|
99
|
+
if operator == "eq":
|
100
|
+
return col == value
|
101
|
+
if operator == "ne":
|
102
|
+
return col != value
|
103
|
+
if operator == "gt":
|
104
|
+
return col > value
|
105
|
+
if operator == "lt":
|
106
|
+
return col < value
|
107
|
+
if operator == "gte":
|
108
|
+
return col >= value
|
109
|
+
if operator == "lte":
|
110
|
+
return col <= value
|
111
|
+
if operator == "in":
|
112
|
+
return col.in_(value)
|
113
|
+
if operator in {"contains", "startswith", "endswith"}:
|
114
|
+
return self._textop_with_autocast(col, operator, value)
|
115
|
+
if operator == "null":
|
116
|
+
return col.is_(None) if value else col.isnot(None)
|
117
|
+
raise ValueError(f"Unsupported operator: {operator}")
|
118
|
+
|
119
|
+
@lru_cache(maxsize=None)
|
120
|
+
def _parse_lookup_meta(self, lookup: str):
|
121
|
+
"""
|
114
122
|
One-time parse of "foo__bar__lt" into:
|
115
123
|
- parts = ["foo","bar"]
|
116
124
|
- operator="lt"
|
117
125
|
- relationship_attr, column_attr pointers
|
118
126
|
"""
|
119
127
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
128
|
+
parts = lookup.split("__")
|
129
|
+
operator = "eq"
|
130
|
+
|
131
|
+
if parts[-1] in {"eq", "ne", "gt", "lt", "gte", "lte", "in", "contains", "startswith", "endswith", "null"}:
|
132
|
+
operator = parts.pop()
|
133
|
+
|
134
|
+
current = self.model
|
135
|
+
rel = None
|
136
|
+
col = None
|
137
|
+
for p in parts:
|
138
|
+
a = getattr(current, p)
|
139
|
+
if hasattr(a, "property") and isinstance(a.property, RelationshipProperty):
|
140
|
+
rel = a
|
141
|
+
current = a.property.mapper.class_
|
142
|
+
else:
|
143
|
+
col = a
|
144
|
+
return parts, operator, rel, col
|
145
|
+
|
146
|
+
def _parse_lookup(self, lookup: str, value: Any):
|
147
|
+
parts, operator, rel_attr, col_attr = self._parse_lookup_meta(lookup)
|
148
|
+
|
149
|
+
if rel_attr is not None and col_attr is None:
|
150
|
+
uselist = rel_attr.property.uselist
|
151
|
+
primaryjoin = rel_attr.property.primaryjoin
|
152
|
+
|
153
|
+
if uselist:
|
154
|
+
target_cls = rel_attr.property.mapper.class_
|
155
|
+
cnt = (
|
156
|
+
select(func.count("*"))
|
157
|
+
.select_from(target_cls)
|
158
|
+
.where(primaryjoin)
|
159
|
+
.correlate(self.model)
|
160
|
+
.scalar_subquery()
|
161
|
+
)
|
162
|
+
return self._build_comparison(cnt, operator, value)
|
163
|
+
else:
|
164
|
+
exists_expr = (
|
165
|
+
select(1)
|
166
|
+
.where(primaryjoin)
|
167
|
+
.correlate(self.model)
|
168
|
+
.exists()
|
169
|
+
)
|
170
|
+
if operator in {"eq", "lte"} and str(value) in {"0", "False", "false"}:
|
171
|
+
return ~exists_expr
|
172
|
+
if operator in {"gt", "gte", "eq"} and str(value) in {"1", "True", "true"}:
|
173
|
+
return exists_expr
|
174
|
+
return self._build_comparison(exists_expr, operator, bool(value))
|
175
|
+
|
176
|
+
expr = self._build_comparison(col_attr, operator, value)
|
177
|
+
|
178
|
+
if rel_attr:
|
179
|
+
if rel_attr.property.uselist:
|
180
|
+
return rel_attr.any(expr)
|
181
|
+
else:
|
182
|
+
return rel_attr.has(expr)
|
183
|
+
|
184
|
+
return expr
|
185
|
+
|
186
|
+
def _q_to_expr(self, q: Union[Q, QJSON]):
    """Compile a ``Q`` tree, or a single ``QJSON`` condition, into one SQL expression.

    The node's own lookups are parsed first and its children second
    (recursively). The resulting clauses are joined with OR when
    ``q._operator == "OR"`` and with AND otherwise. A node with no clauses
    compiles to TRUE. The result is negated with ``~`` when ``q.negated`` is set.
    """
    if isinstance(q, QJSON):
        return self._parse_qjson(q)

    clauses: List[Any] = [self._parse_lookup(k, v) for k, v in q.lookups.items()]
    clauses += [self._q_to_expr(child) for child in q.children]

    if clauses:
        joiner = or_ if q._operator == "OR" else and_
        combined = joiner(*clauses)
    else:
        combined = true()

    return ~combined if q.negated else combined
|
202
|
+
|
203
|
+
def _parse_qjson(self, qjson: QJSON):
|
204
|
+
col = getattr(self.model, qjson.field, None)
|
205
|
+
if col is None:
|
206
|
+
raise ValueError(f"Invalid JSON field: {qjson.field}")
|
207
|
+
|
208
|
+
json_expr = col[qjson.key].astext
|
209
|
+
|
210
|
+
if qjson.operator == "eq":
|
211
|
+
return json_expr == str(qjson.value)
|
212
|
+
if qjson.operator == "ne":
|
213
|
+
return json_expr != str(qjson.value)
|
214
|
+
if qjson.operator == "contains":
|
215
|
+
return json_expr.ilike(f"%{qjson.value}%")
|
216
|
+
if qjson.operator == "startswith":
|
217
|
+
return json_expr.ilike(f"{qjson.value}%")
|
218
|
+
if qjson.operator == "endswith":
|
219
|
+
return json_expr.ilike(f"%{qjson.value}")
|
220
|
+
if qjson.operator == "in":
|
221
|
+
if not isinstance(qjson.value, (list, tuple, set)):
|
222
|
+
raise ValueError(f"{qjson.field}[{qjson.key}]__in requires an iterable")
|
223
|
+
return json_expr.in_(qjson.value)
|
224
|
+
raise ValueError(f"Unsupported QJSON operator: {qjson.operator}")
|
225
|
+
|
226
|
+
def _build_relation_loaders(
|
227
|
+
self,
|
228
|
+
model: Any,
|
229
|
+
relations: Sequence[str] | None = None
|
230
|
+
) -> List[Load]:
|
231
|
+
"""
|
195
232
|
Given e.g. ["media", "properties.property", "properties__property"],
|
196
233
|
returns [
|
197
234
|
selectinload(Product.media),
|
@@ -200,111 +237,149 @@ class BaseManager(Generic[ModelType]):
|
|
200
237
|
|
201
238
|
If `relations` is None or empty, recurse *all* relationships once (cycle-safe).
|
202
239
|
"""
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
240
|
+
loaders: List[Load] = []
|
241
|
+
|
242
|
+
if relations:
|
243
|
+
for path in relations:
|
244
|
+
parts = re.split(r"\.|\_\_", path)
|
245
|
+
current_model = model
|
246
|
+
loader: Load | None = None
|
247
|
+
|
248
|
+
for part in parts:
|
249
|
+
attr = getattr(current_model, part, None)
|
250
|
+
if attr is None or not hasattr(attr, "property"):
|
251
|
+
raise ValueError(f"Invalid relationship path: {path!r}")
|
252
|
+
loader = selectinload(attr) if loader is None else loader.selectinload(attr)
|
253
|
+
current_model = attr.property.mapper.class_
|
254
|
+
|
255
|
+
loaders.append(loader)
|
256
|
+
|
257
|
+
return loaders
|
258
|
+
|
259
|
+
visited = set()
|
260
|
+
|
261
|
+
def recurse(curr_model: Any, curr_loader: Load | None = None):
|
262
|
+
mapper = inspect(curr_model)
|
263
|
+
if mapper in visited:
|
264
|
+
return
|
265
|
+
visited.add(mapper)
|
266
|
+
|
267
|
+
for rel in mapper.relationships:
|
268
|
+
attr = getattr(curr_model, rel.key)
|
269
|
+
loader = (
|
270
|
+
selectinload(attr)
|
271
|
+
if curr_loader is None
|
272
|
+
else curr_loader.selectinload(attr)
|
273
|
+
)
|
274
|
+
loaders.append(loader)
|
275
|
+
recurse(rel.mapper.class_, loader)
|
276
|
+
|
277
|
+
recurse(model)
|
278
|
+
return loaders
|
279
|
+
|
280
|
+
async def _execute_query(self, stmt):
|
281
|
+
result = await self.session.execute(stmt)
|
282
|
+
rows = result.scalars().all()
|
283
|
+
return list({obj.id: obj for obj in rows}.values())
|
284
|
+
|
285
|
+
@query_mutator
def order_by(self, *columns: Any):
    """Append ORDER BY clauses (on a clone of the manager).

    A string argument names a model attribute. All leading "+"/"-" characters
    are stripped from the name, and a leading "-" makes the clause descending
    (otherwise ascending). Non-string arguments are appended exactly as given.

    Raises:
        ValueError: if a string names no attribute of the model.
    """
    for column in columns:
        if not isinstance(column, str):
            self._order_by.append(column)
            continue

        name = column.lstrip("+-")
        model_attr = getattr(self.model, name, None)
        if model_attr is None:
            raise ValueError(f"Invalid order_by field '{name}' for {self.model.__name__}")

        clause = model_attr.desc() if column.startswith("-") else model_attr.asc()
        self._order_by.append(clause)

    return self
|
301
|
+
|
302
|
+
def _is_textual_type(self, col) -> bool:
|
303
|
+
try:
|
304
|
+
return hasattr(col, "type") and isinstance(col.type, (String, Text, Unicode, UnicodeText))
|
305
|
+
except Exception:
|
306
|
+
return False
|
307
|
+
|
308
|
+
def _is_numeric_type(self, col) -> bool:
|
309
|
+
try:
|
310
|
+
return hasattr(col, "type") and isinstance(col.type, (Integer, BigInteger, SmallInteger, Float, Numeric))
|
311
|
+
except Exception:
|
312
|
+
return False
|
313
|
+
|
314
|
+
def _ilike_with_autocast(self, col, pattern: str, raw_value: Any):
    """Build a case-insensitive LIKE on *col*, casting non-text columns to text.

    * Text columns (``_is_textual_type``) produce ``col ILIKE pattern``.
    * Other columns are cast to ``String`` and matched with ILIKE. If the column
      is numeric (``_is_numeric_type``) and *raw_value* is a string of decimal
      digits, an exact ``col == int(raw_value)`` test is OR-ed in.

    Digits are now detected with ``str.isdecimal()``. The former
    ``str.isdigit()`` also accepts characters such as "²", which ``int()``
    rejects with ``ValueError``.
    """
    if self._is_textual_type(col):
        return col.ilike(pattern)

    text_col = cast(col, String())
    if isinstance(raw_value, str) and raw_value.isdecimal() and self._is_numeric_type(col):
        return or_(col == int(raw_value), text_col.ilike(pattern))
    return text_col.ilike(pattern)
|
326
|
+
|
327
|
+
def _textop_with_autocast(self, col, operator: str, raw_value: Any):
|
328
|
+
"""
|
329
|
+
Supports 'contains' | 'startswith' | 'endswith' with auto-cast.
|
330
|
+
"""
|
331
|
+
val = "" if raw_value is None else str(raw_value)
|
332
|
+
if operator == "contains":
|
333
|
+
return self._ilike_with_autocast(col, f"%{val}%", raw_value)
|
334
|
+
if operator == "startswith":
|
335
|
+
return self._ilike_with_autocast(col, f"{val}%", raw_value)
|
336
|
+
if operator == "endswith":
|
337
|
+
return self._ilike_with_autocast(col, f"%{val}", raw_value)
|
338
|
+
raise ValueError(f"Unsupported text operator: {operator}")
|
339
|
+
|
340
|
+
@query_mutator
def filter(self, *expressions: Any, **lookups: Any) -> "BaseManager":
    """Return a clone narrowed by the given conditions. All conditions are AND-ed.

    Keyword lookups (``field[__relation...][__op]=value``) are parsed by
    ``_parse_lookup`` and added first. Positional expressions follow:
    ``Q``/``QJSON`` objects are compiled with ``_q_to_expr``, and anything
    else is used as a ready-made clause. The manager is marked as filtered
    even when no conditions are given.

    NOTE(review): root names of relationship lookups are recorded in
    ``self._joins``. In the code visible here that set is only copied by
    ``_clone`` — confirm it is consumed somewhere.
    """
    self._filtered = True

    for lookup, value in lookups.items():
        root_name = lookup.split("__", 1)[0]
        root_prop = getattr(getattr(self.model, root_name, None), "property", None)
        if isinstance(root_prop, RelationshipProperty):
            self._joins.add(root_name)
        self.filters.append(self._parse_lookup(lookup, value))

    self.filters.extend(
        self._q_to_expr(expr) if isinstance(expr, (Q, QJSON)) else expr
        for expr in expressions
    )
    return self
|
360
|
+
|
361
|
+
@query_mutator
def or_filter(self, *expressions: Any, **lookups: Any) -> "BaseManager[ModelType]":
    """Add one OR group (shortcut for `filter(Q() | Q())`).

    Positional expressions (``Q``/``QJSON`` compiled via ``_q_to_expr``,
    anything else used as-is) come first, then the keyword lookups. All of
    them are OR-ed into a single clause, which is AND-ed with the existing
    filters. With no arguments nothing is added and the clone is not marked
    as filtered.
    """
    alternatives: List[Any] = [
        self._q_to_expr(expr) if isinstance(expr, (Q, QJSON)) else expr
        for expr in expressions
    ]
    alternatives.extend(self._parse_lookup(k, v) for k, v in lookups.items())

    if not alternatives:
        return self

    self._filtered = True
    self.filters.append(or_(*alternatives))
    return self
|
379
|
+
|
380
|
+
@query_mutator
|
381
|
+
def exclude(self, *expressions: Any, **lookups: Any) -> "BaseManager[ModelType]":
|
382
|
+
"""
|
308
383
|
Exclude records that match the given conditions.
|
309
384
|
This is the opposite of filter() - it adds NOT conditions.
|
310
385
|
|
@@ -325,226 +400,227 @@ class BaseManager(Generic[ModelType]):
|
|
325
400
|
# Exclude using QJSON
|
326
401
|
manager.exclude(QJSON("metadata", "type", "eq", "archived"))
|
327
402
|
"""
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
403
|
+
self._filtered = True
|
404
|
+
|
405
|
+
for k, v in lookups.items():
|
406
|
+
root = k.split("__", 1)[0]
|
407
|
+
if hasattr(self.model, root):
|
408
|
+
attr = getattr(self.model, root)
|
409
|
+
if hasattr(attr, "property") and isinstance(attr.property, RelationshipProperty):
|
410
|
+
self._joins.add(root)
|
411
|
+
|
412
|
+
lookup_expr = self._parse_lookup(k, v)
|
413
|
+
self.filters.append(~lookup_expr)
|
414
|
+
|
415
|
+
for expr in expressions:
|
416
|
+
if isinstance(expr, Q) or isinstance(expr, QJSON):
|
417
|
+
q_expr = self._q_to_expr(expr)
|
418
|
+
self.filters.append(~q_expr)
|
419
|
+
else:
|
420
|
+
self.filters.append(~expr)
|
421
|
+
|
422
|
+
return self
|
423
|
+
|
424
|
+
@query_mutator
def limit(self, value: int) -> "BaseManager[ModelType]":
    """Return a clone whose ``all()`` results are capped at *value* rows.

    ``paginate()`` takes its own ``limit`` argument and does not use this setting.
    """
    self._limit = value
    return self
|
428
|
+
|
429
|
+
async def all(self, relations: Optional[List[str]] = None):
    """Return every row matching the accumulated filters, ordering and limit.

    Unlike ``first()``/``count()``, this does not require ``filter()``. An
    unfiltered manager returns the whole table. Rows are de-duplicated by
    ``id`` (see ``_execute_query``).

    ``limit(0)`` is now honoured and returns no rows. Previously a falsy limit
    was skipped and every row was returned.

    Args:
        relations: optional relationship paths to eager-load.
    """
    stmt = select(self.model)

    if relations:
        opts = self._build_relation_loaders(self.model, relations)
        stmt = stmt.options(*opts)

    if self.filters:
        stmt = stmt.filter(and_(*self.filters))
    if self._order_by:
        stmt = stmt.order_by(*self._order_by)
    if self._limit is not None:
        stmt = stmt.limit(self._limit)

    return await self._execute_query(stmt)
|
444
|
+
|
445
|
+
async def first(self, relations: Optional[Sequence[str]] = None):
    """Return the first row matching the accumulated filters, or ``None``.

    Honors ``order_by()``/``sort_by()``. Without an ordering, the database
    decides which row comes first. The query now carries ``LIMIT 1``.
    Previously every matching row was fetched, because the async ``execute()``
    returns a buffered result, and then all but one were discarded.

    Args:
        relations: optional relationship paths to eager-load.

    Raises:
        RuntimeError: if ``filter()`` has not been called.
    """
    self._ensure_filtered()
    stmt = select(self.model).filter(and_(*self.filters))

    if self._order_by:
        stmt = stmt.order_by(*self._order_by)

    if relations:
        opts = self._build_relation_loaders(self.model, relations)
        stmt = stmt.options(*opts)

    result = await self.session.execute(stmt.limit(1))
    return result.scalars().first()
|
458
|
+
|
459
|
+
async def last(self, relations: Optional[Sequence[str]] = None):
    """Return the last row matching the accumulated filters, or ``None``.

    "Last" is relative to the manager's ORDER BY clauses. Each clause's
    direction is inverted (ASC <-> DESC, clause priority kept) and the first
    row of that inverted ordering is returned. Without an ordering, the row
    with the highest ``id`` is returned.

    Previously the list of clauses was only reversed, which kept each
    clause's direction. With a single ordering clause, ``last()`` therefore
    returned the same row as ``first()``. The query now also carries ``LIMIT 1``.

    Args:
        relations: optional relationship paths to eager-load.

    Raises:
        RuntimeError: if ``filter()`` has not been called.
    """
    from sqlalchemy.sql import operators as sa_operators
    from sqlalchemy.sql.elements import UnaryExpression

    def _inverted(clause):
        if isinstance(clause, UnaryExpression):
            if clause.modifier is sa_operators.desc_op:
                return asc(clause.element)
            if clause.modifier is sa_operators.asc_op:
                return desc(clause.element)
            # e.g. NULLS FIRST/LAST wrappers: leave untouched rather than
            # emitting invalid "... NULLS LAST DESC" SQL.
            return clause
        # A bare expression sorts ascending by default, so its inverse is DESC.
        return desc(clause)

    self._ensure_filtered()
    stmt = select(self.model).filter(and_(*self.filters))

    if self._order_by:
        stmt = stmt.order_by(*[_inverted(clause) for clause in self._order_by])
    else:
        stmt = stmt.order_by(self.model.id.desc())

    if relations:
        opts = self._build_relation_loaders(self.model, relations)
        stmt = stmt.options(*opts)

    result = await self.session.execute(stmt.limit(1))
    return result.scalars().first()
|
471
|
+
|
472
|
+
async def count(self) -> int | None:
    """Return how many rows match the accumulated filters.

    Counts ``COUNT(model.id)``. Requires ``filter()`` (raises ``RuntimeError`` otherwise).
    """
    self._ensure_filtered()

    counter = select(func.count(self.model.id)).select_from(self.model)
    if self.filters:
        counter = counter.where(and_(*self.filters))

    total = (await self.session.execute(counter)).scalar_one()
    return int(total)
|
481
|
+
|
482
|
+
async def paginate(self, limit: int = 10, offset: int = 0, relations: Optional[Sequence[str]] = None):
    """Return one page of rows (``LIMIT limit OFFSET offset``) matching the filters.

    Honors the accumulated ordering. The manager's own ``limit()`` setting is
    not used. Results are de-duplicated by ``id`` via ``_execute_query``.
    Requires ``filter()`` (raises ``RuntimeError`` otherwise).
    """
    self._ensure_filtered()
    page_stmt = select(self.model).filter(and_(*self.filters))

    if relations:
        page_stmt = page_stmt.options(*self._build_relation_loaders(self.model, relations))
    if self._order_by:
        page_stmt = page_stmt.order_by(*self._order_by)

    page_stmt = page_stmt.limit(limit).offset(offset)
    return await self._execute_query(page_stmt)
|
494
|
+
|
495
|
+
@with_transaction_error_handling
async def create(self, **kwargs):
    """Build ``self.model(**kwargs)``, add it to the session and flush it.

    The flushed object is then passed to ``_smart_commit``, and its return
    value is returned.
    """
    new_obj = self.model(**kwargs)
    session = self.session
    session.add(new_obj)
    await session.flush()
    return await self._smart_commit(new_obj)
|
501
|
+
|
502
|
+
@with_transaction_error_handling
async def save(self, instance: ModelType):
    """Add *instance* to the session, flush it, and return ``_smart_commit(instance)``."""
    session = self.session
    session.add(instance)
    await session.flush()
    return await self._smart_commit(instance)
|
507
|
+
|
508
|
+
@with_transaction_error_handling
async def lazy_save(self, instance: ModelType, relations: list[str] | None = None) -> ModelType:
    """Commit *instance*, then re-select it by ``id`` with relationships eager-loaded.

    Args:
        instance: object to add and commit (always committed, regardless of
            any open transaction).
        relations: relationship paths to load. ``None`` means every direct
            relationship of the model. If that list is empty, *instance* is
            returned as-is.

    Raises:
        NotFoundException: if the row cannot be re-selected after the commit.
            The error-handling decorator converts it to ``InternalException``.
    """
    self.session.add(instance)
    await self.session.commit()

    if relations is None:
        relations = [rel.key for rel in inspect(self.model).relationships]

    if not relations:
        return instance

    reload_stmt = (
        select(self.model)
        .filter_by(id=instance.id)
        .options(*self._build_relation_loaders(self.model, relations))
    )
    reloaded = (await self.session.execute(reload_stmt)).scalar_one_or_none()
    if reloaded is None:
        raise NotFoundException("Could not reload after save", 404, "0001")
    return reloaded
|
528
|
+
|
529
|
+
@with_transaction_error_handling
async def update(self, instance: ModelType, **fields):
    """Assign each of *fields* onto *instance*, add it to the session and commit it.

    The commit goes through ``_smart_commit``. The same *instance* object is
    returned; it is not refreshed.

    Raises:
        InternalException: if no fields are given.
    """
    if not fields:
        raise InternalException("No fields provided for update")

    for attr_name, new_value in fields.items():
        setattr(instance, attr_name, new_value)

    self.session.add(instance)
    await self._smart_commit()
    return instance
|
538
|
+
|
539
|
+
@with_transaction_error_handling
async def update_by_filters(self, filters: Dict[str, Any], **fields):
    """Bulk-UPDATE the rows matching ``filter_by(**filters)`` and return one of them.

    The ids of the matching rows are collected *before* the UPDATE, so the
    object can be re-fetched even when *fields* changes a column that
    *filters* uses. The first collected id is re-fetched with ``get()``.

    Previously the method ended with ``self.get(**filters)``. ``get()`` takes a
    primary key (``get(pk, relations=None)``), so any filter key other than
    ``pk``/``relations`` (e.g. ``{"id": 1}``) raised ``TypeError`` after the
    UPDATE had already been committed.

    Raises:
        InternalException: if no fields are given.
        NotFoundException: if no row matched *filters*. The error-handling
            decorator converts it to ``InternalException``.
    """
    if not fields:
        raise InternalException("No fields provided for update")

    id_stmt = select(self.model.id).filter_by(**filters)
    matched_ids = (await self.session.execute(id_stmt)).scalars().all()

    stmt = update(self.model).filter_by(**filters).values(**fields)
    await self.session.execute(stmt)
    await self.session.commit()

    if not matched_ids:
        raise NotFoundException("Object does not exist", 404, "0001")
    # NOTE(review): when several rows match, only the first id returned by the
    # database is re-fetched — confirm callers expect a single object.
    return await self.get(matched_ids[0])
|
547
|
+
|
548
|
+
@with_transaction_error_handling
async def delete(self, instance: Optional[ModelType] = None):
    """Delete *instance*, or, without one, every row matching the filters.

    Both paths commit unconditionally and return True.

    NOTE(review): ``filter()`` called with no arguments satisfies
    ``_ensure_filtered()`` but leaves ``self.filters`` empty. The DELETE then
    presumably has no effective WHERE clause and removes every row — confirm
    this is intended.

    Raises:
        RuntimeError: if no instance is given and ``filter()`` was never called.
    """
    if instance is None:
        self._ensure_filtered()
        await self.session.execute(delete(self.model).where(and_(*self.filters)))
    else:
        await self.session.delete(instance)

    await self.session.commit()
    return True
|
559
|
+
|
560
|
+
@with_transaction_error_handling
async def bulk_save(self, instances: Iterable[ModelType]):
    """Add many instances to the session and flush them; returns None.

    *instances* is materialised once, so generators and other iterators are
    accepted, and any empty iterable returns early without touching the
    session. Previously only empty sized containers did, because a generator
    is always truthy. The commit is delegated to ``_smart_commit()``, which
    applies the same "commit only when not in a transaction" rule that was
    previously written inline.
    """
    items = list(instances)
    if not items:
        return
    self.session.add_all(items)
    await self.session.flush()
    await self._smart_commit()
|
568
|
+
|
569
|
+
@with_transaction_error_handling
async def bulk_delete(self):
    """DELETE every row matching the accumulated filters and return the driver's rowcount.

    The commit goes through ``_smart_commit()``. Requires ``filter()``
    (raises ``RuntimeError`` otherwise).
    """
    self._ensure_filtered()
    outcome = await self.session.execute(delete(self.model).where(and_(*self.filters)))
    await self._smart_commit()
    return outcome.rowcount
|
576
|
+
|
577
|
+
async def get(self, pk: Union[str, int, uuid.UUID], relations: Optional[Sequence[str]] = None) -> Any:
    """Fetch the row whose ``id`` equals *pk*, optionally eager-loading *relations*.

    Accumulated filters are not applied; only the ``id`` is used.

    Raises:
        NotFoundException: if no such row exists.
    """
    stmt = select(self.model).filter_by(id=pk)
    if relations:
        stmt = stmt.options(*self._build_relation_loaders(self.model, relations))

    found = (await self.session.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise NotFoundException("Object does not exist", 404, "0001")
    return found
|
588
|
+
|
589
|
+
async def is_exists(self):
    """Return True if at least one row matches the accumulated filters.

    Only the ``id`` of a single row is selected (``LIMIT 1``). Previously a
    full ORM entity was loaded and instantiated just to test for existence.

    Raises:
        RuntimeError: if ``filter()`` has not been called.
    """
    self._ensure_filtered()

    stmt = (
        select(self.model.id)
        .filter(and_(*self.filters))
        .limit(1)
    )
    result = await self.session.execute(stmt)
    return result.scalars().first() is not None
|
599
|
+
|
600
|
+
@query_mutator
def has_relation(self, relation_name: str):
    """Return a clone limited to rows that have at least one related *relation_name* row.

    Adds ``EXISTS (SELECT 1 FROM <target> WHERE <relationship primaryjoin>)``
    and marks the clone as filtered. No explicit ``correlate()`` is used, so
    this presumably relies on SQLAlchemy's auto-correlation with the outer
    query — confirm.

    Raises:
        AttributeError: if the model has no attribute *relation_name*.
    """
    rel_prop = getattr(self.model, relation_name).property
    exists_clause = (
        select(1)
        .select_from(rel_prop.mapper.class_)
        .where(rel_prop.primaryjoin)
        .exists()
    )
    self.filters.append(exists_clause)
    self._filtered = True
    return self
|
612
|
+
|
613
|
+
def _infer_default_agg(self, column) -> str:
|
614
|
+
try:
|
615
|
+
from sqlalchemy import Integer, BigInteger, SmallInteger, Float, Numeric
|
616
|
+
if hasattr(column, "type") and isinstance(column.type, (Integer, BigInteger, SmallInteger, Float, Numeric)):
|
617
|
+
return "sum"
|
618
|
+
except Exception:
|
619
|
+
pass
|
620
|
+
return "max"
|
621
|
+
|
622
|
+
def _order_expr_for_path(self, token: str):
|
623
|
+
"""
|
548
624
|
token grammar:
|
549
625
|
[-]<path>[:<agg>][!first|!last]
|
550
626
|
<path> := "field" | "relation__field" (one hop)
|
@@ -555,103 +631,103 @@ class BaseManager(Generic[ModelType]):
|
|
555
631
|
"stocks__sold:sum"
|
556
632
|
"""
|
557
633
|
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
634
|
+
# strip leading '-' (handled by caller), and parse nulls placement
|
635
|
+
core = token.lstrip("-")
|
636
|
+
nulls_placement = None
|
637
|
+
if core.endswith("!first"):
|
638
|
+
core, nulls_placement = core[:-6], "first"
|
639
|
+
elif core.endswith("!last"):
|
640
|
+
core, nulls_placement = core[:-5], "last"
|
641
|
+
|
642
|
+
# split aggregate suffix if present
|
643
|
+
if ":" in core:
|
644
|
+
path, agg_name = core.split(":", 1)
|
645
|
+
agg_name = agg_name.lower().strip()
|
646
|
+
else:
|
647
|
+
path, agg_name = core, None
|
648
|
+
|
649
|
+
# base column on the model (no relation hop)
|
650
|
+
if "__" not in path and "." not in path:
|
651
|
+
col = getattr(self.model, path, None)
|
652
|
+
if col is None:
|
653
|
+
raise ValueError(f"Invalid order_by field '{path}' for {self.model.__name__}")
|
654
|
+
expr = col
|
655
|
+
if nulls_placement == "first":
|
656
|
+
expr = expr.nullsfirst()
|
657
|
+
elif nulls_placement == "last":
|
658
|
+
expr = expr.nullslast()
|
659
|
+
return expr
|
660
|
+
|
661
|
+
# relation hop (exactly one)
|
662
|
+
parts = re.split(r"\.|\_\_", path)
|
663
|
+
if len(parts) != 2:
|
664
|
+
raise ValueError(f"Only one relation hop supported in order_by: {path!r}")
|
665
|
+
|
666
|
+
rel_name, col_name = parts
|
667
|
+
rel_attr = getattr(self.model, rel_name, None)
|
668
|
+
if rel_attr is None or not hasattr(rel_attr, "property"):
|
669
|
+
raise ValueError(f"Invalid relationship '{rel_name}' on {self.model.__name__}")
|
670
|
+
|
671
|
+
target_mapper = rel_attr.property.mapper
|
672
|
+
target_cls = target_mapper.class_
|
673
|
+
target_col = getattr(target_cls, col_name, None)
|
674
|
+
if target_col is None:
|
675
|
+
raise ValueError(f"Invalid column '{col_name}' on related model {target_cls.__name__}")
|
676
|
+
|
677
|
+
primaryjoin = rel_attr.property.primaryjoin
|
678
|
+
uselist = rel_attr.property.uselist
|
679
|
+
|
680
|
+
# One-to-many (or many-to-many via association): require aggregate (or infer)
|
681
|
+
if uselist:
|
682
|
+
agg_name = agg_name or self._infer_default_agg(target_col)
|
683
|
+
agg_fn = AGG_MAP.get(agg_name)
|
684
|
+
if agg_fn is None:
|
685
|
+
raise ValueError(f"Unsupported aggregate '{agg_name}' in order_by for {path!r}")
|
686
|
+
|
687
|
+
# SELECT agg(related.col) WHERE primaryjoin (correlated)
|
688
|
+
subq = (
|
689
|
+
select(agg_fn(target_col))
|
690
|
+
.where(primaryjoin)
|
691
|
+
.correlate(self.model) # tie to outer row
|
692
|
+
.scalar_subquery()
|
693
|
+
)
|
694
|
+
expr = subq
|
695
|
+
|
696
|
+
else:
|
697
|
+
if agg_name and agg_name != "first":
|
698
|
+
agg_fn = AGG_MAP.get(agg_name)
|
699
|
+
if agg_fn is None:
|
700
|
+
raise ValueError(f"Unsupported aggregate '{agg_name}' in order_by for {path!r}")
|
701
|
+
select_expr = agg_fn(target_col)
|
702
|
+
else:
|
703
|
+
select_expr = target_col
|
704
|
+
|
705
|
+
sub = select(select_expr).where(primaryjoin).correlate(self.model)
|
706
|
+
if agg_name == "first":
|
707
|
+
sub = sub.limit(1)
|
708
|
+
expr = sub.scalar_subquery()
|
709
|
+
|
710
|
+
if nulls_placement == "first":
|
711
|
+
expr = expr.nullsfirst()
|
712
|
+
elif nulls_placement == "last":
|
713
|
+
expr = expr.nullslast()
|
714
|
+
|
715
|
+
return expr
|
716
|
+
|
717
|
+
@query_mutator
def sort_by(self, tokens):
    """Return a clone whose ordering is rebuilt from sort *tokens*.

    Each token follows ``[-]<path>[:<agg>][!first|!last]``. A leading "-"
    sorts descending, otherwise ascending. Path and aggregate parsing is
    delegated to ``_order_expr_for_path``. ``None`` or an empty sequence
    clears the ordering.

    NOTE(review): for tokens ending in ``!first``/``!last`` the expression
    already carries NULLS FIRST/LAST before being wrapped in ``asc()``/``desc()``,
    which may render as ``... NULLS FIRST ASC`` — verify on the target database.
    """
    ordering = []
    for token in tokens or []:
        wrap = desc if token.startswith("-") else asc
        ordering.append(wrap(self._order_expr_for_path(token.lstrip("-"))))
    self._order_by = ordering
    return self
|
725
|
+
|
726
|
+
def model_to_dict(self, instance: ModelType, exclude: Optional[set[str]] = None) -> dict:
    """Return ``{column_key: value}`` for *instance*.

    Keys come from the model's column attributes, collected in ``__init__``.
    Keys listed in *exclude* are skipped.
    """
    skipped = exclude if exclude else set()
    return {
        key: getattr(instance, key)
        for key in self._column_keys
        if key not in skipped
    }
|
729
|
+
|
730
|
+
def _ensure_filtered(self):
|
731
|
+
if not self._filtered:
|
732
|
+
raise RuntimeError("You must call `filter()` before this operation.")
|
657
733
|
|