django-bulk-hooks 0.2.9__py3-none-any.whl → 0.2.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,437 +1,742 @@
1
- """
2
- Bulk executor service for database operations.
3
-
4
- This service coordinates bulk database operations with validation and MTI handling.
5
- """
6
-
7
- import logging
8
- from django.db import transaction
9
- from django.db.models import AutoField
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
- class BulkExecutor:
15
- """
16
- Executes bulk database operations.
17
-
18
- This service coordinates validation, MTI handling, and actual database
19
- operations. It's the only service that directly calls Django ORM methods.
20
-
21
- Dependencies are explicitly injected via constructor.
22
- """
23
-
24
def __init__(self, queryset, analyzer, mti_handler):
    """
    Wire the executor to its collaborators.

    Args:
        queryset: Django QuerySet instance; its ``model`` is cached as ``model_cls``
        analyzer: ModelAnalyzer instance (replaces validator + field_tracker)
        mti_handler: MTIHandler instance
    """
    self.queryset, self.analyzer, self.mti_handler = queryset, analyzer, mti_handler
    # Cached so the other methods do not have to reach through the queryset.
    self.model_cls = queryset.model
37
-
38
def bulk_create(
    self,
    objs,
    batch_size=None,
    ignore_conflicts=False,
    update_conflicts=False,
    update_fields=None,
    unique_fields=None,
    **kwargs,
):
    """
    Create ``objs`` in bulk, routing MTI models through the MTI plan.

    NOTE: The coordinator validates inputs before calling this method; no
    validation is repeated here.

    Args:
        objs: List of model instances to create (pre-validated)
        batch_size: Number of objects to create per batch
        ignore_conflicts: Whether to ignore conflicts (only forwarded on the
            non-MTI path)
        update_conflicts: Whether to update on conflict
        update_fields: Fields to update on conflict
        unique_fields: Fields to use for conflict detection
        **kwargs: Additional arguments (only forwarded on the non-MTI path)

    Returns:
        List of created objects; ``objs`` itself when it is empty
    """
    if not objs:
        return objs

    # Plain (non-MTI) models go straight to Django's native bulk_create.
    if not self.mti_handler.is_mti_model():
        return self._execute_bulk_create(
            objs,
            batch_size,
            ignore_conflicts,
            update_conflicts,
            update_fields,
            unique_fields,
            **kwargs,
        )

    logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk create")
    mti_plan = self.mti_handler.build_create_plan(
        objs,
        batch_size=batch_size,
        update_conflicts=update_conflicts,
        update_fields=update_fields,
        unique_fields=unique_fields,
    )
    return self._execute_mti_create_plan(mti_plan)
93
-
94
def _execute_bulk_create(
    self,
    objs,
    batch_size=None,
    ignore_conflicts=False,
    update_conflicts=False,
    update_fields=None,
    unique_fields=None,
    **kwargs,
):
    """
    Run Django's own ``bulk_create`` for a non-MTI model.

    A plain ``django.db.models.QuerySet`` is built on purpose (rather than
    reusing ``self.queryset``) so the call cannot recurse back into the
    hook-aware queryset.

    ``**kwargs`` is accepted for signature compatibility but is not forwarded
    to Django.

    Returns:
        Whatever ``QuerySet.bulk_create`` returns for ``objs``
    """
    from django.db.models import QuerySet

    django_options = {
        "batch_size": batch_size,
        "ignore_conflicts": ignore_conflicts,
        "update_conflicts": update_conflicts,
        "update_fields": update_fields,
        "unique_fields": unique_fields,
    }
    plain_qs = QuerySet(model=self.model_cls, using=self.queryset.db)
    return plain_qs.bulk_create(objs, **django_options)
123
-
124
def bulk_update(self, objs, fields, batch_size=None):
    """
    Update ``fields`` on ``objs`` in bulk, routing MTI models through the MTI plan.

    NOTE: The coordinator validates inputs before calling this method; no
    validation is repeated here.

    Args:
        objs: List of model instances to update (pre-validated)
        fields: List of field names to update
        batch_size: Number of objects to update per batch

    Returns:
        Number of objects updated (0 when ``objs`` is empty)
    """
    if not objs:
        return 0

    if self.mti_handler.is_mti_model():
        logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk update")
        update_plan = self.mti_handler.build_update_plan(objs, fields, batch_size=batch_size)
        return self._execute_mti_update_plan(update_plan)

    # Non-MTI: a plain Django QuerySet avoids re-entering the hook-aware one.
    from django.db.models import QuerySet

    plain_qs = QuerySet(model=self.model_cls, using=self.queryset.db)
    return plain_qs.bulk_update(objs, fields, batch_size=batch_size)
