rb-commons 0.7.15__py3-none-any.whl → 0.7.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,11 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import re
4
4
  import uuid
5
- from typing import TypeVar, Type, Generic, Optional, List, Dict, Literal, Union, Sequence, Any, Iterable, Callable
6
- from sqlalchemy import select, delete, update, and_, func, desc, inspect, or_, asc, true
5
+ from typing import TypeVar, Type, Generic, Optional, List, Dict, Union, Sequence, Any, Iterable, Callable
6
+ from sqlalchemy import select, delete, update, and_, func, desc, inspect, or_, asc, true, cast
7
7
  from sqlalchemy.exc import IntegrityError, SQLAlchemyError
8
8
  from sqlalchemy.ext.asyncio import AsyncSession
9
9
  from sqlalchemy.orm import declarative_base, selectinload, RelationshipProperty, Load
10
+ from sqlalchemy.sql.sqltypes import String, Text, Unicode, UnicodeText, Integer, BigInteger, SmallInteger, Float, Numeric
10
11
  from rb_commons.http.exceptions import NotFoundException
11
12
  from rb_commons.orm.exceptions import DatabaseException, InternalException
12
13
  from functools import lru_cache, wraps
@@ -14,184 +15,220 @@ from rb_commons.orm.querysets import Q, QJSON
14
15
 
15
16
  ModelType = TypeVar('ModelType', bound=declarative_base())
16
17
 
18
+
def with_transaction_error_handling(func):
    """Decorate an async manager method so database errors roll back the session.

    Whatever the wrapped coroutine raises, ``self.session.rollback()`` is
    awaited first and the error is then surfaced as a project exception:

    * ``IntegrityError`` -> ``InternalException("Constraint violation: ...")``
    * any other ``SQLAlchemyError`` -> ``DatabaseException("Database error: ...")``
    * ``InternalException`` raised by the method itself -> re-raised unchanged
    * anything else -> ``InternalException("Unexpected error: ...")``

    The original exception is chained as ``__cause__``. ``functools.wraps``
    keeps the wrapped method's name and docstring.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except IntegrityError as e:
            await self.session.rollback()
            raise InternalException(f"Constraint violation: {str(e)}") from e
        except SQLAlchemyError as e:
            await self.session.rollback()
            raise DatabaseException(f"Database error: {str(e)}") from e
        except InternalException:
            # Already a project error (e.g. update()'s "No fields provided for
            # update"); wrapping it again would only bury the message behind an
            # "Unexpected error:" prefix while keeping the same exception type.
            await self.session.rollback()
            raise
        except Exception as e:
            await self.session.rollback()
            raise InternalException(f"Unexpected error: {str(e)}") from e

    return wrapper
34
+
31
35
 
F = TypeVar("F", bound=Callable[..., Any])


def query_mutator(func: F) -> F:
    """Run a query-builder method on a clone of the manager (clone-on-write).

    The decorated method receives ``self._clone()`` instead of ``self``, so
    the original manager is never mutated. If the method returns ``None``
    the clone is returned; any other return value is passed through.
    """

    @wraps(func)
    def wrapper(self: "BaseManager[Any]", *args, **kwargs):
        working_copy = self._clone()
        outcome = func(working_copy, *args, **kwargs)
        if outcome is None:
            return working_copy
        return outcome

    return wrapper
44
51
 
45
52
 
# Aggregation name -> callable applied to a column expression.
# "mean" is an alias of "avg"; "first" returns the column unchanged.
AGG_MAP: dict[str, Callable[[Any], Any]] = dict(
    sum=func.sum,
    avg=func.avg,
    mean=func.avg,
    min=func.min,
    max=func.max,
    count=func.count,
    first=lambda c: c,
)
55
62
 
63
+
56
64
  class BaseManager(Generic[ModelType]):
57
- model: Type[ModelType]
65
+ model: Type[ModelType]
58
66
 
def __init__(self, session: AsyncSession) -> None:
    """Bind the manager to *session* and start with an empty query state."""
    self.session: AsyncSession = session

    # Query-builder state; _clone() copies each of these on chained calls.
    self.filters: List[Any] = []
    self._order_by: List[Any] = []
    self._joins: set[str] = set()
    self._limit: Optional[int] = None
    self._filtered: bool = False

    # Keys of every mapped column attribute of the managed model.
    self._column_keys = [attr.key for attr in inspect(self.model).mapper.column_attrs]
69
77
 
70
- def _clone(self) -> "BaseManager[ModelType]":
71
- """
78
+ def _clone(self) -> "BaseManager[ModelType]":
79
+ """
72
80
  Shallow‑copy all mutable query state into a new manager instance.
73
81
  """
