django-bulk-hooks 0.2.15.tar.gz → 0.2.16.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of django-bulk-hooks might be problematic.

Files changed (26)
  1. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/PKG-INFO +1 -1
  2. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/decorators.py +7 -1
  3. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/dispatcher.py +10 -0
  4. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/bulk_executor.py +159 -76
  5. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/coordinator.py +23 -2
  6. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/mti_handler.py +30 -1
  7. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/mti_plans.py +8 -0
  8. django_bulk_hooks-0.2.16/django_bulk_hooks/operations/record_classifier.py +183 -0
  9. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/registry.py +15 -0
  10. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/pyproject.toml +1 -1
  11. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/LICENSE +0 -0
  12. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/README.md +0 -0
  13. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/__init__.py +0 -0
  14. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/changeset.py +0 -0
  15. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/conditions.py +0 -0
  16. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/constants.py +0 -0
  17. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/context.py +0 -0
  18. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/enums.py +0 -0
  19. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/factory.py +0 -0
  20. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/handler.py +0 -0
  21. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/helpers.py +0 -0
  22. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/manager.py +0 -0
  23. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/models.py +0 -0
  24. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/__init__.py +0 -0
  25. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/operations/analyzer.py +0 -0
  26. {django_bulk_hooks-0.2.15 → django_bulk_hooks-0.2.16}/django_bulk_hooks/queryset.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: django-bulk-hooks
-Version: 0.2.15
+Version: 0.2.16
 Summary: Hook-style hooks for Django bulk operations like bulk_create and bulk_update.
 License: MIT
 Keywords: django,bulk,hooks
@@ -290,7 +290,13 @@ def bulk_hook(model_cls, event, when=None, priority=None):
                 return self.func(changeset, new_records, old_records, **kwargs)
             else:
                 # Old signature without changeset
-                return self.func(new_records, old_records, **kwargs)
+                # Only pass changeset in kwargs if the function accepts **kwargs
+                if 'kwargs' in params or any(param.startswith('**') for param in sig.parameters):
+                    kwargs['changeset'] = changeset
+                    return self.func(new_records, old_records, **kwargs)
+                else:
+                    # Function doesn't accept **kwargs, just call with positional args
+                    return self.func(new_records, old_records)
 
     # Register the hook using the registry
     register_hook(
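In effect, a hook declared without **kwargs no longer receives an unexpected changeset keyword. A minimal sketch of the two signatures the dispatcher now distinguishes; the Account model and the "after_update" event name are assumptions for illustration:

    from django_bulk_hooks.decorators import bulk_hook

    @bulk_hook(Account, "after_update")
    def with_kwargs(new_records, old_records, **kwargs):
        # Accepts **kwargs, so the dispatcher injects kwargs["changeset"].
        changeset = kwargs.get("changeset")

    @bulk_hook(Account, "after_update")
    def positional_only(new_records, old_records):
        # No **kwargs: as of 0.2.16 this is called with positional arguments
        # only, instead of choking on an unexpected "changeset" keyword.
        ...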
@@ -244,3 +244,13 @@ def get_dispatcher():
     # Create dispatcher with the registry instance
     _dispatcher = HookDispatcher(get_registry())
     return _dispatcher
+
+
+def reset_dispatcher():
+    """
+    Reset the global dispatcher instance.
+
+    Useful for testing to ensure clean state between tests.
+    """
+    global _dispatcher
+    _dispatcher = None
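A sketch of how a test suite might use the new helper to guarantee isolation, assuming pytest:

    import pytest
    from django_bulk_hooks.dispatcher import reset_dispatcher

    @pytest.fixture(autouse=True)
    def fresh_dispatcher():
        # Drop the module-level singleton so each test builds its own
        # HookDispatcher against the current registry state.
        reset_dispatcher()
        yield
        reset_dispatcher()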
@@ -21,7 +21,7 @@ class BulkExecutor:
     Dependencies are explicitly injected via constructor.
     """
 
-    def __init__(self, queryset, analyzer, mti_handler):
+    def __init__(self, queryset, analyzer, mti_handler, record_classifier):
         """
         Initialize bulk executor with explicit dependencies.
 
@@ -29,10 +29,12 @@ class BulkExecutor:
             queryset: Django QuerySet instance
             analyzer: ModelAnalyzer instance (replaces validator + field_tracker)
             mti_handler: MTIHandler instance
+            record_classifier: RecordClassifier instance
         """
         self.queryset = queryset
         self.analyzer = analyzer
         self.mti_handler = mti_handler
+        self.record_classifier = record_classifier
         self.model_cls = queryset.model
 
     def bulk_create(
@@ -69,13 +71,24 @@ class BulkExecutor:
         # Check if this is an MTI model and route accordingly
         if self.mti_handler.is_mti_model():
             logger.info(f"Detected MTI model {self.model_cls.__name__}, using MTI bulk create")
-            # Build execution plan
+
+            # Classify records using the classifier service
+            existing_record_ids = set()
+            existing_pks_map = {}
+            if update_conflicts and unique_fields:
+                existing_record_ids, existing_pks_map = (
+                    self.record_classifier.classify_for_upsert(objs, unique_fields)
+                )
+
+            # Build execution plan with classification results
             plan = self.mti_handler.build_create_plan(
                 objs,
                 batch_size=batch_size,
                 update_conflicts=update_conflicts,
                 update_fields=update_fields,
                 unique_fields=unique_fields,
+                existing_record_ids=existing_record_ids,
+                existing_pks_map=existing_pks_map,
             )
             # Execute the plan
             return self._execute_mti_create_plan(plan)
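An illustrative call that exercises this path, assuming a hypothetical MTI pair in which Restaurant inherits from Place and slug is a unique field:

    # Place (parent) and Restaurant (child) are hypothetical MTI models.
    Restaurant.objects.bulk_create(
        [
            Restaurant(slug="cafe-one", name="Cafe One"),
            Restaurant(slug="cafe-two", name="Cafe Two"),
        ],
        update_conflicts=True,    # makes the executor call classify_for_upsert()
        unique_fields=["slug"],   # rows already present by slug become UPDATEs
        update_fields=["name"],   # fields rewritten on the existing rows
    )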
@@ -161,12 +174,13 @@ class BulkExecutor:
         Execute an MTI create plan.
 
         This is where ALL database operations happen for MTI bulk_create.
+        Handles both new records (INSERT) and existing records (UPDATE) for upsert.
 
         Args:
             plan: MTICreatePlan object from MTIHandler
 
         Returns:
-            List of created objects with PKs assigned
+            List of created/updated objects with PKs assigned
         """
         from django.db import transaction
         from django.db.models import QuerySet as BaseQuerySet
@@ -175,31 +189,63 @@ class BulkExecutor:
             return []
 
         with transaction.atomic(using=self.queryset.db, savepoint=False):
-            # Step 1: Create all parent objects level by level
+            # Step 1: Create/Update all parent objects level by level
             parent_instances_map = {}  # Maps original obj id() -> {model: parent_instance}
 
             for parent_level in plan.parent_levels:
-                # Bulk create parents for this level
-                bulk_kwargs = {"batch_size": len(parent_level.objects)}
+                # Separate new and existing parent objects
+                new_parents = []
+                existing_parents = []
 
-                if parent_level.update_conflicts:
-                    bulk_kwargs["update_conflicts"] = True
-                    bulk_kwargs["unique_fields"] = parent_level.unique_fields
-                    bulk_kwargs["update_fields"] = parent_level.update_fields
+                for parent_obj in parent_level.objects:
+                    orig_obj_id = parent_level.original_object_map[id(parent_obj)]
+                    if orig_obj_id in plan.existing_record_ids:
+                        existing_parents.append(parent_obj)
+                    else:
+                        new_parents.append(parent_obj)
 
-                # Use base QuerySet to avoid recursion
-                base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
-                created_parents = base_qs.bulk_create(parent_level.objects, **bulk_kwargs)
+                # Bulk create new parents
+                if new_parents:
+                    bulk_kwargs = {"batch_size": len(new_parents)}
+
+                    if parent_level.update_conflicts:
+                        bulk_kwargs["update_conflicts"] = True
+                        bulk_kwargs["unique_fields"] = parent_level.unique_fields
+                        bulk_kwargs["update_fields"] = parent_level.update_fields
+
+                    # Use base QuerySet to avoid recursion
+                    base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
+                    created_parents = base_qs.bulk_create(new_parents, **bulk_kwargs)
+
+                    # Copy generated fields back to parent objects
+                    for created_parent, parent_obj in zip(created_parents, new_parents):
+                        for field in parent_level.model_class._meta.local_fields:
+                            created_value = getattr(created_parent, field.name, None)
+                            if created_value is not None:
+                                setattr(parent_obj, field.name, created_value)
+
+                        parent_obj._state.adding = False
+                        parent_obj._state.db = self.queryset.db
 
-                # Copy generated fields back to parent objects
-                for created_parent, parent_obj in zip(created_parents, parent_level.objects):
-                    for field in parent_level.model_class._meta.local_fields:
-                        created_value = getattr(created_parent, field.name, None)
-                        if created_value is not None:
-                            setattr(parent_obj, field.name, created_value)
+                # Update existing parents
+                if existing_parents and parent_level.update_fields:
+                    # Filter update fields to only those that exist in this parent model
+                    parent_model_fields = {
+                        field.name for field in parent_level.model_class._meta.local_fields
+                    }
+                    filtered_update_fields = [
+                        field for field in parent_level.update_fields
+                        if field in parent_model_fields
+                    ]
+
+                    if filtered_update_fields:
+                        base_qs = BaseQuerySet(model=parent_level.model_class, using=self.queryset.db)
+                        base_qs.bulk_update(existing_parents, filtered_update_fields)
 
-                    parent_obj._state.adding = False
-                    parent_obj._state.db = self.queryset.db
+                    # Mark as not adding
+                    for parent_obj in existing_parents:
+                        parent_obj._state.adding = False
+                        parent_obj._state.db = self.queryset.db
 
             # Map parents back to original objects
             for parent_obj in parent_level.objects:
@@ -208,75 +254,112 @@ class BulkExecutor:
                     parent_instances_map[orig_obj_id] = {}
                 parent_instances_map[orig_obj_id][parent_level.model_class] = parent_obj
 
-            # Step 2: Add parent links to child objects
+            # Step 2: Add parent links to child objects and separate new/existing
+            new_child_objects = []
+            existing_child_objects = []
+
             for child_obj, orig_obj in zip(plan.child_objects, plan.original_objects):
                 parent_instances = parent_instances_map.get(id(orig_obj), {})
 
+                # Set parent links
                 for parent_model, parent_instance in parent_instances.items():
                     parent_link = plan.child_model._meta.get_ancestor_link(parent_model)
                     if parent_link:
                         setattr(child_obj, parent_link.attname, parent_instance.pk)
                         setattr(child_obj, parent_link.name, parent_instance)
-
-            # Step 3: Bulk create child objects using _batched_insert (to bypass MTI check)
-            base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
-            base_qs._prepare_for_bulk_create(plan.child_objects)
-
-            # Partition objects by PK status
-            objs_without_pk, objs_with_pk = [], []
-            for obj in plan.child_objects:
-                if obj._is_pk_set():
-                    objs_with_pk.append(obj)
+
+                # Classify as new or existing
+                if id(orig_obj) in plan.existing_record_ids:
+                    # For existing records, set the PK on child object
+                    pk_value = getattr(orig_obj, 'pk', None)
+                    if pk_value:
+                        setattr(child_obj, 'pk', pk_value)
+                        setattr(child_obj, 'id', pk_value)
+                    existing_child_objects.append(child_obj)
                 else:
-                    objs_without_pk.append(obj)
-
-            # Get fields for insert
-            opts = plan.child_model._meta
-            fields = [f for f in opts.local_fields if not f.generated]
+                    new_child_objects.append(child_obj)
 
-            # Execute bulk insert
-            if objs_with_pk:
-                returned_columns = base_qs._batched_insert(
-                    objs_with_pk,
-                    fields,
-                    batch_size=len(objs_with_pk),
-                )
-                if returned_columns:
-                    for obj, results in zip(objs_with_pk, returned_columns):
-                        if hasattr(opts, "db_returning_fields") and hasattr(opts, "pk"):
-                            for result, field in zip(results, opts.db_returning_fields):
-                                if field != opts.pk:
+            # Step 3: Bulk create new child objects using _batched_insert (to bypass MTI check)
+            if new_child_objects:
+                base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
+                base_qs._prepare_for_bulk_create(new_child_objects)
+
+                # Partition objects by PK status
+                objs_without_pk, objs_with_pk = [], []
+                for obj in new_child_objects:
+                    if obj._is_pk_set():
+                        objs_with_pk.append(obj)
+                    else:
+                        objs_without_pk.append(obj)
+
+                # Get fields for insert
+                opts = plan.child_model._meta
+                fields = [f for f in opts.local_fields if not f.generated]
+
+                # Execute bulk insert
+                if objs_with_pk:
+                    returned_columns = base_qs._batched_insert(
+                        objs_with_pk,
+                        fields,
+                        batch_size=len(objs_with_pk),
+                    )
+                    if returned_columns:
+                        for obj, results in zip(objs_with_pk, returned_columns):
+                            if hasattr(opts, "db_returning_fields") and hasattr(opts, "pk"):
+                                for result, field in zip(results, opts.db_returning_fields):
+                                    if field != opts.pk:
+                                        setattr(obj, field.attname, result)
+                            obj._state.adding = False
+                            obj._state.db = self.queryset.db
+                    else:
+                        for obj in objs_with_pk:
+                            obj._state.adding = False
+                            obj._state.db = self.queryset.db
+
+                if objs_without_pk:
+                    filtered_fields = [
+                        f for f in fields
+                        if not isinstance(f, AutoField) and not f.primary_key
+                    ]
+                    returned_columns = base_qs._batched_insert(
+                        objs_without_pk,
+                        filtered_fields,
+                        batch_size=len(objs_without_pk),
+                    )
+                    if returned_columns:
+                        for obj, results in zip(objs_without_pk, returned_columns):
+                            if hasattr(opts, "db_returning_fields"):
+                                for result, field in zip(results, opts.db_returning_fields):
                                     setattr(obj, field.attname, result)
-                        obj._state.adding = False
-                        obj._state.db = self.queryset.db
-                else:
-                    for obj in objs_with_pk:
-                        obj._state.adding = False
-                        obj._state.db = self.queryset.db
+                            obj._state.adding = False
+                            obj._state.db = self.queryset.db
+                    else:
+                        for obj in objs_without_pk:
+                            obj._state.adding = False
+                            obj._state.db = self.queryset.db
 
-            if objs_without_pk:
-                filtered_fields = [
-                    f for f in fields
-                    if not isinstance(f, AutoField) and not f.primary_key
+            # Step 3.5: Update existing child objects
+            if existing_child_objects and plan.update_fields:
+                # Filter update fields to only those that exist in the child model
+                child_model_fields = {
+                    field.name for field in plan.child_model._meta.local_fields
+                }
+                filtered_child_update_fields = [
+                    field for field in plan.update_fields
+                    if field in child_model_fields
                 ]
-                returned_columns = base_qs._batched_insert(
-                    objs_without_pk,
-                    filtered_fields,
-                    batch_size=len(objs_without_pk),
-                )
-                if returned_columns:
-                    for obj, results in zip(objs_without_pk, returned_columns):
-                        if hasattr(opts, "db_returning_fields"):
-                            for result, field in zip(results, opts.db_returning_fields):
-                                setattr(obj, field.attname, result)
-                        obj._state.adding = False
-                        obj._state.db = self.queryset.db
-                else:
-                    for obj in objs_without_pk:
-                        obj._state.adding = False
-                        obj._state.db = self.queryset.db
+
+                if filtered_child_update_fields:
+                    base_qs = BaseQuerySet(model=plan.child_model, using=self.queryset.db)
+                    base_qs.bulk_update(existing_child_objects, filtered_child_update_fields)
+
+                # Mark as not adding
+                for child_obj in existing_child_objects:
+                    child_obj._state.adding = False
+                    child_obj._state.db = self.queryset.db
 
-            created_children = plan.child_objects
+            # Combine all children for final processing
+            created_children = new_child_objects + existing_child_objects
 
             # Step 4: Copy PKs and auto-generated fields back to original objects
             pk_field_name = plan.child_model._meta.pk.name
@@ -44,6 +44,7 @@ class BulkOperationCoordinator:
         # Lazy initialization
         self._analyzer = None
         self._mti_handler = None
+        self._record_classifier = None
        self._executor = None
        self._dispatcher = None
 
@@ -65,6 +66,15 @@ class BulkOperationCoordinator:
            self._mti_handler = MTIHandler(self.model_cls)
        return self._mti_handler
 
+    @property
+    def record_classifier(self):
+        """Get or create RecordClassifier"""
+        if self._record_classifier is None:
+            from django_bulk_hooks.operations.record_classifier import RecordClassifier
+
+            self._record_classifier = RecordClassifier(self.model_cls)
+        return self._record_classifier
+
     @property
     def executor(self):
         """Get or create BulkExecutor"""
@@ -75,6 +85,7 @@ class BulkOperationCoordinator:
                 queryset=self.queryset,
                 analyzer=self.analyzer,
                 mti_handler=self.mti_handler,
+                record_classifier=self.record_classifier,
             )
         return self._executor
 
@@ -298,7 +309,9 @@ class BulkOperationCoordinator:
 
         # Step 3: Fetch new state (after database update)
         # This captures any Subquery/F() computed values
-        new_instances = list(self.queryset)
+        # Use primary keys to fetch updated instances since queryset filters may no longer match
+        pks = [inst.pk for inst in old_instances]
+        new_instances = list(self.model_cls.objects.filter(pk__in=pks))
 
         # Step 4: Build changeset
         changeset = build_changeset_for_update(
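The refetch-by-PK change matters whenever the UPDATE invalidates the queryset's own filter. A hedged sketch of the failure mode the old line had, using a hypothetical Article model:

    # Hypothetical model and field values.
    Article.objects.filter(status="draft").update(status="published")
    # After the UPDATE, filter(status="draft") matches zero rows, so the old
    # list(self.queryset) re-evaluation produced an empty "new state".
    # Refetching with pk__in=[inst.pk for inst in old_instances] still
    # matches every updated row regardless of the original filter.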
@@ -338,7 +351,10 @@ class BulkOperationCoordinator:
         if modified_fields:
             self._persist_hook_modifications(new_instances, modified_fields)
 
-        # Step 9: Run AFTER_UPDATE hooks (read-only side effects)
+        # Step 9: Take snapshot before AFTER_UPDATE hooks
+        pre_after_hook_state = self._snapshot_instance_state(new_instances)
+
+        # Step 10: Run AFTER_UPDATE hooks (read-only side effects)
         for model_cls in models_in_chain:
             model_changeset = self._build_changeset_for_model(changeset, model_cls)
             self.dispatcher.dispatch(
@@ -347,6 +363,11 @@ class BulkOperationCoordinator:
                 bypass_hooks=False
             )
 
+        # Step 11: Auto-persist AFTER_UPDATE modifications (if any)
+        after_modified_fields = self._detect_modifications(new_instances, pre_after_hook_state)
+        if after_modified_fields:
+            self._persist_hook_modifications(new_instances, after_modified_fields)
+
         return update_count
 
     def _run_before_update_hooks_with_tracking(self, instances, models_in_chain, changeset):
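This accommodates AFTER_UPDATE hooks that mutate the records they receive. A sketch of such a hook; the body and field name are hypothetical:

    def mark_synced(new_records, old_records):
        for record in new_records:
            # Runs after the UPDATE statement. With the snapshot taken in
            # Step 9, this mutation is detected in Step 11 and persisted by
            # a follow-up write instead of being silently lost.
            record.needs_sync = False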
@@ -121,6 +121,8 @@ class MTIHandler:
         update_conflicts=False,
         unique_fields=None,
         update_fields=None,
+        existing_record_ids=None,
+        existing_pks_map=None,
     ):
         """
         Build an execution plan for bulk creating MTI model instances.
@@ -134,6 +136,8 @@ class MTIHandler:
             update_conflicts: Enable UPSERT on conflict
             unique_fields: Fields for conflict detection
             update_fields: Fields to update on conflict
+            existing_record_ids: Set of id() for objects that exist in DB (from RecordClassifier)
+            existing_pks_map: Dict mapping id(obj) -> pk for existing records (from RecordClassifier)
 
         Returns:
             MTICreatePlan object
@@ -149,6 +153,19 @@ class MTIHandler:
 
         batch_size = batch_size or len(objs)
 
+        # Use provided classification (no more DB query here!)
+        if existing_record_ids is None:
+            existing_record_ids = set()
+        if existing_pks_map is None:
+            existing_pks_map = {}
+
+        # Set PKs on existing objects so they can be updated
+        if existing_pks_map:
+            for obj in objs:
+                if id(obj) in existing_pks_map:
+                    setattr(obj, 'pk', existing_pks_map[id(obj)])
+                    setattr(obj, 'id', existing_pks_map[id(obj)])
+
         # Build parent levels
         parent_levels = self._build_parent_levels(
             objs,
@@ -171,6 +188,10 @@ class MTIHandler:
             child_model=inheritance_chain[-1],
             original_objects=objs,
             batch_size=batch_size,
+            existing_record_ids=existing_record_ids,
+            update_conflicts=update_conflicts,
+            unique_fields=unique_fields or [],
+            update_fields=update_fields or [],
         )
 
     def _build_parent_levels(
@@ -272,9 +293,17 @@ class MTIHandler:
             ut = (ut,)
         ut_field_sets = [tuple(group) for group in ut]
 
+        # Check individual field uniqueness
+        unique_field_sets = []
+        for field in model_class._meta.local_fields:
+            if field.unique and not field.primary_key:
+                unique_field_sets.append((field.name,))
+
         # Compare as sets
         provided_set = set(normalized_unique)
-        for group in constraint_field_sets + ut_field_sets:
+        all_constraint_sets = constraint_field_sets + ut_field_sets + unique_field_sets
+
+        for group in all_constraint_sets:
             if provided_set == set(group):
                 return True
         return False
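Before this change only Meta.constraints and Meta.unique_together were consulted, so a conflict target declared directly on a field was rejected. A hypothetical model that the check now accepts with unique_fields=["slug"]:

    class Place(models.Model):
        slug = models.SlugField(unique=True)  # matched via the new field.unique scan
        name = models.CharField(max_length=100)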
@@ -45,6 +45,10 @@ class MTICreatePlan:
         child_model: The child model class
         original_objects: Original objects provided by user
         batch_size: Batch size for operations
+        existing_record_ids: Set of id() of original objects that represent existing DB records
+        update_conflicts: Whether this is an upsert operation
+        unique_fields: Fields used for conflict detection
+        update_fields: Fields to update on conflict
     """
     inheritance_chain: List[Any]
     parent_levels: List[ParentLevel]
@@ -52,6 +56,10 @@ class MTICreatePlan:
     child_model: Any
     original_objects: List[Any]
     batch_size: int = None
+    existing_record_ids: set = field(default_factory=set)
+    update_conflicts: bool = False
+    unique_fields: List[str] = field(default_factory=list)
+    update_fields: List[str] = field(default_factory=list)
 
 
 @dataclass
@@ -0,0 +1,183 @@
+"""
+Record classification service for database queries.
+
+This service handles all database queries related to classifying and fetching
+records based on various criteria (PKs, unique fields, etc.).
+
+Separates data access concerns from business logic.
+"""
+
+import logging
+from django.db.models import Q
+
+logger = logging.getLogger(__name__)
+
+
+class RecordClassifier:
+    """
+    Service for classifying and fetching records via database queries.
+
+    This is the SINGLE point of truth for record classification queries.
+    Keeps database access logic separate from business/planning logic.
+    """
+
+    def __init__(self, model_cls):
+        """
+        Initialize classifier for a specific model.
+
+        Args:
+            model_cls: The Django model class
+        """
+        self.model_cls = model_cls
+
+    def classify_for_upsert(self, objs, unique_fields):
+        """
+        Classify records as new or existing based on unique_fields.
+
+        Queries the database to check which records already exist based on the
+        unique_fields constraint.
+
+        Args:
+            objs: List of model instances
+            unique_fields: List of field names that form the unique constraint
+
+        Returns:
+            Tuple of (existing_record_ids, existing_pks_map)
+            - existing_record_ids: Set of id() for objects that exist in DB
+            - existing_pks_map: Dict mapping id(obj) -> pk for existing records
+        """
+        if not unique_fields or not objs:
+            return set(), {}
+
+        # Build a query to find existing records
+        queries = []
+        obj_to_unique_values = {}
+
+        for obj in objs:
+            # Build lookup dict for this object's unique fields
+            lookup = {}
+            for field_name in unique_fields:
+                value = getattr(obj, field_name, None)
+                if value is None:
+                    # Can't match on None values
+                    break
+                lookup[field_name] = value
+            else:
+                # All unique fields have values, add to query
+                if lookup:
+                    queries.append(Q(**lookup))
+                    obj_to_unique_values[id(obj)] = tuple(lookup.values())
+
+        if not queries:
+            return set(), {}
+
+        # Query for existing records
+        combined_query = queries[0]
+        for q in queries[1:]:
+            combined_query |= q
+
+        existing_records = list(
+            self.model_cls.objects.filter(combined_query).values('pk', *unique_fields)
+        )
+
+        # Map existing records back to original objects
+        existing_record_ids = set()
+        existing_pks_map = {}
+
+        for record in existing_records:
+            record_values = tuple(record[field] for field in unique_fields)
+            # Find which object(s) match these values
+            for obj_id, obj_values in obj_to_unique_values.items():
+                if obj_values == record_values:
+                    existing_record_ids.add(obj_id)
+                    existing_pks_map[obj_id] = record['pk']
+
+        logger.info(
+            f"Classified {len(existing_record_ids)} existing and "
+            f"{len(objs) - len(existing_record_ids)} new records for upsert"
+        )
+
+        return existing_record_ids, existing_pks_map
+
+    def fetch_by_pks(self, pks, select_related=None, prefetch_related=None):
+        """
+        Fetch records by primary keys with optional relationship loading.
+
+        Args:
+            pks: List of primary key values
+            select_related: Optional list of fields to select_related
+            prefetch_related: Optional list of fields to prefetch_related
+
+        Returns:
+            Dict[pk, instance] for O(1) lookups
+        """
+        if not pks:
+            return {}
+
+        queryset = self.model_cls._base_manager.filter(pk__in=pks)
+
+        if select_related:
+            queryset = queryset.select_related(*select_related)
+
+        if prefetch_related:
+            queryset = queryset.prefetch_related(*prefetch_related)
+
+        return {obj.pk: obj for obj in queryset}
+
+    def fetch_by_unique_constraint(self, field_values_map):
+        """
+        Fetch records matching a unique constraint.
+
+        Args:
+            field_values_map: Dict of {field_name: value} for unique constraint
+
+        Returns:
+            Model instance if found, None otherwise
+        """
+        try:
+            return self.model_cls.objects.get(**field_values_map)
+        except self.model_cls.DoesNotExist:
+            return None
+        except self.model_cls.MultipleObjectsReturned:
+            logger.warning(
+                f"Multiple {self.model_cls.__name__} records found for "
+                f"unique constraint {field_values_map}"
+            )
+            return self.model_cls.objects.filter(**field_values_map).first()
+
+    def exists_by_pks(self, pks):
+        """
+        Check if records exist by primary keys without fetching them.
+
+        Args:
+            pks: List of primary key values
+
+        Returns:
+            Set of PKs that exist in the database
+        """
+        if not pks:
+            return set()
+
+        existing_pks = self.model_cls.objects.filter(
+            pk__in=pks
+        ).values_list('pk', flat=True)
+
+        return set(existing_pks)
+
+    def count_by_unique_fields(self, objs, unique_fields):
+        """
+        Count how many objects already exist based on unique fields.
+
+        Useful for validation or reporting before upsert operations.
+
+        Args:
+            objs: List of model instances
+            unique_fields: List of field names that form the unique constraint
+
+        Returns:
+            Tuple of (existing_count, new_count)
+        """
+        existing_ids, _ = self.classify_for_upsert(objs, unique_fields)
+        existing_count = len(existing_ids)
+        new_count = len(objs) - existing_count
+        return existing_count, new_count
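A usage sketch for the new service; the Account model and its unique number field are assumptions for illustration:

    from django_bulk_hooks.operations.record_classifier import RecordClassifier

    classifier = RecordClassifier(Account)
    objs = [Account(number="A-1"), Account(number="A-2")]

    existing_ids, pk_map = classifier.classify_for_upsert(objs, ["number"])
    # existing_ids: set of id(obj) for rows already in the table
    # pk_map: {id(obj): pk} so the planner can assign obj.pk before updating

    existing_count, new_count = classifier.count_by_unique_fields(objs, ["number"])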
@@ -165,6 +165,16 @@ class HookRegistry:
         with self._lock:
             return dict(self._hooks)
 
+    @property
+    def hooks(self) -> Dict[Tuple[Type, str], List[HookInfo]]:
+        """
+        Expose internal hooks dictionary for testing purposes.
+
+        This property provides direct access to the internal hooks storage
+        to allow tests to clear the registry state between test runs.
+        """
+        return self._hooks
+
     def count_hooks(
         self, model: Optional[Type] = None, event: Optional[str] = None
     ) -> int:
@@ -286,3 +296,8 @@ def list_all_hooks() -> Dict[Tuple[Type, str], List[HookInfo]]:
     """
     registry = get_registry()
     return registry.list_all()
+
+
+# Expose hooks dictionary for testing purposes
+# This provides backward compatibility with tests that expect to access _hooks directly
+_hooks = get_registry().hooks
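A sketch of the test cleanup this property enables, assuming pytest:

    import pytest
    from django_bulk_hooks.registry import get_registry

    @pytest.fixture(autouse=True)
    def clean_registry():
        yield
        # The property returns the live internal dict (not a copy), so
        # clearing it empties the registry between tests.
        get_registry().hooks.clear()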
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "django-bulk-hooks"
-version = "0.2.15"
+version = "0.2.16"
 description = "Hook-style hooks for Django bulk operations like bulk_create and bulk_update."
 authors = ["Konrad Beck <konrad.beck@merchantcapital.co.za>"]
 readme = "README.md"