ns-orm 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ns_orm/__init__.py +96 -0
- ns_orm/cli.py +174 -0
- ns_orm/database.py +292 -0
- ns_orm/dialects.py +290 -0
- ns_orm/exceptions.py +26 -0
- ns_orm/expressions.py +108 -0
- ns_orm/fields.py +313 -0
- ns_orm/manager.py +72 -0
- ns_orm/migrations/__init__.py +3 -0
- ns_orm/migrations/autodetector.py +159 -0
- ns_orm/migrations/executor.py +150 -0
- ns_orm/migrations/loader.py +53 -0
- ns_orm/migrations/migration.py +14 -0
- ns_orm/migrations/operations.py +93 -0
- ns_orm/migrations/state.py +42 -0
- ns_orm/migrations/writer.py +79 -0
- ns_orm/model.py +151 -0
- ns_orm/query.py +659 -0
- ns_orm/schema.py +131 -0
- ns_orm/typing.py +39 -0
- ns_orm/utils.py +58 -0
- ns_orm-0.0.0.dist-info/METADATA +289 -0
- ns_orm-0.0.0.dist-info/RECORD +27 -0
- ns_orm-0.0.0.dist-info/WHEEL +5 -0
- ns_orm-0.0.0.dist-info/entry_points.txt +2 -0
- ns_orm-0.0.0.dist-info/licenses/LICENSE +201 -0
- ns_orm-0.0.0.dist-info/top_level.txt +1 -0
ns_orm/query.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, replace
|
|
4
|
+
from typing import Any, Generic, Optional, TypeVar
|
|
5
|
+
|
|
6
|
+
from ns_orm.database import AsyncDatabase, Database
|
|
7
|
+
from ns_orm.dialects import Dialect
|
|
8
|
+
from ns_orm.exceptions import DoesNotExist, MultipleObjectsReturned, QueryError
|
|
9
|
+
from ns_orm.expressions import Condition, Q, normalize_in_values
|
|
10
|
+
|
|
11
|
+
# Type variable for the model class a QuerySet / AsyncQuerySet is bound to;
# terminal methods such as ``all()`` and ``get()`` return instances of it.
TModel = TypeVar("TModel")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _table_alias() -> str:
|
|
15
|
+
return "t"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _qualified_table(model: type[Any], dialect: Dialect) -> str:
|
|
19
|
+
table = dialect.quote_ident(model.table_name())
|
|
20
|
+
schema = getattr(model._meta, "schema", None)
|
|
21
|
+
if schema:
|
|
22
|
+
return f"{dialect.quote_ident(schema)}.{table}"
|
|
23
|
+
return table
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Escape character for LIKE patterns built from user values. "!" is used rather
# than a backslash because it needs no escaping inside a SQL string literal on
# common backends (a bare '\' literal is not portable, e.g. MySQL's default
# mode treats it as an escape).
_LIKE_ESCAPE_CHAR = "!"

# SQL operator for each ordering-comparison lookup.
_COMPARISON_SQL = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}

# Pattern-matching lookups; the "i" prefix means case-insensitive.
_LIKE_LOOKUPS = frozenset(
    {
        "contains",
        "icontains",
        "startswith",
        "istartswith",
        "endswith",
        "iendswith",
    }
)


def _escape_like(value: Any) -> str:
    """Format ``value`` as text with LIKE metacharacters escaped.

    ``%``, ``_`` and the escape character itself are prefixed with
    ``_LIKE_ESCAPE_CHAR`` so they match literally when the pattern is used with
    ``ESCAPE '<char>'``.
    """
    esc = _LIKE_ESCAPE_CHAR
    text = f"{value}"
    # Double the escape character first so the escapes added for "%" and "_"
    # below are not themselves re-escaped.
    return text.replace(esc, esc + esc).replace("%", esc + "%").replace("_", esc + "_")


