django-bulk-hooks 0.1.85__tar.gz → 0.1.87__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of django-bulk-hooks might be problematic. Consult the package's page on its registry for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: django-bulk-hooks
3
- Version: 0.1.85
3
+ Version: 0.1.87
4
4
  Summary: Hook-style hooks for Django bulk operations like bulk_create and bulk_update.
5
5
  License: MIT
6
6
  Keywords: django,bulk,hooks
@@ -195,90 +195,86 @@ class HookCondition:
195
195
  def __invert__(self):
196
196
  return NotCondition(self)
197
197
 
198
+ def get_required_fields(self):
199
+ """
200
+ Returns a set of field names that this condition needs to evaluate.
201
+ Override in subclasses to specify required fields.
202
+ """
203
+ return set()
198
204
 
199
- class IsNotEqual(HookCondition):
200
- def __init__(self, field, value, only_on_change=False):
205
+
206
+ class IsEqual(HookCondition):
207
+ def __init__(self, field, value):
201
208
  self.field = field
202
209
  self.value = value
203
- self.only_on_change = only_on_change
204
210
 
205
211
  def check(self, instance, original_instance=None):
206
- current = resolve_dotted_attr(instance, self.field)
207
- if self.only_on_change:
208
- if original_instance is None:
209
- return False
210
- previous = resolve_dotted_attr(original_instance, self.field)
211
- return previous == self.value and current != self.value
212
- else:
213
- return current != self.value
212
+ current_value = resolve_dotted_attr(instance, self.field)
213
+ return current_value == self.value
214
214
 
215
+ def get_required_fields(self):
216
+ return {self.field.split('.')[0]}
215
217
 
216
- class IsEqual(HookCondition):
217
- def __init__(self, field, value, only_on_change=False):
218
+
219
+ class IsNotEqual(HookCondition):
220
+ def __init__(self, field, value):
218
221
  self.field = field
219
222
  self.value = value
220
- self.only_on_change = only_on_change
221
223
 
222
224
  def check(self, instance, original_instance=None):
223
- current = resolve_dotted_attr(instance, self.field)
224
- if self.only_on_change:
225
- if original_instance is None:
226
- return False
227
- previous = resolve_dotted_attr(original_instance, self.field)
228
- return previous != self.value and current == self.value
229
- else:
230
- return current == self.value
225
+ current_value = resolve_dotted_attr(instance, self.field)
226
+ return current_value != self.value
231
227
 
228
+ def get_required_fields(self):
229
+ return {self.field.split('.')[0]}
232
230
 
233
- class HasChanged(HookCondition):
234
- def __init__(self, field, has_changed=True):
231
+
232
+ class WasEqual(HookCondition):
233
+ def __init__(self, field, value):
235
234
  self.field = field
236
- self.has_changed = has_changed
235
+ self.value = value
237
236
 
238
237
  def check(self, instance, original_instance=None):
239
- if not original_instance:
238
+ if original_instance is None:
240
239
  return False
241
- current = resolve_dotted_attr(instance, self.field)
242
- previous = resolve_dotted_attr(original_instance, self.field)
243
- return (current != previous) == self.has_changed
240
+ original_value = resolve_dotted_attr(original_instance, self.field)
241
+ return original_value == self.value
244
242
 
243
+ def get_required_fields(self):
244
+ return {self.field.split('.')[0]}
245
245
 
246
- class WasEqual(HookCondition):
247
- def __init__(self, field, value, only_on_change=False):
248
- """
249
- Check if a field's original value was `value`.
250
- If only_on_change is True, only return True when the field has changed away from that value.
251
- """
246
+
247
+ class HasChanged(HookCondition):
248
+ def __init__(self, field):
252
249
  self.field = field
253
- self.value = value
254
- self.only_on_change = only_on_change
255
250
 
256
251
  def check(self, instance, original_instance=None):
257
252
  if original_instance is None:
258
- return False
259
- previous = resolve_dotted_attr(original_instance, self.field)
260
- if self.only_on_change:
261
- current = resolve_dotted_attr(instance, self.field)
262
- return previous == self.value and current != self.value
263
- else:
264
- return previous == self.value
253
+ return True
254
+ current_value = resolve_dotted_attr(instance, self.field)
255
+ original_value = resolve_dotted_attr(original_instance, self.field)
256
+ return current_value != original_value
257
+
258
+ def get_required_fields(self):
259
+ return {self.field.split('.')[0]}
265
260
 
266
261
 
267
262
  class ChangesTo(HookCondition):
268
263
  def __init__(self, field, value):
269
- """
270
- Check if a field's value has changed to `value`.
271
- Only returns True when original value != value and current value == value.
272
- """
273
264
  self.field = field
274
265
  self.value = value
275
266
 
276
267
  def check(self, instance, original_instance=None):
277
268
  if original_instance is None:
278
- return False
279
- previous = resolve_dotted_attr(original_instance, self.field)
280
- current = resolve_dotted_attr(instance, self.field)
281
- return previous != self.value and current == self.value
269
+ current_value = resolve_dotted_attr(instance, self.field)
270
+ return current_value == self.value
271
+
272
+ current_value = resolve_dotted_attr(instance, self.field)
273
+ original_value = resolve_dotted_attr(original_instance, self.field)
274
+ return current_value == self.value and current_value != original_value
275
+
276
+ def get_required_fields(self):
277
+ return {self.field.split('.')[0]}
282
278
 
283
279
 
284
280
  class IsGreaterThan(HookCondition):
@@ -322,30 +318,41 @@ class IsLessThanOrEqual(HookCondition):
322
318
 
323
319
 
324
320
  class AndCondition(HookCondition):
325
- def __init__(self, cond1, cond2):
326
- self.cond1 = cond1
327
- self.cond2 = cond2
321
+ def __init__(self, condition1, condition2):
322
+ self.condition1 = condition1
323
+ self.condition2 = condition2
328
324
 
329
325
  def check(self, instance, original_instance=None):