74
- clone = self.__class__(self.session)
75
- clone.filters = list(self.filters)
76
- clone._order_by = list(self._order_by)
77
- clone._limit = self._limit
78
- clone._joins = set(self._joins)
79
- clone._filtered = self._filtered
80
- return clone
81
-
82
- async def _smart_commit(self, instance: Optional[ModelType] = None) -> Optional[ModelType]:
83
- if not self.session.in_transaction():
84
- await self.session.commit()
85
- if instance is not None:
86
- await self.session.refresh(instance)
87
- return instance
88
- return None
89
-
90
- def _build_comparison(self, col, operator: str, value: Any):
91
- if operator == "eq":
92
- return col == value
93
- if operator == "ne":
94
- return col != value
95
- if operator == "gt":
96
- return col > value
97
- if operator == "lt":
98
- return col < value
99
- if operator == "gte":
100
- return col >= value
101
- if operator == "lte":
102
- return col <= value
103
- if operator == "in":
104
- return col.in_(value)
105
- if operator == "contains":
106
- return col.ilike(f"%{value}%")
107
- if operator == "null":
108
- return col.is_(None) if value else col.isnot(None)
109
- raise ValueError(f"Unsupported operator: {operator}")
110
-
111
- @lru_cache(maxsize=None)
112
- def _parse_lookup_meta(self, lookup: str):
113
- """
82
+ clone = self.__class__(self.session)
83
+ clone.filters = list(self.filters)
84
+ clone._order_by = list(self._order_by)
85
+ clone._limit = self._limit
86
+ clone._joins = set(self._joins)
87
+ clone._filtered = self._filtered
88
+ return clone
89
+
90
+ async def _smart_commit(self, instance: Optional[ModelType] = None) -> Optional[ModelType]:
# Commits the session only when no transaction is open
# (session.in_transaction() is False). May refresh `instance` via
# session.refresh() and return it; the final line returns None.
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# refresh / "return instance" lines run only after a commit or on every call.
# If only after a commit, create()/save() (which return this value) would
# return None whenever a transaction is already open — confirm.
# NOTE(review): create()/save() call session.flush() right before this, and
# SQLAlchemy autobegins a transaction on flush, so in_transaction() is likely
# already True here — verify where the commit actually happens.
91
+ if not self.session.in_transaction():
92
+ await self.session.commit()
93
+ if instance is not None:
94
+ await self.session.refresh(instance)
95
+ return instance
96
+ return None
97
+
98
+ def _build_comparison(self, col, operator: str, value: Any):
99
+ if operator == "eq":
100
+ return col == value
101
+ if operator == "ne":
102
+ return col != value
103
+ if operator == "gt":
104
+ return col > value
105
+ if operator == "lt":
106
+ return col < value
107
+ if operator == "gte":
108
+ return col >= value
109
+ if operator == "lte":
110
+ return col <= value
111
+ if operator == "in":
112
+ return col.in_(value)
113
+ if operator in {"contains", "startswith", "endswith"}:
114
+ return self._textop_with_autocast(col, operator, value)
115
+ if operator == "null":
116
+ return col.is_(None) if value else col.isnot(None)
117
+ raise ValueError(f"Unsupported operator: {operator}")
118
+
119
+ @lru_cache(maxsize=None)
120
+ def _parse_lookup_meta(self, lookup: str):
121
+ """
114
122
  One-time parse of "foo__bar__lt" into:
115
123
  - parts = ["foo","bar"]
116
124
  - operator="lt"
117
125
  - relationship_attr, column_attr pointers