def _where_from_q(
    model: type[Any], dialect: Dialect, q: Q
) -> tuple[str, dict[str, Any]]:
    """Compile ``q`` into a SQL boolean expression plus its bound parameters.

    Columns are qualified with ``_table_alias()`` and placeholders are named
    ``:p1``, ``:p2``, ... . Each lookup key is ``field`` or ``field__op`` where
    ``op`` is one of ``exact`` (default), ``gt``, ``gte``, ``lt``, ``lte``,
    ``in``, ``contains``, ``icontains``, ``startswith``, ``istartswith``,
    ``endswith``, ``iendswith`` or ``isnull``.

    Args:
        model: Model class whose ``_meta.fields`` defines the valid fields.
        dialect: Dialect used to quote column identifiers.
        q: The condition tree to compile.

    Returns:
        ``(sql, params)`` where ``sql`` is the WHERE expression (without the
        ``WHERE`` keyword) and ``params`` maps placeholder names to values.

    Raises:
        QueryError: for an unknown field or an unsupported lookup.
    """

    def _lookup_compiler(
        *,
        lookup: str,
        value: Any,
        table_alias: str,
        param_prefix: str,
        start_index: int,
    ) -> tuple[Condition, int]:
        # Compile one lookup; return its condition and the next free param index.
        parts = lookup.split("__")
        field = parts[0]
        if field not in model._meta.fields:
            raise QueryError(f"Unknown field: {field}")
        # Only one operator segment is supported; reject extras such as
        # "name__in__x" instead of silently ignoring them.
        if len(parts) > 2:
            raise QueryError(f"Unsupported lookup: {lookup}")
        op = parts[1] if len(parts) > 1 else "exact"

        col = f"{table_alias}.{dialect.quote_ident(field)}"
        key = f"{param_prefix}{start_index}"

        if op == "exact":
            if value is None:
                # "col = NULL" is never true in SQL; match NULL explicitly.
                return Condition(sql=f"{col} IS NULL", params={}), start_index
            return (
                Condition(sql=f"{col} = :{key}", params={key: value}),
                start_index + 1,
            )

        if op in _COMPARISON_SQL:
            return (
                Condition(
                    sql=f"{col} {_COMPARISON_SQL[op]} :{key}", params={key: value}
                ),
                start_index + 1,
            )

        if op == "in":
            values = normalize_in_values(value)
            if not values:
                # An empty IN list can match nothing; emit a constant-false
                # predicate rather than an empty list.
                return Condition(sql="1=0", params={}), start_index
            params: dict[str, Any] = {}
            for i, v in enumerate(values):
                params[f"{param_prefix}{start_index + i}"] = v
            placeholders = ", ".join(f":{k}" for k in params)
            return (
                Condition(sql=f"{col} IN ({placeholders})", params=params),
                start_index + len(params),
            )

        if op in _LIKE_LOOKUPS:
            # Escape user-supplied wildcards so e.g. contains="50%" matches the
            # literal text "50%" rather than "50" followed by anything.
            escaped = _escape_like(value)
            if op.endswith("contains"):
                pattern = f"%{escaped}%"
            elif op.endswith("startswith"):
                pattern = f"{escaped}%"
            else:
                pattern = f"%{escaped}"
            escape_clause = f"ESCAPE '{_LIKE_ESCAPE_CHAR}'"
            if op.startswith("i"):
                sql = f"LOWER({col}) LIKE LOWER(:{key}) {escape_clause}"
            else:
                sql = f"{col} LIKE :{key} {escape_clause}"
            return Condition(sql=sql, params={key: pattern}), start_index + 1

        if op == "isnull":
            keyword = "NULL" if value else "NOT NULL"
            return Condition(sql=f"{col} IS {keyword}", params={}), start_index

        raise QueryError(f"Unsupported lookup: {lookup}")

    cond, _ = q.compile(
        table_alias=_table_alias(),
        param_prefix="p",
        start_index=1,
        lookup_compiler=_lookup_compiler,
    )
    return cond.sql, dict(cond.params)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