330
- return self.cond1.check(instance, original_instance) and self.cond2.check(
331
- instance, original_instance
326
+ return (
327
+ self.condition1.check(instance, original_instance)
328
+ and self.condition2.check(instance, original_instance)
332
329
  )
333
330
 
331
+ def get_required_fields(self):
332
+ return self.condition1.get_required_fields() | self.condition2.get_required_fields()
333
+
334
334
 
335
335
  class OrCondition(HookCondition):
336
- def __init__(self, cond1, cond2):
337
- self.cond1 = cond1
338
- self.cond2 = cond2
336
+ def __init__(self, condition1, condition2):
337
+ self.condition1 = condition1
338
+ self.condition2 = condition2
339
339
 
340
340
  def check(self, instance, original_instance=None):
341
- return self.cond1.check(instance, original_instance) or self.cond2.check(
342
- instance, original_instance
341
+ return (
342
+ self.condition1.check(instance, original_instance)
343
+ or self.condition2.check(instance, original_instance)
343
344
  )
344
345
 
346
+ def get_required_fields(self):
347
+ return self.condition1.get_required_fields() | self.condition2.get_required_fields()
348
+
345
349
 
346
350
  class NotCondition(HookCondition):
347
- def __init__(self, cond):
348
- self.cond = cond
351
+ def __init__(self, condition):
352
+ self.condition = condition
349
353
 
350
354
  def check(self, instance, original_instance=None):
351
- return not self.cond.check(instance, original_instance)
355
+ return not self.condition.check(instance, original_instance)
356
+
357
+ def get_required_fields(self):
358
+ return self.condition.get_required_fields()
@@ -8,7 +8,12 @@ from django_bulk_hooks.conditions import safe_get_related_object, safe_get_relat
8
8
  logger = logging.getLogger(__name__)
9
9
 
10
10
 
11
+ # Cache for hook handlers to avoid creating them repeatedly
12
+ _handler_cache = {}
13
+
11
14
  def run(model_cls, event, new_instances, original_instances=None, ctx=None):
15
+ # Get hooks from cache or fetch them
16
+ cache_key = (model_cls, event)
12
17
  hooks = get_hooks(model_cls, event)
13
18
 
14
19
  if not hooks:
@@ -32,19 +37,32 @@ def run(model_cls, event, new_instances, original_instances=None, ctx=None):
32
37
  logger.error("Unexpected error during validation for %s: %s", instance, e)
33
38
  raise
34
39
 
40
+ # Pre-create None list for originals if needed
41
+ if original_instances is None:
42
+ original_instances = [None] * len(new_instances)
43
+
44
+ # Process all hooks
35
45
  for handler_cls, method_name, condition, priority in hooks:
36
- handler_instance = handler_cls()
37
- func = getattr(handler_instance, method_name)
46
+ # Get or create handler instance from cache
47
+ handler_key = (handler_cls, method_name)
48
+ if handler_key not in _handler_cache:
49
+ handler_instance = handler_cls()
50
+ func = getattr(handler_instance, method_name)
51
+ _handler_cache[handler_key] = (handler_instance, func)
52
+ else:
53
+ handler_instance, func = _handler_cache[handler_key]
54
+
55
+ # If no condition, process all instances at once
56
+ if not condition:
57
+ func(new_records=new_instances, old_records=original_instances if any(original_instances) else None)
58
+ continue
38
59
 
60
+ # For conditional hooks, filter instances first
39
61
  to_process_new = []
40
62
  to_process_old = []
41
63
 
42
- for new, original in zip(
43
- new_instances,
44
- original_instances or [None] * len(new_instances),
45
- strict=True,
46
- ):
47
- if not condition or condition.check(new, original):
64
+ for new, original in zip(new_instances, original_instances, strict=True):
65
+ if condition.check(new, original):
48
66
  to_process_new.append(new)
49
67
  to_process_old.append(original)
50
68
 
@@ -0,0 +1,296 @@
1
+ from django.db import models, transaction
2
+
3
+ from django_bulk_hooks import engine
4
+ from django_bulk_hooks.constants import (
5
+ AFTER_CREATE,
6
+ AFTER_DELETE,
7
+ AFTER_UPDATE,
8
+ BEFORE_CREATE,
9
+ BEFORE_DELETE,
10
+ BEFORE_UPDATE,
11
+ VALIDATE_CREATE,
12
+ VALIDATE_DELETE,
13
+ VALIDATE_UPDATE,
14
+ )
15
+ from django_bulk_hooks.context import HookContext
16
+ from django_bulk_hooks.queryset import HookQuerySet
17
+
18
+
19
+ class BulkHookManager(models.Manager):
20
+ # Default chunk sizes - can be overridden per model
21
+ DEFAULT_CHUNK_SIZE = 200
22
+ DEFAULT_RELATED_CHUNK_SIZE = 500 # Higher for related object fetching
23
+
24
+ def __init__(self):
25
+ super().__init__()
26
+ self._chunk_size = self.DEFAULT_CHUNK_SIZE
27
+ self._related_chunk_size = self.DEFAULT_RELATED_CHUNK_SIZE
28
+ self._prefetch_related_fields = set()
29
+ self._select_related_fields = set()
30
+
31
+ def configure(self, chunk_size=None, related_chunk_size=None,
32
+ select_related=None, prefetch_related=None):
33
+ """
34
+ Configure bulk operation parameters for this manager.
35
+
36
+ Args:
37
+ chunk_size: Number of objects to process in each bulk operation chunk
38
+ related_chunk_size: Number of objects to fetch in each related object query
39
+ select_related: List of fields to always select_related in bulk operations
40
+ prefetch_related: List of fields to always prefetch_related in bulk operations
41
+ """
42
+ if chunk_size is not None:
43
+ self._chunk_size = chunk_size
44
+ if related_chunk_size is not None:
45
+ self._related_chunk_size = related_chunk_size
46
+ if select_related:
47
+ self._select_related_fields.update(select_related)
48
+ if prefetch_related:
49
+ self._prefetch_related_fields.update(prefetch_related)
50
+
51
+ def _load_originals_optimized(self, pks, fields_to_fetch=None):
52
+ """
53
+ Optimized loading of original instances with smart batching and field selection.
54
+ """
55
+ queryset = self.model.objects.filter(pk__in=pks)
56
+
57
+ # Only select specific fields if provided
58
+ if fields_to_fetch:
59
+ queryset = queryset.only('pk', *fields_to_fetch)
60
+
61
+ # Apply configured related field optimizations
62
+ if self._select_related_fields:
63
+ queryset = queryset.select_related(*self._select_related_fields)
64
+ if self._prefetch_related_fields:
65
+ queryset = queryset.prefetch_related(*self._prefetch_related_fields)
66
+
67
+ # Batch load in chunks to avoid memory issues
68
+ all_originals = []
69
+ for i in range(0, len(pks), self._related_chunk_size):
70
+ chunk_pks = pks[i:i + self._related_chunk_size]
71
+ chunk_originals = list(queryset.filter(pk__in=chunk_pks))
72
+ all_originals.extend(chunk_originals)
73
+
74
+ return all_originals
75
+
76
+ def _get_fields_to_fetch(self, objs, fields):
77
+ """
78
+ Determine which fields need to be fetched based on what's being updated
79
+ and what's needed for hooks.
80
+ """
81
+ fields_to_fetch = set(fields)
82
+
83
+ # Add fields needed by registered hooks
84
+ from django_bulk_hooks.registry import get_hooks
85
+ hooks = get_hooks(self.model, "before_update") + get_hooks(self.model, "after_update")
86
+
87
+ for handler_cls, method_name, condition, _ in hooks:
88
+ if condition:
89
+ # If there's a condition, we need all fields it might access
90
+ fields_to_fetch.update(condition.get_required_fields())
91
+
92
+ return fields_to_fetch
93
+
94
+ @transaction.atomic
95
+ def bulk_update(self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs):
96
+ if not objs:
97
+ return []
98
+
99
+ model_cls = self.model
100
+
101
+ if any(not isinstance(obj, model_cls) for obj in objs):
102
+ raise TypeError(
103
+ f"bulk_update expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
104
+ )
105
+
106
+ if not bypass_hooks:
107
+ # Determine which fields we need to fetch
108
+ fields_to_fetch = self._get_fields_to_fetch(objs, fields)
109
+
110
+ # Load originals efficiently
111
+ pks = [obj.pk for obj in objs if obj.pk is not None]
112
+ originals = self._load_originals_optimized(pks, fields_to_fetch)
113
+
114
+ # Create a mapping for quick lookup
115
+ original_map = {obj.pk: obj for obj in originals}
116
+
117
+ # Align originals with new instances
118
+ aligned_originals = [original_map.get(obj.pk) for obj in objs]
119
+
120
+ ctx = HookContext(model_cls)
121
+
122
+ # Run validation hooks first
123
+ if not bypass_validation:
124
+ engine.run(model_cls, VALIDATE_UPDATE, objs, aligned_originals, ctx=ctx)
125
+
126
+ # Then run business logic hooks
127
+ engine.run(model_cls, BEFORE_UPDATE, objs, aligned_originals, ctx=ctx)
128
+
129
+ # Automatically detect fields that were modified during BEFORE_UPDATE hooks
130
+ modified_fields = self._detect_modified_fields(objs, aligned_originals)
131
+ if modified_fields:
132
+ fields_set = set(fields)
133
+ fields_set.update(modified_fields)
134
+ fields = list(fields_set)
135
+
136
+ # Process in chunks
137
+ for i in range(0, len(objs), self._chunk_size):
138
+ chunk = objs[i:i + self._chunk_size]
139
+ super(models.Manager, self).bulk_update(chunk, fields, **kwargs)
140
+
141
+ if not bypass_hooks:
142
+ engine.run(model_cls, AFTER_UPDATE, objs, aligned_originals, ctx=ctx)
143
+
144
+ return objs
145
+
146
+ def _detect_modified_fields(self, new_instances, original_instances):
147
+ """
148
+ Detect fields that were modified during BEFORE_UPDATE hooks by comparing
149
+ new instances with their original values.
150
+ """
151
+ if not original_instances:
152
+ return set()
153
+
154
+ # Create a mapping of pk to original instance for efficient lookup
155
+ original_map = {obj.pk: obj for obj in original_instances if obj.pk is not None}
156
+
157
+ modified_fields = set()
158
+
159
+ for new_instance in new_instances:
160
+ if new_instance.pk is None:
161
+ continue
162
+
163
+ original = original_map.get(new_instance.pk)
164
+ if not original:
165
+ continue
166
+
167
+ # Compare all fields to detect changes
168
+ for field in new_instance._meta.fields:
169
+ if field.name == "id":
170
+ continue
171
+
172
+ new_value = getattr(new_instance, field.name)
173
+ original_value = getattr(original, field.name)
174
+
175
+ # Handle different field types appropriately
176
+ if field.is_relation:
177
+ # For foreign keys, compare the pk values
178
+ new_pk = new_value.pk if new_value else None
179
+ original_pk = original_value.pk if original_value else None
180
+ if new_pk != original_pk:
181
+ modified_fields.add(field.name)
182
+ else:
183
+ # For regular fields, use direct comparison
184
+ if new_value != original_value:
185
+ modified_fields.add(field.name)
186
+
187
+ return modified_fields
188
+
189
+ @transaction.atomic
190
+ def bulk_create(self, objs, bypass_hooks=False, bypass_validation=False, **kwargs):
191
+ if not objs:
192
+ return []
193
+
194
+ model_cls = self.model
195
+ result = []
196
+
197
+ if any(not isinstance(obj, model_cls) for obj in objs):
198
+ raise TypeError(
199
+ f"bulk_create expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
200
+ )
201
+
202
+ if not bypass_hooks:
203
+ ctx = HookContext(model_cls)
204
+
205
+ # Process validation in chunks to avoid memory issues
206
+ if not bypass_validation:
207
+ for i in range(0, len(objs), self._chunk_size):
208
+ chunk = objs[i:i + self._chunk_size]
209
+ engine.run(model_cls, VALIDATE_CREATE, chunk, ctx=ctx)
210
+
211
+ # Process before_create hooks in chunks
212
+ for i in range(0, len(objs), self._chunk_size):
213
+ chunk = objs[i:i + self._chunk_size]
214
+ engine.run(model_cls, BEFORE_CREATE, chunk, ctx=ctx)
215
+
216
+ # Perform bulk create in chunks
217
+ for i in range(0, len(objs), self._chunk_size):
218
+ chunk = objs[i:i + self._chunk_size]
219
+ created_chunk = super(models.Manager, self).bulk_create(chunk, **kwargs)
220
+ result.extend(created_chunk)
221
+
222
+ if not bypass_hooks:
223
+ # Process after_create hooks in chunks
224
+ for i in range(0, len(result), self._chunk_size):
225
+ chunk = result[i:i + self._chunk_size]
226
+ engine.run(model_cls, AFTER_CREATE, chunk, ctx=ctx)
227
+
228
+ return result
229
+
230
+ @transaction.atomic
231
+ def bulk_delete(self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False):
232
+ if not objs:
233
+ return []
234
+
235
+ model_cls = self.model
236
+ chunk_size = batch_size or self._chunk_size
237
+
238
+ if any(not isinstance(obj, model_cls) for obj in objs):
239
+ raise TypeError(
240
+ f"bulk_delete expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
241
+ )
242
+
243
+ ctx = HookContext(model_cls)
244
+
245
+ if not bypass_hooks:
246
+ # Process hooks in chunks
247
+ for i in range(0, len(objs), chunk_size):
248
+ chunk = objs[i:i + chunk_size]
249
+
250
+ if not bypass_validation:
251
+ engine.run(model_cls, VALIDATE_DELETE, chunk, ctx=ctx)
252
+ engine.run(model_cls, BEFORE_DELETE, chunk, ctx=ctx)
253
+
254
+ # Collect PKs and delete in chunks
255
+ pks = [obj.pk for obj in objs if obj.pk is not None]
256
+ for i in range(0, len(pks), chunk_size):
257
+ chunk_pks = pks[i:i + chunk_size]
258
+ model_cls._base_manager.filter(pk__in=chunk_pks).delete()
259
+
260
+ if not bypass_hooks:
261
+ # Process after_delete hooks in chunks
262
+ for i in range(0, len(objs), chunk_size):
263
+ chunk = objs[i:i + chunk_size]
264
+ engine.run(model_cls, AFTER_DELETE, chunk, ctx=ctx)
265
+
266
+ return objs
267
+
268
+ @transaction.atomic
269
+ def update(self, **kwargs):
270
+ objs = list(self.all())
271
+ if not objs:
272
+ return 0
273
+ for key, value in kwargs.items():
274
+ for obj in objs:
275
+ setattr(obj, key, value)
276
+ self.bulk_update(objs, fields=list(kwargs.keys()))
277
+ return len(objs)
278
+
279
+ @transaction.atomic
280
+ def delete(self):
281
+ objs = list(self.all())
282
+ if not objs:
283
+ return 0
284
+ self.bulk_delete(objs)
285
+ return len(objs)
286
+
287
+ @transaction.atomic
288
+ def save(self, obj):
289
+ if obj.pk:
290
+ self.bulk_update(
291
+ [obj],
292
+ fields=[field.name for field in obj._meta.fields if field.name != "id"],
293
+ )
294
+ else:
295
+ self.bulk_create([obj])
296
+ return obj
@@ -91,40 +91,24 @@ class HookModelMixin(models.Model):
91
91
  is_create = self.pk is None
92
92
  ctx = HookContext(self.__class__)
93
93
 
94
- if is_create:
95
- # For create operations, run BEFORE hooks first
96
- with patch_foreign_key_behavior():
94
+ # Use a single context manager for all hooks
95
+ with patch_foreign_key_behavior():
96
+ if is_create:
97
+ # For create operations
97
98
  run(self.__class__, BEFORE_CREATE, [self], ctx=ctx)
98
-
99
- # Then let Django save
100
- super().save(*args, **kwargs)
101
-
102
- # Then run AFTER hooks
103
- with patch_foreign_key_behavior():
99
+ super().save(*args, **kwargs)
104
100
  run(self.__class__, AFTER_CREATE, [self], ctx=ctx)
105
- else:
106
- # For update operations, we need to get the old record
107
- try:
108
- old_instance = self.__class__.objects.get(pk=self.pk)
109
-
110
- # Run BEFORE hooks first
111
- with patch_foreign_key_behavior():
101
+ else:
102
+ # For update operations
103
+ try:
104
+ old_instance = self.__class__.objects.get(pk=self.pk)
112
105
  run(self.__class__, BEFORE_UPDATE, [self], [old_instance], ctx=ctx)
113
-
114
- # Then let Django save
115
- super().save(*args, **kwargs)
116
-
117
- # Then run AFTER hooks
118
- with patch_foreign_key_behavior():
106
+ super().save(*args, **kwargs)
119
107
  run(self.__class__, AFTER_UPDATE, [self], [old_instance], ctx=ctx)
120
- except self.__class__.DoesNotExist:
121
- # If the old instance doesn't exist, treat as create
122
- with patch_foreign_key_behavior():
108
+ except self.__class__.DoesNotExist:
109
+ # If the old instance doesn't exist, treat as create
123
110
  run(self.__class__, BEFORE_CREATE, [self], ctx=ctx)
124
-
125
- super().save(*args, **kwargs)
126
-
127
- with patch_foreign_key_behavior():
111
+ super().save(*args, **kwargs)
128
112
  run(self.__class__, AFTER_CREATE, [self], ctx=ctx)
129
113
 
130
114
  return self
@@ -132,16 +116,11 @@ class HookModelMixin(models.Model):
132
116
  def delete(self, *args, **kwargs):
133
117
  ctx = HookContext(self.__class__)
134
118
 
135
- # Run validation hooks first
119
+ # Use a single context manager for all hooks
136
120
  with patch_foreign_key_behavior():
137
121
  run(self.__class__, VALIDATE_DELETE, [self], ctx=ctx)
138
-
139
- # Then run business logic hooks
140
- with patch_foreign_key_behavior():
141
122
  run(self.__class__, BEFORE_DELETE, [self], ctx=ctx)
142
-
143
- result = super().delete(*args, **kwargs)
144
-
145
- with patch_foreign_key_behavior():
123
+ result = super().delete(*args, **kwargs)
146
124
  run(self.__class__, AFTER_DELETE, [self], ctx=ctx)
125
+
147
126
  return result
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "django-bulk-hooks"
3
- version = "0.1.85"
3
+ version = "0.1.87"
4
4
  description = "Hook-style hooks for Django bulk operations like bulk_create and bulk_update."
5
5
  authors = ["Konrad Beck <konrad.beck@merchantcapital.co.za>"]
6
6
  readme = "README.md"
@@ -1,208 +0,0 @@
1
- from django.db import models, transaction
2
-
3
- from django_bulk_hooks import engine
4
- from django_bulk_hooks.constants import (
5
- AFTER_CREATE,
6
- AFTER_DELETE,
7
- AFTER_UPDATE,
8
- BEFORE_CREATE,
9
- BEFORE_DELETE,
10
- BEFORE_UPDATE,
11
- VALIDATE_CREATE,
12
- VALIDATE_DELETE,
13
- VALIDATE_UPDATE,
14
- )
15
- from django_bulk_hooks.context import HookContext
16
- from django_bulk_hooks.queryset import HookQuerySet
17
-
18
-
19
- class BulkHookManager(models.Manager):
20
- CHUNK_SIZE = 200
21
-
22
- def get_queryset(self):
23
- return HookQuerySet(self.model, using=self._db)
24
-
25
- @transaction.atomic
26
- def bulk_update(
27
- self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs
28
- ):
29
- if not objs:
30
- return []
31
-
32
- model_cls = self.model
33
-
34
- if any(not isinstance(obj, model_cls) for obj in objs):
35
- raise TypeError(
36
- f"bulk_update expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
37
- )
38
-
39
- if not bypass_hooks:
40
- # Load originals for hook comparison
41
- originals = list(
42
- model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
43
- )
44
-
45
- ctx = HookContext(model_cls)
46
-
47
- # Run validation hooks first
48
- if not bypass_validation:
49
- engine.run(model_cls, VALIDATE_UPDATE, objs, originals, ctx=ctx)
50
-
51
- # Then run business logic hooks
52
- engine.run(model_cls, BEFORE_UPDATE, objs, originals, ctx=ctx)
53
-
54
- # Automatically detect fields that were modified during BEFORE_UPDATE hooks
55
- modified_fields = self._detect_modified_fields(objs, originals)
56
- if modified_fields:
57
- # Convert to set for efficient union operation
58
- fields_set = set(fields)
59
- fields_set.update(modified_fields)
60
- fields = list(fields_set)
61
-
62
- for i in range(0, len(objs), self.CHUNK_SIZE):
63
- chunk = objs[i : i + self.CHUNK_SIZE]
64
- # Call the base implementation to avoid re-triggering this method
65
- super(models.Manager, self).bulk_update(chunk, fields, **kwargs)
66
-
67
- if not bypass_hooks:
68
- engine.run(model_cls, AFTER_UPDATE, objs, originals, ctx=ctx)
69
-
70
- return objs
71
-
72
- def _detect_modified_fields(self, new_instances, original_instances):
73
- """
74
- Detect fields that were modified during BEFORE_UPDATE hooks by comparing
75
- new instances with their original values.
76
- """
77
- if not original_instances:
78
- return set()
79
-
80
- # Create a mapping of pk to original instance for efficient lookup
81
- original_map = {obj.pk: obj for obj in original_instances if obj.pk is not None}
82
-
83
- modified_fields = set()
84
-
85
- for new_instance in new_instances:
86
- if new_instance.pk is None:
87
- continue
88
-
89
- original = original_map.get(new_instance.pk)
90
- if not original:
91
- continue
92
-
93
- # Compare all fields to detect changes
94
- for field in new_instance._meta.fields:
95
- if field.name == "id":
96
- continue
97
-
98
- new_value = getattr(new_instance, field.name)
99
- original_value = getattr(original, field.name)
100
-
101
- # Handle different field types appropriately
102
- if field.is_relation:
103
- # For foreign keys, compare the pk values
104
- new_pk = new_value.pk if new_value else None
105
- original_pk = original_value.pk if original_value else None
106
- if new_pk != original_pk:
107
- modified_fields.add(field.name)
108
- else:
109
- # For regular fields, use direct comparison
110
- if new_value != original_value:
111
- modified_fields.add(field.name)
112
-
113
- return modified_fields
114
-
115
- @transaction.atomic
116
- def bulk_create(self, objs, bypass_hooks=False, bypass_validation=False, **kwargs):
117
- model_cls = self.model
118
-
119
- if any(not isinstance(obj, model_cls) for obj in objs):
120
- raise TypeError(
121
- f"bulk_create expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
122
- )
123
-
124
- result = []
125
-
126
- if not bypass_hooks:
127
- ctx = HookContext(model_cls)
128
-
129
- # Run validation hooks first
130
- if not bypass_validation:
131
- engine.run(model_cls, VALIDATE_CREATE, objs, ctx=ctx)
132
-
133
- # Then run business logic hooks
134
- engine.run(model_cls, BEFORE_CREATE, objs, ctx=ctx)
135
-
136
- for i in range(0, len(objs), self.CHUNK_SIZE):
137
- chunk = objs[i : i + self.CHUNK_SIZE]
138
- result.extend(super(models.Manager, self).bulk_create(chunk, **kwargs))
139
-
140
- if not bypass_hooks:
141
- engine.run(model_cls, AFTER_CREATE, result, ctx=ctx)
142
-
143
- return result
144
-
145
- @transaction.atomic
146
- def bulk_delete(
147
- self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False
148
- ):
149
- if not objs:
150
- return []
151
-
152
- model_cls = self.model
153
-
154
- if any(not isinstance(obj, model_cls) for obj in objs):
155
- raise TypeError(
156
- f"bulk_delete expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
157
- )
158
-
159
- ctx = HookContext(model_cls)
160
-
161
- if not bypass_hooks:
162
- # Run validation hooks first
163
- if not bypass_validation:
164
- engine.run(model_cls, VALIDATE_DELETE, objs, ctx=ctx)
165
-
166
- # Then run business logic hooks
167
- engine.run(model_cls, BEFORE_DELETE, objs, ctx=ctx)
168
-
169
- pks = [obj.pk for obj in objs if obj.pk is not None]
170
-
171
- # Use base manager for the actual deletion to prevent recursion
172
- # The hooks have already been fired above, so we don't need them again
173
- model_cls._base_manager.filter(pk__in=pks).delete()
174
-
175
- if not bypass_hooks:
176
- engine.run(model_cls, AFTER_DELETE, objs, ctx=ctx)
177
-
178
- return objs
179
-
180
- @transaction.atomic
181
- def update(self, **kwargs):
182
- objs = list(self.all())
183
- if not objs:
184
- return 0
185
- for key, value in kwargs.items():
186
- for obj in objs:
187
- setattr(obj, key, value)
188
- self.bulk_update(objs, fields=list(kwargs.keys()))
189
- return len(objs)
190
-
191
- @transaction.atomic
192
- def delete(self):
193
- objs = list(self.all())
194
- if not objs:
195
- return 0
196
- self.bulk_delete(objs)
197
- return len(objs)
198
-
199
- @transaction.atomic
200
- def save(self, obj):
201
- if obj.pk:
202
- self.bulk_update(
203
- [obj],
204
- fields=[field.name for field in obj._meta.fields if field.name != "id"],
205
- )
206
- else:
207
- self.bulk_create([obj])
208
- return obj