django-bulk-hooks 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of django-bulk-hooks might be problematic; see the package registry's advisory for this release for more details.

@@ -5,14 +5,14 @@ This module contains all services for bulk operations following
5
5
  a clean, service-based architecture.
6
6
  """
7
7
 
8
- from django_bulk_hooks.operations.coordinator import BulkOperationCoordinator
9
8
  from django_bulk_hooks.operations.analyzer import ModelAnalyzer
10
9
  from django_bulk_hooks.operations.bulk_executor import BulkExecutor
10
+ from django_bulk_hooks.operations.coordinator import BulkOperationCoordinator
11
11
  from django_bulk_hooks.operations.mti_handler import MTIHandler
12
12
 
13
13
  __all__ = [
14
- 'BulkOperationCoordinator',
15
- 'ModelAnalyzer',
16
- 'BulkExecutor',
17
- 'MTIHandler',
14
+ "BulkExecutor",
15
+ "BulkOperationCoordinator",
16
+ "MTIHandler",
17
+ "ModelAnalyzer",
18
18
  ]
@@ -84,7 +84,7 @@ class ModelAnalyzer:
84
84
  if invalid_types:
85
85
  raise TypeError(
86
86
  f"{operation} expected instances of {self.model_cls.__name__}, "
87
- f"but got {invalid_types}"
87
+ f"but got {invalid_types}",
88
88
  )
89
89
 
90
90
  def _check_has_pks(self, objs, operation="operation"):
@@ -94,7 +94,7 @@ class ModelAnalyzer:
94
94
  if missing_pks:
95
95
  raise ValueError(
96
96
  f"{operation} cannot operate on unsaved {self.model_cls.__name__} instances. "
97
- f"{len(missing_pks)} object(s) have no primary key."
97
+ f"{len(missing_pks)} object(s) have no primary key.",
98
98
  )
99
99
 
100
100
  # ========== Data Fetching Methods ==========
@@ -130,7 +130,7 @@ class ModelAnalyzer:
130
130
  auto_now_fields = []
131
131
  for field in self.model_cls._meta.fields:
132
132
  if getattr(field, "auto_now", False) or getattr(
133
- field, "auto_now_add", False
133
+ field, "auto_now_add", False,
134
134
  ):
135
135
  auto_now_fields.append(field.name)
136
136
  return auto_now_fields
@@ -224,28 +224,28 @@ class ModelAnalyzer:
224
224
  """
225
225
  from django.db.models import Expression
226
226
  from django.db.models.expressions import Combinable
227
-
227
+
228
228
  # Simple value - return as-is
229
229
  if not isinstance(expression, (Expression, Combinable)):
230
230
  return expression
231
-
231
+
232
232
  # For complex expressions, evaluate them in database context
233
233
  # Use annotate() which Django properly handles for all expression types
234
234
  try:
235
235
  # Create a queryset for just this instance
236
236
  instance_qs = self.model_cls.objects.filter(pk=instance.pk)
237
-
237
+
238
238
  # Use annotate with the expression and let Django resolve it
239
239
  resolved_value = instance_qs.annotate(
240
- _resolved_value=expression
241
- ).values_list('_resolved_value', flat=True).first()
242
-
240
+ _resolved_value=expression,
241
+ ).values_list("_resolved_value", flat=True).first()
242
+
243
243
  return resolved_value
244
244
  except Exception as e:
245
245
  # If expression resolution fails, log and return original
246
246
  logger.warning(
247
247
  f"Failed to resolve expression for field '{field_name}' "
248
- f"on {self.model_cls.__name__}: {e}. Using original value."
248
+ f"on {self.model_cls.__name__}: {e}. Using original value.",
249
249
  )
250
250
  return expression
251
251
 
@@ -266,12 +266,12 @@ class ModelAnalyzer:
266
266
  """
267
267
  if not instances or not update_kwargs:
268
268
  return []
269
-
269
+
270
270
  fields_updated = list(update_kwargs.keys())
271
-
271
+
272
272
  for field_name, value in update_kwargs.items():
273
273
  for instance in instances:
274
274
  resolved_value = self.resolve_expression(field_name, value, instance)
275
275
  setattr(instance, field_name, resolved_value)
276
-
277
- return fields_updated
276
+
277
+ return fields_updated
@@ -5,6 +5,7 @@ This service coordinates bulk database operations with validation and MTI handli
5
5
  """
