ns-orm 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ns_orm/query.py ADDED
@@ -0,0 +1,659 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from typing import Any, Generic, Optional, TypeVar
5
+
6
+ from ns_orm.database import AsyncDatabase, Database
7
+ from ns_orm.dialects import Dialect
8
+ from ns_orm.exceptions import DoesNotExist, MultipleObjectsReturned, QueryError
9
+ from ns_orm.expressions import Condition, Q, normalize_in_values
10
+
11
# Type variable for the model class that QuerySet / AsyncQuerySet are generic over.
TModel = TypeVar("TModel")
12
+
13
+
14
def _table_alias() -> str:
    """Return the fixed alias under which the model's table is referenced in SQL."""
    alias = "t"
    return alias
16
+
17
+
18
def _qualified_table(model: type[Any], dialect: Dialect) -> str:
    """Return the quoted table reference for *model*.

    When ``model._meta`` has a truthy ``schema`` attribute the result is
    ``<schema>.<table>`` (both parts quoted by *dialect*); otherwise it is
    just the quoted table name.
    """
    quoted_name = dialect.quote_ident(model.table_name())
    schema_name = getattr(model._meta, "schema", None)
    if not schema_name:
        return quoted_name
    return f"{dialect.quote_ident(schema_name)}.{quoted_name}"
24
+
25
+
26
def _where_from_q(
    model: type[Any], dialect: Dialect, q: Q
) -> tuple[str, dict[str, Any]]:
    """Compile the filter tree *q* for *model* into ``(where_sql, params)``.

    Each leaf lookup is ``field`` or ``field__op``. Supported ops are
    ``exact`` (the default), ``gt``, ``gte``, ``lt``, ``lte``, ``in``,
    ``contains``, ``startswith``, ``endswith`` (plus the ``i``-prefixed
    case-insensitive variants of those three) and ``isnull``. Columns are
    qualified with ``_table_alias()``. Values are bound as named parameters
    ``:p1``, ``:p2``, ... and never spliced into the SQL text.

    Raises:
        QueryError: for an unknown field or an unsupported lookup.
    """
    comparison_ops = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}
    # (prefix, suffix) wildcards wrapped around the value for LIKE-style ops.
    like_wildcards = {
        "contains": ("%", "%"),
        "startswith": ("", "%"),
        "endswith": ("%", ""),
    }
    # LIKE escape character. "!" is used rather than a backslash because a
    # backslash is itself an escape inside string literals on some databases.
    like_escape = "!"

    def _escape_like(value: Any) -> str:
        """Escape LIKE metacharacters so *value* is matched literally."""
        text = str(value)
        # Escape the escape character first so later escapes are not doubled.
        for ch in (like_escape, "%", "_"):
            text = text.replace(ch, like_escape + ch)
        return text

    def _lookup_compiler(
        *,
        lookup: str,
        value: Any,
        table_alias: str,
        param_prefix: str,
        start_index: int,
    ) -> tuple[Condition, int]:
        """Compile one ``field__op=value`` lookup.

        Returns the condition and the next unused parameter index.
        """
        parts = lookup.split("__")
        if len(parts) > 2:
            # Reject rather than silently ignore extra "__" segments.
            raise QueryError(f"Unsupported lookup: {lookup}")
        field = parts[0]
        op = parts[1] if len(parts) > 1 else "exact"
        if field not in model._meta.fields:
            raise QueryError(f"Unknown field: {field}")

        col = f"{table_alias}.{dialect.quote_ident(field)}"
        key = f"{param_prefix}{start_index}"

        if op == "exact":
            if value is None:
                # "col = NULL" is never true in SQL; match NULLs explicitly.
                return Condition(sql=f"{col} IS NULL", params={}), start_index
            return (
                Condition(sql=f"{col} = :{key}", params={key: value}),
                start_index + 1,
            )
        if op in comparison_ops:
            return (
                Condition(
                    sql=f"{col} {comparison_ops[op]} :{key}", params={key: value}
                ),
                start_index + 1,
            )
        if op == "in":
            values = normalize_in_values(value)
            if not values:
                # "IN ()" is invalid SQL; an empty set simply matches nothing.
                return Condition(sql="1=0", params={}), start_index
            params = {
                f"{param_prefix}{idx}": v
                for idx, v in enumerate(values, start=start_index)
            }
            placeholders = ", ".join(f":{k}" for k in params)
            return (
                Condition(sql=f"{col} IN ({placeholders})", params=params),
                start_index + len(params),
            )
        insensitive = op.startswith("i") and op[1:] in like_wildcards
        base_op = op[1:] if insensitive else op
        if base_op in like_wildcards:
            head, tail = like_wildcards[base_op]
            pattern = f"{head}{_escape_like(value)}{tail}"
            if insensitive:
                sql = f"LOWER({col}) LIKE LOWER(:{key}) ESCAPE '{like_escape}'"
            else:
                sql = f"{col} LIKE :{key} ESCAPE '{like_escape}'"
            return Condition(sql=sql, params={key: pattern}), start_index + 1
        if op == "isnull":
            null_test = "IS NULL" if value else "IS NOT NULL"
            return Condition(sql=f"{col} {null_test}", params={}), start_index
        raise QueryError(f"Unsupported lookup: {lookup}")

    cond, _ = q.compile(
        table_alias=_table_alias(),
        param_prefix="p",
        start_index=1,
        lookup_compiler=_lookup_compiler,
    )
    # NOTE(review): in case an empty Q compiles to empty SQL, substitute a
    # tautology so callers' "... WHERE {sql}" stays valid either way.
    return cond.sql or "1=1", dict(cond.params)