118
126
  """
119
127
 
120
- parts = lookup.split("__")
121
- operator = "eq"
122
- if parts[-1] in {"eq", "ne", "gt", "lt", "gte", "lte", "in", "contains", "null"}:
123
- operator = parts.pop()
124
-
125
- current = self.model
126
- rel = None
127
- col = None
128
- for p in parts:
129
- a = getattr(current, p)
130
- if hasattr(a, "property") and isinstance(a.property, RelationshipProperty):
131
- rel = a
132
- current = a.property.mapper.class_
133
- else:
134
- col = a
135
- return parts, operator, rel, col
136
-
137
- def _parse_lookup(self, lookup: str, value: Any):
138
- parts, operator, rel_attr, col_attr = self._parse_lookup_meta(lookup)
139
- expr = self._build_comparison(col_attr, operator, value)
140
-
141
- if rel_attr:
142
- if rel_attr.property.uselist:
143
- return rel_attr.any(expr)
144
- else:
145
- return rel_attr.has(expr)
146
-
147
- return expr
148
-
149
- def _q_to_expr(self, q: Union[Q, QJSON]):
150
- if isinstance(q, QJSON):
151
- return self._parse_qjson(q)
152
-
153
- clauses: List[Any] = [self._parse_lookup(k, v) for k, v in q.lookups.items()]
154
- for child in q.children:
155
- clauses.append(self._q_to_expr(child))
156
-
157
- if not clauses:
158
- combined = true()
159
- elif q._operator == "OR":
160
- combined = or_(*clauses)
161
- else:
162
- combined = and_(*clauses)
163
-
164
- return ~combined if q.negated else combined
165
-
166
- def _parse_qjson(self, qjson: QJSON):
167
- col = getattr(self.model, qjson.field, None)
168
- if col is None:
169
- raise ValueError(f"Invalid JSON field: {qjson.field}")
170
-
171
- json_expr = col[qjson.key].astext
172
-
173
- if qjson.operator == "eq":
174
- return json_expr == str(qjson.value)
175
- if qjson.operator == "ne":
176
- return json_expr != str(qjson.value)
177
- if qjson.operator == "contains":
178
- return json_expr.ilike(f"%{qjson.value}%")
179
- if qjson.operator == "startswith":
180
- return json_expr.ilike(f"{qjson.value}%")
181
- if qjson.operator == "endswith":
182
- return json_expr.ilike(f"%{qjson.value}")
183
- if qjson.operator == "in":
184
- if not isinstance(qjson.value, (list, tuple, set)):
185
- raise ValueError(f"{qjson.field}[{qjson.key}]__in requires an iterable")
186
- return json_expr.in_(qjson.value)
187
- raise ValueError(f"Unsupported QJSON operator: {qjson.operator}")
188
-
189
- def _build_relation_loaders(
190
- self,
191
- model: Any,
192
- relations: Sequence[str] | None = None
193
- ) -> List[Load]:
194
- """
128
+ parts = lookup.split("__")
129
+ operator = "eq"
130
+
131
+ if parts[-1] in {"eq", "ne", "gt", "lt", "gte", "lte", "in", "contains", "startswith", "endswith", "null"}:
132
+ operator = parts.pop()
133
+
134
+ current = self.model
135
+ rel = None
136
+ col = None
137
+ for p in parts:
138
+ a = getattr(current, p)
139
+ if hasattr(a, "property") and isinstance(a.property, RelationshipProperty):
140
+ rel = a
141
+ current = a.property.mapper.class_
142
+ else:
143
+ col = a
144
+ return parts, operator, rel, col
145
+
146
+ def _parse_lookup(self, lookup: str, value: Any):
147
+ parts, operator, rel_attr, col_attr = self._parse_lookup_meta(lookup)
148
+
149
+ if rel_attr is not None and col_attr is None:
150
+ uselist = rel_attr.property.uselist
151
+ primaryjoin = rel_attr.property.primaryjoin
152
+
153
+ if uselist:
154
+ target_cls = rel_attr.property.mapper.class_
155
+ cnt = (
156
+ select(func.count("*"))
157
+ .select_from(target_cls)
158
+ .where(primaryjoin)
159
+ .correlate(self.model)
160
+ .scalar_subquery()
161
+ )
162
+ return self._build_comparison(cnt, operator, value)
163
+ else:
164
+ exists_expr = (
165
+ select(1)
166
+ .where(primaryjoin)
167
+ .correlate(self.model)
168
+ .exists()
169
+ )
170
+ if operator in {"eq", "lte"} and str(value) in {"0", "False", "false"}:
171
+ return ~exists_expr
172
+ if operator in {"gt", "gte", "eq"} and str(value) in {"1", "True", "true"}:
173
+ return exists_expr
174
+ return self._build_comparison(exists_expr, operator, bool(value))
175
+
176
+ expr = self._build_comparison(col_attr, operator, value)
177
+
178
+ if rel_attr:
179
+ if rel_attr.property.uselist:
180
+ return rel_attr.any(expr)
181
+ else:
182
+ return rel_attr.has(expr)
183
+
184
+ return expr
185
+
def _q_to_expr(self, q: Union[Q, QJSON]):
    """Compile a ``Q`` tree (or a single ``QJSON`` leaf) into one SQL expression.

    The node's keyword lookups come first, then its children in order. They
    are joined with OR when ``q._operator == "OR"`` and with AND otherwise.
    An empty node compiles to ``true()``, and ``q.negated`` wraps the result
    in NOT.
    """
    if isinstance(q, QJSON):
        return self._parse_qjson(q)

    clauses: List[Any] = [self._parse_lookup(k, v) for k, v in q.lookups.items()]
    clauses.extend(self._q_to_expr(child) for child in q.children)

    if clauses:
        joiner = or_ if q._operator == "OR" else and_
        combined = joiner(*clauses)
    else:
        combined = true()

    return ~combined if q.negated else combined
202
+
203
+ def _parse_qjson(self, qjson: QJSON):
204
+ col = getattr(self.model, qjson.field, None)
205
+ if col is None:
206
+ raise ValueError(f"Invalid JSON field: {qjson.field}")
207
+
208
+ json_expr = col[qjson.key].astext
209
+
210
+ if qjson.operator == "eq":
211
+ return json_expr == str(qjson.value)
212
+ if qjson.operator == "ne":
213
+ return json_expr != str(qjson.value)
214
+ if qjson.operator == "contains":
215
+ return json_expr.ilike(f"%{qjson.value}%")
216
+ if qjson.operator == "startswith":
217
+ return json_expr.ilike(f"{qjson.value}%")
218
+ if qjson.operator == "endswith":
219
+ return json_expr.ilike(f"%{qjson.value}")
220
+ if qjson.operator == "in":
221
+ if not isinstance(qjson.value, (list, tuple, set)):
222
+ raise ValueError(f"{qjson.field}[{qjson.key}]__in requires an iterable")
223
+ return json_expr.in_(qjson.value)
224
+ raise ValueError(f"Unsupported QJSON operator: {qjson.operator}")
225
+
226
+ def _build_relation_loaders(
227
+ self,
228
+ model: Any,
229
+ relations: Sequence[str] | None = None
230
+ ) -> List[Load]:
231
+ """
195
232
  Given e.g. ["media", "properties.property", "properties__property"],
196
233
  returns [
197
234
  selectinload(Product.media),
@@ -200,111 +237,149 @@ class BaseManager(Generic[ModelType]):
200
237
 
201
238
  If `relations` is None or empty, recurse *all* relationships once (cycle-safe).
202
239
  """