class QueryState:
    """Immutable snapshot of everything a queryset has accumulated.

    Instances are frozen; querysets derive a modified copy with
    ``dataclasses.replace`` instead of mutating one in place.
    """

    # Accumulated WHERE condition; querysets start from an empty ``Q()``.
    where: Q
    # Field names to sort by, in priority order; a leading "-" means DESC.
    order_by: tuple[str, ...]
    # Maximum number of rows, or None when no limit has been set.
    limit: Optional[int]
    # Number of rows to skip, or None when no offset has been set.
    offset: Optional[int]
    # Relation names (FK relations or M2M fields) loaded after the main query.
    prefetch: tuple[str, ...]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class QuerySet(Generic[TModel]):
    """Immutable, chainable query builder bound to a synchronous ``Database``.

    Builder methods (``filter``, ``exclude``, ``order_by``, ``limit``,
    ``offset``, ``prefetch_related``) never mutate ``self``; each returns a new
    ``QuerySet`` whose :class:`QueryState` was derived with
    ``dataclasses.replace``. SQL is issued only by the terminal methods
    (``all``, ``first``, ``get``, ``create``, ``save_instance``, ``update``,
    ``delete``).
    """

    def __init__(
        self, model: type[TModel], *, db: Database, state: Optional[QueryState] = None
    ):
        self.model = model
        self.db = db
        self.state = state or QueryState(
            where=Q(), order_by=(), limit=None, offset=None, prefetch=()
        )

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------

    def _clone(self, **changes: Any) -> QuerySet[TModel]:
        """Return a new queryset on the same model/db with ``changes`` applied."""
        return QuerySet(self.model, db=self.db, state=replace(self.state, **changes))

    def _and_where(self, q: Q) -> QuerySet[TModel]:
        """Return a clone whose WHERE condition is the current one AND-ed with ``q``."""
        # An empty accumulated condition is replaced rather than combined.
        combined = self.state.where & q if not self.state.where.is_empty() else q
        return self._clone(where=combined)

    def filter(self, **lookups: Any) -> QuerySet[TModel]:
        """Return a queryset additionally restricted by ``Q(**lookups)``."""
        return self._and_where(Q(**lookups))

    def exclude(self, **lookups: Any) -> QuerySet[TModel]:
        """Return a queryset additionally restricted by the negation ``~Q(**lookups)``."""
        return self._and_where(~Q(**lookups))

    def order_by(self, *fields: str) -> QuerySet[TModel]:
        """Return a queryset sorted by ``fields``, replacing any previous ordering.

        Prefix a name with ``-`` for descending order. Names are validated when
        SQL is generated, not here.
        """
        return self._clone(order_by=fields)

    def limit(self, n: int) -> QuerySet[TModel]:
        """Return a queryset capped at ``int(n)`` rows."""
        return self._clone(limit=int(n))

    def offset(self, n: int) -> QuerySet[TModel]:
        """Return a queryset that skips the first ``int(n)`` rows."""
        return self._clone(offset=int(n))

    def prefetch_related(self, *names: str) -> QuerySet[TModel]:
        """Return a queryset that also loads the named FK / M2M relations.

        Names accumulate across calls; duplicates are dropped, keeping the
        first-seen order.
        """
        merged = tuple(dict.fromkeys(self.state.prefetch + tuple(names)))
        return self._clone(prefetch=merged)

    def select_related(self, *names: str) -> QuerySet[TModel]:
        """Alias of :meth:`prefetch_related` (relations load via extra queries)."""
        return self.prefetch_related(*names)

    # ------------------------------------------------------------------
    # SQL generation
    # ------------------------------------------------------------------

    def _select_columns(self) -> list[str]:
        """Return ``alias."col" AS "col"`` for every model field, in field order."""
        alias = _table_alias()
        cols = []
        for name in self.model._meta.fields:
            col = self.db.dialect.quote_ident(name)
            cols.append(f"{alias}.{col} AS {col}")
        return cols

    def _order_by_sql(self) -> str:
        """Return the `` ORDER BY ...`` clause, or ``""`` when unordered.

        Raises:
            QueryError: if an ordering name is not a model field.
        """
        if not self.state.order_by:
            return ""
        alias = _table_alias()
        parts: list[str] = []
        for f in self.state.order_by:
            direction = "ASC"
            name = f
            if f.startswith("-"):
                direction = "DESC"
                name = f[1:]
            if name not in self.model._meta.fields:
                raise QueryError(f"Unknown order_by field: {name}")
            parts.append(f"{alias}.{self.db.dialect.quote_ident(name)} {direction}")
        return " ORDER BY " + ", ".join(parts)

    def _select_sql(self) -> tuple[str, dict[str, Any]]:
        """Build the SELECT (WHERE, ORDER BY, dialect LIMIT/OFFSET) and its params."""
        where_sql, params = _where_from_q(self.model, self.db.dialect, self.state.where)
        cols = ", ".join(self._select_columns())
        sql = (
            f"SELECT {cols} FROM {_qualified_table(self.model, self.db.dialect)} "
            f"{_table_alias()} WHERE {where_sql}"
        )
        sql += self._order_by_sql()
        sql = self.db.dialect.apply_limit_offset(
            sql, self.state.limit, self.state.offset
        )
        return sql, params

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def all(self) -> list[TModel]:
        """Run the SELECT and return instances built with ``model.parse_obj``.

        Requested relations are prefetched onto the instances before returning.
        """
        sql, params = self._select_sql()
        rows = self.db.fetch_all(sql, params)
        instances = [self.model.parse_obj(r) for r in rows]
        if self.state.prefetch and instances:
            self._prefetch(instances)
        return instances

    def first(self) -> Optional[TModel]:
        """Return the first matching instance, or ``None`` when nothing matches.

        Without an explicit ``order_by`` the query is ordered by primary key so
        the result is deterministic; an unordered ``LIMIT 1`` may return any row.
        """
        qs = self if self.state.order_by else self.order_by(self.model.pk_name())
        items = qs.limit(1).all()
        return items[0] if items else None

    def get(self, **lookups: Any) -> TModel:
        """Return the single instance matching ``lookups``.

        Raises:
            DoesNotExist: if no row matches.
            MultipleObjectsReturned: if more than one row matches.
        """
        # Fetching two rows is enough to detect "more than one".
        items = self.filter(**lookups).limit(2).all()
        if not items:
            raise DoesNotExist(f"{self.model.__name__} matching query does not exist")
        if len(items) > 1:
            raise MultipleObjectsReturned(f"Multiple {self.model.__name__} returned")
        return items[0]

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def create(self, **data: Any) -> TModel:
        """Build ``self.model(**data)``, INSERT it, and return the instance.

        This always inserts, even when the instance already carries a primary
        key (e.g. supplied in ``data``). Routing through ``save_instance`` would
        turn such a create into an UPDATE that never creates the row. A
        database-reported key is copied onto the instance only when it has none.
        """
        inst = self.model(**data)
        pk_name = inst.pk_name()
        new_pk = self._insert(inst.to_db_dict())
        if new_pk is not None and getattr(inst, pk_name, None) is None:
            setattr(inst, pk_name, new_pk)
        return inst

    def save_instance(self, inst: Any) -> Any:
        """Persist ``inst``: INSERT when its primary key is ``None``, else UPDATE by key.

        The UPDATE runs on a fresh queryset so filters accumulated on ``self``
        cannot silently prevent the row from being saved. Returns ``inst``.
        """
        pk_name = inst.pk_name()
        pk_value = getattr(inst, pk_name, None)
        data = inst.to_db_dict()
        if pk_value is None:
            new_pk = self._insert(data)
            if new_pk is not None:
                setattr(inst, pk_name, new_pk)
            return inst
        # NOTE(review): an UPDATE that matches no row is not followed by an
        # INSERT, so an instance with a preassigned key that was never stored is
        # not persisted here; use create() for that case.
        QuerySet(self.model, db=self.db).filter(**{pk_name: pk_value}).update(**data)
        return inst

    def _insert(self, data: dict[str, Any]) -> Any:
        """INSERT the non-``None`` model-field entries of ``data``; return the new key.

        The key is reported for the "postgres" dialect (``RETURNING``) and the
        "sqlite" dialect (``last_insert_rowid()``). Other dialects return ``None``.

        Raises:
            QueryError: if no insertable column remains.
        """
        # None values are omitted so the database applies its own defaults.
        cols = [
            c
            for c in data.keys()
            if c in self.model._meta.fields and data[c] is not None
        ]
        if not cols:
            raise QueryError("No insertable columns provided")
        col_sql = ", ".join(self.db.dialect.quote_ident(c) for c in cols)
        values_sql = ", ".join(f":p{i + 1}" for i in range(len(cols)))
        params = {f"p{i + 1}": data[c] for i, c in enumerate(cols)}
        table = _qualified_table(self.model, self.db.dialect)
        pk = self.model.pk_name()
        if self.db.dialect.name == "postgres":
            sql = (
                f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql}) "
                f"RETURNING {self.db.dialect.quote_ident(pk)}"
            )
            row = self.db.fetch_one(sql, params)
            return None if row is None else row.get(pk)

        sql = f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql})"
        self.db.execute(sql, params)
        if self.db.dialect.name == "sqlite":
            # NOTE(review): last_insert_rowid() is per-connection and is the
            # rowid, which equals the pk only for INTEGER PRIMARY KEY tables;
            # assumes execute/fetch_one share a connection -- confirm.
            row = self.db.fetch_one("SELECT last_insert_rowid() AS id")
            return None if row is None else row.get("id")
        return None

    def update(self, **data: Any) -> int:
        """UPDATE every row matched by this queryset with the model fields in ``data``.

        Keys that are not model fields are ignored; if none remain, no SQL is
        issued and ``0`` is returned. Otherwise returns the result of
        ``db.execute`` (annotated ``int``; presumably the affected-row count).
        """
        set_cols = [k for k in data if k in self.model._meta.fields]
        if not set_cols:
            return 0
        dialect = self.db.dialect
        where_sql, where_params = _where_from_q(self.model, dialect, self.state.where)
        # WHERE placeholders are p1, p2, ...; SET values use a distinct "s"
        # prefix so the two sets can neither collide nor overwrite each other.
        params: dict[str, Any] = dict(where_params)
        sets: list[str] = []
        for idx, col in enumerate(set_cols, start=1):
            key = f"s{idx}"
            sets.append(f"{dialect.quote_ident(col)} = :{key}")
            params[key] = data[col]
        # The WHERE fragment qualifies columns with the shared alias, so the
        # statement must declare it; SET targets stay unqualified.
        sql = (
            f"UPDATE {_qualified_table(self.model, dialect)} AS {_table_alias()} "
            f"SET {', '.join(sets)} WHERE {where_sql}"
        )
        return self.db.execute(sql, params)

    def delete(self) -> int:
        """DELETE the rows matched by this queryset; return ``db.execute``'s result.

        With no filters the WHERE clause is whatever an empty ``Q`` compiles
        to -- presumably every row.
        """
        where_sql, params = _where_from_q(self.model, self.db.dialect, self.state.where)
        table = _qualified_table(self.model, self.db.dialect)
        # Declare the alias the WHERE fragment's columns are qualified with.
        sql = f"DELETE FROM {table} AS {_table_alias()} WHERE {where_sql}"
        return self.db.execute(sql, params)

    # ------------------------------------------------------------------
    # Prefetching
    # ------------------------------------------------------------------

    def _prefetch(self, instances: list[Any]) -> None:
        """Load every requested relation onto ``instances``.

        Raises:
            QueryError: if a name is neither an FK relation nor an M2M field.
        """
        for name in self.state.prefetch:
            if name in self.model._meta.relations:
                self._prefetch_fk(instances, name)
            elif name in self.model._meta.m2m:
                self._prefetch_m2m(instances, name)
            else:
                raise QueryError(f"Unknown relation for prefetch: {name}")

    def _prefetch_fk(self, instances: list[Any], rel_name: str) -> None:
        """Set ``rel_name`` on each instance to its related object (or ``None``).

        Related rows are fetched with one ``pk__in`` query.
        """
        fk_col, fk = self.model._meta.relations[rel_name]
        to_model = fk.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        to_pk = to_model.pk_name()
        ids = list(
            dict.fromkeys(
                getattr(i, fk_col)
                for i in instances
                if getattr(i, fk_col, None) is not None
            )
        )
        if not ids:
            for i in instances:
                setattr(i, rel_name, None)
            return
        related = QuerySet(to_model, db=self.db).filter(**{to_pk + "__in": ids}).all()
        mapping = {getattr(r, to_pk): r for r in related}
        for i in instances:
            setattr(i, rel_name, mapping.get(getattr(i, fk_col)))

    def _prefetch_m2m(self, instances: list[Any], field_name: str) -> None:
        """Set ``field_name`` on each instance to its list of related objects.

        Two queries are issued: one on the through table for (from, to) id
        pairs, then one ``pk__in`` query for the related rows. Through-table and
        column names default to ``<from_table>_<to_table>`` and
        ``<table>_<pk>`` when the relation does not specify them.
        """
        rel = self.model._meta.m2m[field_name]
        to_model = rel.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        from_pk = self.model.pk_name()
        to_pk = to_model.pk_name()
        through = rel.through or f"{self.model.table_name()}_{to_model.table_name()}"
        from_col = rel.from_field or f"{self.model.table_name()}_{from_pk}"
        to_col = rel.to_field or f"{to_model.table_name()}_{to_pk}"

        base_ids = list(
            dict.fromkeys(
                getattr(i, from_pk)
                for i in instances
                if getattr(i, from_pk, None) is not None
            )
        )
        if not base_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        q = self.db.dialect.quote_ident
        params: dict[str, Any] = {
            f"p{idx}": v for idx, v in enumerate(base_ids, start=1)
        }
        placeholders = ", ".join(f":{k}" for k in params)
        sql = (
            f"SELECT {q(from_col)} AS from_id, {q(to_col)} AS to_id "
            f"FROM {q(through)} WHERE {q(from_col)} IN ({placeholders})"
        )
        pairs = self.db.fetch_all(sql, params)
        to_ids = list(dict.fromkeys(p["to_id"] for p in pairs))
        if not to_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        related = QuerySet(to_model, db=self.db).filter(**{to_pk + "__in": to_ids}).all()
        to_map = {getattr(r, to_pk): r for r in related}
        by_from: dict[Any, list[Any]] = {}
        for p in pairs:
            by_from.setdefault(p["from_id"], []).append(to_map.get(p["to_id"]))
        for i in instances:
            # Links whose target row was not returned are dropped.
            setattr(
                i,
                field_name,
                [x for x in by_from.get(getattr(i, from_pk), []) if x is not None],
            )
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class AsyncQuerySet(Generic[TModel]):
    """Async counterpart of ``QuerySet`` bound to an ``AsyncDatabase``.

    Builder methods are synchronous and return new ``AsyncQuerySet`` objects
    carrying a derived :class:`QueryState`. Terminal methods (``all``,
    ``first``, ``get``, ``create``, ``save_instance``, ``update``, ``delete``)
    are coroutines that await the database.
    """

    def __init__(
        self,
        model: type[TModel],
        *,
        db: AsyncDatabase,
        state: Optional[QueryState] = None,
    ) -> None:
        self.model = model
        self.db = db
        self.state = state or QueryState(
            where=Q(), order_by=(), limit=None, offset=None, prefetch=()
        )

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------

    def _clone(self, **changes: Any) -> AsyncQuerySet[TModel]:
        """Return a new queryset on the same model/db with ``changes`` applied."""
        return AsyncQuerySet(
            self.model, db=self.db, state=replace(self.state, **changes)
        )

    def _and_where(self, q: Q) -> AsyncQuerySet[TModel]:
        """Return a clone whose WHERE condition is the current one AND-ed with ``q``."""
        # An empty accumulated condition is replaced rather than combined.
        combined = self.state.where & q if not self.state.where.is_empty() else q
        return self._clone(where=combined)

    def filter(self, **lookups: Any) -> AsyncQuerySet[TModel]:
        """Return a queryset additionally restricted by ``Q(**lookups)``."""
        return self._and_where(Q(**lookups))

    def exclude(self, **lookups: Any) -> AsyncQuerySet[TModel]:
        """Return a queryset additionally restricted by the negation ``~Q(**lookups)``."""
        return self._and_where(~Q(**lookups))

    def order_by(self, *fields: str) -> AsyncQuerySet[TModel]:
        """Return a queryset sorted by ``fields``, replacing any previous ordering.

        Prefix a name with ``-`` for descending order. Names are validated when
        SQL is generated, not here.
        """
        return self._clone(order_by=fields)

    def limit(self, n: int) -> AsyncQuerySet[TModel]:
        """Return a queryset capped at ``int(n)`` rows."""
        return self._clone(limit=int(n))

    def offset(self, n: int) -> AsyncQuerySet[TModel]:
        """Return a queryset that skips the first ``int(n)`` rows."""
        return self._clone(offset=int(n))

    def prefetch_related(self, *names: str) -> AsyncQuerySet[TModel]:
        """Return a queryset that also loads the named FK / M2M relations.

        Names accumulate across calls; duplicates are dropped, keeping the
        first-seen order.
        """
        merged = tuple(dict.fromkeys(self.state.prefetch + tuple(names)))
        return self._clone(prefetch=merged)

    def select_related(self, *names: str) -> AsyncQuerySet[TModel]:
        """Alias of :meth:`prefetch_related` (relations load via extra queries)."""
        return self.prefetch_related(*names)

    # ------------------------------------------------------------------
    # SQL generation
    # ------------------------------------------------------------------

    def _select_columns(self) -> list[str]:
        """Return ``alias."col" AS "col"`` for every model field, in field order."""
        alias = _table_alias()
        cols = []
        for name in self.model._meta.fields:
            col = self.db.dialect.quote_ident(name)
            cols.append(f"{alias}.{col} AS {col}")
        return cols

    def _order_by_sql(self) -> str:
        """Return the `` ORDER BY ...`` clause, or ``""`` when unordered.

        Raises:
            QueryError: if an ordering name is not a model field.
        """
        if not self.state.order_by:
            return ""
        alias = _table_alias()
        parts: list[str] = []
        for f in self.state.order_by:
            direction = "ASC"
            name = f
            if f.startswith("-"):
                direction = "DESC"
                name = f[1:]
            if name not in self.model._meta.fields:
                raise QueryError(f"Unknown order_by field: {name}")
            parts.append(f"{alias}.{self.db.dialect.quote_ident(name)} {direction}")
        return " ORDER BY " + ", ".join(parts)

    def _select_sql(self) -> tuple[str, dict[str, Any]]:
        """Build the SELECT (WHERE, ORDER BY, dialect LIMIT/OFFSET) and its params."""
        where_sql, params = _where_from_q(self.model, self.db.dialect, self.state.where)
        cols = ", ".join(self._select_columns())
        sql = (
            f"SELECT {cols} FROM {_qualified_table(self.model, self.db.dialect)} "
            f"{_table_alias()} WHERE {where_sql}"
        )
        sql += self._order_by_sql()
        sql = self.db.dialect.apply_limit_offset(
            sql, self.state.limit, self.state.offset
        )
        return sql, params

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    async def all(self) -> list[TModel]:
        """Run the SELECT and return instances built with ``model.parse_obj``.

        Requested relations are prefetched onto the instances before returning.
        """
        sql, params = self._select_sql()
        rows = await self.db.fetch_all(sql, params)
        instances = [self.model.parse_obj(r) for r in rows]
        if self.state.prefetch and instances:
            await self._prefetch(instances)
        return instances

    async def first(self) -> Optional[TModel]:
        """Return the first matching instance, or ``None`` when nothing matches.

        Without an explicit ``order_by`` the query is ordered by primary key so
        the result is deterministic; an unordered ``LIMIT 1`` may return any row.
        """
        qs = self if self.state.order_by else self.order_by(self.model.pk_name())
        items = await qs.limit(1).all()
        return items[0] if items else None

    async def get(self, **lookups: Any) -> TModel:
        """Return the single instance matching ``lookups``.

        Raises:
            DoesNotExist: if no row matches.
            MultipleObjectsReturned: if more than one row matches.
        """
        # Fetching two rows is enough to detect "more than one".
        items = await self.filter(**lookups).limit(2).all()
        if not items:
            raise DoesNotExist(f"{self.model.__name__} matching query does not exist")
        if len(items) > 1:
            raise MultipleObjectsReturned(f"Multiple {self.model.__name__} returned")
        return items[0]

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    async def create(self, **data: Any) -> TModel:
        """Build ``self.model(**data)``, INSERT it, and return the instance.

        This always inserts, even when the instance already carries a primary
        key (e.g. supplied in ``data``). Routing through ``save_instance`` would
        turn such a create into an UPDATE that never creates the row. A
        database-reported key is copied onto the instance only when it has none.
        """
        inst = self.model(**data)
        pk_name = inst.pk_name()
        new_pk = await self._insert(inst.to_db_dict())
        if new_pk is not None and getattr(inst, pk_name, None) is None:
            setattr(inst, pk_name, new_pk)
        return inst

    async def save_instance(self, inst: Any) -> Any:
        """Persist ``inst``: INSERT when its primary key is ``None``, else UPDATE by key.

        The UPDATE runs on a fresh queryset so filters accumulated on ``self``
        cannot silently prevent the row from being saved. Returns ``inst``.
        """
        pk_name = inst.pk_name()
        pk_value = getattr(inst, pk_name, None)
        data = inst.to_db_dict()
        if pk_value is None:
            new_pk = await self._insert(data)
            if new_pk is not None:
                setattr(inst, pk_name, new_pk)
            return inst
        # NOTE(review): an UPDATE that matches no row is not followed by an
        # INSERT, so an instance with a preassigned key that was never stored is
        # not persisted here; use create() for that case.
        await (
            AsyncQuerySet(self.model, db=self.db)
            .filter(**{pk_name: pk_value})
            .update(**data)
        )
        return inst

    async def _insert(self, data: dict[str, Any]) -> Any:
        """INSERT the non-``None`` model-field entries of ``data``; return the new key.

        The key is reported for the "postgres" dialect (``RETURNING``) and the
        "sqlite" dialect (``last_insert_rowid()``). Other dialects return ``None``.

        Raises:
            QueryError: if no insertable column remains.
        """
        # None values are omitted so the database applies its own defaults.
        cols = [
            c
            for c in data.keys()
            if c in self.model._meta.fields and data[c] is not None
        ]
        if not cols:
            raise QueryError("No insertable columns provided")
        col_sql = ", ".join(self.db.dialect.quote_ident(c) for c in cols)
        values_sql = ", ".join(f":p{i + 1}" for i in range(len(cols)))
        params = {f"p{i + 1}": data[c] for i, c in enumerate(cols)}
        table = _qualified_table(self.model, self.db.dialect)
        pk = self.model.pk_name()
        if self.db.dialect.name == "postgres":
            sql = (
                f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql}) "
                f"RETURNING {self.db.dialect.quote_ident(pk)}"
            )
            row = await self.db.fetch_one(sql, params)
            return None if row is None else row.get(pk)

        sql = f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql})"
        await self.db.execute(sql, params)
        if self.db.dialect.name == "sqlite":
            # NOTE(review): last_insert_rowid() is per-connection and is the
            # rowid, which equals the pk only for INTEGER PRIMARY KEY tables;
            # assumes execute/fetch_one share a connection -- confirm.
            row = await self.db.fetch_one("SELECT last_insert_rowid() AS id")
            return None if row is None else row.get("id")
        return None

    async def update(self, **data: Any) -> int:
        """UPDATE every row matched by this queryset with the model fields in ``data``.

        Keys that are not model fields are ignored; if none remain, no SQL is
        issued and ``0`` is returned. Otherwise returns the result of
        ``db.execute`` (annotated ``int``; presumably the affected-row count).
        """
        set_cols = [k for k in data if k in self.model._meta.fields]
        if not set_cols:
            return 0
        dialect = self.db.dialect
        where_sql, where_params = _where_from_q(self.model, dialect, self.state.where)
        # WHERE placeholders are p1, p2, ...; SET values use a distinct "s"
        # prefix so the two sets can neither collide nor overwrite each other.
        params: dict[str, Any] = dict(where_params)
        sets: list[str] = []
        for idx, col in enumerate(set_cols, start=1):
            key = f"s{idx}"
            sets.append(f"{dialect.quote_ident(col)} = :{key}")
            params[key] = data[col]
        # The WHERE fragment qualifies columns with the shared alias, so the
        # statement must declare it; SET targets stay unqualified.
        sql = (
            f"UPDATE {_qualified_table(self.model, dialect)} AS {_table_alias()} "
            f"SET {', '.join(sets)} WHERE {where_sql}"
        )
        return await self.db.execute(sql, params)

    async def delete(self) -> int:
        """DELETE the rows matched by this queryset; return ``db.execute``'s result.

        With no filters the WHERE clause is whatever an empty ``Q`` compiles
        to -- presumably every row.
        """
        where_sql, params = _where_from_q(self.model, self.db.dialect, self.state.where)
        table = _qualified_table(self.model, self.db.dialect)
        # Declare the alias the WHERE fragment's columns are qualified with.
        sql = f"DELETE FROM {table} AS {_table_alias()} WHERE {where_sql}"
        return await self.db.execute(sql, params)

    # ------------------------------------------------------------------
    # Prefetching
    # ------------------------------------------------------------------

    async def _prefetch(self, instances: list[Any]) -> None:
        """Load every requested relation onto ``instances``, one relation at a time.

        Raises:
            QueryError: if a name is neither an FK relation nor an M2M field.
        """
        for name in self.state.prefetch:
            if name in self.model._meta.relations:
                await self._prefetch_fk(instances, name)
            elif name in self.model._meta.m2m:
                await self._prefetch_m2m(instances, name)
            else:
                raise QueryError(f"Unknown relation for prefetch: {name}")

    async def _prefetch_fk(self, instances: list[Any], rel_name: str) -> None:
        """Set ``rel_name`` on each instance to its related object (or ``None``).

        Related rows are fetched with one ``pk__in`` query.
        """
        fk_col, fk = self.model._meta.relations[rel_name]
        to_model = fk.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        to_pk = to_model.pk_name()
        ids = list(
            dict.fromkeys(
                getattr(i, fk_col)
                for i in instances
                if getattr(i, fk_col, None) is not None
            )
        )
        if not ids:
            for i in instances:
                setattr(i, rel_name, None)
            return
        related = (
            await AsyncQuerySet(to_model, db=self.db)
            .filter(**{to_pk + "__in": ids})
            .all()
        )
        mapping = {getattr(r, to_pk): r for r in related}
        for i in instances:
            setattr(i, rel_name, mapping.get(getattr(i, fk_col)))

    async def _prefetch_m2m(self, instances: list[Any], field_name: str) -> None:
        """Set ``field_name`` on each instance to its list of related objects.

        Two queries are issued: one on the through table for (from, to) id
        pairs, then one ``pk__in`` query for the related rows. Through-table and
        column names default to ``<from_table>_<to_table>`` and
        ``<table>_<pk>`` when the relation does not specify them.
        """
        rel = self.model._meta.m2m[field_name]
        to_model = rel.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        from_pk = self.model.pk_name()
        to_pk = to_model.pk_name()
        through = rel.through or f"{self.model.table_name()}_{to_model.table_name()}"
        from_col = rel.from_field or f"{self.model.table_name()}_{from_pk}"
        to_col = rel.to_field or f"{to_model.table_name()}_{to_pk}"

        base_ids = list(
            dict.fromkeys(
                getattr(i, from_pk)
                for i in instances
                if getattr(i, from_pk, None) is not None
            )
        )
        if not base_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        q = self.db.dialect.quote_ident
        params: dict[str, Any] = {
            f"p{idx}": v for idx, v in enumerate(base_ids, start=1)
        }
        placeholders = ", ".join(f":{k}" for k in params)
        sql = (
            f"SELECT {q(from_col)} AS from_id, {q(to_col)} AS to_id "
            f"FROM {q(through)} WHERE {q(from_col)} IN ({placeholders})"
        )
        pairs = await self.db.fetch_all(sql, params)
        to_ids = list(dict.fromkeys(p["to_id"] for p in pairs))
        if not to_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        related = (
            await AsyncQuerySet(to_model, db=self.db)
            .filter(**{to_pk + "__in": to_ids})
            .all()
        )
        to_map = {getattr(r, to_pk): r for r in related}
        by_from: dict[Any, list[Any]] = {}
        for p in pairs:
            by_from.setdefault(p["from_id"], []).append(to_map.get(p["to_id"]))
        for i in instances:
            # Links whose target row was not returned are dropped.
            setattr(
                i,
                field_name,
                [x for x in by_from.get(getattr(i, from_pk), []) if x is not None],
            )
|