156
-
157
- # ==================== MTI PLAN EXECUTION ====================
158
-
159
def _execute_mti_create_plan(self, plan):
    """
    Execute an MTI create plan.

    This is where ALL database operations happen for MTI bulk_create. Inside
    one ``transaction.atomic`` block on the queryset's database it:

    1. bulk-creates the parent rows level by level,
    2. points every child object at its freshly created parents,
    3. inserts the child rows through ``_batched_insert`` (which bypasses
       Django's MTI check in ``bulk_create``),
    4. copies PKs and auto-generated values back onto the original objects.

    Args:
        plan: MTICreatePlan object from MTIHandler

    Returns:
        List of created objects with PKs assigned (``plan.original_objects``),
        or an empty list when ``plan`` is falsy
    """
    if not plan:
        return []

    from django.db import transaction
    from django.db.models import QuerySet as BaseQuerySet

    db_alias = self.queryset.db

    def mark_saved(instance):
        """Flag *instance* as persisted on the target database."""
        instance._state.adding = False
        instance._state.db = db_alias

    with transaction.atomic(using=db_alias, savepoint=False):
        # Step 1: Create all parent objects level by level
        parent_instances_map = {}  # Maps original obj id() -> {model: parent_instance}

        for parent_level in plan.parent_levels:
            bulk_kwargs = {"batch_size": len(parent_level.objects)}

            if parent_level.update_conflicts:
                bulk_kwargs["update_conflicts"] = True
                bulk_kwargs["unique_fields"] = parent_level.unique_fields
                bulk_kwargs["update_fields"] = parent_level.update_fields

            # Use base QuerySet to avoid recursion
            parent_qs = BaseQuerySet(model=parent_level.model_class, using=db_alias)
            created_parents = parent_qs.bulk_create(parent_level.objects, **bulk_kwargs)

            # Copy generated fields back to parent objects
            for created_parent, parent_obj in zip(created_parents, parent_level.objects):
                for field in parent_level.model_class._meta.local_fields:
                    # Read/write the raw column attribute: going through
                    # ``field.name`` on a foreign key resolves the related
                    # object and can issue one query per row.
                    created_value = getattr(created_parent, field.attname, None)
                    if created_value is not None:
                        setattr(parent_obj, field.attname, created_value)
                mark_saved(parent_obj)

            # Map parents back to original objects
            for parent_obj in parent_level.objects:
                orig_obj_id = parent_level.original_object_map[id(parent_obj)]
                parent_instances_map.setdefault(orig_obj_id, {})[parent_level.model_class] = parent_obj

        # Step 2: Add parent links to child objects
        for child_obj, orig_obj in zip(plan.child_objects, plan.original_objects):
            for parent_model, parent_instance in parent_instances_map.get(id(orig_obj), {}).items():
                parent_link = plan.child_model._meta.get_ancestor_link(parent_model)
                if parent_link:
                    setattr(child_obj, parent_link.attname, parent_instance.pk)
                    setattr(child_obj, parent_link.name, parent_instance)

        # Step 3: Bulk create child objects using _batched_insert (to bypass MTI check)
        child_qs = BaseQuerySet(model=plan.child_model, using=db_alias)
        child_qs._prepare_for_bulk_create(plan.child_objects)

        # Partition objects by PK status
        objs_without_pk, objs_with_pk = [], []
        for obj in plan.child_objects:
            (objs_with_pk if obj._is_pk_set() else objs_without_pk).append(obj)

        opts = plan.child_model._meta
        # ``generated`` is not present on every field implementation, so a
        # missing attribute is treated as "not generated".
        fields = [f for f in opts.local_fields if not getattr(f, "generated", False)]

        def insert_batch(batch, insert_fields, keep_existing_pk):
            """Insert *batch* and copy any DB-returned column values onto it."""
            returned_columns = child_qs._batched_insert(
                batch,
                insert_fields,
                batch_size=len(batch),
            )
            if returned_columns and hasattr(opts, "db_returning_fields"):
                for obj, results in zip(batch, returned_columns):
                    for result, field in zip(results, opts.db_returning_fields):
                        # A PK supplied by the caller must not be overwritten.
                        if keep_existing_pk and field == opts.pk:
                            continue
                        setattr(obj, field.attname, result)
            # Every object in the batch was sent to the database, whether or
            # not the backend returned columns for it.
            for obj in batch:
                mark_saved(obj)

        if objs_with_pk:
            insert_batch(objs_with_pk, fields, keep_existing_pk=True)

        if objs_without_pk:
            # Let the database assign auto/primary-key columns.
            filtered_fields = [
                f for f in fields
                if not isinstance(f, AutoField) and not f.primary_key
            ]
            insert_batch(objs_without_pk, filtered_fields, keep_existing_pk=False)

        created_children = plan.child_objects

        # Step 4: Copy PKs and auto-generated fields back to original objects
        pk_field_name = plan.child_model._meta.pk.name

        for orig_obj, child_obj in zip(plan.original_objects, created_children):
            setattr(orig_obj, pk_field_name, getattr(child_obj, pk_field_name))

            parent_instances = parent_instances_map.get(id(orig_obj), {})

            for model_class in plan.inheritance_chain:
                # Pick the object that holds this level's values
                if model_class in parent_instances:
                    source_obj = parent_instances[model_class]
                elif model_class == plan.child_model:
                    source_obj = child_obj
                else:
                    continue

                for field in model_class._meta.local_fields:
                    if field.name == pk_field_name:
                        continue

                    # Skip parent link fields
                    if hasattr(field, 'remote_field') and field.remote_field:
                        parent_link = plan.child_model._meta.get_ancestor_link(model_class)
                        if parent_link and field.name == parent_link.name:
                            continue

                    # Only auto_now_add, auto_now, and db_returning fields are copied
                    is_auto_field = (
                        getattr(field, 'auto_now_add', False)
                        or getattr(field, 'auto_now', False)
                        or getattr(field, 'db_returning', False)
                    )
                    if not is_auto_field:
                        continue

                    source_value = getattr(source_obj, field.name, None)
                    if source_value is not None:
                        setattr(orig_obj, field.name, source_value)

            mark_saved(orig_obj)

    return plan.original_objects
