plain.postgres 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. plain/postgres/CHANGELOG.md +1028 -0
  2. plain/postgres/README.md +925 -0
  3. plain/postgres/__init__.py +120 -0
  4. plain/postgres/agents/.claude/rules/plain-postgres.md +78 -0
  5. plain/postgres/aggregates.py +236 -0
  6. plain/postgres/backups/__init__.py +0 -0
  7. plain/postgres/backups/cli.py +148 -0
  8. plain/postgres/backups/clients.py +94 -0
  9. plain/postgres/backups/core.py +172 -0
  10. plain/postgres/base.py +1415 -0
  11. plain/postgres/cli/__init__.py +3 -0
  12. plain/postgres/cli/db.py +142 -0
  13. plain/postgres/cli/migrations.py +1085 -0
  14. plain/postgres/config.py +18 -0
  15. plain/postgres/connection.py +1331 -0
  16. plain/postgres/connections.py +77 -0
  17. plain/postgres/constants.py +13 -0
  18. plain/postgres/constraints.py +495 -0
  19. plain/postgres/database_url.py +94 -0
  20. plain/postgres/db.py +59 -0
  21. plain/postgres/default_settings.py +38 -0
  22. plain/postgres/deletion.py +475 -0
  23. plain/postgres/dialect.py +640 -0
  24. plain/postgres/entrypoints.py +4 -0
  25. plain/postgres/enums.py +103 -0
  26. plain/postgres/exceptions.py +217 -0
  27. plain/postgres/expressions.py +1912 -0
  28. plain/postgres/fields/__init__.py +2118 -0
  29. plain/postgres/fields/encrypted.py +354 -0
  30. plain/postgres/fields/json.py +413 -0
  31. plain/postgres/fields/mixins.py +30 -0
  32. plain/postgres/fields/related.py +1192 -0
  33. plain/postgres/fields/related_descriptors.py +290 -0
  34. plain/postgres/fields/related_lookups.py +223 -0
  35. plain/postgres/fields/related_managers.py +661 -0
  36. plain/postgres/fields/reverse_descriptors.py +229 -0
  37. plain/postgres/fields/reverse_related.py +328 -0
  38. plain/postgres/fields/timezones.py +143 -0
  39. plain/postgres/forms.py +773 -0
  40. plain/postgres/functions/__init__.py +189 -0
  41. plain/postgres/functions/comparison.py +127 -0
  42. plain/postgres/functions/datetime.py +454 -0
  43. plain/postgres/functions/math.py +140 -0
  44. plain/postgres/functions/mixins.py +59 -0
  45. plain/postgres/functions/text.py +282 -0
  46. plain/postgres/functions/window.py +125 -0
  47. plain/postgres/indexes.py +286 -0
  48. plain/postgres/lookups.py +758 -0
  49. plain/postgres/meta.py +584 -0
  50. plain/postgres/migrations/__init__.py +53 -0
  51. plain/postgres/migrations/autodetector.py +1379 -0
  52. plain/postgres/migrations/exceptions.py +54 -0
  53. plain/postgres/migrations/executor.py +188 -0
  54. plain/postgres/migrations/graph.py +364 -0
  55. plain/postgres/migrations/loader.py +377 -0
  56. plain/postgres/migrations/migration.py +180 -0
  57. plain/postgres/migrations/operations/__init__.py +34 -0
  58. plain/postgres/migrations/operations/base.py +139 -0
  59. plain/postgres/migrations/operations/fields.py +373 -0
  60. plain/postgres/migrations/operations/models.py +798 -0
  61. plain/postgres/migrations/operations/special.py +184 -0
  62. plain/postgres/migrations/optimizer.py +74 -0
  63. plain/postgres/migrations/questioner.py +340 -0
  64. plain/postgres/migrations/recorder.py +119 -0
  65. plain/postgres/migrations/serializer.py +378 -0
  66. plain/postgres/migrations/state.py +882 -0
  67. plain/postgres/migrations/utils.py +147 -0
  68. plain/postgres/migrations/writer.py +302 -0
  69. plain/postgres/options.py +207 -0
  70. plain/postgres/otel.py +231 -0
  71. plain/postgres/preflight.py +336 -0
  72. plain/postgres/query.py +2242 -0
  73. plain/postgres/query_utils.py +456 -0
  74. plain/postgres/registry.py +217 -0
  75. plain/postgres/schema.py +1885 -0
  76. plain/postgres/sql/__init__.py +40 -0
  77. plain/postgres/sql/compiler.py +1869 -0
  78. plain/postgres/sql/constants.py +22 -0
  79. plain/postgres/sql/datastructures.py +222 -0
  80. plain/postgres/sql/query.py +2947 -0
  81. plain/postgres/sql/where.py +374 -0
  82. plain/postgres/test/__init__.py +0 -0
  83. plain/postgres/test/pytest.py +117 -0
  84. plain/postgres/test/utils.py +18 -0
  85. plain/postgres/transaction.py +222 -0
  86. plain/postgres/types.py +92 -0
  87. plain/postgres/types.pyi +751 -0
  88. plain/postgres/utils.py +345 -0
  89. plain_postgres-0.84.0.dist-info/METADATA +937 -0
  90. plain_postgres-0.84.0.dist-info/RECORD +93 -0
  91. plain_postgres-0.84.0.dist-info/WHEEL +4 -0
  92. plain_postgres-0.84.0.dist-info/entry_points.txt +5 -0
  93. plain_postgres-0.84.0.dist-info/licenses/LICENSE +61 -0