111
+
112
+
113
@dataclass(eq=True, frozen=True)
class QueryState:
    """Immutable description of a pending query.

    Instances are frozen; callers derive modified copies with
    ``dataclasses.replace`` instead of mutating them.
    """

    # Accumulated filter tree.
    where: Q
    # Field names to sort by; a leading "-" requests descending order.
    order_by: tuple[str, ...]
    # Maximum number of rows, or None when unset.
    limit: Optional[int]
    # Number of leading rows to skip, or None when unset.
    offset: Optional[int]
    # Relation names to load after the main query.
    prefetch: tuple[str, ...]
120
+
121
+
122
class QuerySet(Generic[TModel]):
    """Immutable, chainable query over one model using a synchronous ``Database``.

    Builder methods (``filter``, ``exclude``, ``order_by``, ``limit``,
    ``offset``, ``prefetch_related``/``select_related``) return a new
    QuerySet and never modify the receiver. SQL runs only in ``all``,
    ``first``, ``get`` and the write methods ``create``, ``save_instance``,
    ``update`` and ``delete``.
    """

    def __init__(
        self, model: type[TModel], *, db: Database, state: Optional[QueryState] = None
    ):
        self.model = model
        self.db = db
        self.state = state or QueryState(
            where=Q(), order_by=(), limit=None, offset=None, prefetch=()
        )

    # -- builders -----------------------------------------------------------

    def _clone(self, **changes: Any) -> QuerySet[TModel]:
        """Return a new QuerySet on the same model/db with *changes* applied to the state."""
        return QuerySet(self.model, db=self.db, state=replace(self.state, **changes))

    def _and_where(self, q: Q) -> Q:
        """Return the current filter tree AND-ed with *q* (just *q* if there is none yet)."""
        if self.state.where.is_empty():
            return q
        return self.state.where & q

    def filter(self, **lookups: Any) -> QuerySet[TModel]:
        """Return a QuerySet further restricted to rows matching *lookups*.

        Keys use ``field`` or ``field__op`` syntax (e.g. ``age__gte=18``).
        """
        return self._clone(where=self._and_where(Q(**lookups)))

    def exclude(self, **lookups: Any) -> QuerySet[TModel]:
        """Return a QuerySet further restricted to rows NOT matching *lookups*."""
        return self._clone(where=self._and_where(~Q(**lookups)))

    def order_by(self, *fields: str) -> QuerySet[TModel]:
        """Return a QuerySet sorted by *fields*, replacing any previous ordering.

        Prefix a field name with ``-`` for descending order.
        """
        return self._clone(order_by=fields)

    def limit(self, n: int) -> QuerySet[TModel]:
        """Return a QuerySet limited to ``int(n)`` rows."""
        return self._clone(limit=int(n))

    def offset(self, n: int) -> QuerySet[TModel]:
        """Return a QuerySet that skips the first ``int(n)`` rows."""
        return self._clone(offset=int(n))

    def prefetch_related(self, *names: str) -> QuerySet[TModel]:
        """Return a QuerySet that also loads the named relations after the main query.

        Names accumulate across calls; duplicates are dropped, keeping
        first-seen order.
        """
        merged = tuple(dict.fromkeys(self.state.prefetch + tuple(names)))
        return self._clone(prefetch=merged)

    def select_related(self, *names: str) -> QuerySet[TModel]:
        """Alias of ``prefetch_related``: relations load via separate queries, not a JOIN."""
        return self.prefetch_related(*names)

    # -- SQL generation -----------------------------------------------------

    def _select_columns(self) -> list[str]:
        """Return ``alias.col AS col`` select expressions for every model field."""
        alias = _table_alias()
        quote = self.db.dialect.quote_ident
        cols = []
        for name in self.model._meta.fields:
            col = quote(name)
            cols.append(f"{alias}.{col} AS {col}")
        return cols

    def _order_by_sql(self) -> str:
        """Return the ``" ORDER BY ..."`` clause, or ``""`` when no ordering is set.

        Raises:
            QueryError: if an ordering field is not a model field.
        """
        if not self.state.order_by:
            return ""
        alias = _table_alias()
        parts: list[str] = []
        for f in self.state.order_by:
            name, direction = (f[1:], "DESC") if f.startswith("-") else (f, "ASC")
            if name not in self.model._meta.fields:
                raise QueryError(f"Unknown order_by field: {name}")
            parts.append(f"{alias}.{self.db.dialect.quote_ident(name)} {direction}")
        return " ORDER BY " + ", ".join(parts)

    def _select_sql(self) -> tuple[str, dict[str, Any]]:
        """Build the full SELECT statement (filters, ordering, limit/offset) and its params."""
        dialect = self.db.dialect
        where_sql, params = _where_from_q(self.model, dialect, self.state.where)
        cols = ", ".join(self._select_columns())
        sql = (
            f"SELECT {cols} FROM {_qualified_table(self.model, dialect)} "
            f"{_table_alias()} WHERE {where_sql}"
        )
        sql += self._order_by_sql()
        sql = dialect.apply_limit_offset(sql, self.state.limit, self.state.offset)
        return sql, params

    # -- reads ----------------------------------------------------------------

    def all(self) -> list[TModel]:
        """Execute the query and return model instances built with ``parse_obj``.

        Requested prefetches are loaded onto the instances before returning.
        """
        sql, params = self._select_sql()
        rows = self.db.fetch_all(sql, params)
        instances = [self.model.parse_obj(r) for r in rows]
        if self.state.prefetch and instances:
            self._prefetch(instances)
        return instances

    def first(self) -> Optional[TModel]:
        """Return the first matching instance, or None if there is none."""
        items = self.limit(1).all()
        return items[0] if items else None

    def get(self, **lookups: Any) -> TModel:
        """Return the single instance matching *lookups* (plus existing filters).

        Raises:
            DoesNotExist: if no row matches.
            MultipleObjectsReturned: if more than one row matches.
        """
        # Fetch at most two rows: enough to detect "more than one".
        items = self.filter(**lookups).limit(2).all()
        if not items:
            raise DoesNotExist(f"{self.model.__name__} matching query does not exist")
        if len(items) > 1:
            raise MultipleObjectsReturned(f"Multiple {self.model.__name__} returned")
        return items[0]

    # -- writes ---------------------------------------------------------------

    def create(self, **data: Any) -> TModel:
        """Instantiate the model from *data*, INSERT it and return the instance.

        When *data* does not set the primary key, a key reported back by the
        database (if any) is written onto the instance.
        """
        inst = self.model(**data)
        pk_name = inst.pk_name()
        if getattr(inst, pk_name, None) is None:
            self.save_instance(inst)
        else:
            # save_instance treats a set pk as "row already exists" and would
            # issue an UPDATE that matches nothing; a new row needs an INSERT.
            self._insert(inst.to_db_dict())
        return inst

    def save_instance(self, inst: Any) -> Any:
        """Persist *inst* and return it.

        An instance whose primary key is None is INSERTed (and receives the
        generated key when one is reported). Otherwise its row is assumed to
        exist and is UPDATEd by primary key with all values from
        ``to_db_dict()``.
        """
        pk_name = inst.pk_name()
        pk_value = getattr(inst, pk_name, None)
        data = inst.to_db_dict()
        if pk_value is None:
            new_pk = self._insert(data)
            if new_pk is not None:
                setattr(inst, pk_name, new_pk)
            return inst
        self.filter(**{pk_name: pk_value}).update(**data)
        return inst

    def _insert(self, data: dict[str, Any]) -> Any:
        """INSERT one row from *data* and return the new primary key if known.

        Only model fields with non-None values are inserted, so omitted
        columns get their database defaults. The key comes from ``RETURNING``
        on postgres or ``last_insert_rowid()`` on sqlite; other dialects
        return None.

        Raises:
            QueryError: if no insertable column remains.
        """
        dialect = self.db.dialect
        cols = [
            c
            for c in data
            if c in self.model._meta.fields and data[c] is not None
        ]
        if not cols:
            raise QueryError("No insertable columns provided")
        col_sql = ", ".join(dialect.quote_ident(c) for c in cols)
        values_sql = ", ".join(f":p{i}" for i in range(1, len(cols) + 1))
        params = {f"p{i}": data[c] for i, c in enumerate(cols, start=1)}
        table = _qualified_table(self.model, dialect)
        pk = self.model.pk_name()
        if dialect.name == "postgres":
            sql = (
                f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql}) "
                f"RETURNING {dialect.quote_ident(pk)}"
            )
            row = self.db.fetch_one(sql, params)
            return None if row is None else row.get(pk)

        sql = f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql})"
        self.db.execute(sql, params)
        if dialect.name == "sqlite":
            row = self.db.fetch_one("SELECT last_insert_rowid() AS id")
            return None if row is None else row.get("id")
        return None

    def update(self, **data: Any) -> int:
        """Set the given columns on every row matched by this QuerySet.

        Keys of *data* that are not model fields are ignored.

        Returns:
            The result of ``db.execute`` for the UPDATE, or ``0`` without
            touching the database when no key of *data* is a model field.
        """
        set_cols = [k for k in data if k in self.model._meta.fields]
        if not set_cols:
            return 0
        dialect = self.db.dialect
        where_sql, where_params = _where_from_q(self.model, dialect, self.state.where)
        # _where_from_q names its parameters with a "p" prefix; SET values use
        # an "s" prefix so merging the two dicts cannot overwrite either side.
        params: dict[str, Any] = {}
        sets: list[str] = []
        for idx, c in enumerate(set_cols, start=1):
            key = f"s{idx}"
            sets.append(f"{dialect.quote_ident(c)} = :{key}")
            params[key] = data[c]
        params.update(where_params)
        # WHERE columns are qualified with the table alias, so the target
        # table must be declared under that same alias.
        sql = (
            f"UPDATE {_qualified_table(self.model, dialect)} AS {_table_alias()} "
            f"SET {', '.join(sets)} WHERE {where_sql}"
        )
        return self.db.execute(sql, params)

    def delete(self) -> int:
        """DELETE every row matched by this QuerySet; return ``db.execute``'s result."""
        dialect = self.db.dialect
        where_sql, params = _where_from_q(self.model, dialect, self.state.where)
        table = _qualified_table(self.model, dialect)
        # Alias the target table to match the alias-qualified WHERE columns.
        sql = f"DELETE FROM {table} AS {_table_alias()} WHERE {where_sql}"
        return self.db.execute(sql, params)

    # -- prefetching ----------------------------------------------------------

    def _prefetch(self, instances: list[Any]) -> None:
        """Load every requested relation onto *instances*.

        Raises:
            QueryError: for a name that is neither a FK relation nor an m2m field.
        """
        for name in self.state.prefetch:
            if name in self.model._meta.relations:
                self._prefetch_fk(instances, name)
            elif name in self.model._meta.m2m:
                self._prefetch_m2m(instances, name)
            else:
                raise QueryError(f"Unknown relation for prefetch: {name}")

    def _prefetch_fk(self, instances: list[Any], rel_name: str) -> None:
        """Attach the related object for FK relation *rel_name* to each instance.

        All targets are loaded with one ``pk__in`` query. Instances whose FK
        value is None or unmatched get ``None``.
        """
        fk_col, fk = self.model._meta.relations[rel_name]
        to_model = fk.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        to_pk = to_model.pk_name()
        ids = list(
            dict.fromkeys(
                getattr(i, fk_col)
                for i in instances
                if getattr(i, fk_col, None) is not None
            )
        )
        if not ids:
            for i in instances:
                setattr(i, rel_name, None)
            return
        related = QuerySet(to_model, db=self.db).filter(**{to_pk + "__in": ids}).all()
        mapping = {getattr(r, to_pk): r for r in related}
        for i in instances:
            setattr(i, rel_name, mapping.get(getattr(i, fk_col)))

    def _prefetch_m2m(self, instances: list[Any], field_name: str) -> None:
        """Attach the list of related objects for m2m field *field_name* to each instance.

        Reads ``(from_id, to_id)`` pairs from the through table and loads the
        targets with one ``pk__in`` query. Each instance gets its targets in
        the order the pair rows were returned (``[]`` when it has none).
        Names left unset on the relation default to through table
        ``{from_table}_{to_table}`` and columns ``{from_table}_{from_pk}`` /
        ``{to_table}_{to_pk}``.
        """
        rel = self.model._meta.m2m[field_name]
        to_model = rel.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        dialect = self.db.dialect
        from_pk = self.model.pk_name()
        to_pk = to_model.pk_name()
        through = rel.through or f"{self.model.table_name()}_{to_model.table_name()}"
        from_col = rel.from_field or f"{self.model.table_name()}_{from_pk}"
        to_col = rel.to_field or f"{to_model.table_name()}_{to_pk}"

        base_ids = list(
            dict.fromkeys(
                getattr(i, from_pk)
                for i in instances
                if getattr(i, from_pk, None) is not None
            )
        )
        if not base_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        # NOTE(review): the through table is quoted but not schema-qualified,
        # unlike the model tables from _qualified_table; confirm this is intended.
        params = {f"p{idx}": v for idx, v in enumerate(base_ids, start=1)}
        placeholders = ", ".join(f":{k}" for k in params)
        quoted_from = dialect.quote_ident(from_col)
        sql = (
            f"SELECT {quoted_from} AS from_id, {dialect.quote_ident(to_col)} AS to_id "
            f"FROM {dialect.quote_ident(through)} WHERE {quoted_from} IN ({placeholders})"
        )
        pairs = self.db.fetch_all(sql, params)
        to_ids = list(dict.fromkeys(p["to_id"] for p in pairs))
        if not to_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        related = QuerySet(to_model, db=self.db).filter(**{to_pk + "__in": to_ids}).all()
        to_map = {getattr(r, to_pk): r for r in related}
        by_from: dict[Any, list[Any]] = {}
        for p in pairs:
            by_from.setdefault(p["from_id"], []).append(to_map.get(p["to_id"]))
        for i in instances:
            # Drop through rows whose target row was not found.
            setattr(
                i,
                field_name,
                [x for x in by_from.get(getattr(i, from_pk), []) if x is not None],
            )
