django-bulk-hooks 0.2.44__py3-none-any.whl → 0.2.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,33 @@
1
1
  """
2
2
  Bulk executor service for database operations.
3
3
 
4
- This service coordinates bulk database operations with validation and MTI handling.
4
+ Coordinates bulk database operations with validation and MTI handling.
5
+ This service is the only component that directly calls Django ORM methods.
5
6
  """
6
7
 
7
8
  import logging
9
+ from typing import Any
10
+ from typing import Dict
11
+ from typing import List
12
+ from typing import Optional
13
+ from typing import Set
14
+ from typing import Tuple
8
15
 
9
16
  from django.db import transaction
10
- from django.db.models import AutoField, ForeignKey, Case, When, Value
17
+ from django.db.models import AutoField
18
+ from django.db.models import Case
19
+ from django.db.models import ForeignKey
20
+ from django.db.models import Model
21
+ from django.db.models import QuerySet
22
+ from django.db.models import Value
23
+ from django.db.models import When
24
+ from django.db.models.constants import OnConflict
11
25
  from django.db.models.functions import Cast
12
26
 
27
+ from django_bulk_hooks.helpers import tag_upsert_metadata
28
+ from django_bulk_hooks.operations.field_utils import get_field_value_for_db
29
+ from django_bulk_hooks.operations.field_utils import handle_auto_now_fields_for_inheritance_chain
30
+
13
31
  logger = logging.getLogger(__name__)
14
32
 
15
33
 
@@ -17,19 +35,25 @@ class BulkExecutor:
17
35
  """
18
36
  Executes bulk database operations.
19
37
 
20
- This service coordinates validation, MTI handling, and actual database
21
- operations. It's the only service that directly calls Django ORM methods.
38
+ Coordinates validation, MTI handling, and database operations.
39
+ This is the only service that directly calls Django ORM methods.
22
40
 
23
- Dependencies are explicitly injected via constructor.
41
+ All dependencies are explicitly injected via constructor for testability.
24
42
  """
25
43
 
26
- def __init__(self, queryset, analyzer, mti_handler, record_classifier):
44
+ def __init__(
45
+ self,
46
+ queryset: QuerySet,
47
+ analyzer: Any,
48
+ mti_handler: Any,
49
+ record_classifier: Any,
50
+ ) -> None:
27
51
  """
28
52
  Initialize bulk executor with explicit dependencies.
29
53
 
30
54
  Args:
31
55
  queryset: Django QuerySet instance
32
- analyzer: ModelAnalyzer instance (replaces validator + field_tracker)
56
+ analyzer: ModelAnalyzer instance (validation and field tracking)
33
57
  mti_handler: MTIHandler instance
34
58
  record_classifier: RecordClassifier instance
35
59
  """
@@ -41,51 +65,43 @@ class BulkExecutor:
41
65
 
42
66
  def bulk_create(
43
67
  self,
44
- objs,
45
- batch_size=None,
46
- ignore_conflicts=False,
47
- update_conflicts=False,
48
- update_fields=None,
49
- unique_fields=None,
50
- existing_record_ids=None,
51
- existing_pks_map=None,
52
- **kwargs,
53
- ):
68
+ objs: List[Model],
69
+ batch_size: Optional[int] = None,
70
+ ignore_conflicts: bool = False,
71
+ update_conflicts: bool = False,
72
+ update_fields: Optional[List[str]] = None,
73
+ unique_fields: Optional[List[str]] = None,
74
+ existing_record_ids: Optional[Set[int]] = None,
75
+ existing_pks_map: Optional[Dict[int, int]] = None,
76
+ **kwargs: Any,
77
+ ) -> List[Model]:
54
78
  """
55
79
  Execute bulk create operation.
56
80
 
57
- NOTE: Coordinator is responsible for validation before calling this method.
58
- This executor trusts that inputs have already been validated.
81
+ NOTE: Coordinator validates inputs before calling this method.
82
+ This executor trusts that inputs are pre-validated.
59
83
 
60
84
  Args:
61
- objs: List of model instances to create (pre-validated)
62
- batch_size: Number of objects to create per batch
85
+ objs: Model instances to create (pre-validated)
86
+ batch_size: Objects per batch
63
87
  ignore_conflicts: Whether to ignore conflicts
64
88
  update_conflicts: Whether to update on conflict
65
89
  update_fields: Fields to update on conflict
66
- unique_fields: Fields to use for conflict detection
90
+ unique_fields: Fields for conflict detection
91
+ existing_record_ids: Pre-classified existing record IDs
92
+ existing_pks_map: Pre-classified existing PK mapping
67
93
  **kwargs: Additional arguments
68
94
 
69
95
  Returns:
70
- List of created objects
96
+ List of created/updated objects
71
97
  """
72
98
  if not objs:
73
99
  return objs
74
100
 
75
- # Check if this is an MTI model and route accordingly
101
+ # Route to appropriate handler
76
102
  if self.mti_handler.is_mti_model():