325
-
326
def _execute_mti_update_plan(self, plan):
    """
    Execute an MTI update plan.

    Updates each table in the inheritance chain with a single UPDATE per table,
    using one CASE/WHEN expression per field. All tables are updated inside one
    ``transaction.atomic`` block on the queryset's database.

    Args:
        plan: MTIUpdatePlan object from MTIHandler

    Returns:
        Number of objects updated, summed over every table that was touched;
        0 when ``plan`` is falsy or no object carries a PK

    Raises:
        Exception: any database error raised by an UPDATE is logged and
            re-raised so the surrounding transaction is rolled back instead
            of continuing in a broken state.
    """
    if not plan:
        return 0

    # Pair every object with its own PK exactly once. Objects without a PK are
    # dropped here, so the WHEN clauses below can never attach one object's
    # values to another object's key.
    keyed_objects = []
    for obj in plan.objects:
        obj_pk = getattr(obj, "pk", None) or getattr(obj, "id", None)
        if obj_pk:
            keyed_objects.append((obj_pk, obj))

    if not keyed_objects:
        return 0

    from django.db import transaction
    from django.db.models import Case, Value, When, QuerySet as BaseQuerySet

    root_pks = [pk for pk, _ in keyed_objects]
    total_updated = 0

    with transaction.atomic(using=self.queryset.db, savepoint=False):
        # Update each table in the chain
        for field_group in plan.field_groups:
            if not field_group.fields:
                continue

            base_qs = BaseQuerySet(model=field_group.model_class, using=self.queryset.db)
            pk_filter = {f"{field_group.filter_field}__in": root_pks}

            # Nothing to do for this table when none of the rows exist.
            if not base_qs.filter(**pk_filter).exists():
                continue

            # Build CASE statements for bulk update
            case_statements = {}
            for field_name in field_group.fields:
                field = field_group.model_class._meta.get_field(field_name)

                # Use column name for FK fields
                is_fk = getattr(field, 'is_relation', False) and hasattr(field, 'attname')
                if is_fk:
                    db_field_name = field.attname
                    target_field = field.target_field
                else:
                    db_field_name = field_name
                    target_field = field

                when_statements = []
                for pk, obj in keyed_objects:
                    value = getattr(obj, db_field_name)

                    # For FK fields, store the related object's key, never the object
                    if is_fk and value is not None and hasattr(value, 'pk'):
                        value = value.pk

                    when_statements.append(
                        When(
                            **{field_group.filter_field: pk},
                            then=Value(value, output_field=target_field),
                        )
                    )

                if when_statements:
                    case_statements[db_field_name] = Case(
                        *when_statements, output_field=target_field
                    )

            # Execute bulk update
            if case_statements:
                try:
                    total_updated += base_qs.filter(**pk_filter).update(**case_statements)
                except Exception as e:
                    # Swallowing a database error inside atomic() would leave
                    # the inheritance chain partially updated; log and propagate.
                    logger.exception(f"MTI bulk update failed for {field_group.model_class.__name__}: {e}")
                    raise

    return total_updated
419
-
420
def delete_queryset(self):
    """
    Delete the rows selected by the executor's queryset.

    NOTE: The coordinator validates inputs before calling this method; no
    validation is repeated here.

    Returns:
        Tuple of (count, details dict); ``(0, {})`` when the queryset is falsy
    """
    target = self.queryset
    if not target:
        return 0, {}

    # Call the base implementation explicitly instead of ``target.delete()``.
    from django.db.models import QuerySet

    return QuerySet.delete(target)