6
6
 
7
7
  import logging
8
+
8
9
  from django.db import transaction
9
10
  from django.db.models import AutoField
10
11
 
@@ -21,7 +22,7 @@ class BulkExecutor:
21
22
  Dependencies are explicitly injected via constructor.
22
23
  """
23
24
 
24
- def __init__(self, queryset, analyzer, mti_handler):
25
+ def __init__(self, queryset, analyzer, mti_handler, record_classifier):
25
26
  """
26
27
  Initialize bulk executor with explicit dependencies.
27
28
 
@@ -29,10 +30,12 @@ class BulkExecutor:
29
30
  queryset: Django QuerySet instance
30
31
  analyzer: ModelAnalyzer instance (replaces validator + field_tracker)
31
32
  mti_handler: MTIHandler instance
33
+ record_classifier: RecordClassifier instance
32
34
  """
33
35
  self.queryset = queryset
34
36
  self.analyzer = analyzer
35
37
  self.mti_handler = mti_handler
38
+ self.record_classifier = record_classifier
36
39
  self.model_cls = queryset.model
37
40
 
38
41
  def bulk_create(
@@ -69,13 +72,24 @@ class BulkExecutor:
69
72
  # Check if this is an MTI model and route accordingly
70
73
  if self.mti_handler.is_mti_model():
71
74
  logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk create")
72
- # Build execution plan
75
+
76
+ # Classify records using the classifier service
77
+ existing_record_ids = set()
78
+ existing_pks_map = {}
79
+ if update_conflicts and unique_fields:
80
+ existing_record_ids, existing_pks_map = (
81
+ self.record_classifier.classify_for_upsert(objs, unique_fields)
82
+ )
83
+
84
+ # Build execution plan with classification results
73
85
  plan = self.mti_handler.build_create_plan(
74
86
  objs,
75
87
  batch_size=batch_size,
76
88
  update_conflicts=update_conflicts,
77
89
  update_fields=update_fields,
78
90
  unique_fields=unique_fields,
91
+ existing_record_ids=existing_record_ids,
92
+ existing_pks_map=existing_pks_map,
79
93
  )
80
94
  # Execute the plan
81
95
  return self._execute_mti_create_plan(plan)
@@ -161,134 +175,203 @@ class BulkExecutor:
161
175
  Execute an MTI create plan.
162
176
 
163
177
  This is where ALL database operations happen for MTI bulk_create.
178
+ Handles both new records (INSERT) and existing records (UPDATE) for upsert.
164
179
 
165
180
  Args:
166
181
  plan: MTICreatePlan object from MTIHandler
167
182
 
168
183
  Returns:
169
- List of created objects with PKs assigned
184
+ List of created/updated objects with PKs assigned
170
185
  """
171
- from django.db import transaction
172
186
  from django.db.models import QuerySet as BaseQuerySet
173
-
187
+
174
188
  if not plan:
175
189
  return []
176
-
190
+
177
191
  with transaction.atomic(using=self.queryset.db, savepoint=False):
178
- # Step 1: Create all parent objects level by level
192
+ # Step 1: Create/Update all parent objects level by level
179
193
  parent_instances_map = {} # Maps original obj id() -> {model: parent_instance}
180
-
194
+
181
195
  for parent_level in plan.parent_levels:
182
- # Bulk create parents for this level
183
- bulk_kwargs = {"batch_size": len(parent_level.objects)}
184
-
185
- if parent_level.update_conflicts:
186
- bulk_kwargs["update_conflicts"] = True
187
- bulk_kwargs["unique_fields"] = parent_level.unique_fields
188
- bulk_kwargs["update_fields"] = parent_level.update_fields
189
-
190
- # Use base QuerySet to avoid recursion
191
- base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
192
- created_parents = base_qs.bulk_create(parent_level.objects, **bulk_kwargs)
193
-
194
- # Copy generated fields back to parent objects
195
- for created_parent, parent_obj in zip(created_parents, parent_level.objects):
196
- for field in parent_level.model_class._meta.local_fields:
197
- created_value = getattr(created_parent, field.name, None)
198
- if created_value is not None:
199
- setattr(parent_obj, field.name, created_value)
200
-
201
- parent_obj._state.adding = False
202
- parent_obj._state.db = self.queryset.db
203
-
196
+ # Separate new and existing parent objects
197
+ new_parents = []
198
+ existing_parents = []
199
+
200
+ for parent_obj in parent_level.objects:
201
+ orig_obj_id = parent_level.original_object_map[id(parent_obj)]
202
+ if orig_obj_id in plan.existing_record_ids:
203
+ existing_parents.append(parent_obj)
204
+ else:
205
+ new_parents.append(parent_obj)
206
+
207
+ # Bulk create new parents
208
+ if new_parents:
209
+ bulk_kwargs = {"batch_size": len(new_parents)}
210
+
211
+ if parent_level.update_conflicts:
212
+ bulk_kwargs["update_conflicts"] = True
213
+ bulk_kwargs["unique_fields"] = parent_level.unique_fields
214
+ bulk_kwargs["update_fields"] = parent_level.update_fields
215
+
216
+ # Use base QuerySet to avoid recursion
217
+ base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
218
+ created_parents = base_qs.bulk_create(new_parents, **bulk_kwargs)
219
+
220
+ # Copy generated fields back to parent objects
221
+ for created_parent, parent_obj in zip(created_parents, new_parents):
222
+ for field in parent_level.model_class._meta.local_fields:
223
+ created_value = getattr(created_parent, field.name, None)
224
+ if created_value is not None:
225
+ setattr(parent_obj, field.name, created_value)
226
+
227
+ parent_obj._state.adding = False
228
+ parent_obj._state.db = self.queryset.db
229
+
230
+ # Update existing parents
231
+ if existing_parents and parent_level.update_fields:
232
+ # Filter update fields to only those that exist in this parent model
233
+ parent_model_fields = {
234
+ field.name for field in parent_level.model_class._meta.local_fields
235
+ }
236
+ filtered_update_fields = [
237
+ field for field in parent_level.update_fields
238
+ if field in parent_model_fields
239
+ ]
240
+
241
+ if filtered_update_fields:
242
+ base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
243
+ base_qs.bulk_update(existing_parents, filtered_update_fields)
244
+
245
+ # Mark as not adding
246
+ for parent_obj in existing_parents:
247
+ parent_obj._state.adding = False
248
+ parent_obj._state.db = self.queryset.db
249
+
204
250
  # Map parents back to original objects
205
251
  for parent_obj in parent_level.objects:
206
252
  orig_obj_id = parent_level.original_object_map[id(parent_obj)]
207
253
  if orig_obj_id not in parent_instances_map:
208
254
  parent_instances_map[orig_obj_id] = {}
209
255
  parent_instances_map[orig_obj_id][parent_level.model_class] = parent_obj
210
-
211
- # Step 2: Add parent links to child objects
256
+
257
+ # Step 2: Add parent links to child objects and separate new/existing
258
+ new_child_objects = []
259
+ existing_child_objects = []
260
+
212
261
  for child_obj, orig_obj in zip(plan.child_objects, plan.original_objects):
213
262
  parent_instances = parent_instances_map.get(id(orig_obj), {})
214
-
263
+
264
+ # Set parent links
215
265
  for parent_model, parent_instance in parent_instances.items():
216
266
  parent_link = plan.child_model._meta.get_ancestor_link(parent_model)
217
267
  if parent_link:
218
268
  setattr(child_obj, parent_link.attname, parent_instance.pk)
219
269
  setattr(child_obj, parent_link.name, parent_instance)