77
- logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk create")
78
-
79
- # Use pre-classified records if provided, otherwise classify now
80
- if existing_record_ids is None or existing_pks_map is None:
81
- existing_record_ids = set()
82
- existing_pks_map = {}
83
- if update_conflicts and unique_fields:
84
- existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(objs, unique_fields)
85
-
86
- # Build execution plan with classification results
87
- plan = self.mti_handler.build_create_plan(
88
- objs,
103
+ result = self._handle_mti_create(
104
+ objs=objs,
89
105
  batch_size=batch_size,
90
106
  update_conflicts=update_conflicts,
91
107
  update_fields=update_fields,
@@ -93,55 +109,136 @@ class BulkExecutor:
93
109
  existing_record_ids=existing_record_ids,
94
110
  existing_pks_map=existing_pks_map,
95
111
  )
96
- # Execute the plan
97
- result = self._execute_mti_create_plan(plan)
98
-
99
- # Tag objects with upsert metadata for hook dispatching
100
- if update_conflicts and unique_fields:
101
- self._tag_upsert_metadata(result, existing_record_ids)
102
-
103
- return result
104
-
105
- # Non-MTI model - use Django's native bulk_create
106
- result = self._execute_bulk_create(
107
- objs,
108
- batch_size,
109
- ignore_conflicts,
110
- update_conflicts,
111
- update_fields,
112
- unique_fields,
113
- **kwargs,
112
+ else:
113
+ result = self._execute_standard_bulk_create(
114
+ objs=objs,
115
+ batch_size=batch_size,
116
+ ignore_conflicts=ignore_conflicts,
117
+ update_conflicts=update_conflicts,
118
+ update_fields=update_fields,
119
+ unique_fields=unique_fields,
120
+ **kwargs,
121
+ )
122
+
123
+ # Tag upsert metadata
124
+ self._handle_upsert_metadata_tagging(
125
+ result_objects=result,
126
+ objs=objs,
127
+ update_conflicts=update_conflicts,
128
+ unique_fields=unique_fields,
129
+ existing_record_ids=existing_record_ids,
130
+ existing_pks_map=existing_pks_map,
114
131
  )
115
-
116
- # Tag objects with upsert metadata for hook dispatching
117
- if update_conflicts and unique_fields:
118
- # Use pre-classified results if available, otherwise classify now
119
- if existing_record_ids is None:
120
- existing_record_ids, _ = self.record_classifier.classify_for_upsert(objs, unique_fields)
121
- self._tag_upsert_metadata(result, existing_record_ids)
122
-
132
+
123
133
  return result
124
134
 
125
- def _execute_bulk_create(
126
- self,
127
- objs,
128
- batch_size=None,
129
- ignore_conflicts=False,
130
- update_conflicts=False,
131
- update_fields=None,
132
- unique_fields=None,
133
- **kwargs,
134
- ):
135
+ def bulk_update(self, objs: List[Model], fields: List[str], batch_size: Optional[int] = None) -> int:
135
136
  """
136
- Execute the actual Django bulk_create.
137
+ Execute bulk update operation.
138
+
139
+ NOTE: Coordinator validates inputs before calling this method.
140
+ This executor trusts that inputs are pre-validated.
141
+
142
+ Args:
143
+ objs: Model instances to update (pre-validated)
144
+ fields: Field names to update
145
+ batch_size: Objects per batch
137
146
 
138
- This is the only method that directly calls Django ORM.
139
- We must call the base Django QuerySet to avoid recursion.
147
+ Returns:
148
+ Number of objects updated
140
149
  """
141
- from django.db.models import QuerySet
150
+ if not objs:
151
+ return 0
152
+
153
+ # Ensure auto_now fields are included
154
+ fields = self._add_auto_now_fields(fields, objs)
155
+
156
+ # Route to appropriate handler
157
+ if self.mti_handler.is_mti_model():
158
+ logger.info(f"Using MTI bulk update for {self.model_cls.__name__}")
159
+ plan = self.mti_handler.build_update_plan(objs, fields, batch_size=batch_size)
160
+ return self._execute_mti_update_plan(plan)
161
+
162
+ # Standard bulk update
163
+ base_qs = self._get_base_queryset()
164
+ return base_qs.bulk_update(objs, fields, batch_size=batch_size)
165
+
166
+ def delete_queryset(self) -> Tuple[int, Dict[str, int]]:
167
+ """
168
+ Execute delete on the queryset.
169
+
170
+ NOTE: Coordinator validates inputs before calling this method.
171
+
172
+ Returns:
173
+ Tuple of (count, details dict)
174
+ """
175
+ if not self.queryset:
176
+ return 0, {}
177
+
178
+ return QuerySet.delete(self.queryset)
179
+
180
+ # ==================== Private: Create Helpers ====================
181
+
182
+ def _handle_mti_create(
183
+ self,
184
+ objs: List[Model],
185
+ batch_size: Optional[int],
186
+ update_conflicts: bool,
187
+ update_fields: Optional[List[str]],
188
+ unique_fields: Optional[List[str]],
189
+ existing_record_ids: Optional[Set[int]],
190
+ existing_pks_map: Optional[Dict[int, int]],
191
+ ) -> List[Model]:
192
+ """Handle MTI model creation with classification and planning."""
193
+ # Classify records if not pre-classified
194
+ if existing_record_ids is None or existing_pks_map is None:
195
+ existing_record_ids, existing_pks_map = self._classify_mti_records(objs, update_conflicts, unique_fields)
196
+
197
+ # Build and execute plan
198
+ plan = self.mti_handler.build_create_plan(
199
+ objs=objs,
200
+ batch_size=batch_size,
201
+ update_conflicts=update_conflicts,
202
+ update_fields=update_fields,
203
+ unique_fields=unique_fields,
204
+ existing_record_ids=existing_record_ids,
205
+ existing_pks_map=existing_pks_map,
206
+ )
207
+
208
+ return self._execute_mti_create_plan(plan)
209
+
210
+ def _classify_mti_records(
211
+ self,
212
+ objs: List[Model],
213
+ update_conflicts: bool,
214
+ unique_fields: Optional[List[str]],
215
+ ) -> Tuple[Set[int], Dict[int, int]]:
216
+ """Classify MTI records for upsert operations."""
217
+ if not update_conflicts or not unique_fields:
218
+ return set(), {}
219
+
220
+ # Find correct model to query
221
+ query_model = self.mti_handler.find_model_with_unique_fields(unique_fields)
222
+ logger.info(f"MTI upsert: querying {query_model.__name__} for unique fields {unique_fields}")
223
+
224
+ existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(objs, unique_fields, query_model=query_model)
225
+
226
+ logger.info(f"MTI classification: {len(existing_record_ids)} existing, {len(objs) - len(existing_record_ids)} new")
142
227
 
143
- # Create a base Django queryset (not our HookQuerySet)
144
- base_qs = QuerySet(model=self.model_cls, using=self.queryset.db)
228
+ return existing_record_ids, existing_pks_map
229
+
230
+ def _execute_standard_bulk_create(
231
+ self,
232
+ objs: List[Model],
233
+ batch_size: Optional[int],
234
+ ignore_conflicts: bool,
235
+ update_conflicts: bool,
236
+ update_fields: Optional[List[str]],
237
+ unique_fields: Optional[List[str]],
238
+ **kwargs: Any,
239
+ ) -> List[Model]:
240
+ """Execute Django's native bulk_create for non-MTI models."""
241
+ base_qs = self._get_base_queryset()
145
242
 
146
243
  return base_qs.bulk_create(
147
244
  objs,
@@ -152,396 +249,494 @@ class BulkExecutor:
152
249
  unique_fields=unique_fields,
153
250
  )
154
251
 
155
- def bulk_update(self, objs, fields, batch_size=None):
252
+ def _handle_upsert_metadata_tagging(
253
+ self,
254
+ result_objects: List[Model],
255
+ objs: List[Model],
256
+ update_conflicts: bool,
257
+ unique_fields: Optional[List[str]],
258
+ existing_record_ids: Optional[Set[int]],
259
+ existing_pks_map: Optional[Dict[int, int]],
260
+ ) -> None:
156
261
  """
157
- Execute bulk update operation.
262
+ Tag upsert metadata on result objects.
158
263
 
159
- NOTE: Coordinator is responsible for validation before calling this method.
160
- This executor trusts that inputs have already been validated.
264
+ Centralizes metadata tagging logic for both MTI and non-MTI paths.
161
265
 
162
266
  Args:
163
- objs: List of model instances to update (pre-validated)
164
- fields: List of field names to update
165
- batch_size: Number of objects to update per batch
267
+ result_objects: Objects returned from bulk operation
268
+ objs: Original objects passed to bulk_create
269
+ update_conflicts: Whether this was an upsert operation
270
+ unique_fields: Fields used for conflict detection
271
+ existing_record_ids: Pre-classified existing record IDs
272
+ existing_pks_map: Pre-classified existing PK mapping
273
+ """
274
+ if not (update_conflicts and unique_fields):
275
+ return
276
+
277
+ # Classify if needed
278
+ if existing_record_ids is None or existing_pks_map is None:
279
+ existing_record_ids, existing_pks_map = self.record_classifier.classify_for_upsert(objs, unique_fields)
280
+
281
+ tag_upsert_metadata(result_objects, existing_record_ids, existing_pks_map)
282
+
283
+ # ==================== Private: Update Helpers ====================
284
+
285
+ def _add_auto_now_fields(self, fields: List[str], objs: List[Model]) -> List[str]:
286
+ """
287
+ Add auto_now fields to update list for all models in chain.
288
+
289
+ Handles both MTI and non-MTI models uniformly.
290
+
291
+ Args:
292
+ fields: Original field list
293
+ objs: Objects being updated
166
294
 
167
295
  Returns:
168
- Number of objects updated
296
+ Field list with auto_now fields included
169
297
  """
170
- if not objs:
171
- return 0
298
+ fields = list(fields) # Copy to avoid mutation
172
299
 
173
- # Check if this is an MTI model and route accordingly
300
+ # Get models to check
174
301
  if self.mti_handler.is_mti_model():
175
- logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk update")
176
- # Build execution plan
177
- plan = self.mti_handler.build_update_plan(objs, fields, batch_size=batch_size)
178
- # Execute the plan
179
- return self._execute_mti_update_plan(plan)
302
+ models_to_check = self.mti_handler.get_inheritance_chain()
303
+ else:
304
+ models_to_check = [self.model_cls]
180
305
 
181
- # Non-MTI model - use Django's native bulk_update
182
- # Validation already done by coordinator
183
- from django.db.models import QuerySet
306
+ # Handle auto_now fields uniformly
307
+ auto_now_fields = handle_auto_now_fields_for_inheritance_chain(models_to_check, objs, for_update=True)
184
308
 
185
- base_qs = QuerySet(model=self.model_cls, using=self.queryset.db)
186
- return base_qs.bulk_update(objs, fields, batch_size=batch_size)
309
+ # Add to fields list if not present
310
+ for auto_now_field in auto_now_fields:
311
+ if auto_now_field not in fields:
312
+ fields.append(auto_now_field)
313
+
314
+ return fields
187
315
 
188
- # ==================== MTI PLAN EXECUTION ====================
316
+ # ==================== Private: MTI Create Execution ====================
189
317
 
190
- def _execute_mti_create_plan(self, plan):
318
+ def _execute_mti_create_plan(self, plan: Any) -> List[Model]:
191
319
  """
192
320
  Execute an MTI create plan.
193
321
 
194
- This is where ALL database operations happen for MTI bulk_create.
195
- Handles both new records (INSERT) and existing records (UPDATE) for upsert.
322
+ Handles INSERT and UPDATE for upsert operations.
196
323
 
197
324
  Args:
198
- plan: MTICreatePlan object from MTIHandler
325
+ plan: MTICreatePlan from MTIHandler
199
326
 
200
327
  Returns:
201
328
  List of created/updated objects with PKs assigned
202
329
  """
203
- from django.db.models import QuerySet as BaseQuerySet
204
-
205
330
  if not plan:
206
331
  return []
207
332
 
208
333
  with transaction.atomic(using=self.queryset.db, savepoint=False):
209
- # Step 1: Create/Update all parent objects level by level
210
- parent_instances_map = {} # Maps original obj id() -> {model: parent_instance}
211
-
212
- for parent_level in plan.parent_levels:
213
- # Separate new and existing parent objects
214
- new_parents = []
215
- existing_parents = []
216
-
217
- for parent_obj in parent_level.objects:
218
- orig_obj_id = parent_level.original_object_map[id(parent_obj)]
219
- if orig_obj_id in plan.existing_record_ids:
220
- existing_parents.append(parent_obj)
221
- else:
222
- new_parents.append(parent_obj)
223
-
224
- # Bulk create new parents
225
- if new_parents:
226
- bulk_kwargs = {"batch_size": len(new_parents)}
227
-
228
- if parent_level.update_conflicts:
229
- bulk_kwargs["update_conflicts"] = True
230
- bulk_kwargs["unique_fields"] = parent_level.unique_fields
231
- bulk_kwargs["update_fields"] = parent_level.update_fields
232
-
233
- # Use base QuerySet to avoid recursion
234
- base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
235
- created_parents = base_qs.bulk_create(new_parents, **bulk_kwargs)
236
-
237
- # Copy generated fields back to parent objects
238
- for created_parent, parent_obj in zip(created_parents, new_parents):
239
- for field in parent_level.model_class._meta.local_fields:
240
- created_value = getattr(created_parent, field.name, None)
241
- if created_value is not None:
242
- setattr(parent_obj, field.name, created_value)
243
-
244
- parent_obj._state.adding = False
245
- parent_obj._state.db = self.queryset.db
246
-
247
- # Update existing parents
248
- if existing_parents and parent_level.update_fields:
249
- # Filter update fields to only those that exist in this parent model
250
- parent_model_fields = {field.name for field in parent_level.model_class._meta.local_fields}
251
- filtered_update_fields = [field for field in parent_level.update_fields if field in parent_model_fields]
252
-
253
- if filtered_update_fields:
254
- base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
255
- base_qs.bulk_update(existing_parents, filtered_update_fields)
256
-
257
- # Mark as not adding
258
- for parent_obj in existing_parents:
259
- parent_obj._state.adding = False
260
- parent_obj._state.db = self.queryset.db
261
-
262
- # Map parents back to original objects
263
- for parent_obj in parent_level.objects:
264
- orig_obj_id = parent_level.original_object_map[id(parent_obj)]
265
- if orig_obj_id not in parent_instances_map:
266
- parent_instances_map[orig_obj_id] = {}
267
- parent_instances_map[orig_obj_id][parent_level.model_class] = parent_obj
268
-
269
- # Step 2: Add parent links to child objects and separate new/existing
270
- new_child_objects = []
271
- existing_child_objects = []
272
-
273
- for child_obj, orig_obj in zip(plan.child_objects, plan.original_objects):
274
- parent_instances = parent_instances_map.get(id(orig_obj), {})
275
-
276
- # Set parent links
277
- for parent_model, parent_instance in parent_instances.items():
278
- parent_link = plan.child_model._meta.get_ancestor_link(parent_model)
279
- if parent_link:
280
- setattr(child_obj, parent_link.attname, parent_instance.pk)
281
- setattr(child_obj, parent_link.name, parent_instance)
282
-
283
- # Classify as new or existing
284
- if id(orig_obj) in plan.existing_record_ids:
285
- # For existing records, set the PK on child object
286
- pk_value = getattr(orig_obj, "pk", None)
287
- if pk_value:
288
- child_obj.pk = pk_value
289
- child_obj.id = pk_value
290
- existing_child_objects.append(child_obj)
291
- else:
292
- new_child_objects.append(child_obj)
293
-
294
- # Step 3: Bulk create new child objects using _batched_insert (to bypass MTI check)
295
- if new_child_objects:
296
- base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
297
- base_qs._prepare_for_bulk_create(new_child_objects)
298
-
299
- # Partition objects by PK status
300
- objs_without_pk, objs_with_pk = [], []
301
- for obj in new_child_objects:
302
- if obj._is_pk_set():
303
- objs_with_pk.append(obj)
304
- else:
305
- objs_without_pk.append(obj)
306
-
307
- # Get fields for insert
308
- opts = plan.child_model._meta
309
- fields = [f for f in opts.local_fields if not f.generated]
310
-
311
- # Execute bulk insert
312
- if objs_with_pk:
313
- returned_columns = base_qs._batched_insert(
314
- objs_with_pk,
315
- fields,
316
- batch_size=len(objs_with_pk),
317
- )
318
- if returned_columns:
319
- for obj, results in zip(objs_with_pk, returned_columns):
320
- if hasattr(opts, "db_returning_fields") and hasattr(opts, "pk"):
321
- for result, field in zip(results, opts.db_returning_fields):
322
- if field != opts.pk:
323
- setattr(obj, field.attname, result)
324
- obj._state.adding = False
325
- obj._state.db = self.queryset.db
326
- else:
327
- for obj in objs_with_pk:
328
- obj._state.adding = False
329
- obj._state.db = self.queryset.db
330
-
331
- if objs_without_pk:
332
- filtered_fields = [f for f in fields if not isinstance(f, AutoField) and not f.primary_key]
333
- returned_columns = base_qs._batched_insert(
334
- objs_without_pk,
335
- filtered_fields,
336
- batch_size=len(objs_without_pk),
337
- )
338
- if returned_columns:
339
- for obj, results in zip(objs_without_pk, returned_columns):
340
- if hasattr(opts, "db_returning_fields"):
341
- for result, field in zip(results, opts.db_returning_fields):
342
- setattr(obj, field.attname, result)
343
- obj._state.adding = False
344
- obj._state.db = self.queryset.db
345
- else:
346
- for obj in objs_without_pk:
347
- obj._state.adding = False
348
- obj._state.db = self.queryset.db
349
-
350
- # Step 3.5: Update existing child objects
351
- if existing_child_objects and plan.update_fields:
352
- # Filter update fields to only those that exist in the child model
353
- child_model_fields = {field.name for field in plan.child_model._meta.local_fields}
354
- filtered_child_update_fields = [field for field in plan.update_fields if field in child_model_fields]
355
-
356
- if filtered_child_update_fields:
357
- base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
358
- base_qs.bulk_update(existing_child_objects, filtered_child_update_fields)
359
-
360
- # Mark as not adding
361
- for child_obj in existing_child_objects:
362
- child_obj._state.adding = False
363
- child_obj._state.db = self.queryset.db
364
-
365
- # Combine all children for final processing
366
- created_children = new_child_objects + existing_child_objects
367
-
368
- # Step 4: Copy PKs and auto-generated fields back to original objects
369
- pk_field_name = plan.child_model._meta.pk.name
370
-
371
- for orig_obj, child_obj in zip(plan.original_objects, created_children):
372
- # Copy PK
373
- child_pk = getattr(child_obj, pk_field_name)
374
- setattr(orig_obj, pk_field_name, child_pk)
375
-
376
- # Copy auto-generated fields from all levels
377
- parent_instances = parent_instances_map.get(id(orig_obj), {})
378
-
379
- for model_class in plan.inheritance_chain:
380
- # Get source object for this level
381
- if model_class in parent_instances:
382
- source_obj = parent_instances[model_class]
383
- elif model_class == plan.child_model:
384
- source_obj = child_obj
385
- else:
386
- continue
387
-
388
- # Copy auto-generated field values
389
- for field in model_class._meta.local_fields:
390
- if field.name == pk_field_name:
391
- continue
392
-
393
- # Skip parent link fields
394
- if hasattr(field, "remote_field") and field.remote_field:
395
- parent_link = plan.child_model._meta.get_ancestor_link(model_class)
396
- if parent_link and field.name == parent_link.name:
397
- continue
398
-
399
- # Copy auto_now_add, auto_now, and db_returning fields
400
- if (
401
- getattr(field, "auto_now_add", False)
402
- or getattr(field, "auto_now", False)
403
- or getattr(field, "db_returning", False)
404
- ):
405
- source_value = getattr(source_obj, field.name, None)
406
- if source_value is not None:
407
- setattr(orig_obj, field.name, source_value)
408
-
409
- # Update object state
410
- orig_obj._state.adding = False
411
- orig_obj._state.db = self.queryset.db
334
+ # Step 1: Upsert all parent levels
335
+ parent_instances_map = self._upsert_parent_levels(plan)
336
+
337
+ # Step 2: Link children to parents
338
+ self._link_children_to_parents(plan, parent_instances_map)
339
+
340
+ # Step 3: Handle child objects (insert new, update existing)
341
+ self._handle_child_objects(plan)
342
+
343
+ # Step 4: Copy PKs and auto-fields back to original objects
344
+ self._copy_fields_to_original_objects(plan, parent_instances_map)
412
345
 
413
346
  return plan.original_objects
414
347
 
415
- def _execute_mti_update_plan(self, plan):
348
+ def _upsert_parent_levels(self, plan: Any) -> Dict[int, Dict[type, Model]]:
349
+ """
350
+ Upsert all parent objects level by level.
351
+
352
+ Returns:
353
+ Mapping of original obj id() -> {model: parent_instance}
354
+ """
355
+ parent_instances_map: Dict[int, Dict[type, Model]] = {}
356
+
357
+ for parent_level in plan.parent_levels:
358
+ base_qs = QuerySet(model=parent_level.model_class, using=self.queryset.db)
359
+
360
+ # Build bulk_create kwargs
361
+ bulk_kwargs = {"batch_size": len(parent_level.objects)}
362
+
363
+ if parent_level.update_conflicts:
364
+ self._add_upsert_kwargs(bulk_kwargs, parent_level)
365
+
366
+ # Execute upsert
367
+ upserted_parents = base_qs.bulk_create(parent_level.objects, **bulk_kwargs)
368
+
369
+ # Copy generated fields back
370
+ self._copy_generated_fields(upserted_parents, parent_level.objects, parent_level.model_class)
371
+
372
+ # Map parents to original objects
373
+ self._map_parents_to_originals(parent_level, parent_instances_map)
374
+
375
+ return parent_instances_map
376
+
377
+ def _add_upsert_kwargs(self, bulk_kwargs: Dict[str, Any], parent_level: Any) -> None:
378
+ """Add upsert parameters to bulk_create kwargs."""
379
+ bulk_kwargs["update_conflicts"] = True
380
+ bulk_kwargs["unique_fields"] = parent_level.unique_fields
381
+
382
+ # Filter update fields
383
+ parent_model_fields = {field.name for field in parent_level.model_class._meta.local_fields}
384
+ filtered_update_fields = [field for field in parent_level.update_fields if field in parent_model_fields]
385
+
386
+ if filtered_update_fields:
387
+ bulk_kwargs["update_fields"] = filtered_update_fields
388
+
389
+ def _copy_generated_fields(
390
+ self,
391
+ upserted_parents: List[Model],
392
+ parent_objs: List[Model],
393
+ model_class: type[Model],
394
+ ) -> None:
395
+ """Copy generated fields from upserted objects back to parent objects."""
396
+ for upserted_parent, parent_obj in zip(upserted_parents, parent_objs):
397
+ for field in model_class._meta.local_fields:
398
+ # Use attname for FK fields to avoid queries
399
+ field_attr = field.attname if isinstance(field, ForeignKey) else field.name
400
+ upserted_value = getattr(upserted_parent, field_attr, None)
401
+ if upserted_value is not None:
402
+ setattr(parent_obj, field_attr, upserted_value)
403
+
404
+ parent_obj._state.adding = False
405
+ parent_obj._state.db = self.queryset.db
406
+
407
+ def _map_parents_to_originals(self, parent_level: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
408
+ """Map parent instances back to original objects."""
409
+ for parent_obj in parent_level.objects:
410
+ orig_obj_id = parent_level.original_object_map[id(parent_obj)]
411
+ if orig_obj_id not in parent_instances_map:
412
+ parent_instances_map[orig_obj_id] = {}
413
+ parent_instances_map[orig_obj_id][parent_level.model_class] = parent_obj
414
+
415
+ def _link_children_to_parents(self, plan: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
416
+ """Link child objects to their parent objects and set PKs."""
417
+ for child_obj, orig_obj in zip(plan.child_objects, plan.original_objects):
418
+ parent_instances = parent_instances_map.get(id(orig_obj), {})
419
+
420
+ for parent_model, parent_instance in parent_instances.items():
421
+ parent_link = plan.child_model._meta.get_ancestor_link(parent_model)
422
+
423
+ if parent_link:
424
+ parent_pk = parent_instance.pk
425
+ setattr(child_obj, parent_link.attname, parent_pk)
426
+ setattr(child_obj, parent_link.name, parent_instance)
427
+ # In MTI, child PK equals parent PK
428
+ child_obj.pk = parent_pk
429
+ child_obj.id = parent_pk
430
+ else:
431
+ logger.warning(f"No parent link found for {parent_model} in {plan.child_model}")
432
+
433
+ def _handle_child_objects(self, plan: Any) -> None:
434
+ """Handle child object insertion and updates."""
435
+ base_qs = QuerySet(model=plan.child_model, using=self.queryset.db)
436
+
437
+ # Split objects: new vs existing
438
+ objs_without_pk, objs_with_pk = self._split_child_objects(plan, base_qs)
439
+
440
+ # Update existing children
441
+ if objs_with_pk and plan.update_fields:
442
+ self._update_existing_children(base_qs, objs_with_pk, plan)
443
+
444
+ # Insert new children
445
+ if objs_without_pk:
446
+ self._insert_new_children(base_qs, objs_without_pk, plan)
447
+
448
+ def _split_child_objects(self, plan: Any, base_qs: QuerySet) -> Tuple[List[Model], List[Model]]:
449
+ """Split child objects into new and existing."""
450
+ if not plan.update_conflicts:
451
+ return plan.child_objects, []
452
+
453
+ # Check which child records exist
454
+ parent_pks = [
455
+ getattr(child_obj, plan.child_model._meta.pk.attname, None)
456
+ for child_obj in plan.child_objects
457
+ if getattr(child_obj, plan.child_model._meta.pk.attname, None)
458
+ ]
459
+
460
+ existing_child_pks = set()
461
+ if parent_pks:
462
+ existing_child_pks = set(base_qs.filter(pk__in=parent_pks).values_list("pk", flat=True))
463
+
464
+ objs_without_pk = []
465
+ objs_with_pk = []
466
+
467
+ for child_obj in plan.child_objects:
468
+ child_pk = getattr(child_obj, plan.child_model._meta.pk.attname, None)
469
+ if child_pk and child_pk in existing_child_pks:
470
+ objs_with_pk.append(child_obj)
471
+ else:
472
+ objs_without_pk.append(child_obj)
473
+
474
+ return objs_without_pk, objs_with_pk
475
+
476
+ def _update_existing_children(self, base_qs: QuerySet, objs_with_pk: List[Model], plan: Any) -> None:
477
+ """Update existing child records."""
478
+ child_model_fields = {field.name for field in plan.child_model._meta.local_fields}
479
+ filtered_child_update_fields = [field for field in plan.update_fields if field in child_model_fields]
480
+
481
+ if filtered_child_update_fields:
482
+ base_qs.bulk_update(objs_with_pk, filtered_child_update_fields)
483
+
484
+ for obj in objs_with_pk:
485
+ obj._state.adding = False
486
+ obj._state.db = self.queryset.db
487
+
488
+ def _insert_new_children(self, base_qs: QuerySet, objs_without_pk: List[Model], plan: Any) -> None:
489
+ """Insert new child records using _batched_insert."""
490
+ base_qs._prepare_for_bulk_create(objs_without_pk)
491
+ opts = plan.child_model._meta
492
+
493
+ # Get fields for insertion
494
+ filtered_fields = [f for f in opts.local_fields if not f.generated]
495
+
496
+ # Build upsert kwargs
497
+ kwargs = self._build_batched_insert_kwargs(plan, len(objs_without_pk))
498
+
499
+ # Execute insert
500
+ returned_columns = base_qs._batched_insert(objs_without_pk, filtered_fields, **kwargs)
501
+
502
+ # Process returned columns
503
+ self._process_returned_columns(objs_without_pk, returned_columns, opts)
504
+
505
+ def _build_batched_insert_kwargs(self, plan: Any, batch_size: int) -> Dict[str, Any]:
506
+ """Build kwargs for _batched_insert call."""
507
+ kwargs = {"batch_size": batch_size}
508
+
509
+ if not (plan.update_conflicts and plan.child_unique_fields):
510
+ return kwargs
511
+
512
+ batched_unique_fields = plan.child_unique_fields
513
+ batched_update_fields = plan.child_update_fields
514
+
515
+ if batched_update_fields:
516
+ on_conflict = OnConflict.UPDATE
517
+ else:
518
+ # No update fields on child - use IGNORE
519
+ on_conflict = OnConflict.IGNORE
520
+ batched_update_fields = None
521
+
522
+ kwargs.update(
523
+ {
524
+ "on_conflict": on_conflict,
525
+ "update_fields": batched_update_fields,
526
+ "unique_fields": batched_unique_fields,
527
+ }
528
+ )
529
+
530
+ return kwargs
531
+
532
+ def _process_returned_columns(self, objs: List[Model], returned_columns: Any, opts: Any) -> None:
533
+ """Process returned columns from _batched_insert."""
534
+ if returned_columns:
535
+ for obj, results in zip(objs, returned_columns):
536
+ if hasattr(opts, "db_returning_fields"):
537
+ for result, field in zip(results, opts.db_returning_fields):
538
+ setattr(obj, field.attname, result)
539
+ obj._state.adding = False
540
+ obj._state.db = self.queryset.db
541
+ else:
542
+ for obj in objs:
543
+ obj._state.adding = False
544
+ obj._state.db = self.queryset.db
545
+
546
+ def _copy_fields_to_original_objects(self, plan: Any, parent_instances_map: Dict[int, Dict[type, Model]]) -> None:
547
+ """Copy PKs and auto-generated fields to original objects."""
548
+ pk_field_name = plan.child_model._meta.pk.name
549
+
550
+ for orig_obj, child_obj in zip(plan.original_objects, plan.child_objects):
551
+ # Copy PK
552
+ child_pk = getattr(child_obj, pk_field_name)
553
+ setattr(orig_obj, pk_field_name, child_pk)
554
+
555
+ # Copy auto-generated fields from all levels
556
+ self._copy_auto_generated_fields(orig_obj, child_obj, plan, parent_instances_map, pk_field_name)
557
+
558
+ # Update state
559
+ orig_obj._state.adding = False
560
+ orig_obj._state.db = self.queryset.db
561
+
562
+ def _copy_auto_generated_fields(
563
+ self,
564
+ orig_obj: Model,
565
+ child_obj: Model,
566
+ plan: Any,
567
+ parent_instances_map: Dict[int, Dict[type, Model]],
568
+ pk_field_name: str,
569
+ ) -> None:
570
+ """Copy auto-generated fields from all inheritance levels."""
571
+ parent_instances = parent_instances_map.get(id(orig_obj), {})
572
+
573
+ for model_class in plan.inheritance_chain:
574
+ # Get source object
575
+ if model_class in parent_instances:
576
+ source_obj = parent_instances[model_class]
577
+ elif model_class == plan.child_model:
578
+ source_obj = child_obj
579
+ else:
580
+ continue
581
+
582
+ # Copy auto-generated fields
583
+ for field in model_class._meta.local_fields:
584
+ if field.name == pk_field_name:
585
+ continue
586
+
587
+ # Skip parent link fields
588
+ if self._is_parent_link_field(field, plan.child_model, model_class):
589
+ continue
590
+
591
+ # Copy auto_now, auto_now_add, and db_returning fields
592
+ if self._is_auto_generated_field(field):
593
+ source_value = getattr(source_obj, field.name, None)
594
+ if source_value is not None:
595
+ setattr(orig_obj, field.name, source_value)
596
+
597
+ def _is_parent_link_field(self, field: Any, child_model: type[Model], model_class: type[Model]) -> bool:
598
+ """Check if field is a parent link field."""
599
+ if not (hasattr(field, "remote_field") and field.remote_field):
600
+ return False
601
+
602
+ parent_link = child_model._meta.get_ancestor_link(model_class)
603
+ return parent_link and field.name == parent_link.name
604
+
605
+ def _is_auto_generated_field(self, field: Any) -> bool:
606
+ """Check if field is auto-generated."""
607
+ return getattr(field, "auto_now_add", False) or getattr(field, "auto_now", False) or getattr(field, "db_returning", False)
608
+
609
+ # ==================== Private: MTI Update Execution ====================
610
+
611
+ def _execute_mti_update_plan(self, plan: Any) -> int:
416
612
  """
417
613
  Execute an MTI update plan.
418
614
 
419
- Updates each table in the inheritance chain using CASE/WHEN for bulk updates.
615
+ Updates each table in the inheritance chain using CASE/WHEN.
420
616
 
421
617
  Args:
422
- plan: MTIUpdatePlan object from MTIHandler
618
+ plan: MTIUpdatePlan from MTIHandler
423
619
 
424
620
  Returns:
425
621
  Number of objects updated
426
622
  """
427
- from django.db.models import Case
428
- from django.db.models import QuerySet as BaseQuerySet
429
- from django.db.models import Value
430
- from django.db.models import When
431
-
432
623
  if not plan:
433
624
  return 0
434
625
 
435
- total_updated = 0
436
-
437
- # Get PKs for filtering
438
- root_pks = [
439
- getattr(obj, "pk", None) or getattr(obj, "id", None)
440
- for obj in plan.objects
441
- if getattr(obj, "pk", None) or getattr(obj, "id", None)
442
- ]
443
-
626
+ root_pks = self._get_root_pks(plan.objects)
444
627
  if not root_pks:
445
628
  return 0
446
629
 
630
+ total_updated = 0
631
+
447
632
  with transaction.atomic(using=self.queryset.db, savepoint=False):
448
- # Update each table in the chain
449
633
  for field_group in plan.field_groups:
450
634
  if not field_group.fields:
451
635
  continue
452
636
 
453
- base_qs = BaseQuerySet(model=field_group.model_class, using=self.queryset.db)
637
+ updated_count = self._update_field_group(field_group, root_pks, plan.objects)
638
+ total_updated += updated_count
454
639
 
455
- # Check if records exist
456
- existing_count = base_qs.filter(**{f"{field_group.filter_field}__in": root_pks}).count()
457
- if existing_count == 0:
458
- continue
640
+ return total_updated
459
641
 
460
- # Build CASE statements for bulk update
461
- case_statements = {}
462
- for field_name in field_group.fields:
463
- field = field_group.model_class._meta.get_field(field_name)
464
- when_statements = []
465
-
466
- # Determine the correct output field for type casting
467
- # For ForeignKey fields, use the target field to ensure correct SQL types
468
- is_fk = isinstance(field, ForeignKey)
469
- case_output_field = field.target_field if is_fk else field
470
-
471
- for pk, obj in zip(root_pks, plan.objects):
472
- obj_pk = getattr(obj, "pk", None) or getattr(obj, "id", None)
473
- if obj_pk is None:
474
- continue
475
-
476
- # Get the field value - handle ForeignKey fields specially
477
- value = getattr(obj, field.attname, None) if is_fk else getattr(obj, field_name)
478
-
479
- # Handle NULL values specially for ForeignKey fields
480
- if is_fk and value is None:
481
- # For ForeignKey fields with None values, use Cast to ensure proper NULL type
482
- # PostgreSQL needs explicit type casting for NULL values in CASE statements
483
- when_statements.append(
484
- When(
485
- **{field_group.filter_field: pk},
486
- then=Cast(Value(None), output_field=case_output_field),
487
- ),
488
- )
489
- else:
490
- # For non-None values or non-FK fields, use Value with output_field
491
- when_statements.append(
492
- When(
493
- **{field_group.filter_field: pk},
494
- then=Value(value, output_field=case_output_field),
495
- ),
496
- )
497
-
498
- if when_statements:
499
- case_statements[field_name] = Case(*when_statements, output_field=case_output_field)
500
-
501
- # Execute bulk update
502
- if case_statements:
503
- try:
504
- updated_count = base_qs.filter(
505
- **{f"{field_group.filter_field}__in": root_pks},
506
- ).update(**case_statements)
507
- total_updated += updated_count
508
- except Exception as e:
509
- logger.error(f"MTI bulk update failed for {field_group.model_class.__name__}: {e}")
642
+ def _get_root_pks(self, objs: List[Model]) -> List[Any]:
643
+ """Extract primary keys from objects."""
644
+ return [
645
+ getattr(obj, "pk", None) or getattr(obj, "id", None) for obj in objs if getattr(obj, "pk", None) or getattr(obj, "id", None)
646
+ ]
510
647
 
511
- return total_updated
648
def _update_field_group(self, field_group: Any, root_pks: List[Any], objs: List[Model]) -> int:
    """
    Update the table behind a single field group.

    Args:
        field_group: Group exposing ``model_class`` and the data read by the
            helper methods called here.
        root_pks: Primary keys of the rows to update.
        objs: Objects carrying the new values.

    Returns:
        Number of rows updated; ``0`` when no matching rows exist or no CASE
        statement could be built.
    """
    table_qs = QuerySet(model=field_group.model_class, using=self.queryset.db)

    if self._check_records_exist(table_qs, field_group, root_pks):
        case_statements = self._build_case_statements(field_group, root_pks, objs)
        if case_statements:
            return self._execute_field_group_update(table_qs, field_group, root_pks, case_statements)
        logger.debug(f"No CASE statements for {field_group.model_class.__name__}")

    return 0
529
665
 
530
- return QuerySet.delete(self.queryset)
666
+ def _check_records_exist(self, base_qs: QuerySet, field_group: Any, root_pks: List[Any]) -> bool:
667
+ """Check if any records exist for update."""
668
+ existing_count = base_qs.filter(**{f"{field_group.filter_field}__in": root_pks}).count()
669
+ return existing_count > 0
531
670
 
532
- def _tag_upsert_metadata(self, result_objects, existing_record_ids):
533
- """
534
- Tag objects with metadata indicating whether they were created or updated.
535
-
536
- This metadata is used by the coordinator to determine which hooks to fire.
537
- The metadata is temporary and will be cleaned up after hook execution.
538
-
539
- Args:
540
- result_objects: List of objects returned from bulk operation
541
- existing_record_ids: Set of id() for objects that existed before the operation
542
- """
543
- for obj in result_objects:
544
- # Tag with metadata for hook dispatching
545
- was_created = id(obj) not in existing_record_ids
546
- obj._bulk_hooks_was_created = was_created
547
- obj._bulk_hooks_upsert_metadata = True
671
def _build_case_statements(self, field_group: Any, root_pks: List[Any], objs: List[Model]) -> Dict[str, Case]:
    """
    Build one CASE statement per field of the group.

    Args:
        field_group: Group exposing ``fields`` and ``model_class``.
        root_pks: Primary keys of the rows to update.
        objs: Objects carrying the new values.

    Returns:
        Mapping of field name to CASE statement, in the order of
        ``field_group.fields``; fields without a statement are left out.
    """
    logger.debug(f"Building CASE statements for {field_group.model_class.__name__} with {len(field_group.fields)} fields")

    built = (
        (field_name, self._build_field_case_statement(field_name, field_group, root_pks, objs))
        for field_name in field_group.fields
    )
    return {field_name: case_stmt for field_name, case_stmt in built if case_stmt}
683
+
684
def _build_field_case_statement(
    self,
    field_name: str,
    field_group: Any,
    root_pks: List[Any],
    objs: List[Model],
) -> Optional[Case]:
    """
    Build the CASE statement that assigns one field's per-object values.

    Each object contributes a WHEN branch keyed on its own primary key.
    Pairing ``root_pks`` with ``objs`` positionally is unsafe: ``root_pks``
    omits objects without a primary key, so a single such object would shift
    every later pairing and attach values to the wrong rows.

    Args:
        field_name: Name of the field on ``field_group.model_class``.
        field_group: Group exposing ``model_class`` and ``filter_field``.
        root_pks: Primary keys selected for the update. Kept for interface
            compatibility; WHEN keys are read from the objects themselves.
        objs: Objects carrying the new field values.

    Returns:
        A ``Case`` expression, or ``None`` when no object has a primary key.
    """
    field = field_group.model_class._meta.get_field(field_name)
    when_statements = []

    for obj in objs:
        obj_pk = getattr(obj, "pk", None) or getattr(obj, "id", None)
        if not obj_pk:
            # Same truthiness rule _get_root_pks applies, so the WHEN
            # branches cover exactly the rows the update filters on.
            continue

        # Get and convert field value
        value = get_field_value_for_db(obj, field_name, field_group.model_class)
        value = field.to_python(value)

        # Create WHEN with type casting
        when_statements.append(
            When(
                **{field_group.filter_field: obj_pk},
                then=Cast(Value(value), output_field=field),
            )
        )

    if when_statements:
        return Case(*when_statements, output_field=field)

    return None
715
+
716
def _execute_field_group_update(
    self,
    base_qs: QuerySet,
    field_group: Any,
    root_pks: List[Any],
    case_statements: Dict[str, Case],
) -> int:
    """
    Run the UPDATE for one field group.

    Args:
        base_qs: Queryset over the field group's model.
        field_group: Group exposing ``filter_field`` and ``model_class``.
        root_pks: Primary keys restricting which rows are updated.
        case_statements: Field name to CASE statement mapping, passed to
            ``update()`` as keyword arguments.

    Returns:
        Row count reported by ``update()``.

    Raises:
        Exception: Any error from the query is logged and re-raised unchanged.
    """
    logger.debug(f"Executing update for {field_group.model_class.__name__} with {len(case_statements)} fields")

    try:
        pk_lookup = {f"{field_group.filter_field}__in": root_pks}
        updated_count = base_qs.filter(**pk_lookup).update(**case_statements)

        logger.debug(f"Updated {updated_count} records in {field_group.model_class.__name__}")

        return updated_count

    except Exception as e:
        # Log for diagnosis, then let the caller's transaction handling decide.
        logger.error(f"MTI bulk update failed for {field_group.model_class.__name__}: {e}")
        raise
737
+
738
+ # ==================== Private: Utilities ====================
739
+
740
def _get_base_queryset(self) -> QuerySet:
    """
    Build a plain Django ``QuerySet`` for this executor's model.

    The queryset targets ``self.model_cls`` on the same database alias as
    ``self.queryset``. A base ``QuerySet`` is used to avoid recursion
    (presumably back into a customised queryset class; confirm with callers).

    Returns:
        A fresh ``QuerySet`` instance.
    """
    model_cls = self.model_cls
    return QuerySet(model=model_cls, using=self.queryset.db)