@@ -0,0 +1,661 @@
1
+ """
2
+ Managers for related objects.
3
+
4
+ These managers provide the API for working with collections of related objects
5
+ through foreign key and many-to-many relationships.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Callable, Iterable
14
+
15
+ from plain.postgres.base import Model
16
+ from plain.postgres.fields.related import ForeignKeyField, ManyToManyField
17
+
18
+ import builtins
19
+
20
+ from plain.postgres import transaction
21
+ from plain.postgres.db import get_connection
22
+ from plain.postgres.dialect import quote_name
23
+ from plain.postgres.expressions import Window
24
+ from plain.postgres.functions import RowNumber
25
+ from plain.postgres.lookups import GreaterThan, LessThanOrEqual
26
+ from plain.postgres.query import QuerySet
27
+ from plain.postgres.query_utils import Q
28
+ from plain.postgres.utils import resolve_callables
29
+
30
+ # TypeVar for generic manager support
31
+ T = TypeVar("T", bound="Model")
32
+ # TypeVar for custom QuerySet types (defaults to QuerySet[Any] when not specified)
33
+ QS = TypeVar("QS", bound="QuerySet[Any]", default="QuerySet[Any]")
34
+
35
+
36
def _filter_prefetch_queryset(
    queryset: QuerySet, field_name: str, instances: Iterable[Model]
) -> QuerySet:
    """
    Restrict ``queryset`` to rows whose ``field_name`` is one of ``instances``.

    If ``queryset`` is sliced, the slice is applied per related instance rather
    than globally: the LIMIT/OFFSET is replaced by a ``RowNumber()`` window
    partitioned by ``field_name`` (ordered like the queryset) and bounded by
    the original low/high marks.

    The limits are cleared on a copy, so the caller's queryset keeps its slice.
    """
    predicate = Q(**{f"{field_name}__in": instances})
    if queryset.sql_query.is_sliced:
        # clear_limits() below mutates the SQL query in place. Clone first so a
        # caller-owned queryset (e.g. one attached to a reusable prefetch
        # lookup) is not left permanently un-sliced after this call.
        queryset = queryset.all()
        sql_query = queryset.sql_query
        low_mark, high_mark = sql_query.low_mark, sql_query.high_mark
        order_by = [expr for expr, _ in sql_query.get_compiler().get_order_by()]
        window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
        predicate &= GreaterThan(window, low_mark)
        if high_mark is not None:
            predicate &= LessThanOrEqual(window, high_mark)
        sql_query.clear_limits()
    return queryset.filter(predicate)
53
+
54
+
55
class BaseRelatedManager(Generic[T, QS]):
    """
    Common interface shared by every related-object manager.

    Concrete managers implement ``get_queryset()``; the read-only ``query``
    property is just a shortcut that delegates to it.
    """

    def get_queryset(self) -> QS:
        """Build the QuerySet scoped to this relationship (subclass hook)."""
        raise NotImplementedError("Subclasses must implement get_queryset()")

    @property
    def query(self) -> QS:
        """Shortcut for ``self.get_queryset()``."""
        return self.get_queryset()
70
+
71
+
72
class ReverseForeignKeyManager(BaseRelatedManager[T, QS]):
    """
    Manager for the reverse side of a foreign key relation.

    Bound to one ``instance`` (the object the foreign key points at), it
    exposes the ``related_model`` rows whose ``field`` references that
    instance. It can attach rows (``add``/``create``/``get_or_create``/
    ``update_or_create``/``set``). When the foreign key allows NULL, it can
    also detach them (``remove``/``clear``).
    """

    # Type hints for attributes
    model: type[T]
    instance: Model
    field: ForeignKeyField
    core_filters: dict[str, Model]
    allow_null: bool

    def __init__(
        self, instance: Model, field: ForeignKeyField, related_model: type[Model]
    ):
        """
        Args:
            instance: The object whose related rows this manager handles.
            field: The ForeignKeyField on ``related_model`` pointing at
                ``instance``'s model.
            related_model: The model declaring ``field``.
        """
        assert field.name is not None, "Field must have a name"
        self.model = cast(type[T], related_model)
        self.instance = instance
        self.field = field
        self.core_filters = {self.field.name: instance}
        self.allow_null = self.field.allow_null

    def _check_fk_val(self) -> None:
        """Raise ValueError if ``instance`` lacks a value for any FK target field."""
        for field in self.field.foreign_related_fields:
            if getattr(self.instance, field.attname) is None:
                raise ValueError(
                    f'"{self.instance!r}" needs to have a value for field '
                    f'"{field.attname}" before this relationship can be used.'
                )

    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
        """
        Filter the queryset for the instance this manager is bound to.

        Returns ``queryset.none()`` when any targeted value on ``instance`` is
        None. Otherwise records ``instance`` in ``_known_related_objects``,
        keyed by its target value (a tuple when the FK has several targets).
        """
        from plain.postgres.exceptions import FieldError

        queryset._defer_next_filter = True
        queryset = queryset.filter(**self.core_filters)
        for field in self.field.foreign_related_fields:
            val = getattr(self.instance, field.attname)
            if val is None:
                return queryset.none()

        try:
            target_field = self.field.target_field
        except FieldError:
            # The relationship has multiple target fields. Use a tuple
            # for related object id.
            rel_obj_id = tuple(
                getattr(self.instance, path_target.attname)
                for path_target in self.field.path_infos[-1].target_fields
            )
        else:
            rel_obj_id = getattr(self.instance, target_field.attname)
        queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
        return queryset

    def _remove_prefetched_objects(self) -> None:
        """Drop this relation's entry from the instance's prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(
                self.field.remote_field.get_cache_name()
            )
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QS:
        """
        Return the related rows, preferring a prefetched entry when one exists.

        Raises:
            ValueError: if ``instance`` has no primary key value.
        """
        # Even if this relation is not to primary key, we require still primary key value.
        # The wish is that the instance has been already saved to DB,
        # although having a primary key value isn't a guarantee of that.
        if self.instance.id is None:
            raise ValueError(
                f"{self.instance.__class__.__name__!r} instance needs to have a "
                "primary key value before this relationship can be used."
            )
        try:
            return self.instance._prefetched_objects_cache[
                self.field.remote_field.get_cache_name()
            ]
        except (AttributeError, KeyError):
            queryset = self.model.query
            return cast(QS, self._apply_rel_filters(queryset))

    def get_prefetch_queryset(
        self, instances: Iterable[Model], queryset: QuerySet | None = None
    ) -> tuple[
        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
    ]:
        """
        Build the queryset that prefetches this relation for many ``instances``.

        Each fetched row has its FK attribute pointed at the matching instance
        unless it is already cached. Returns the queryset, the row-key function,
        the instance-key function, ``False``, the cache name and ``False``.
        """
        if queryset is None:
            queryset = self.model.query

        # `instances` is consumed twice (the lookup dict and the IN filter);
        # materialize it so a one-shot iterator is not exhausted by the first
        # pass, which would leave the filter empty.
        instances = list(instances)
        rel_obj_attr = self.field.get_local_related_value
        instance_attr = self.field.get_foreign_related_value
        instances_dict = {instance_attr(inst): inst for inst in instances}
        queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)

        # Since we just bypassed this class' get_queryset(), we must manage
        # the reverse relation manually.
        for rel_obj in queryset:
            if not self.field.is_cached(rel_obj):
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, self.field.name, instance)
        cache_name = self.field.remote_field.get_cache_name()
        return queryset, rel_obj_attr, instance_attr, False, cache_name, False

    def add(self, *objs: T, bulk: bool = True) -> None:
        """
        Point the foreign key of each of ``objs`` at ``instance``.

        With ``bulk=True`` the objects must already be saved; they are updated
        with a single UPDATE filtered by id. With ``bulk=False`` each object is
        saved individually inside one transaction.

        Raises:
            TypeError: if an object is not an instance of ``model``.
            ValueError: with ``bulk=True``, if an object is unsaved.
        """
        self._check_fk_val()
        self._remove_prefetched_objects()

        def check_and_update_obj(obj: Any) -> None:
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            setattr(obj, self.field.name, self.instance)

        if bulk:
            ids = []
            for obj in objs:
                check_and_update_obj(obj)
                if obj._state.adding:
                    raise ValueError(
                        f"{obj!r} instance isn't saved. Use bulk=False or save "
                        "the object first."
                    )
                ids.append(obj.id)
            self.model._model_meta.base_queryset.filter(id__in=ids).update(
                **{
                    self.field.name: self.instance,
                }
            )
        else:
            with transaction.atomic(savepoint=False):
                for obj in objs:
                    check_and_update_obj(obj)
                    obj.save()

    def create(self, **kwargs: Any) -> T:
        """Create a ``model`` object whose foreign key points at ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(T, self.model.query.create(**kwargs))

    def get_or_create(self, **kwargs: Any) -> tuple[T, bool]:
        """``get_or_create`` with the foreign key fixed to ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(tuple[T, bool], self.model.query.get_or_create(**kwargs))

    def update_or_create(self, **kwargs: Any) -> tuple[T, bool]:
        """``update_or_create`` with the foreign key fixed to ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(tuple[T, bool], self.model.query.update_or_create(**kwargs))

    def remove(self, *objs: T, bulk: bool = True) -> None:
        """
        Detach ``objs`` from ``instance`` by setting their foreign key to NULL.

        Raises:
            AttributeError: if the foreign key does not allow NULL.
            TypeError: if an object is not an instance of ``model``.
            DoesNotExist: (of ``field.remote_field.model``) if an object is
                not currently related to ``instance``.
        """
        # remove() is only provided if the ForeignKeyField can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call remove() on a related manager for field "
                f"{self.field.name} where null=False."
            )
        if not objs:
            return
        self._check_fk_val()
        val = self.field.get_foreign_related_value(self.instance)
        old_ids = set()
        for obj in objs:
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            # Is obj actually part of this descriptor set?
            if self.field.get_local_related_value(obj) == val:
                old_ids.add(obj.id)
            else:
                raise self.field.remote_field.model.DoesNotExist(
                    f"{obj!r} is not related to {self.instance!r}."
                )
        # Drop the prefetch cache before building the queryset so it comes from
        # the relation filters alone, not from a cached prefetch queryset that
        # may carry extra filters of its own.
        self._remove_prefetched_objects()
        self._clear(self.query.filter(id__in=old_ids), bulk)

    def clear(self, *, bulk: bool = True) -> None:
        """
        Set the foreign key to NULL on every row related to ``instance``.

        Raises:
            AttributeError: if the foreign key does not allow NULL.
        """
        # clear() is only provided if the ForeignKeyField can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call clear() on a related manager for field "
                f"{self.field.name} where null=False."
            )
        self._check_fk_val()
        # Drop the prefetch cache first: otherwise `self.query` returns the
        # cached prefetch entry, and the rows cleared would be those of that
        # (possibly stale or narrowed) cached queryset, not the database's.
        self._remove_prefetched_objects()
        self._clear(self.query, bulk)

    def _clear(self, queryset: QuerySet, bulk: bool) -> None:
        """Set the FK to NULL for ``queryset``: one UPDATE, or per-object saves."""
        self._remove_prefetched_objects()
        if bulk:
            # `QuerySet.update()` is intrinsically atomic.
            queryset.update(**{self.field.name: None})
        else:
            with transaction.atomic(savepoint=False):
                for obj in queryset:
                    setattr(obj, self.field.name, None)
                    obj.save(update_fields=[self.field.name])

    def set(self, objs: Any, *, bulk: bool = True, clear: bool = False) -> None:
        """
        Make ``objs`` the related set of ``instance``.

        For a nullable FK, currently related rows missing from ``objs`` are
        detached (all rows first when ``clear=True``). For a non-nullable FK,
        ``objs`` are only added, because rows cannot be detached.
        """
        self._check_fk_val()
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`.
        objs = tuple(objs)

        if self.allow_null:
            with transaction.atomic(savepoint=False):
                if clear:
                    self.clear(bulk=bulk)
                    self.add(*objs, bulk=bulk)
                else:
                    # Diff against the database, not a cached prefetch entry.
                    self._remove_prefetched_objects()
                    old_objs = set(self.query.all())
                    new_objs = []
                    for obj in objs:
                        if obj in old_objs:
                            old_objs.remove(obj)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_objs, bulk=bulk)
                    self.add(*new_objs, bulk=bulk)
        else:
            self.add(*objs, bulk=bulk)
298
+
299
+
300
class ManyToManyManager(BaseRelatedManager[T, QS]):
    """
    Manager for both forward and reverse sides of a many-to-many relation.

    This manager handles both directions of many-to-many relations with
    conditional logic for symmetrical relationships (which only apply to
    forward relations). Links are stored as rows of the ``through`` model.
    ``source_field`` is the through FK pointing at ``instance``'s side, and
    ``target_field`` is the through FK pointing at ``model``'s side.
    """

    # Type hints for attributes
    model: type[T]
    instance: Model
    field: ManyToManyField
    through: type[Model]
    query_field_name: str
    prefetch_cache_name: str
    source_field_name: str
    target_field_name: str
    source_field: ForeignKeyField
    target_field: ForeignKeyField
    symmetrical: bool
    core_filters: dict[str, Any]
    id_field_names: dict[str, str]
    related_val: tuple[Any, ...]

    def __init__(
        self,
        instance: Model,
        field: ManyToManyField,
        through: type[Model],
        related_model: type[Model],
        is_reverse: bool,
        symmetrical: bool = False,
    ):
        """
        Raises:
            ValueError: if ``instance`` lacks a value for the field the through
                FK targets, or has no primary key value.
        """
        assert field.name is not None, "Field must have a name"
        self.model = cast(type[T], related_model)
        # Set direction-specific attributes
        if is_reverse:
            # Reverse: accessing from the target model back to the source
            self.query_field_name = field.name
            self.prefetch_cache_name = field.related_query_name()
            self.source_field_name = field.m2m_reverse_field_name()
            self.target_field_name = field.m2m_field_name()
            self.symmetrical = False  # Reverse relations are never symmetrical
        else:
            # Forward: accessing from the source model to the target
            self.query_field_name = field.related_query_name()
            self.prefetch_cache_name = field.name
            self.source_field_name = field.m2m_field_name()
            self.target_field_name = field.m2m_reverse_field_name()
            self.symmetrical = symmetrical

        # Initialize common M2M attributes
        self.instance = instance
        self.through = through

        # M2M through model fields are always ForeignKey
        self.source_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.source_field_name),
        )
        self.target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.target_field_name),
        )

        self.core_filters = {}
        self.id_field_names = {}
        for lh_field, rh_field in self.source_field.related_fields:
            core_filter_key = f"{self.query_field_name}__{rh_field.name}"
            self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
            self.id_field_names[lh_field.name] = rh_field.name  # type: ignore[assignment]

        self.related_val = self.source_field.get_foreign_related_value(instance)
        if None in self.related_val:
            raise ValueError(
                f'"{instance!r}" needs to have a value for field "{self.id_field_names[self.source_field_name]}" before '
                "this many-to-many relationship can be used."
            )
        # Even if this relation is not to primary key, we require still primary key value.
        if instance.id is None:
            raise ValueError(
                f"{instance.__class__.__name__!r} instance needs to have a primary key value before "
                "a many-to-many relationship can be used."
            )

    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
        """Filter the queryset for the instance this manager is bound to."""
        queryset._defer_next_filter = True
        return queryset._next_is_sticky().filter(**self.core_filters)

    def _remove_prefetched_objects(self) -> None:
        """Drop this relation's entry from the instance's prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QS:
        """Return the related rows, preferring a prefetched entry when present."""
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            queryset = self.model.query
            return cast(QS, self._apply_rel_filters(queryset))

    def get_prefetch_queryset(
        self, instances: Iterable[Model], queryset: QuerySet | None = None
    ) -> tuple[
        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
    ]:
        """
        Build the queryset that prefetches this relation for ``instances``.

        Each fetched row is annotated (via ``extra(select=...)``) with the
        through table's source-FK column value(s) as
        ``_prefetch_related_val_<attname>``. Rows can then be keyed back to the
        instance they were fetched for. Returns the queryset, the row-key
        function, the instance-key function, ``False``, ``prefetch_cache_name``
        and ``False``.
        """
        if queryset is None:
            queryset = self.model.query

        queryset = _filter_prefetch_queryset(
            queryset._next_is_sticky(), self.query_field_name, instances
        )

        # M2M: need to annotate the query in order to get the primary model
        # that the secondary model was actually related to.
        fk = self.source_field
        join_table = fk.model.model_options.db_table
        qn = quote_name
        queryset = queryset.extra(
            select={
                f"_prefetch_related_val_{f.attname}": f"{qn(join_table)}.{qn(f.column)}"
                for f in fk.local_related_fields
            }
        )
        conn = get_connection()
        return (
            queryset,
            lambda result: tuple(
                getattr(result, f"_prefetch_related_val_{f.attname}")
                for f in fk.local_related_fields
            ),
            lambda inst: tuple(
                f.get_db_prep_value(getattr(inst, f.attname), conn)
                for f in fk.foreign_related_fields
            ),
            False,
            self.prefetch_cache_name,
            False,
        )

    def clear(self) -> None:
        """Delete every through row linking ``instance`` (both directions if symmetrical)."""
        with transaction.atomic(savepoint=False):
            self._remove_prefetched_objects()
            filters = self._build_remove_filters(self.model.query)
            self.through.query.filter(filters).delete()

    def set(
        self,
        objs: Any,
        *,
        clear: bool = False,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """
        Make ``objs`` (model instances or target values) the related set.

        With ``clear=True`` all links are deleted and ``objs`` re-added.
        Otherwise only the difference is applied: stale links are removed and
        missing ones added.
        """
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`.
        objs = tuple(objs)

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear()
                self.add(*objs, through_defaults=through_defaults)
            else:
                old_ids = set(
                    self.query.values_list(
                        self.target_field.target_field.attname, flat=True
                    )
                )

                new_objs = []
                for obj in objs:
                    fk_val = (
                        self.target_field.get_foreign_related_value(obj)[0]
                        if isinstance(obj, self.model)
                        else self.target_field.get_prep_value(obj)
                    )
                    if fk_val in old_ids:
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)

                self.remove(*old_ids)
                self.add(*new_objs, through_defaults=through_defaults)

    def create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> T:
        """Create a ``model`` object from ``kwargs`` and link it to ``instance``."""
        new_obj = self.model.query.create(**kwargs)
        self.add(new_obj, through_defaults=through_defaults)
        return cast(T, new_obj)

    def get_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Get or create a ``model`` object matching ``kwargs`` and link it here.

        The lookup runs against all ``model`` rows, not only those already
        related to ``instance``, so a found object may not be linked yet.
        ``add()`` skips links that already exist, so calling it unconditionally
        guarantees the relationship without duplicating it. ``through_defaults``
        applies to any link that is newly created.

        Returns:
            ``(obj, created)``, where ``created`` reports whether ``obj`` itself
            was newly created.
        """
        obj, created = self.model.query.get_or_create(**kwargs)
        self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def update_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Update or create a ``model`` object matching ``kwargs`` and link it here.

        As with ``get_or_create``, the lookup is not restricted to related
        rows, so the object is always passed to ``add()`` (a no-op for links
        that already exist).

        Returns:
            ``(obj, created)``, where ``created`` reports whether ``obj`` itself
            was newly created.
        """
        obj, created = self.model.query.update_or_create(**kwargs)
        self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def _get_target_ids(self, target_field_name: str, objs: Any) -> builtins.set[Any]:
        """
        Return the set of ids of `objs` that the target field references.

        ``objs`` may mix ``model`` instances and raw values. Raw values are
        converted with the through FK's ``get_prep_value``.

        Raises:
            ValueError: if a ``model`` instance has no value for the target.
            TypeError: if an object is a model instance of the wrong class.
        """
        from plain.postgres import Model

        target_ids: set[Any] = set()
        # M2M through model fields are always ForeignKey
        target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(target_field_name),
        )
        for obj in objs:
            if isinstance(obj, self.model):
                target_id = target_field.get_foreign_related_value(obj)[0]
                if target_id is None:
                    raise ValueError(
                        f'Cannot add "{obj!r}": the value for field "{target_field_name}" is None'
                    )
                target_ids.add(target_id)
            elif isinstance(obj, Model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            else:
                target_ids.add(target_field.get_prep_value(obj))
        return target_ids

    def _get_missing_target_ids(
        self,
        source_field_name: str,
        target_field_name: str,
        target_ids: builtins.set[Any],
    ) -> builtins.set[Any]:
        """Return the subset of ids of `objs` that aren't already assigned to this relationship."""
        vals = self.through.query.values_list(target_field_name, flat=True).filter(
            **{
                source_field_name: self.related_val[0],
                f"{target_field_name}__in": target_ids,
            }
        )
        return target_ids.difference(vals)

    def _add_items(
        self,
        source_field_name: str,
        target_field_name: str,
        *objs: Any,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """Bulk-create through rows for ``objs`` not yet linked in this direction."""
        if not objs:
            return

        through_defaults = dict(resolve_callables(through_defaults or {}))
        target_ids = self._get_target_ids(target_field_name, objs)

        missing_target_ids = self._get_missing_target_ids(
            source_field_name, target_field_name, target_ids
        )
        with transaction.atomic(savepoint=False):
            # Add the ones that aren't there already.
            self.through.query.bulk_create(
                [
                    self.through(
                        **through_defaults,
                        **{
                            f"{source_field_name}_id": self.related_val[0],
                            f"{target_field_name}_id": target_id,
                        },
                    )
                    for target_id in missing_target_ids
                ],
            )

    def _remove_items(
        self, source_field_name: str, target_field_name: str, *objs: Any
    ) -> None:
        """
        Delete the through rows linking ``instance`` to ``objs``.

        ``objs`` may be ``model`` instances or raw target values. The field-name
        arguments are not consulted. ``_build_remove_filters`` uses
        ``self.source_field_name``/``self.target_field_name`` instead, and also
        covers the mirror direction for symmetrical relations.
        """
        if not objs:
            return

        # Check that all the objects are of the right type
        old_ids = set()
        for obj in objs:
            if isinstance(obj, self.model):
                fk_val = self.target_field.get_foreign_related_value(obj)[0]
                old_ids.add(fk_val)
            else:
                old_ids.add(obj)

        with transaction.atomic(savepoint=False):
            target_model_qs = self.model.query
            if target_model_qs._has_filters():
                old_vals = target_model_qs.filter(
                    **{f"{self.target_field.target_field.attname}__in": old_ids}
                )
            else:
                old_vals = old_ids
            filters = self._build_remove_filters(old_vals)
            self.through.query.filter(filters).delete()

    def _build_remove_filters(self, removed_vals: Any) -> Any:
        """Build the Q selecting through rows to delete for ``removed_vals``."""
        filters = Q.create([(self.source_field_name, self.related_val)])
        # No need to add a subquery condition if removed_vals is a QuerySet without
        # filters.
        removed_vals_filters = (
            not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
        )
        if removed_vals_filters:
            filters = filters & Q.create(
                [(f"{self.target_field_name}__in", removed_vals)]
            )
        # Add symmetrical filters for forward symmetrical relations
        if self.symmetrical:
            symmetrical_filters = Q.create([(self.target_field_name, self.related_val)])
            if removed_vals_filters:
                symmetrical_filters = symmetrical_filters & Q.create(
                    [(f"{self.source_field_name}__in", removed_vals)]
                )
            filters = filters | symmetrical_filters
        return filters

    def add(self, *objs: T, through_defaults: dict[str, Any] | None = None) -> None:
        """Link ``objs`` to ``instance``; existing links are left untouched."""
        self._remove_prefetched_objects()
        with transaction.atomic(savepoint=False):
            self._add_items(
                self.source_field_name,
                self.target_field_name,
                *objs,
                through_defaults=through_defaults,
            )
            # If this is a symmetrical m2m relation to self, add the mirror
            # entry in the m2m table.
            if self.symmetrical:
                self._add_items(
                    self.target_field_name,
                    self.source_field_name,
                    *objs,
                    through_defaults=through_defaults,
                )

    def remove(self, *objs: T) -> None:
        """Unlink ``objs`` (instances or target values) from ``instance``."""
        self._remove_prefetched_objects()
        self._remove_items(self.source_field_name, self.target_field_name, *objs)