220
-
221
- # Step 3: Bulk create child objects using _batched_insert (to bypass MTI check)
222
- base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
223
- base_qs._prepare_for_bulk_create(plan.child_objects)
224
-
225
- # Partition objects by PK status
226
- objs_without_pk, objs_with_pk = [], []
227
- for obj in plan.child_objects:
228
- if obj._is_pk_set():
229
- objs_with_pk.append(obj)
270
+
271
+ # Classify as new or existing
272
+ if id(orig_obj) in plan.existing_record_ids:
273
+ # For existing records, set the PK on child object
274
+ pk_value = getattr(orig_obj, "pk", None)
275
+ if pk_value:
276
+ child_obj.pk = pk_value
277
+ child_obj.id = pk_value
278
+ existing_child_objects.append(child_obj)
230
279
  else:
231
- objs_without_pk.append(obj)
232
-
233
- # Get fields for insert
234
- opts = plan.child_model._meta
235
- fields = [f for f in opts.local_fields if not f.generated]
236
-
237
- # Execute bulk insert
238
- if objs_with_pk:
239
- returned_columns = base_qs._batched_insert(
240
- objs_with_pk,
241
- fields,
242
- batch_size=len(objs_with_pk),
243
- )
244
- if returned_columns:
245
- for obj, results in zip(objs_with_pk, returned_columns):
246
- if hasattr(opts, "db_returning_fields") and hasattr(opts, "pk"):
247
- for result, field in zip(results, opts.db_returning_fields):
248
- if field != opts.pk:
280
+ new_child_objects.append(child_obj)
281
+
282
+ # Step 3: Bulk create new child objects using _batched_insert (to bypass MTI check)
283
+ if new_child_objects:
284
+ base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
285
+ base_qs._prepare_for_bulk_create(new_child_objects)
286
+
287
+ # Partition objects by PK status
288
+ objs_without_pk, objs_with_pk = [], []
289
+ for obj in new_child_objects:
290
+ if obj._is_pk_set():
291
+ objs_with_pk.append(obj)
292
+ else:
293
+ objs_without_pk.append(obj)
294
+
295
+ # Get fields for insert
296
+ opts = plan.child_model._meta
297
+ fields = [f for f in opts.local_fields if not f.generated]
298
+
299
+ # Execute bulk insert
300
+ if objs_with_pk:
301
+ returned_columns = base_qs._batched_insert(
302
+ objs_with_pk,
303
+ fields,
304
+ batch_size=len(objs_with_pk),
305
+ )
306
+ if returned_columns:
307
+ for obj, results in zip(objs_with_pk, returned_columns):
308
+ if hasattr(opts, "db_returning_fields") and hasattr(opts, "pk"):
309
+ for result, field in zip(results, opts.db_returning_fields):
310
+ if field != opts.pk:
311
+ setattr(obj, field.attname, result)
312
+ obj._state.adding = False
313
+ obj._state.db = self.queryset.db
314
+ else:
315
+ for obj in objs_with_pk:
316
+ obj._state.adding = False
317
+ obj._state.db = self.queryset.db
318
+
319
+ if objs_without_pk:
320
+ filtered_fields = [
321
+ f for f in fields
322
+ if not isinstance(f, AutoField) and not f.primary_key
323
+ ]
324
+ returned_columns = base_qs._batched_insert(
325
+ objs_without_pk,
326
+ filtered_fields,
327
+ batch_size=len(objs_without_pk),
328
+ )
329
+ if returned_columns:
330
+ for obj, results in zip(objs_without_pk, returned_columns):
331
+ if hasattr(opts, "db_returning_fields"):
332
+ for result, field in zip(results, opts.db_returning_fields):
249
333
  setattr(obj, field.attname, result)