1
+ """
2
+ Bulk executor service for database operations.
3
+
4
+ Coordinates bulk database operations with validation and MTI handling.
5
+ This service is the only component that directly calls Django ORM methods.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any
10
+ from typing import Dict
11
+ from typing import List
12
+ from typing import Optional
13
+ from typing import Set
14
+ from typing import Tuple
15
+
16
+ from django.db import transaction
17
+ from django.db.models import AutoField
18
+ from django.db.models import Case
19
+ from django.db.models import ForeignKey
20
+ from django.db.models import Model
21
+ from django.db.models import QuerySet
22
+ from django.db.models import Value
23
+ from django.db.models import When
24
+ from django.db.models.constants import OnConflict
25
+ from django.db.models.functions import Cast
26
+
27
+ from django_bulk_hooks.helpers import tag_upsert_metadata
28
+ from django_bulk_hooks.operations.field_utils import get_field_value_for_db
29
+ from django_bulk_hooks.operations.field_utils import handle_auto_now_fields_for_inheritance_chain
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ class BulkExecutor:
35
+ """
36
+ Executes bulk database operations.
37
+
38
+ Coordinates validation, MTI handling, and database operations.
39
+ This is the only service that directly calls Django ORM methods.
40
+
41
+ All dependencies are explicitly injected via constructor for testability.
42
+ """
43
+
44
def __init__(
    self,
    queryset: QuerySet,
    analyzer: Any,
    mti_handler: Any,
    record_classifier: Any,
) -> None:
    """
    Store the injected collaborators on the executor.

    Args:
        queryset: Django QuerySet instance; its ``model`` is cached as ``model_cls``
        analyzer: ModelAnalyzer instance (validation and field tracking)
        mti_handler: MTIHandler instance
        record_classifier: RecordClassifier instance
    """
    self.queryset, self.analyzer = queryset, analyzer
    self.mti_handler, self.record_classifier = mti_handler, record_classifier
    # Cached so the other methods do not have to reach through the queryset.
    self.model_cls = queryset.model
65
+
66
def bulk_create(
    self,
    objs: List[Model],
    batch_size: Optional[int] = None,
    ignore_conflicts: bool = False,
    update_conflicts: bool = False,
    update_fields: Optional[List[str]] = None,
    unique_fields: Optional[List[str]] = None,
    existing_record_ids: Optional[Set[int]] = None,
    existing_pks_map: Optional[Dict[int, int]] = None,
    **kwargs: Any,
) -> List[Model]:
    """
    Execute bulk create operation.

    NOTE: Coordinator validates inputs before calling this method.
    This executor trusts that inputs are pre-validated.

    For upserts (``update_conflicts`` with ``unique_fields``) that arrive
    without a pre-computed classification, the records are classified here
    once, BEFORE anything is written, and that single result is shared by the
    create step and the metadata tagging step.

    Args:
        objs: Model instances to create (pre-validated)
        batch_size: Objects per batch
        ignore_conflicts: Whether to ignore conflicts (non-MTI path only)
        update_conflicts: Whether to update on conflict
        update_fields: Fields to update on conflict
        unique_fields: Fields for conflict detection
        existing_record_ids: Pre-classified existing record IDs
        existing_pks_map: Pre-classified existing PK mapping
        **kwargs: Additional arguments (non-MTI path only)

    Returns:
        List of created/updated objects; ``objs`` itself when it is empty
    """
    if not objs:
        return objs

    is_mti = self.mti_handler.is_mti_model()

    # Classify up front. Doing it after the write (as a fallback inside the
    # tagging step) would see the rows this very call just inserted.
    # NOTE(review): assumes classify_for_upsert inspects current database
    # state - confirm against RecordClassifier.
    is_upsert = bool(update_conflicts and unique_fields)
    if is_upsert and (existing_record_ids is None or existing_pks_map is None):
        if is_mti:
            existing_record_ids, existing_pks_map = self._classify_mti_records(objs, update_conflicts, unique_fields)
        else:
            existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(objs, unique_fields)

    # Route to appropriate handler
    if is_mti:
        result = self._handle_mti_create(
            objs=objs,
            batch_size=batch_size,
            update_conflicts=update_conflicts,
            update_fields=update_fields,
            unique_fields=unique_fields,
            existing_record_ids=existing_record_ids,
            existing_pks_map=existing_pks_map,
        )
    else:
        result = self._execute_standard_bulk_create(
            objs=objs,
            batch_size=batch_size,
            ignore_conflicts=ignore_conflicts,
            update_conflicts=update_conflicts,
            update_fields=update_fields,
            unique_fields=unique_fields,
            **kwargs,
        )

    # Tag upsert metadata (a no-op unless this was an upsert)
    self._handle_upsert_metadata_tagging(
        result_objects=result,
        objs=objs,
        update_conflicts=update_conflicts,
        unique_fields=unique_fields,
        existing_record_ids=existing_record_ids,
        existing_pks_map=existing_pks_map,
    )

    return result
134
+
135
def bulk_update(self, objs: List[Model], fields: List[str], batch_size: Optional[int] = None) -> int:
    """
    Update ``fields`` on ``objs`` in bulk.

    NOTE: Coordinator validates inputs before calling this method.
    This executor trusts that inputs are pre-validated.

    Args:
        objs: Model instances to update (pre-validated)
        fields: Field names to update; auto_now fields are appended to a copy
        batch_size: Objects per batch

    Returns:
        Number of objects updated (0 when ``objs`` is empty)
    """
    if not objs:
        return 0

    # auto_now columns must be written as well, so extend the field list first.
    fields = self._add_auto_now_fields(fields, objs)

    if not self.mti_handler.is_mti_model():
        # Standard bulk update
        return self._get_base_queryset().bulk_update(objs, fields, batch_size=batch_size)

    logger.info(f"Using MTI bulk update for {self.model_cls.__name__}")
    update_plan = self.mti_handler.build_update_plan(objs, fields, batch_size=batch_size)
    return self._execute_mti_update_plan(update_plan)
165
+
166
def delete_queryset(self) -> Tuple[int, Dict[str, int]]:
    """
    Delete the rows selected by the executor's queryset.

    NOTE: Coordinator validates inputs before calling this method.

    Returns:
        Tuple of (count, details dict); ``(0, {})`` when the queryset is falsy
    """
    target = self.queryset
    if target:
        # Call the base implementation explicitly instead of ``target.delete()``.
        return QuerySet.delete(target)
    return 0, {}
179
+
180
+ # ==================== Private: Create Helpers ====================
181
+
182
def _handle_mti_create(
    self,
    objs: List[Model],
    batch_size: Optional[int],
    update_conflicts: bool,
    update_fields: Optional[List[str]],
    unique_fields: Optional[List[str]],
    existing_record_ids: Optional[Set[int]],
    existing_pks_map: Optional[Dict[int, int]],
) -> List[Model]:
    """
    Create MTI model instances: classify, build a plan, execute it.

    Classification is only computed here when the caller did not supply both
    ``existing_record_ids`` and ``existing_pks_map``.

    Returns:
        The result of executing the MTI create plan
    """
    already_classified = existing_record_ids is not None and existing_pks_map is not None
    if not already_classified:
        classification = self._classify_mti_records(objs, update_conflicts, unique_fields)
        existing_record_ids, existing_pks_map = classification

    plan_options = {
        "objs": objs,
        "batch_size": batch_size,
        "update_conflicts": update_conflicts,
        "update_fields": update_fields,
        "unique_fields": unique_fields,
        "existing_record_ids": existing_record_ids,
        "existing_pks_map": existing_pks_map,
    }
    create_plan = self.mti_handler.build_create_plan(**plan_options)
    return self._execute_mti_create_plan(create_plan)
209
+
210
def _classify_mti_records(
    self,
    objs: List[Model],
    update_conflicts: bool,
    unique_fields: Optional[List[str]],
) -> Tuple[Set[int], Dict[int, int]]:
    """
    Classify MTI records as existing or new for an upsert.

    Returns:
        ``(existing_record_ids, existing_pks_map)`` from the record
        classifier, or ``(set(), {})`` when this is not an upsert
        (no ``update_conflicts`` or no ``unique_fields``).
    """
    is_upsert = update_conflicts and unique_fields
    if not is_upsert:
        return set(), {}

    # The unique fields may live on a parent table, so ask the MTI handler
    # which model should be queried.
    query_model = self.mti_handler.find_model_with_unique_fields(unique_fields)
    logger.info(f"MTI upsert: querying {query_model.__name__} for unique fields {unique_fields}")

    existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(
        objs,
        unique_fields,
        query_model=query_model,
    )
    logger.info(f"MTI classification: {len(existing_record_ids)} existing, {len(objs) - len(existing_record_ids)} new")

    return existing_record_ids, existing_pks_map
229
+
230
def _execute_standard_bulk_create(
    self,
    objs: List[Model],
    batch_size: Optional[int],
    ignore_conflicts: bool,
    update_conflicts: bool,
    update_fields: Optional[List[str]],
    unique_fields: Optional[List[str]],
    **kwargs: Any,
) -> List[Model]:
    """
    Run Django's native ``bulk_create`` for a non-MTI model.

    ``**kwargs`` is accepted for signature compatibility but is not forwarded
    to Django.

    Returns:
        Whatever the base queryset's ``bulk_create`` returns for ``objs``
    """
    django_options = {
        "batch_size": batch_size,
        "ignore_conflicts": ignore_conflicts,
        "update_conflicts": update_conflicts,
        "update_fields": update_fields,
        "unique_fields": unique_fields,
    }
    return self._get_base_queryset().bulk_create(objs, **django_options)
251
+
252
def _handle_upsert_metadata_tagging(
    self,
    result_objects: List[Model],
    objs: List[Model],
    update_conflicts: bool,
    unique_fields: Optional[List[str]],
    existing_record_ids: Optional[Set[int]],
    existing_pks_map: Optional[Dict[int, int]],
) -> None:
    """
    Tag upsert metadata on the objects returned by a bulk create.

    Shared by the MTI and non-MTI paths. Does nothing unless the operation
    was an upsert (``update_conflicts`` together with ``unique_fields``).

    Args:
        result_objects: Objects returned from bulk operation
        objs: Original objects passed to bulk_create
        update_conflicts: Whether this was an upsert operation
        unique_fields: Fields used for conflict detection
        existing_record_ids: Pre-classified existing record IDs
        existing_pks_map: Pre-classified existing PK mapping
    """
    if not update_conflicts or not unique_fields:
        return

    # Fall back to classifying here when the caller supplied no classification.
    missing_classification = existing_record_ids is None or existing_pks_map is None
    if missing_classification:
        existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(objs, unique_fields)

    tag_upsert_metadata(result_objects, existing_record_ids, existing_pks_map)
282
+
283
+ # ==================== Private: Update Helpers ====================
284
+
285
def _add_auto_now_fields(self, fields: List[str], objs: List[Model]) -> List[str]:
    """
    Return a copy of ``fields`` extended with the relevant auto_now fields.

    For MTI models every model in the inheritance chain is checked; otherwise
    only ``self.model_cls`` is.

    Args:
        fields: Original field list (left untouched)
        objs: Objects being updated

    Returns:
        New field list with the auto_now fields appended (no duplicates added)
    """
    extended = list(fields)  # Work on a copy so the caller's list is not mutated

    handler = self.mti_handler
    models_to_check = handler.get_inheritance_chain() if handler.is_mti_model() else [self.model_cls]

    for name in handle_auto_now_fields_for_inheritance_chain(models_to_check, objs, for_update=True):
        if name not in extended:
            extended.append(name)

    return extended
315
+
316
+ # ==================== Private: MTI Create Execution ====================
317
+
318
def _execute_mti_create_plan(self, plan: Any) -> List[Model]:
    """
    Run an MTI create plan inside a single transaction.

    Covers both INSERT and UPDATE for upsert operations.

    Args:
        plan: MTICreatePlan from MTIHandler

    Returns:
        List of created/updated objects with PKs assigned
        (``plan.original_objects``), or an empty list when ``plan`` is falsy
    """
    if not plan:
        return []

    with transaction.atomic(using=self.queryset.db, savepoint=False):
        # 1) parents first, so their PKs exist before children reference them
        parents_by_original = self._upsert_parent_levels(plan)
        # 2) point each child at its parents
        self._link_children_to_parents(plan, parents_by_original)
        # 3) insert new children / update existing ones
        self._handle_child_objects(plan)
        # 4) mirror PKs and auto-fields onto the caller's objects
        self._copy_fields_to_original_objects(plan, parents_by_original)

    return plan.original_objects
347
+
348
def _upsert_parent_levels(self, plan: Any) -> Dict[int, Dict[type, Model]]:
    """
    Upsert every parent level of the plan, in the order the plan lists them.

    Returns:
        Mapping of original obj id() -> {model: parent_instance}
    """
    parents_by_original: Dict[int, Dict[type, Model]] = {}

    for level in plan.parent_levels:
        # A plain Django QuerySet is used for each level's model.
        level_qs = QuerySet(model=level.model_class, using=self.queryset.db)

        # Each level is sent as a single batch.
        create_kwargs = {"batch_size": len(level.objects)}
        if level.update_conflicts:
            self._add_upsert_kwargs(create_kwargs, level)

        upserted = level_qs.bulk_create(level.objects, **create_kwargs)

        # Pull generated values back onto the plan's parent objects, then
        # record which original object each parent belongs to.
        self._copy_generated_fields(upserted, level.objects, level.model_class)
        self._map_parents_to_originals(level, parents_by_original)

    return parents_by_original
376
+
377
def _add_upsert_kwargs(self, bulk_kwargs: Dict[str, Any], parent_level: Any) -> None:
    """
    Add upsert parameters for one parent level to ``bulk_kwargs`` in place.

    ``update_fields`` is narrowed to the fields declared locally on the
    level's model and is only added when at least one such field remains.

    Args:
        bulk_kwargs: bulk_create keyword arguments to extend (mutated)
        parent_level: Plan level providing ``unique_fields``, ``update_fields``
            and ``model_class``
    """
    bulk_kwargs["update_conflicts"] = True
    bulk_kwargs["unique_fields"] = parent_level.unique_fields

    # Keep only the update fields that belong to this level's own table.
    parent_model_fields = {field.name for field in parent_level.model_class._meta.local_fields}
    # update_fields may be None; treat that as "no fields" instead of failing
    # with a TypeError while iterating.
    requested_fields = parent_level.update_fields or ()
    filtered_update_fields = [name for name in requested_fields if name in parent_model_fields]

    if filtered_update_fields:
        bulk_kwargs["update_fields"] = filtered_update_fields
388
+
389
def _copy_generated_fields(
    self,
    upserted_parents: List[Model],
    parent_objs: List[Model],
    model_class: type[Model],
) -> None:
    """
    Copy non-None local field values from upserted objects onto parent objects.

    Objects are paired positionally. Each parent object is also marked as
    saved on the queryset's database.
    """
    # Foreign keys are copied through their attname (raw column value) so that
    # reading them does not trigger a query.
    attr_names = [
        field.attname if isinstance(field, ForeignKey) else field.name
        for field in model_class._meta.local_fields
    ]

    for source, target in zip(upserted_parents, parent_objs):
        for attr in attr_names:
            value = getattr(source, attr, None)
            if value is not None:
                setattr(target, attr, value)

        target._state.adding = False
        target._state.db = self.queryset.db
406
+
407
def _map_parents_to_originals(self, parent_level: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
    """
    Record each parent instance under the id() of its original object.

    ``parent_instances_map`` is updated in place:
    original obj id() -> {model_class: parent_instance}.
    """
    model_class = parent_level.model_class
    for parent in parent_level.objects:
        original_id = parent_level.original_object_map[id(parent)]
        parent_instances_map.setdefault(original_id, {})[model_class] = parent
414
+
415
def _link_children_to_parents(self, plan: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
    """
    Point every child object at its parent instances and set its PK.

    Children and original objects are paired positionally. A warning is
    logged for any parent model the child model has no ancestor link to.
    """
    child_meta = plan.child_model._meta

    for child, original in zip(plan.child_objects, plan.original_objects):
        parents = parent_instances_map.get(id(original), {})

        for parent_model, parent in parents.items():
            link = child_meta.get_ancestor_link(parent_model)
            if not link:
                logger.warning(f"No parent link found for {parent_model} in {plan.child_model}")
                continue

            parent_pk = parent.pk
            setattr(child, link.attname, parent_pk)
            setattr(child, link.name, parent)
            # In MTI, child PK equals parent PK
            child.pk = parent_pk
            child.id = parent_pk
432
+
433
def _handle_child_objects(self, plan: Any) -> None:
    """
    Persist the plan's child objects: update existing rows, insert new ones.

    Existing children are only updated when the plan carries ``update_fields``.
    """
    child_qs = QuerySet(model=plan.child_model, using=self.queryset.db)

    # First element: children to insert; second: children that already exist.
    to_insert, to_update = self._split_child_objects(plan, child_qs)

    if plan.update_fields and to_update:
        self._update_existing_children(child_qs, to_update, plan)

    if to_insert:
        self._insert_new_children(child_qs, to_insert, plan)
447
+
448
def _split_child_objects(self, plan: Any, base_qs: QuerySet) -> Tuple[List[Model], List[Model]]:
    """
    Split the plan's child objects into rows to insert and rows that exist.

    Without ``plan.update_conflicts`` every child is treated as new and no
    query is issued. Otherwise one query looks up which child PKs are already
    present in the child table.

    Args:
        plan: Create plan providing ``child_objects``, ``child_model`` and
            ``update_conflicts``
        base_qs: Queryset over the child model used for the existence lookup

    Returns:
        ``(objs_without_pk, objs_with_pk)`` - children to insert, followed by
        children whose row already exists
    """
    if not plan.update_conflicts:
        return plan.child_objects, []

    pk_attname = plan.child_model._meta.pk.attname

    # Compare against None rather than relying on truthiness, so a valid but
    # falsy key (for example 0) is still looked up.
    candidate_pks = [
        pk
        for pk in (getattr(child_obj, pk_attname, None) for child_obj in plan.child_objects)
        if pk is not None
    ]

    existing_child_pks = set()
    if candidate_pks:
        existing_child_pks = set(base_qs.filter(pk__in=candidate_pks).values_list("pk", flat=True))

    objs_without_pk: List[Model] = []
    objs_with_pk: List[Model] = []

    for child_obj in plan.child_objects:
        child_pk = getattr(child_obj, pk_attname, None)
        if child_pk is not None and child_pk in existing_child_pks:
            objs_with_pk.append(child_obj)
        else:
            objs_without_pk.append(child_obj)

    return objs_without_pk, objs_with_pk
475
+
476
def _update_existing_children(self, base_qs: QuerySet, objs_with_pk: List[Model], plan: Any) -> None:
    """
    Bulk-update child rows that already exist, then mark them as saved.

    Only the plan's update fields declared locally on the child model are
    written; when none remain, no UPDATE is issued but the objects are still
    marked as saved.
    """
    local_names = {field.name for field in plan.child_model._meta.local_fields}
    child_fields = [name for name in plan.update_fields if name in local_names]

    if child_fields:
        base_qs.bulk_update(objs_with_pk, child_fields)

    db_alias = self.queryset.db
    for child in objs_with_pk:
        child._state.adding = False
        child._state.db = db_alias
487
+
488
def _insert_new_children(self, base_qs: QuerySet, objs_without_pk: List[Model], plan: Any) -> None:
    """
    Insert child rows through the QuerySet's ``_batched_insert``.

    The objects are first passed to ``_prepare_for_bulk_create``. The
    insert then covers every local field of the child model whose
    ``generated`` flag is falsy, and whatever ``_batched_insert`` returns
    is handed to ``_process_returned_columns``.

    Args:
        base_qs: QuerySet providing ``_prepare_for_bulk_create`` and
            ``_batched_insert``.
        objs_without_pk: Child objects to insert.
        plan: Plan object exposing ``child_model`` (and the attributes read
            by ``_build_batched_insert_kwargs``).
    """
    child_meta = plan.child_model._meta
    base_qs._prepare_for_bulk_create(objs_without_pk)

    insertable = []
    for column_field in child_meta.local_fields:
        if column_field.generated:
            continue
        insertable.append(column_field)

    # The whole list is sent as one batch; conflict options come from the plan.
    insert_options = self._build_batched_insert_kwargs(plan, len(objs_without_pk))
    self._process_returned_columns(
        objs_without_pk,
        base_qs._batched_insert(objs_without_pk, insertable, **insert_options),
        child_meta,
    )
504
+
505
def _build_batched_insert_kwargs(self, plan: Any, batch_size: int) -> Dict[str, Any]:
    """
    Assemble the keyword arguments for a ``_batched_insert`` call.

    Args:
        plan: Plan object exposing ``update_conflicts``,
            ``child_unique_fields`` and ``child_update_fields``.
        batch_size: Value stored under the ``"batch_size"`` key.

    Returns:
        A dict that always contains ``batch_size``. When the plan requests
        conflict updates and has child unique fields, it also contains
        ``on_conflict``, ``update_fields`` and ``unique_fields``.
    """
    options = {"batch_size": batch_size}

    if plan.update_conflicts and plan.child_unique_fields:
        child_updates = plan.child_update_fields
        if child_updates:
            options["on_conflict"] = OnConflict.UPDATE
            options["update_fields"] = child_updates
        else:
            # Nothing on the child table to update: fall back to IGNORE.
            options["on_conflict"] = OnConflict.IGNORE
            options["update_fields"] = None
        options["unique_fields"] = plan.child_unique_fields

    return options
531
+
532
def _process_returned_columns(self, objs: List[Model], returned_columns: Any, opts: Any) -> None:
    """
    Apply the rows returned by ``_batched_insert`` and flag objects as saved.

    Args:
        objs: Objects that were just inserted.
        returned_columns: Per-object result rows, or a falsy value when the
            insert returned nothing.
        opts: Model options; ``db_returning_fields`` is used when present.

    When ``returned_columns`` is falsy every object is flagged as saved.
    Otherwise objects are paired positionally with the returned rows, so
    only objects that have a matching row are processed.
    """

    def mark_persisted(instance):
        instance._state.adding = False
        instance._state.db = self.queryset.db

    if not returned_columns:
        for inserted in objs:
            mark_persisted(inserted)
        return

    for inserted, row in zip(objs, returned_columns):
        if hasattr(opts, "db_returning_fields"):
            for value, column in zip(row, opts.db_returning_fields):
                setattr(inserted, column.attname, value)
        mark_persisted(inserted)
545
+
546
def _copy_fields_to_original_objects(self, plan: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
    """
    Propagate saved values from the child objects back to the originals.

    ``plan.original_objects`` and ``plan.child_objects`` are walked in
    lockstep. For each pair, the attribute named after the child model's
    primary-key field is copied, auto-generated fields are copied via
    ``_copy_auto_generated_fields``, and the original is flagged as saved
    on this executor's database alias.

    Args:
        plan: Plan object exposing ``child_model``, ``original_objects``
            and ``child_objects``.
        parent_instances_map: Mapping forwarded unchanged to
            ``_copy_auto_generated_fields``.
    """
    pk_name = plan.child_model._meta.pk.name

    for original, saved_child in zip(plan.original_objects, plan.child_objects):
        setattr(original, pk_name, getattr(saved_child, pk_name))

        self._copy_auto_generated_fields(original, saved_child, plan, parent_instances_map, pk_name)

        state = original._state
        state.adding = False
        state.db = self.queryset.db
561
+
562
def _copy_auto_generated_fields(
    self,
    orig_obj: Model,
    child_obj: Model,
    plan: Any,
    parent_instances_map: Dict[int, Dict[type, Model]],
    pk_field_name: str,
) -> None:
    """
    Copy auto-generated field values onto ``orig_obj`` for every level.

    For each model in ``plan.inheritance_chain`` the value source is the
    parent instance registered for ``orig_obj`` in ``parent_instances_map``
    (keyed by ``id(orig_obj)``), or ``child_obj`` when the model is the
    plan's child model. Models with no source are skipped.

    Args:
        orig_obj: Object receiving the values.
        child_obj: Saved child-level object.
        plan: Plan object exposing ``inheritance_chain`` and ``child_model``.
        parent_instances_map: ``id(original) -> {model class: instance}``.
        pk_field_name: Field name that is never copied by this method.
    """
    parents_for_obj = parent_instances_map.get(id(orig_obj), {})

    for level in plan.inheritance_chain:
        if level in parents_for_obj:
            donor = parents_for_obj[level]
        elif level == plan.child_model:
            donor = child_obj
        else:
            continue

        for column in level._meta.local_fields:
            attr = column.name

            # The primary key and parent-link fields are never copied here.
            excluded = attr == pk_field_name or self._is_parent_link_field(column, plan.child_model, level)
            if excluded or not self._is_auto_generated_field(column):
                continue

            # A value of None on the source is treated as "nothing to copy".
            generated = getattr(donor, attr, None)
            if generated is not None:
                setattr(orig_obj, attr, generated)
596
+
597
def _is_parent_link_field(self, field: Any, child_model: type[Model], model_class: type[Model]) -> bool:
    """
    Tell whether ``field`` is the link from ``child_model`` to ``model_class``.

    Args:
        field: Field object to inspect.
        child_model: Model whose ``_meta.get_ancestor_link`` is consulted.
        model_class: Ancestor model passed to ``get_ancestor_link``.

    Returns:
        True when the field is relational (has a truthy ``remote_field``)
        and its name equals the name of the ancestor link; False otherwise,
        including when no ancestor link exists.
    """
    if not getattr(field, "remote_field", None):
        return False

    parent_link = child_model._meta.get_ancestor_link(model_class)
    # Coerce explicitly: a bare ``parent_link and ...`` would leak ``None``
    # to callers when there is no ancestor link, despite the bool annotation.
    return bool(parent_link and field.name == parent_link.name)
604
+
605
def _is_auto_generated_field(self, field: Any) -> bool:
    """
    Report whether ``field`` carries any of the auto-generated markers.

    The markers checked, in order, are ``auto_now_add``, ``auto_now`` and
    ``db_returning``; a missing attribute counts as False.
    """
    for marker in ("auto_now_add", "auto_now"):
        flag = getattr(field, marker, False)
        if flag:
            return flag
    return getattr(field, "db_returning", False)
608
+
609
+ # ==================== Private: MTI Update Execution ====================
610
+
611
def _execute_mti_update_plan(self, plan: Any) -> int:
    """
    Execute an MTI update plan.

    Every field group of the plan that lists at least one field is updated
    via ``_update_field_group``; all updates run inside a single
    ``transaction.atomic`` block (no savepoint) on this executor's
    database alias.

    Args:
        plan: Update plan exposing ``objects`` and ``field_groups``.

    Returns:
        The sum of the counts reported for each field group, or 0 when the
        plan is falsy or none of its objects has a primary key.
    """
    if not plan:
        return 0

    pk_values = self._get_root_pks(plan.objects)
    if not pk_values:
        return 0

    with transaction.atomic(using=self.queryset.db, savepoint=False):
        return sum(
            self._update_field_group(group, pk_values, plan.objects)
            for group in plan.field_groups
            if group.fields
        )
641
+
642
def _get_root_pks(self, objs: List[Model]) -> List[Any]:
    """
    Collect the primary-key value of every object that has one.

    For each object the ``pk`` attribute is preferred and ``id`` is the
    fallback; objects for which neither yields a truthy value are left out,
    so the result may be shorter than ``objs``.
    """
    collected = []
    for obj in objs:
        key = getattr(obj, "pk", None) or getattr(obj, "id", None)
        if key:
            collected.append(key)
    return collected
647
+
648
def _update_field_group(self, field_group: Any, root_pks: List[Any], objs: List[Model]) -> int:
    """
    Update the rows belonging to one field group.

    Args:
        field_group: Group exposing ``model_class`` (plus what the
            delegated helpers read from it).
        root_pks: Primary-key values selecting the rows to update.
        objs: Objects supplying the new field values.

    Returns:
        The count reported by ``_execute_field_group_update``, or 0 when no
        matching rows exist or no CASE statements could be built.
    """
    table_qs = QuerySet(model=field_group.model_class, using=self.queryset.db)

    if self._check_records_exist(table_qs, field_group, root_pks):
        assignments = self._build_case_statements(field_group, root_pks, objs)
        if assignments:
            return self._execute_field_group_update(table_qs, field_group, root_pks, assignments)
        logger.debug(f"No CASE statements for {field_group.model_class.__name__}")

    return 0
665
+
666
def _check_records_exist(self, base_qs: QuerySet, field_group: Any, root_pks: List[Any]) -> bool:
    """
    Check whether at least one row matches the given primary keys.

    Args:
        base_qs: QuerySet to filter.
        field_group: Group exposing ``filter_field``, the name of the
            lookup field matched against ``root_pks``.
        root_pks: Primary-key values to look for.

    Returns:
        True when the filtered queryset has at least one row.
    """
    # exists() asks only whether a row is present, rather than counting
    # every matching row just to compare the total against zero.
    return base_qs.filter(**{f"{field_group.filter_field}__in": root_pks}).exists()
670
+
671
def _build_case_statements(self, field_group: Any, root_pks: List[Any], objs: List[Model]) -> Dict[str, Case]:
    """
    Build one CASE expression per field of the group.

    Args:
        field_group: Group exposing ``fields`` and ``model_class``.
        root_pks: Forwarded to ``_build_field_case_statement``.
        objs: Forwarded to ``_build_field_case_statement``.

    Returns:
        Mapping of field name to its CASE expression; fields for which no
        expression was produced are omitted.
    """
    logger.debug(f"Building CASE statements for {field_group.model_class.__name__} with {len(field_group.fields)} fields")

    candidates = (
        (name, self._build_field_case_statement(name, field_group, root_pks, objs))
        for name in field_group.fields
    )
    return {name: expression for name, expression in candidates if expression}
683
+
684
def _build_field_case_statement(
    self,
    field_name: str,
    field_group: Any,
    root_pks: List[Any],
    objs: List[Model],
) -> Optional[Case]:
    """
    Build the CASE expression that assigns ``field_name`` per object.

    One WHEN clause is produced for every object that has a primary key:
    it matches ``field_group.filter_field`` against that object's own key
    and yields the object's value for the field, cast to the field's type.

    Args:
        field_name: Name of the field on ``field_group.model_class``.
        field_group: Group exposing ``model_class`` and ``filter_field``.
        root_pks: Kept for interface compatibility. Keys are read from the
            objects themselves so each value stays paired with its own row.
        objs: Objects supplying the values.

    Returns:
        A ``Case`` expression, or None when no object contributed a clause.
    """
    field = field_group.model_class._meta.get_field(field_name)
    when_statements = []

    for obj in objs:
        # Use the object's own key rather than zipping with ``root_pks``:
        # that list has key-less objects filtered out, so positional
        # pairing shifts every later value onto the wrong row as soon as
        # one object lacks a key. The truthiness test mirrors the filter
        # applied by ``_get_root_pks``.
        obj_pk = getattr(obj, "pk", None) or getattr(obj, "id", None)
        if not obj_pk:
            continue

        # Get and convert field value
        value = get_field_value_for_db(obj, field_name, field_group.model_class)
        value = field.to_python(value)

        # Create WHEN with type casting
        when_statements.append(
            When(
                **{field_group.filter_field: obj_pk},
                then=Cast(Value(value), output_field=field),
            )
        )

    if when_statements:
        return Case(*when_statements, output_field=field)

    return None
715
+
716
def _execute_field_group_update(
    self,
    base_qs: QuerySet,
    field_group: Any,
    root_pks: List[Any],
    case_statements: Dict[str, Case],
) -> int:
    """
    Run the UPDATE for one field group.

    Args:
        base_qs: QuerySet to filter and update.
        field_group: Group exposing ``filter_field`` and ``model_class``.
        root_pks: Primary-key values selecting the rows.
        case_statements: Field name to CASE expression, passed to
            ``update`` as keyword arguments.

    Returns:
        The value returned by ``QuerySet.update``.

    Raises:
        Exception: Any error raised while updating is logged and re-raised
            unchanged.
    """
    logger.debug(f"Executing update for {field_group.model_class.__name__} with {len(case_statements)} fields")

    try:
        row_filter = {f"{field_group.filter_field}__in": root_pks}
        updated_count = base_qs.filter(**row_filter).update(**case_statements)
        logger.debug(f"Updated {updated_count} records in {field_group.model_class.__name__}")
        return updated_count
    except Exception as e:
        logger.error(f"MTI bulk update failed for {field_group.model_class.__name__}: {e}")
        raise
737
+
738
+ # ==================== Private: Utilities ====================
739
+
740
def _get_base_queryset(self) -> QuerySet:
    """
    Return a plain Django QuerySet for this executor's model.

    The QuerySet is built directly from the ``QuerySet`` class, bound to
    ``self.model_cls`` and the database alias of ``self.queryset``, to
    avoid recursion.
    """
    db_alias = self.queryset.db
    return QuerySet(model=self.model_cls, using=db_alias)