203
- loaders: List[Load] = []
204
-
205
- if relations:
206
- for path in relations:
207
- parts = re.split(r"\.|\_\_", path)
208
- current_model = model
209
- loader: Load | None = None
210
-
211
- for part in parts:
212
- attr = getattr(current_model, part, None)
213
- if attr is None or not hasattr(attr, "property"):
214
- raise ValueError(f"Invalid relationship path: {path!r}")
215
- loader = selectinload(attr) if loader is None else loader.selectinload(attr)
216
- current_model = attr.property.mapper.class_
217
-
218
- loaders.append(loader)
219
-
220
- return loaders
221
-
222
- visited = set()
223
-
224
- def recurse(curr_model: Any, curr_loader: Load | None = None):
225
- mapper = inspect(curr_model)
226
- if mapper in visited:
227
- return
228
- visited.add(mapper)
229
-
230
- for rel in mapper.relationships:
231
- attr = getattr(curr_model, rel.key)
232
- loader = (
233
- selectinload(attr)
234
- if curr_loader is None
235
- else curr_loader.selectinload(attr)
236
- )
237
- loaders.append(loader)
238
- recurse(rel.mapper.class_, loader)
239
-
240
- recurse(model)
241
- return loaders
242
-
243
- async def _execute_query(self, stmt):
244
- result = await self.session.execute(stmt)
245
- rows = result.scalars().all()
246
- return list({obj.id: obj for obj in rows}.values())
247
-
248
- @query_mutator
249
- def order_by(self, *columns: Any):
250
- """Collect ORDER BY clauses.
240
+ loaders: List[Load] = []
241
+
242
+ if relations:
243
+ for path in relations:
244
+ parts = re.split(r"\.|\_\_", path)
245
+ current_model = model
246
+ loader: Load | None = None
247
+
248
+ for part in parts:
249
+ attr = getattr(current_model, part, None)
250
+ if attr is None or not hasattr(attr, "property"):
251
+ raise ValueError(f"Invalid relationship path: {path!r}")
252
+ loader = selectinload(attr) if loader is None else loader.selectinload(attr)
253
+ current_model = attr.property.mapper.class_
254
+
255
+ loaders.append(loader)
256
+
257
+ return loaders
258
+
259
+ visited = set()
260
+
261
+ def recurse(curr_model: Any, curr_loader: Load | None = None):
262
+ mapper = inspect(curr_model)
263
+ if mapper in visited:
264
+ return
265
+ visited.add(mapper)
266
+
267
+ for rel in mapper.relationships:
268
+ attr = getattr(curr_model, rel.key)
269
+ loader = (
270
+ selectinload(attr)
271
+ if curr_loader is None
272
+ else curr_loader.selectinload(attr)
273
+ )
274
+ loaders.append(loader)
275
+ recurse(rel.mapper.class_, loader)
276
+
277
+ recurse(model)
278
+ return loaders
279
+
280
+ async def _execute_query(self, stmt):
281
+ result = await self.session.execute(stmt)
282
+ rows = result.scalars().all()
283
+ return list({obj.id: obj for obj in rows}.values())
284
+
@query_mutator
def order_by(self, *columns: Any):
    """Append ORDER BY clauses to the (cloned) query.

    A string names a model attribute: a leading "-" sorts descending and a
    leading "+" (or no sign) ascending. Any non-string is appended unchanged.

    Raises:
        ValueError: a string names no attribute of ``self.model``.
    """
    for spec in columns:
        if not isinstance(spec, str):
            self._order_by.append(spec)
            continue

        name = spec.lstrip("+-")
        attribute = getattr(self.model, name, None)
        if attribute is None:
            raise ValueError(f"Invalid order_by field '{name}' for {self.model.__name__}")
        self._order_by.append(attribute.desc() if spec.startswith("-") else attribute.asc())

    return self
301
+
302
+ def _is_textual_type(self, col) -> bool:
303
+ try:
304
+ return hasattr(col, "type") and isinstance(col.type, (String, Text, Unicode, UnicodeText))
305
+ except Exception:
306
+ return False
307
+
308
+ def _is_numeric_type(self, col) -> bool:
309
+ try:
310
+ return hasattr(col, "type") and isinstance(col.type, (Integer, BigInteger, SmallInteger, Float, Numeric))
311
+ except Exception:
312
+ return False
313
+
314
+ def _ilike_with_autocast(self, col, pattern: str, raw_value: Any):
315
+ """
316
+ Use ILIKE on text columns. Otherwise cast to text.
317
+ If value looks integer and the column is numeric, do (col == int(value)) OR ILIKE(cast(col)).
307
318
  """
319
+ if self._is_textual_type(col):
320
+ return col.ilike(pattern)
321
+
322
+ text_col = cast(col, String())
323
+ if isinstance(raw_value, str) and raw_value.isdigit() and self._is_numeric_type(col):
324
+ return or_(col == int(raw_value), text_col.ilike(pattern))
325
+ return text_col.ilike(pattern)
326
+
327
+ def _textop_with_autocast(self, col, operator: str, raw_value: Any):
328
+ """
329
+ Supports 'contains' | 'startswith' | 'endswith' with auto-cast.
330
+ """
331
+ val = "" if raw_value is None else str(raw_value)
332
+ if operator == "contains":
333
+ return self._ilike_with_autocast(col, f"%{val}%", raw_value)
334
+ if operator == "startswith":
335
+ return self._ilike_with_autocast(col, f"{val}%", raw_value)
336
+ if operator == "endswith":
337
+ return self._ilike_with_autocast(col, f"%{val}", raw_value)
338
+ raise ValueError(f"Unsupported text operator: {operator}")
339
+
@query_mutator
def filter(self, *expressions: Any, **lookups: Any) -> "BaseManager":
    """AND new conditions onto the (cloned) query and mark it as filtered.

    * Keyword lookups use the ``path__op=value`` syntax handled by
      ``_parse_lookup``. When the first path segment is a relationship, its
      name is also recorded in ``self._joins``.
    * Positional ``Q``/``QJSON`` objects are compiled with ``_q_to_expr``.
    * Any other positional value is appended as a raw SQL expression.

    ``_filtered`` is set even when no condition is given.
    """
    self._filtered = True

    for lookup, value in lookups.items():
        head = lookup.split("__", 1)[0]
        if hasattr(self.model, head):
            prop = getattr(getattr(self.model, head), "property", None)
            if isinstance(prop, RelationshipProperty):
                self._joins.add(head)
        self.filters.append(self._parse_lookup(lookup, value))

    for expression in expressions:
        if isinstance(expression, (Q, QJSON)):
            self.filters.append(self._q_to_expr(expression))
        else:
            self.filters.append(expression)

    return self
360
+
@query_mutator
def or_filter(self, *expressions: Any, **lookups: Any) -> "BaseManager[ModelType]":
    """Add one OR group (shortcut for `filter(Q() | Q())`).

    Positional arguments come first (``Q``/``QJSON`` compiled, others used
    as-is), followed by keyword lookups. With no conditions at all the query
    is left untouched and not marked as filtered.
    """
    alternatives: List[Any] = [
        self._q_to_expr(e) if isinstance(e, (Q, QJSON)) else e
        for e in expressions
    ]
    alternatives += [self._parse_lookup(k, v) for k, v in lookups.items()]

    if alternatives:
        self._filtered = True
        self.filters.append(or_(*alternatives))
    return self
379
+
@query_mutator
def exclude(self, *expressions: Any, **lookups: Any) -> "BaseManager[ModelType]":
    """
    Exclude records that match the given conditions.
    This is the opposite of filter() - it adds NOT conditions.

    Every keyword lookup and every positional expression is negated on its own
    and AND-ed onto the query, so a row is dropped if it matches ANY of them.
    ``Q``/``QJSON`` objects are compiled before negation, and relationship
    roots of keyword lookups are recorded in ``self._joins`` as in filter().

    NOTE: under SQL three-valued logic, NOT (col = v) is NULL for rows where
    col is NULL, so such rows are excluded as well.

    Examples:
        manager.exclude(status="archived")
        # Exclude using QJSON
        manager.exclude(QJSON("metadata", "type", "eq", "archived"))
    """
    self._filtered = True

    for lookup, value in lookups.items():
        head = lookup.split("__", 1)[0]
        if hasattr(self.model, head):
            prop = getattr(getattr(self.model, head), "property", None)
            if isinstance(prop, RelationshipProperty):
                self._joins.add(head)
        self.filters.append(~self._parse_lookup(lookup, value))

    for expression in expressions:
        compiled = self._q_to_expr(expression) if isinstance(expression, (Q, QJSON)) else expression
        self.filters.append(~compiled)

    return self
423
+
@query_mutator
def limit(self, value: int) -> "BaseManager[ModelType]":
    """Record a row limit on a clone of the manager.

    all() applies it only when it is truthy, so ``None`` and ``0`` both mean
    "no limit" there.
    """
    self._limit = value
    return self
428
+
async def all(self, relations: Optional[List[str]] = None):
    """Fetch every row matching the collected filters.

    No ``_ensure_filtered()`` call is made here, so an unfiltered manager
    returns the whole table. Eager-loaded *relations*, the collected ordering
    and a truthy ``limit()`` value are applied, and rows are de-duplicated by
    id through ``_execute_query``.
    """
    query = select(self.model)
    if relations:
        query = query.options(*self._build_relation_loaders(self.model, relations))
    if self.filters:
        query = query.filter(and_(*self.filters))
    if self._order_by:
        query = query.order_by(*self._order_by)
    if self._limit:
        query = query.limit(self._limit)
    return await self._execute_query(query)
444
+
async def first(self, relations: Optional[Sequence[str]] = None):
    """Return the first matching row (per the collected ordering) or ``None``.

    ``LIMIT 1`` is applied, so only one row — and only that row's eager-loaded
    *relations* — is fetched instead of the whole result set. An empty filter
    list adds no WHERE clause (instead of calling ``and_()`` with no
    arguments).
    """
    self._ensure_filtered()
    stmt = select(self.model)
    if self.filters:
        stmt = stmt.filter(and_(*self.filters))

    if self._order_by:
        stmt = stmt.order_by(*self._order_by)

    if relations:
        stmt = stmt.options(*self._build_relation_loaders(self.model, relations))

    result = await self.session.execute(stmt.limit(1))
    return result.scalars().first()
458
+
async def last(self, relations: Optional[Sequence[str]] = None):
    """Return the last matching row according to the collected ordering.

    Every ORDER BY clause has its direction flipped (ASC <-> DESC, NULLS
    FIRST <-> NULLS LAST) while the clauses keep their priority. The first row
    of that reversed ordering is then fetched with ``LIMIT 1``. Without an
    explicit ordering, the row with the highest ``id`` is returned.

    Merely reversing the *list* of clauses would not do this: with a single
    ``name ASC`` clause it would return the same row as first().
    """
    from sqlalchemy.sql import operators

    nulls_first_op = getattr(operators, "nulls_first_op", None) or getattr(operators, "nullsfirst_op", None)
    nulls_last_op = getattr(operators, "nulls_last_op", None) or getattr(operators, "nullslast_op", None)

    def flipped(clause):
        modifier = getattr(clause, "modifier", None)
        if modifier is operators.desc_op:
            return clause.element.asc()
        if modifier is operators.asc_op:
            return clause.element.desc()
        if modifier is not None and modifier is nulls_first_op:
            return flipped(clause.element).nullslast()
        if modifier is not None and modifier is nulls_last_op:
            return flipped(clause.element).nullsfirst()
        if hasattr(clause, "desc"):
            # Bare column / expression: ascending by default, so flip to DESC.
            return clause.desc()
        # E.g. a text() clause: direction unknown, keep it unchanged.
        return clause

    self._ensure_filtered()
    stmt = select(self.model)
    if self.filters:
        stmt = stmt.filter(and_(*self.filters))

    ordering = [flipped(c) for c in self._order_by] if self._order_by else [self.model.id.desc()]
    stmt = stmt.order_by(*ordering).limit(1)

    if relations:
        stmt = stmt.options(*self._build_relation_loaders(self.model, relations))

    result = await self.session.execute(stmt)
    return result.scalars().first()
471
+
async def count(self) -> int | None:
    """Return how many rows match the collected filters.

    Calls ``_ensure_filtered()`` first (presumably rejecting unfiltered
    queries — defined elsewhere). Counts non-NULL ``id`` values.
    """
    self._ensure_filtered()

    query = select(func.count(self.model.id)).select_from(self.model)
    if self.filters:
        query = query.where(and_(*self.filters))

    outcome = await self.session.execute(query)
    return int(outcome.scalar_one())
481
+
482
async def paginate(self, limit: int = 10, offset: int = 0, relations: Optional[Sequence[str]] = None):
    """Return one page (``LIMIT limit OFFSET offset``) of rows matching the filters.

    :param limit: page size.
    :param offset: number of rows to skip.
    :param relations: optional relationship names to eager-load.
    :return: whatever ``_execute_query`` returns for the built statement.
    :raises RuntimeError: via ``_ensure_filtered`` if ``filter()`` was not called.
    """
    self._ensure_filtered()
    query = select(self.model).filter(and_(*self.filters))

    if relations:
        loaders = self._build_relation_loaders(self.model, relations)
        query = query.options(*loaders)

    if self._order_by:
        query = query.order_by(*self._order_by)

    page = query.limit(limit).offset(offset)
    return await self._execute_query(page)
494
+
495
@with_transaction_error_handling
async def create(self, **kwargs):
    """Build a model instance from ``kwargs``, add and flush it, then hand it to ``_smart_commit``.

    :return: the value returned by ``_smart_commit`` for the new instance.
    """
    new_obj = self.model(**kwargs)
    self.session.add(new_obj)
    await self.session.flush()
    return await self._smart_commit(new_obj)
501
+
502
@with_transaction_error_handling
async def save(self, instance: ModelType):
    """Add ``instance`` to the session, flush it, then hand it to ``_smart_commit``.

    :return: the value returned by ``_smart_commit`` for ``instance``.
    """
    session = self.session
    session.add(instance)
    await session.flush()
    return await self._smart_commit(instance)
507
+
508
@with_transaction_error_handling
async def lazy_save(self, instance: ModelType, relations: list[str] | None = None) -> ModelType:
    """Add and commit ``instance``, then re-select it with relationships eager-loaded.

    :param instance: the model instance to persist.
    :param relations: relationship names to eager-load on the reloaded row.
        ``None`` means every relationship on the model's mapper; an empty list
        (or a model without relationships) skips the reload.
    :return: the reloaded row, or ``instance`` itself when no reload happens.
    :raises NotFoundException: if the row cannot be selected by ``id`` after commit.
    """
    self.session.add(instance)
    # Unlike save(), this commits directly rather than going through _smart_commit.
    await self.session.commit()

    if relations is None:
        # ``inspect`` is the module-level sqlalchemy.inspect (the same function as
        # sqlalchemy.inspection.inspect), so no local re-import is needed.
        relations = [rel.key for rel in inspect(self.model).relationships]

    if not relations:
        return instance

    stmt = (
        select(self.model)
        .filter_by(id=instance.id)
        .options(*self._build_relation_loaders(self.model, relations))
    )
    result = await self.session.execute(stmt)
    loaded = result.scalar_one_or_none()
    if loaded is None:
        raise NotFoundException("Could not reload after save", 404, "0001")
    return loaded
528
+
529
@with_transaction_error_handling
async def update(self, instance: ModelType, **fields):
    """Assign each of ``fields`` onto ``instance`` and commit via ``_smart_commit``.

    :return: the same ``instance`` object.
    :raises InternalException: if ``fields`` is empty.
    """
    if not fields:
        raise InternalException("No fields provided for update")

    for attr_name, new_value in fields.items():
        setattr(instance, attr_name, new_value)

    self.session.add(instance)
    await self._smart_commit()
    return instance
538
+
539
@with_transaction_error_handling
async def update_by_filters(self, filters: Dict[str, Any], **fields):
    """Issue an UPDATE for rows matching ``filters``, commit, and return one updated row.

    :param filters: column-equality criteria passed to ``filter_by`` (note: an
        empty dict matches every row).
    :param fields: column values to set.
    :return: the first row matching ``filters`` after the update. If ``fields``
        changes a column used in ``filters``, the re-select will not match and
        ``NotFoundException`` is raised even though the UPDATE was committed.
    :raises InternalException: if ``fields`` is empty.
    :raises NotFoundException: if no row matches ``filters`` after the update.
    """
    if not fields:
        raise InternalException("No fields provided for update")

    stmt = update(self.model).filter_by(**filters).values(**fields)
    await self.session.execute(stmt)
    await self.session.commit()

    # get() only accepts a primary key (``get(pk, relations)``), so calling it with
    # arbitrary column keywords raised TypeError; re-select with the same criteria.
    result = await self.session.execute(select(self.model).filter_by(**filters))
    instance = result.scalars().first()
    if instance is None:
        raise NotFoundException("Object does not exist", 404, "0001")
    return instance
547
+
548
@with_transaction_error_handling
async def delete(self, instance: Optional[ModelType] = None):
    """Delete ``instance`` if given, otherwise every row matching the filters; then commit.

    :return: ``True`` after the commit.
    :raises RuntimeError: via ``_ensure_filtered`` when deleting by filters
        without a prior ``filter()`` call.
    """
    if instance is None:
        self._ensure_filtered()
        await self.session.execute(delete(self.model).where(and_(*self.filters)))
    else:
        await self.session.delete(instance)

    await self.session.commit()
    return True
559
+
560
@with_transaction_error_handling
async def bulk_save(self, instances: Iterable[ModelType]):
    """Add many instances, flush them, and commit via ``_smart_commit``.

    :param instances: any iterable of model instances. It is materialised once,
        so generators work and an empty generator is detected as empty.
    :return: ``None``.
    """
    items = list(instances)
    if not items:
        return

    self.session.add_all(items)
    await self.session.flush()
    # flush() autobegins a transaction, so ``session.in_transaction()`` is always
    # True at this point and a commit guarded by it could never run. Commit the
    # same way update() and bulk_delete() do.
    await self._smart_commit()
568
+
569
@with_transaction_error_handling
async def bulk_delete(self):
    """Delete every row matching the filters and return the driver-reported row count.

    :raises RuntimeError: via ``_ensure_filtered`` if ``filter()`` was not called.
    """
    self._ensure_filtered()
    outcome = await self.session.execute(
        delete(self.model).where(and_(*self.filters))
    )
    await self._smart_commit()
    return outcome.rowcount
576
+
577
async def get(self, pk: Union[str, int, uuid.UUID], relations: Optional[Sequence[str]] = None) -> Any:
    """Fetch the single row whose ``id`` equals ``pk``.

    :param pk: value compared against the model's ``id`` attribute.
    :param relations: optional relationship names to eager-load.
    :raises NotFoundException: if no row has that ``id``.
    """
    query = select(self.model).filter_by(id=pk)
    if relations:
        loaders = self._build_relation_loaders(self.model, relations)
        query = query.options(*loaders)

    found = (await self.session.execute(query)).scalar_one_or_none()
    if found is not None:
        return found
    raise NotFoundException("Object does not exist", 404, "0001")
588
+
589
async def is_exists(self):
    """Return ``True`` when at least one row matches the filters.

    :raises RuntimeError: via ``_ensure_filtered`` if ``filter()`` was not called.
    """
    self._ensure_filtered()

    probe = select(self.model).filter(and_(*self.filters)).limit(1)
    outcome = await self.session.execute(probe)
    hit = outcome.scalars().first()
    return hit is not None
599
+
600
@query_mutator
def has_relation(self, relation_name: str):
    """Filter to rows that have at least one related object through ``relation_name``.

    Uses SQLAlchemy's relationship comparators:
    - ``any()`` for collection relationships (``uselist``);
    - ``has()`` for scalar ones.

    Both build a correlated EXISTS that also handles secondary (association)
    tables and self-referential relationships. A hand-built
    ``select(1).select_from(target).where(primaryjoin)`` does not handle these:
    with a secondary table, primaryjoin only reaches the association table.

    :param relation_name: name of a relationship attribute on ``self.model``.
    :raises AttributeError: if ``self.model`` has no attribute by that name.
    :return: this manager, so calls can be chained.
    """
    relationship = getattr(self.model, relation_name)
    if relationship.property.uselist:
        criterion = relationship.any()
    else:
        criterion = relationship.has()
    self.filters.append(criterion)
    self._filtered = True
    return self
612
+
613
+ def _infer_default_agg(self, column) -> str:
614
+ try:
615
+ from sqlalchemy import Integer, BigInteger, SmallInteger, Float, Numeric
616
+ if hasattr(column, "type") and isinstance(column.type, (Integer, BigInteger, SmallInteger, Float, Numeric)):
617
+ return "sum"
618
+ except Exception:
619
+ pass
620
+ return "max"
621
+
622
def _order_expr_for_path(self, token: str):
    """
    Translate one sort token into an ORDER BY expression (direction not applied).

    token grammar:
        [-]<path>[:<agg>][!first|!last]
        <path> := "field" | "relation__field" (one hop; "relation.field" also accepted)
        <agg>  := name looked up in AGG_MAP after lower() and strip()

    - A leading "-" is stripped here; the caller decides ASC/DESC.
    - "!first" / "!last" apply NULLS FIRST / NULLS LAST to the returned expression.
    - A plain field returns the model attribute itself.
    - A collection relationship (``uselist``) becomes a correlated scalar subquery
      ``SELECT <agg>(related.field) WHERE <primaryjoin>``; a missing <agg> is
      chosen by ``_infer_default_agg``.
    - A scalar relationship becomes a correlated scalar subquery selecting the
      related field (wrapped in <agg> when given); ":first" adds LIMIT 1 instead.

    example:
        "stocks__sold:sum"

    :raises ValueError: for unknown fields, relationships or related columns,
        unsupported aggregates, or paths with more than one relation hop.
    """

    def _apply_nulls(expression, placement):
        # Attach the optional NULLS FIRST / NULLS LAST modifier.
        if placement == "first":
            return expression.nullsfirst()
        if placement == "last":
            return expression.nullslast()
        return expression

    # Drop the direction prefix, then peel off an optional nulls suffix.
    spec = token.lstrip("-")
    placement = None
    for suffix, where in (("!first", "first"), ("!last", "last")):
        if spec.endswith(suffix):
            spec = spec[: -len(suffix)]
            placement = where
            break

    # Everything after the first ":" is the aggregate name.
    path, colon, agg_part = spec.partition(":")
    agg_name = agg_part.lower().strip() if colon else None

    # Plain column on the model itself.
    if "__" not in path and "." not in path:
        column = getattr(self.model, path, None)
        if column is None:
            raise ValueError(f"Invalid order_by field '{path}' for {self.model.__name__}")
        return _apply_nulls(column, placement)

    # Exactly one relation hop.
    hops = re.split(r"\.|\_\_", path)
    if len(hops) != 2:
        raise ValueError(f"Only one relation hop supported in order_by: {path!r}")
    rel_name, col_name = hops

    rel_attr = getattr(self.model, rel_name, None)
    if rel_attr is None or not hasattr(rel_attr, "property"):
        raise ValueError(f"Invalid relationship '{rel_name}' on {self.model.__name__}")

    rel_prop = rel_attr.property
    target_cls = rel_prop.mapper.class_
    target_col = getattr(target_cls, col_name, None)
    if target_col is None:
        raise ValueError(f"Invalid column '{col_name}' on related model {target_cls.__name__}")

    join_cond = rel_prop.primaryjoin
    is_collection = rel_prop.uselist

    if is_collection:
        # Many related rows per parent row: an aggregate is required.
        agg_name = agg_name or self._infer_default_agg(target_col)
        agg_fn = AGG_MAP.get(agg_name)
        if agg_fn is None:
            raise ValueError(f"Unsupported aggregate '{agg_name}' in order_by for {path!r}")
        expr = (
            select(agg_fn(target_col))
            .where(join_cond)
            .correlate(self.model)  # bind to the outer row
            .scalar_subquery()
        )
    else:
        wants_first = agg_name == "first"
        if agg_name and not wants_first:
            agg_fn = AGG_MAP.get(agg_name)
            if agg_fn is None:
                raise ValueError(f"Unsupported aggregate '{agg_name}' in order_by for {path!r}")
            selected = agg_fn(target_col)
        else:
            selected = target_col

        subselect = select(selected).where(join_cond).correlate(self.model)
        if wants_first:
            subselect = subselect.limit(1)
        expr = subselect.scalar_subquery()

    return _apply_nulls(expr, placement)
716
+
717
@query_mutator
def sort_by(self, tokens):
    """Replace the current ordering with clauses parsed from sort tokens.

    Tokens follow ``[-]<path>[:<agg>][!first|!last]``; a leading ``-`` sorts
    descending. The ``!first``/``!last`` suffix is stripped here and applied
    as NULLS FIRST/LAST *after* ASC/DESC, so the SQL reads
    ``<expr> DESC NULLS LAST``. Wrapping an expression that already carries
    NULLS FIRST/LAST in ``asc()``/``desc()`` would render
    ``<expr> NULLS LAST DESC``, which is not valid ORDER BY syntax
    (e.g. in PostgreSQL).

    :param tokens: iterable of sort tokens; ``None`` or empty clears the ordering.
    :return: this manager, so calls can be chained.
    """
    self._order_by = []
    for tok in tokens or []:
        direction = desc if tok.startswith("-") else asc
        name = tok.lstrip("-")

        nulls_placement = None
        if name.endswith("!first"):
            name, nulls_placement = name[:-6], "first"
        elif name.endswith("!last"):
            name, nulls_placement = name[:-5], "last"

        clause = direction(self._order_expr_for_path(name))
        if nulls_placement == "first":
            clause = clause.nullsfirst()
        elif nulls_placement == "last":
            clause = clause.nullslast()
        self._order_by.append(clause)
    return self
725
+
726
def model_to_dict(self, instance: ModelType, exclude: set[str] = None) -> dict:
    """Map every key in ``self._column_keys`` (minus ``exclude``) to its value on ``instance``.

    Keys keep the iteration order of ``self._column_keys``.
    """
    skipped = exclude if exclude else set()
    snapshot = {}
    for key in self._column_keys:
        if key in skipped:
            continue
        snapshot[key] = getattr(instance, key)
    return snapshot
729
+
730
+ def _ensure_filtered(self):
731
+ if not self._filtered:
732
+ raise RuntimeError("You must call `filter()` before this operation.")
657
733