250
- obj._state.adding = False
251
- obj._state.db = self.queryset.db
252
- else:
253
- for obj in objs_with_pk:
254
- obj._state.adding = False
255
- obj._state.db = self.queryset.db
256
-
257
- if objs_without_pk:
258
- filtered_fields = [
259
- f for f in fields
260
- if not isinstance(f, AutoField) and not f.primary_key
334
+ obj._state.adding = False
335
+ obj._state.db = self.queryset.db
336
+ else:
337
+ for obj in objs_without_pk:
338
+ obj._state.adding = False
339
+ obj._state.db = self.queryset.db
340
+
341
+ # Step 3.5: Update existing child objects
342
+ if existing_child_objects and plan.update_fields:
343
+ # Filter update fields to only those that exist in the child model
344
+ child_model_fields = {
345
+ field.name for field in plan.child_model._meta.local_fields
346
+ }
347
+ filtered_child_update_fields = [
348
+ field for field in plan.update_fields
349
+ if field in child_model_fields
261
350
  ]
262
- returned_columns = base_qs._batched_insert(
263
- objs_without_pk,
264
- filtered_fields,
265
- batch_size=len(objs_without_pk),
266
- )
267
- if returned_columns:
268
- for obj, results in zip(objs_without_pk, returned_columns):
269
- if hasattr(opts, "db_returning_fields"):
270
- for result, field in zip(results, opts.db_returning_fields):
271
- setattr(obj, field.attname, result)
272
- obj._state.adding = False
273
- obj._state.db = self.queryset.db
274
- else:
275
- for obj in objs_without_pk:
276
- obj._state.adding = False
277
- obj._state.db = self.queryset.db
278
-
279
- created_children = plan.child_objects
280
-
351
+
352
+ if filtered_child_update_fields:
353
+ base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
354
+ base_qs.bulk_update(existing_child_objects, filtered_child_update_fields)
355
+
356
+ # Mark as not adding
357
+ for child_obj in existing_child_objects:
358
+ child_obj._state.adding = False
359
+ child_obj._state.db = self.queryset.db
360
+
361
+ # Combine all children for final processing
362
+ created_children = new_child_objects + existing_child_objects
363
+
281
364
  # Step 4: Copy PKs and auto-generated fields back to original objects
282
365
  pk_field_name = plan.child_model._meta.pk.name
283
-
366
+
284
367
  for orig_obj, child_obj in zip(plan.original_objects, created_children):
285
368
  # Copy PK
286
369
  child_pk = getattr(child_obj, pk_field_name)
287
370
  setattr(orig_obj, pk_field_name, child_pk)
288
-
371
+
289
372
  # Copy auto-generated fields from all levels
290
373
  parent_instances = parent_instances_map.get(id(orig_obj), {})
291
-
374
+
292
375
  for model_class in plan.inheritance_chain:
293
376
  # Get source object for this level
294
377
  if model_class in parent_instances:
@@ -297,30 +380,30 @@ class BulkExecutor:
297
380
  source_obj = child_obj
298
381
  else:
299
382
  continue
300
-
383
+
301
384
  # Copy auto-generated field values
302
385
  for field in model_class._meta.local_fields:
303
386
  if field.name == pk_field_name:
304
387
  continue
305
-
388
+
306
389
  # Skip parent link fields
307
- if hasattr(field, 'remote_field') and field.remote_field:
390
+ if hasattr(field, "remote_field") and field.remote_field:
308
391
  parent_link = plan.child_model._meta.get_ancestor_link(model_class)
309
392
  if parent_link and field.name == parent_link.name:
310
393
  continue
311
-
394
+
312
395
  # Copy auto_now_add, auto_now, and db_returning fields
313
- if (getattr(field, 'auto_now_add', False) or
314
- getattr(field, 'auto_now', False) or
315
- getattr(field, 'db_returning', False)):
396
+ if (getattr(field, "auto_now_add", False) or
397
+ getattr(field, "auto_now", False) or
398
+ getattr(field, "db_returning", False)):
316
399
  source_value = getattr(source_obj, field.name, None)
317
400
  if source_value is not None:
318
401
  setattr(orig_obj, field.name, source_value)
319
-
402
+
320
403
  # Update object state
321
404
  orig_obj._state.adding = False
322
405
  orig_obj._state.db = self.queryset.db
323
-
406
+
324
407
  return plan.original_objects
325
408
 
326
409
  def _execute_mti_update_plan(self, plan):
@@ -335,86 +418,94 @@ class BulkExecutor:
335
418
  Returns:
336
419
  Number of objects updated
337
420
  """
338
- from django.db import transaction
339
- from django.db.models import Case, Value, When, QuerySet as BaseQuerySet
340
-
421
+ from django.db.models import Case
422
+ from django.db.models import QuerySet as BaseQuerySet
423
+ from django.db.models import Value
424
+ from django.db.models import When
425
+
341
426
  if not plan:
342
427
  return 0
343
-
428
+
344
429
  total_updated = 0
345
-
430
+
346
431
  # Get PKs for filtering
347
432
  root_pks = [
348
- getattr(obj, "pk", None) or getattr(obj, "id", None)
349
- for obj in plan.objects
433
+ getattr(obj, "pk", None) or getattr(obj, "id", None)
434
+ for obj in plan.objects
350
435
  if getattr(obj, "pk", None) or getattr(obj, "id", None)
351
436
  ]
352
-
437
+
353
438
  if not root_pks:
354
439
  return 0
355
-
440
+
356
441
  with transaction.atomic(using=self.queryset.db, savepoint=False):
357
442
  # Update each table in the chain
358
443
  for field_group in plan.field_groups:
359
444
  if not field_group.fields:
360
445
  continue
361
-
446
+
362
447
  base_qs = BaseQuerySet(model=field_group.model_class, using=self.queryset.db)
363
-
448
+
364
449
  # Check if records exist
365
450
  existing_count = base_qs.filter(**{f"{field_group.filter_field}__in": root_pks}).count()
366
451
  if existing_count == 0:
367
452
  continue
368
-
453
+
369
454
  # Build CASE statements for bulk update
370
455
  case_statements = {}
371
456
  for field_name in field_group.fields:
372
457
  field = field_group.model_class._meta.get_field(field_name)
373
-
458
+
374
459
  # Use column name for FK fields
375
- if getattr(field, 'is_relation', False) and hasattr(field, 'attname'):
460
+ if getattr(field, "is_relation", False) and hasattr(field, "attname"):
376
461
  db_field_name = field.attname
377
462
  target_field = field.target_field
378
463
  else:
379
464
  db_field_name = field_name
380
465
  target_field = field
381
-
466
+
382
467
  when_statements = []
383
468
  for pk, obj in zip(root_pks, plan.objects):
384
469
  obj_pk = getattr(obj, "pk", None) or getattr(obj, "id", None)
385
470
  if obj_pk is None:
386
471
  continue
387
-
472
+
388
473
  value = getattr(obj, db_field_name)
389
-
474
+
390
475
  # For FK fields, ensure we get the actual ID value, not the related object
391
- if getattr(field, 'is_relation', False) and hasattr(field, 'attname'):
476
+ if getattr(field, "is_relation", False) and hasattr(field, "attname"):
392
477
  # If value is a model instance, get its pk
393
- if value is not None and hasattr(value, 'pk'):
478
+ if value is not None and hasattr(value, "pk"):
394
479
  value = value.pk
395
-
480
+ # If value is a string representation of an ID, convert to int
481
+ elif value is not None and isinstance(value, str) and value.isdigit():
482
+ value = int(value)
483
+ # If value is None or empty string, ensure it's None
484
+ elif value == "":
485
+ value = None
486
+
396
487
  when_statements.append(
397
488
  When(
398
489
  **{field_group.filter_field: pk},
399
490
  then=Value(value, output_field=target_field),
400
- )
491
+ ),
401
492
  )
402
-
493
+
403
494
  if when_statements:
404
495
  case_statements[db_field_name] = Case(
405
- *when_statements, output_field=target_field
496
+ *when_statements, output_field=target_field,
406
497
  )
407
-
498
+
408
499
  # Execute bulk update
409
500
  if case_statements:
410
501
  try:
411
502
  updated_count = base_qs.filter(
412
- **{f"{field_group.filter_field}__in": root_pks}
503
+ **{f"{field_group.filter_field}__in": root_pks},
413
504
  ).update(**case_statements)
414
505
  total_updated += updated_count
415
506
  except Exception as e:
416
507
  logger.error(f"MTI bulk update failed for {field_group.model_class.__name__}: {e}")
417
-
508
+
418
509
  return total_updated
419
510
 
420
511
  def delete_queryset(self):