386
+
387
+
388
class AsyncQuerySet(Generic[TModel]):
    """Immutable, chainable query over one model using an ``AsyncDatabase``.

    Builder methods (``filter``, ``exclude``, ``order_by``, ``limit``,
    ``offset``, ``prefetch_related``/``select_related``) are synchronous and
    return a new AsyncQuerySet. The methods that run SQL (``all``, ``first``,
    ``get``, ``create``, ``save_instance``, ``update``, ``delete``) are
    coroutines.
    """

    def __init__(
        self,
        model: type[TModel],
        *,
        db: AsyncDatabase,
        state: Optional[QueryState] = None,
    ) -> None:
        self.model = model
        self.db = db
        self.state = state or QueryState(
            where=Q(), order_by=(), limit=None, offset=None, prefetch=()
        )

    # -- builders -----------------------------------------------------------

    def _clone(self, **changes: Any) -> AsyncQuerySet[TModel]:
        """Return a new AsyncQuerySet on the same model/db with *changes* applied to the state."""
        return AsyncQuerySet(
            self.model, db=self.db, state=replace(self.state, **changes)
        )

    def _and_where(self, q: Q) -> Q:
        """Return the current filter tree AND-ed with *q* (just *q* if there is none yet)."""
        if self.state.where.is_empty():
            return q
        return self.state.where & q

    def filter(self, **lookups: Any) -> AsyncQuerySet[TModel]:
        """Return a query further restricted to rows matching *lookups*.

        Keys use ``field`` or ``field__op`` syntax (e.g. ``age__gte=18``).
        """
        return self._clone(where=self._and_where(Q(**lookups)))

    def exclude(self, **lookups: Any) -> AsyncQuerySet[TModel]:
        """Return a query further restricted to rows NOT matching *lookups*."""
        return self._clone(where=self._and_where(~Q(**lookups)))

    def order_by(self, *fields: str) -> AsyncQuerySet[TModel]:
        """Return a query sorted by *fields*, replacing any previous ordering.

        Prefix a field name with ``-`` for descending order.
        """
        return self._clone(order_by=fields)

    def limit(self, n: int) -> AsyncQuerySet[TModel]:
        """Return a query limited to ``int(n)`` rows."""
        return self._clone(limit=int(n))

    def offset(self, n: int) -> AsyncQuerySet[TModel]:
        """Return a query that skips the first ``int(n)`` rows."""
        return self._clone(offset=int(n))

    def prefetch_related(self, *names: str) -> AsyncQuerySet[TModel]:
        """Return a query that also loads the named relations after the main query.

        Names accumulate across calls; duplicates are dropped, keeping
        first-seen order.
        """
        merged = tuple(dict.fromkeys(self.state.prefetch + tuple(names)))
        return self._clone(prefetch=merged)

    def select_related(self, *names: str) -> AsyncQuerySet[TModel]:
        """Alias of ``prefetch_related``: relations load via separate queries, not a JOIN."""
        return self.prefetch_related(*names)

    # -- SQL generation -----------------------------------------------------

    def _select_columns(self) -> list[str]:
        """Return ``alias.col AS col`` select expressions for every model field."""
        alias = _table_alias()
        quote = self.db.dialect.quote_ident
        cols = []
        for name in self.model._meta.fields:
            col = quote(name)
            cols.append(f"{alias}.{col} AS {col}")
        return cols

    def _order_by_sql(self) -> str:
        """Return the ``" ORDER BY ..."`` clause, or ``""`` when no ordering is set.

        Raises:
            QueryError: if an ordering field is not a model field.
        """
        if not self.state.order_by:
            return ""
        alias = _table_alias()
        parts: list[str] = []
        for f in self.state.order_by:
            name, direction = (f[1:], "DESC") if f.startswith("-") else (f, "ASC")
            if name not in self.model._meta.fields:
                raise QueryError(f"Unknown order_by field: {name}")
            parts.append(f"{alias}.{self.db.dialect.quote_ident(name)} {direction}")
        return " ORDER BY " + ", ".join(parts)

    def _select_sql(self) -> tuple[str, dict[str, Any]]:
        """Build the full SELECT statement (filters, ordering, limit/offset) and its params."""
        dialect = self.db.dialect
        where_sql, params = _where_from_q(self.model, dialect, self.state.where)
        cols = ", ".join(self._select_columns())
        sql = (
            f"SELECT {cols} FROM {_qualified_table(self.model, dialect)} "
            f"{_table_alias()} WHERE {where_sql}"
        )
        sql += self._order_by_sql()
        sql = dialect.apply_limit_offset(sql, self.state.limit, self.state.offset)
        return sql, params

    # -- reads ----------------------------------------------------------------

    async def all(self) -> list[TModel]:
        """Execute the query and return model instances built with ``parse_obj``.

        Requested prefetches are loaded onto the instances before returning.
        """
        sql, params = self._select_sql()
        rows = await self.db.fetch_all(sql, params)
        instances = [self.model.parse_obj(r) for r in rows]
        if self.state.prefetch and instances:
            await self._prefetch(instances)
        return instances

    async def first(self) -> Optional[TModel]:
        """Return the first matching instance, or None if there is none."""
        items = await self.limit(1).all()
        return items[0] if items else None

    async def get(self, **lookups: Any) -> TModel:
        """Return the single instance matching *lookups* (plus existing filters).

        Raises:
            DoesNotExist: if no row matches.
            MultipleObjectsReturned: if more than one row matches.
        """
        # Fetch at most two rows: enough to detect "more than one".
        items = await self.filter(**lookups).limit(2).all()
        if not items:
            raise DoesNotExist(f"{self.model.__name__} matching query does not exist")
        if len(items) > 1:
            raise MultipleObjectsReturned(f"Multiple {self.model.__name__} returned")
        return items[0]

    # -- writes ---------------------------------------------------------------

    async def create(self, **data: Any) -> TModel:
        """Instantiate the model from *data*, INSERT it and return the instance.

        When *data* does not set the primary key, a key reported back by the
        database (if any) is written onto the instance.
        """
        inst = self.model(**data)
        pk_name = inst.pk_name()
        if getattr(inst, pk_name, None) is None:
            await self.save_instance(inst)
        else:
            # save_instance treats a set pk as "row already exists" and would
            # issue an UPDATE that matches nothing; a new row needs an INSERT.
            await self._insert(inst.to_db_dict())
        return inst

    async def save_instance(self, inst: Any) -> Any:
        """Persist *inst* and return it.

        An instance whose primary key is None is INSERTed (and receives the
        generated key when one is reported). Otherwise its row is assumed to
        exist and is UPDATEd by primary key with all values from
        ``to_db_dict()``.
        """
        pk_name = inst.pk_name()
        pk_value = getattr(inst, pk_name, None)
        data = inst.to_db_dict()
        if pk_value is None:
            new_pk = await self._insert(data)
            if new_pk is not None:
                setattr(inst, pk_name, new_pk)
            return inst
        await self.filter(**{pk_name: pk_value}).update(**data)
        return inst

    async def _insert(self, data: dict[str, Any]) -> Any:
        """INSERT one row from *data* and return the new primary key if known.

        Only model fields with non-None values are inserted, so omitted
        columns get their database defaults. The key comes from ``RETURNING``
        on postgres or ``last_insert_rowid()`` on sqlite; other dialects
        return None.

        Raises:
            QueryError: if no insertable column remains.
        """
        dialect = self.db.dialect
        cols = [
            c
            for c in data
            if c in self.model._meta.fields and data[c] is not None
        ]
        if not cols:
            raise QueryError("No insertable columns provided")
        col_sql = ", ".join(dialect.quote_ident(c) for c in cols)
        values_sql = ", ".join(f":p{i}" for i in range(1, len(cols) + 1))
        params = {f"p{i}": data[c] for i, c in enumerate(cols, start=1)}
        table = _qualified_table(self.model, dialect)
        pk = self.model.pk_name()
        if dialect.name == "postgres":
            sql = (
                f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql}) "
                f"RETURNING {dialect.quote_ident(pk)}"
            )
            row = await self.db.fetch_one(sql, params)
            return None if row is None else row.get(pk)

        sql = f"INSERT INTO {table} ({col_sql}) VALUES ({values_sql})"
        await self.db.execute(sql, params)
        if dialect.name == "sqlite":
            row = await self.db.fetch_one("SELECT last_insert_rowid() AS id")
            return None if row is None else row.get("id")
        return None

    async def update(self, **data: Any) -> int:
        """Set the given columns on every row matched by this query.

        Keys of *data* that are not model fields are ignored.

        Returns:
            The awaited result of ``db.execute`` for the UPDATE, or ``0``
            without touching the database when no key of *data* is a model
            field.
        """
        set_cols = [k for k in data if k in self.model._meta.fields]
        if not set_cols:
            return 0
        dialect = self.db.dialect
        where_sql, where_params = _where_from_q(self.model, dialect, self.state.where)
        # _where_from_q names its parameters with a "p" prefix; SET values use
        # an "s" prefix so merging the two dicts cannot overwrite either side.
        params: dict[str, Any] = {}
        sets: list[str] = []
        for idx, c in enumerate(set_cols, start=1):
            key = f"s{idx}"
            sets.append(f"{dialect.quote_ident(c)} = :{key}")
            params[key] = data[c]
        params.update(where_params)
        # WHERE columns are qualified with the table alias, so the target
        # table must be declared under that same alias.
        sql = (
            f"UPDATE {_qualified_table(self.model, dialect)} AS {_table_alias()} "
            f"SET {', '.join(sets)} WHERE {where_sql}"
        )
        return await self.db.execute(sql, params)

    async def delete(self) -> int:
        """DELETE every row matched by this query; return ``db.execute``'s awaited result."""
        dialect = self.db.dialect
        where_sql, params = _where_from_q(self.model, dialect, self.state.where)
        table = _qualified_table(self.model, dialect)
        # Alias the target table to match the alias-qualified WHERE columns.
        sql = f"DELETE FROM {table} AS {_table_alias()} WHERE {where_sql}"
        return await self.db.execute(sql, params)

    # -- prefetching ----------------------------------------------------------

    async def _prefetch(self, instances: list[Any]) -> None:
        """Load every requested relation onto *instances*, one relation at a time.

        Raises:
            QueryError: for a name that is neither a FK relation nor an m2m field.
        """
        for name in self.state.prefetch:
            if name in self.model._meta.relations:
                await self._prefetch_fk(instances, name)
            elif name in self.model._meta.m2m:
                await self._prefetch_m2m(instances, name)
            else:
                raise QueryError(f"Unknown relation for prefetch: {name}")

    async def _prefetch_fk(self, instances: list[Any], rel_name: str) -> None:
        """Attach the related object for FK relation *rel_name* to each instance.

        All targets are loaded with one ``pk__in`` query. Instances whose FK
        value is None or unmatched get ``None``.
        """
        fk_col, fk = self.model._meta.relations[rel_name]
        to_model = fk.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        to_pk = to_model.pk_name()
        ids = list(
            dict.fromkeys(
                getattr(i, fk_col)
                for i in instances
                if getattr(i, fk_col, None) is not None
            )
        )
        if not ids:
            for i in instances:
                setattr(i, rel_name, None)
            return
        related = (
            await AsyncQuerySet(to_model, db=self.db)
            .filter(**{to_pk + "__in": ids})
            .all()
        )
        mapping = {getattr(r, to_pk): r for r in related}
        for i in instances:
            setattr(i, rel_name, mapping.get(getattr(i, fk_col)))

    async def _prefetch_m2m(self, instances: list[Any], field_name: str) -> None:
        """Attach the list of related objects for m2m field *field_name* to each instance.

        Reads ``(from_id, to_id)`` pairs from the through table and loads the
        targets with one ``pk__in`` query. Each instance gets its targets in
        the order the pair rows were returned (``[]`` when it has none).
        Names left unset on the relation default to through table
        ``{from_table}_{to_table}`` and columns ``{from_table}_{from_pk}`` /
        ``{to_table}_{to_pk}``.
        """
        rel = self.model._meta.m2m[field_name]
        to_model = rel.to
        if isinstance(to_model, str):
            to_model = self.model._resolve_model(to_model)
        dialect = self.db.dialect
        from_pk = self.model.pk_name()
        to_pk = to_model.pk_name()
        through = rel.through or f"{self.model.table_name()}_{to_model.table_name()}"
        from_col = rel.from_field or f"{self.model.table_name()}_{from_pk}"
        to_col = rel.to_field or f"{to_model.table_name()}_{to_pk}"

        base_ids = list(
            dict.fromkeys(
                getattr(i, from_pk)
                for i in instances
                if getattr(i, from_pk, None) is not None
            )
        )
        if not base_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        # NOTE(review): the through table is quoted but not schema-qualified,
        # unlike the model tables from _qualified_table; confirm this is intended.
        params = {f"p{idx}": v for idx, v in enumerate(base_ids, start=1)}
        placeholders = ", ".join(f":{k}" for k in params)
        quoted_from = dialect.quote_ident(from_col)
        sql = (
            f"SELECT {quoted_from} AS from_id, {dialect.quote_ident(to_col)} AS to_id "
            f"FROM {dialect.quote_ident(through)} WHERE {quoted_from} IN ({placeholders})"
        )
        pairs = await self.db.fetch_all(sql, params)
        to_ids = list(dict.fromkeys(p["to_id"] for p in pairs))
        if not to_ids:
            for i in instances:
                setattr(i, field_name, [])
            return

        related = (
            await AsyncQuerySet(to_model, db=self.db)
            .filter(**{to_pk + "__in": to_ids})
            .all()
        )
        to_map = {getattr(r, to_pk): r for r in related}
        by_from: dict[Any, list[Any]] = {}
        for p in pairs:
            by_from.setdefault(p["from_id"], []).append(to_map.get(p["to_id"]))
        for i in instances:
            # Drop through rows whose target row was not found.
            setattr(
                i,
                field_name,
                [x for x in by_from.get(getattr(i, from_pk), []) if x is not None],
            )