django-bulk-hooks 0.1.86__tar.gz → 0.1.88__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of django-bulk-hooks might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: django-bulk-hooks
3
- Version: 0.1.86
3
+ Version: 0.1.88
4
4
  Summary: Hook-style hooks for Django bulk operations like bulk_create and bulk_update.
5
5
  License: MIT
6
6
  Keywords: django,bulk,hooks
@@ -195,90 +195,104 @@ class HookCondition:
195
195
  def __invert__(self):
196
196
  return NotCondition(self)
197
197
 
198
+ def get_required_fields(self):
199
+ """
200
+ Returns a set of field names that this condition needs to evaluate.
201
+ Override in subclasses to specify required fields.
202
+ """
203
+ return set()
198
204
 
199
- class IsNotEqual(HookCondition):
200
- def __init__(self, field, value, only_on_change=False):
205
+
206
+ class IsEqual(HookCondition):
207
+ def __init__(self, field, value):
201
208
  self.field = field
202
209
  self.value = value
203
- self.only_on_change = only_on_change
204
210
 
205
211
  def check(self, instance, original_instance=None):
206
- current = resolve_dotted_attr(instance, self.field)
207
- if self.only_on_change:
208
- if original_instance is None:
209
- return False
210
- previous = resolve_dotted_attr(original_instance, self.field)
211
- return previous == self.value and current != self.value
212
- else:
213
- return current != self.value
212
+ current_value = resolve_dotted_attr(instance, self.field)
213
+ return current_value == self.value
214
214
 
215
+ def get_required_fields(self):
216
+ return {self.field.split('.')[0]}
215
217
 
216
- class IsEqual(HookCondition):
217
- def __init__(self, field, value, only_on_change=False):
218
+
219
+ class IsNotEqual(HookCondition):
220
+ def __init__(self, field, value):
218
221
  self.field = field
219
222
  self.value = value
220
- self.only_on_change = only_on_change
221
223
 
222
224
  def check(self, instance, original_instance=None):
223
- current = resolve_dotted_attr(instance, self.field)
224
- if self.only_on_change:
225
- if original_instance is None:
226
- return False
227
- previous = resolve_dotted_attr(original_instance, self.field)
228
- return previous != self.value and current == self.value
229
- else:
230
- return current == self.value
225
+ current_value = resolve_dotted_attr(instance, self.field)
226
+ return current_value != self.value
231
227
 
228
+ def get_required_fields(self):
229
+ return {self.field.split('.')[0]}
232
230
 
233
- class HasChanged(HookCondition):
234
- def __init__(self, field, has_changed=True):
231
+
232
+ class WasEqual(HookCondition):
233
+ def __init__(self, field, value):
235
234
  self.field = field
236
- self.has_changed = has_changed
235
+ self.value = value
237
236
 
238
237
  def check(self, instance, original_instance=None):
239
- if not original_instance:
238
+ if original_instance is None:
240
239
  return False
241
- current = resolve_dotted_attr(instance, self.field)
242
- previous = resolve_dotted_attr(original_instance, self.field)
243
- return (current != previous) == self.has_changed
240
+ original_value = resolve_dotted_attr(original_instance, self.field)
241
+ return original_value == self.value
244
242
 
243
+ def get_required_fields(self):
244
+ return {self.field.split('.')[0]}
245
245
 
246
- class WasEqual(HookCondition):
247
- def __init__(self, field, value, only_on_change=False):
246
+
247
+ class HasChanged(HookCondition):
248
+ def __init__(self, field, has_changed=True):
248
249
  """
249
- Check if a field's original value was `value`.
250
- If only_on_change is True, only return True when the field has changed away from that value.
250
+ Check if a field's value has changed or remained the same.
251
+
252
+ Args:
253
+ field: The field name to check
254
+ has_changed: If True (default), condition passes when field has changed.
255
+ If False, condition passes when field has remained the same.
256
+ This is useful for:
257
+ - Detecting stable/unchanged fields
258
+ - Validating field immutability
259
+ - Ensuring critical fields remain constant
260
+ - State machine validations
251
261
  """
252
262
  self.field = field
253
- self.value = value
254
- self.only_on_change = only_on_change
263
+ self.has_changed = has_changed
255
264
 
256
265
  def check(self, instance, original_instance=None):
257
266
  if original_instance is None:
258
- return False
259
- previous = resolve_dotted_attr(original_instance, self.field)
260
- if self.only_on_change:
261
- current = resolve_dotted_attr(instance, self.field)
262
- return previous == self.value and current != self.value
263
- else:
264
- return previous == self.value
267
+ # For new instances:
268
+ # - If we're checking for changes (has_changed=True), return False since there's no change yet
269
+ # - If we're checking for stability (has_changed=False), return True since it's technically unchanged
270
+ return not self.has_changed
271
+
272
+ current_value = resolve_dotted_attr(instance, self.field)
273
+ original_value = resolve_dotted_attr(original_instance, self.field)
274
+ return (current_value != original_value) == self.has_changed
275
+
276
+ def get_required_fields(self):
277
+ return {self.field.split('.')[0]}
265
278
 
266
279
 
267
280
  class ChangesTo(HookCondition):
268
281
  def __init__(self, field, value):
269
- """
270
- Check if a field's value has changed to `value`.
271
- Only returns True when original value != value and current value == value.
272
- """
273
282
  self.field = field
274
283
  self.value = value
275
284
 
276
285
  def check(self, instance, original_instance=None):
277
286
  if original_instance is None:
278
- return False
279
- previous = resolve_dotted_attr(original_instance, self.field)
280
- current = resolve_dotted_attr(instance, self.field)
281
- return previous != self.value and current == self.value
287
+ current_value = resolve_dotted_attr(instance, self.field)
288
+ return current_value == self.value
289
+
290
+ current_value = resolve_dotted_attr(instance, self.field)
291
+ original_value = resolve_dotted_attr(original_instance, self.field)
292
+ return current_value == self.value and current_value != original_value
293
+
294
+ def get_required_fields(self):
295
+ return {self.field.split('.')[0]}
282
296
 
283
297
 
284
298
  class IsGreaterThan(HookCondition):
@@ -322,30 +336,41 @@ class IsLessThanOrEqual(HookCondition):
322
336
 
323
337
 
324
338
  class AndCondition(HookCondition):
325
- def __init__(self, cond1, cond2):
326
- self.cond1 = cond1
327
- self.cond2 = cond2
339
+ def __init__(self, condition1, condition2):
340
+ self.condition1 = condition1
341
+ self.condition2 = condition2
328
342
 
329
343
  def check(self, instance, original_instance=None):
330
- return self.cond1.check(instance, original_instance) and self.cond2.check(
331
- instance, original_instance
344
+ return (
345
+ self.condition1.check(instance, original_instance)
346
+ and self.condition2.check(instance, original_instance)
332
347
  )
333
348
 
349
+ def get_required_fields(self):
350
+ return self.condition1.get_required_fields() | self.condition2.get_required_fields()
351
+
334
352
 
335
353
  class OrCondition(HookCondition):
336
- def __init__(self, cond1, cond2):
337
- self.cond1 = cond1
338
- self.cond2 = cond2
354
+ def __init__(self, condition1, condition2):
355
+ self.condition1 = condition1
356
+ self.condition2 = condition2
339
357
 
340
358
  def check(self, instance, original_instance=None):
341
- return self.cond1.check(instance, original_instance) or self.cond2.check(
342
- instance, original_instance
359
+ return (
360
+ self.condition1.check(instance, original_instance)
361
+ or self.condition2.check(instance, original_instance)
343
362
  )
344
363
 
364
+ def get_required_fields(self):
365
+ return self.condition1.get_required_fields() | self.condition2.get_required_fields()
366
+
345
367
 
346
368
  class NotCondition(HookCondition):
347
- def __init__(self, cond):
348
- self.cond = cond
369
+ def __init__(self, condition):
370
+ self.condition = condition
349
371
 
350
372
  def check(self, instance, original_instance=None):
351
- return not self.cond.check(instance, original_instance)
373
+ return not self.condition.check(instance, original_instance)
374
+
375
+ def get_required_fields(self):
376
+ return self.condition.get_required_fields()
@@ -0,0 +1,296 @@
1
+ from django.db import models, transaction
2
+
3
+ from django_bulk_hooks import engine
4
+ from django_bulk_hooks.constants import (
5
+ AFTER_CREATE,
6
+ AFTER_DELETE,
7
+ AFTER_UPDATE,
8
+ BEFORE_CREATE,
9
+ BEFORE_DELETE,
10
+ BEFORE_UPDATE,
11
+ VALIDATE_CREATE,
12
+ VALIDATE_DELETE,
13
+ VALIDATE_UPDATE,
14
+ )
15
+ from django_bulk_hooks.context import HookContext
16
+ from django_bulk_hooks.queryset import HookQuerySet
17
+
18
+
19
+ class BulkHookManager(models.Manager):
20
+ # Default chunk sizes - can be overridden per model
21
+ DEFAULT_CHUNK_SIZE = 200
22
+ DEFAULT_RELATED_CHUNK_SIZE = 500 # Higher for related object fetching
23
+
24
+ def __init__(self):
25
+ super().__init__()
26
+ self._chunk_size = self.DEFAULT_CHUNK_SIZE
27
+ self._related_chunk_size = self.DEFAULT_RELATED_CHUNK_SIZE
28
+ self._prefetch_related_fields = set()
29
+ self._select_related_fields = set()
30
+
31
+ def configure(self, chunk_size=None, related_chunk_size=None,
32
+ select_related=None, prefetch_related=None):
33
+ """
34
+ Configure bulk operation parameters for this manager.
35
+
36
+ Args:
37
+ chunk_size: Number of objects to process in each bulk operation chunk
38
+ related_chunk_size: Number of objects to fetch in each related object query
39
+ select_related: List of fields to always select_related in bulk operations
40
+ prefetch_related: List of fields to always prefetch_related in bulk operations
41
+ """
42
+ if chunk_size is not None:
43
+ self._chunk_size = chunk_size
44
+ if related_chunk_size is not None:
45
+ self._related_chunk_size = related_chunk_size
46
+ if select_related:
47
+ self._select_related_fields.update(select_related)
48
+ if prefetch_related:
49
+ self._prefetch_related_fields.update(prefetch_related)
50
+
51
+ def _load_originals_optimized(self, pks, fields_to_fetch=None):
52
+ """
53
+ Optimized loading of original instances with smart batching and field selection.
54
+ """
55
+ queryset = self.model.objects.filter(pk__in=pks)
56
+
57
+ # Only select specific fields if provided
58
+ if fields_to_fetch:
59
+ queryset = queryset.only('pk', *fields_to_fetch)
60
+
61
+ # Apply configured related field optimizations
62
+ if self._select_related_fields:
63
+ queryset = queryset.select_related(*self._select_related_fields)
64
+ if self._prefetch_related_fields:
65
+ queryset = queryset.prefetch_related(*self._prefetch_related_fields)
66
+
67
+ # Batch load in chunks to avoid memory issues
68
+ all_originals = []
69
+ for i in range(0, len(pks), self._related_chunk_size):
70
+ chunk_pks = pks[i:i + self._related_chunk_size]
71
+ chunk_originals = list(queryset.filter(pk__in=chunk_pks))
72
+ all_originals.extend(chunk_originals)
73
+
74
+ return all_originals
75
+
76
+ def _get_fields_to_fetch(self, objs, fields):
77
+ """
78
+ Determine which fields need to be fetched based on what's being updated
79
+ and what's needed for hooks.
80
+ """
81
+ fields_to_fetch = set(fields)
82
+
83
+ # Add fields needed by registered hooks
84
+ from django_bulk_hooks.registry import get_hooks
85
+ hooks = get_hooks(self.model, "before_update") + get_hooks(self.model, "after_update")
86
+
87
+ for handler_cls, method_name, condition, _ in hooks:
88
+ if condition:
89
+ # If there's a condition, we need all fields it might access
90
+ fields_to_fetch.update(condition.get_required_fields())
91
+
92
+ return fields_to_fetch
93
+
94
+ @transaction.atomic
95
+ def bulk_update(self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs):
96
+ if not objs:
97
+ return []
98
+
99
+ model_cls = self.model
100
+
101
+ if any(not isinstance(obj, model_cls) for obj in objs):
102
+ raise TypeError(
103
+ f"bulk_update expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
104
+ )
105
+
106
+ if not bypass_hooks:
107
+ # Determine which fields we need to fetch
108
+ fields_to_fetch = self._get_fields_to_fetch(objs, fields)
109
+
110
+ # Load originals efficiently
111
+ pks = [obj.pk for obj in objs if obj.pk is not None]
112
+ originals = self._load_originals_optimized(pks, fields_to_fetch)
113
+
114
+ # Create a mapping for quick lookup
115
+ original_map = {obj.pk: obj for obj in originals}
116
+
117
+ # Align originals with new instances
118
+ aligned_originals = [original_map.get(obj.pk) for obj in objs]
119
+
120
+ ctx = HookContext(model_cls)
121
+
122
+ # Run validation hooks first
123
+ if not bypass_validation:
124
+ engine.run(model_cls, VALIDATE_UPDATE, objs, aligned_originals, ctx=ctx)
125
+
126
+ # Then run business logic hooks
127
+ engine.run(model_cls, BEFORE_UPDATE, objs, aligned_originals, ctx=ctx)
128
+
129
+ # Automatically detect fields that were modified during BEFORE_UPDATE hooks
130
+ modified_fields = self._detect_modified_fields(objs, aligned_originals)
131
+ if modified_fields:
132
+ fields_set = set(fields)
133
+ fields_set.update(modified_fields)
134
+ fields = list(fields_set)
135
+
136
+ # Process in chunks
137
+ for i in range(0, len(objs), self._chunk_size):
138
+ chunk = objs[i:i + self._chunk_size]
139
+ super(models.Manager, self).bulk_update(chunk, fields, **kwargs)
140
+
141
+ if not bypass_hooks:
142
+ engine.run(model_cls, AFTER_UPDATE, objs, aligned_originals, ctx=ctx)
143
+
144
+ return objs
145
+
146
+ def _detect_modified_fields(self, new_instances, original_instances):
147
+ """
148
+ Detect fields that were modified during BEFORE_UPDATE hooks by comparing
149
+ new instances with their original values.
150
+ """
151
+ if not original_instances:
152
+ return set()
153
+
154
+ # Create a mapping of pk to original instance for efficient lookup
155
+ original_map = {obj.pk: obj for obj in original_instances if obj.pk is not None}
156
+
157
+ modified_fields = set()
158
+
159
+ for new_instance in new_instances:
160
+ if new_instance.pk is None:
161
+ continue
162
+
163
+ original = original_map.get(new_instance.pk)
164
+ if not original:
165
+ continue
166
+
167
+ # Compare all fields to detect changes
168
+ for field in new_instance._meta.fields:
169
+ if field.name == "id":
170
+ continue
171
+
172
+ new_value = getattr(new_instance, field.name)
173
+ original_value = getattr(original, field.name)
174
+
175
+ # Handle different field types appropriately
176
+ if field.is_relation:
177
+ # For foreign keys, compare the pk values
178
+ new_pk = new_value.pk if new_value else None
179
+ original_pk = original_value.pk if original_value else None
180
+ if new_pk != original_pk:
181
+ modified_fields.add(field.name)
182
+ else:
183
+ # For regular fields, use direct comparison
184
+ if new_value != original_value:
185
+ modified_fields.add(field.name)
186
+
187
+ return modified_fields
188
+
189
+ @transaction.atomic
190
+ def bulk_create(self, objs, bypass_hooks=False, bypass_validation=False, **kwargs):
191
+ if not objs:
192
+ return []
193
+
194
+ model_cls = self.model
195
+ result = []
196
+
197
+ if any(not isinstance(obj, model_cls) for obj in objs):
198
+ raise TypeError(
199
+ f"bulk_create expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
200
+ )
201
+
202
+ if not bypass_hooks:
203
+ ctx = HookContext(model_cls)
204
+
205
+ # Process validation in chunks to avoid memory issues
206
+ if not bypass_validation:
207
+ for i in range(0, len(objs), self._chunk_size):
208
+ chunk = objs[i:i + self._chunk_size]
209
+ engine.run(model_cls, VALIDATE_CREATE, chunk, ctx=ctx)
210
+
211
+ # Process before_create hooks in chunks
212
+ for i in range(0, len(objs), self._chunk_size):
213
+ chunk = objs[i:i + self._chunk_size]
214
+ engine.run(model_cls, BEFORE_CREATE, chunk, ctx=ctx)
215
+
216
+ # Perform bulk create in chunks
217
+ for i in range(0, len(objs), self._chunk_size):
218
+ chunk = objs[i:i + self._chunk_size]
219
+ created_chunk = super(models.Manager, self).bulk_create(chunk, **kwargs)
220
+ result.extend(created_chunk)
221
+
222
+ if not bypass_hooks:
223
+ # Process after_create hooks in chunks
224
+ for i in range(0, len(result), self._chunk_size):
225
+ chunk = result[i:i + self._chunk_size]
226
+ engine.run(model_cls, AFTER_CREATE, chunk, ctx=ctx)
227
+
228
+ return result
229
+
230
+ @transaction.atomic
231
+ def bulk_delete(self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False):
232
+ if not objs:
233
+ return []
234
+
235
+ model_cls = self.model
236
+ chunk_size = batch_size or self._chunk_size
237
+
238
+ if any(not isinstance(obj, model_cls) for obj in objs):
239
+ raise TypeError(
240
+ f"bulk_delete expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
241
+ )
242
+
243
+ ctx = HookContext(model_cls)
244
+
245
+ if not bypass_hooks:
246
+ # Process hooks in chunks
247
+ for i in range(0, len(objs), chunk_size):
248
+ chunk = objs[i:i + chunk_size]
249
+
250
+ if not bypass_validation:
251
+ engine.run(model_cls, VALIDATE_DELETE, chunk, ctx=ctx)
252
+ engine.run(model_cls, BEFORE_DELETE, chunk, ctx=ctx)
253
+
254
+ # Collect PKs and delete in chunks
255
+ pks = [obj.pk for obj in objs if obj.pk is not None]
256
+ for i in range(0, len(pks), chunk_size):
257
+ chunk_pks = pks[i:i + chunk_size]
258
+ model_cls._base_manager.filter(pk__in=chunk_pks).delete()
259
+
260
+ if not bypass_hooks:
261
+ # Process after_delete hooks in chunks
262
+ for i in range(0, len(objs), chunk_size):
263
+ chunk = objs[i:i + chunk_size]
264
+ engine.run(model_cls, AFTER_DELETE, chunk, ctx=ctx)
265
+
266
+ return objs
267
+
268
+ @transaction.atomic
269
+ def update(self, **kwargs):
270
+ objs = list(self.all())
271
+ if not objs:
272
+ return 0
273
+ for key, value in kwargs.items():
274
+ for obj in objs:
275
+ setattr(obj, key, value)
276
+ self.bulk_update(objs, fields=list(kwargs.keys()))
277
+ return len(objs)
278
+
279
+ @transaction.atomic
280
+ def delete(self):
281
+ objs = list(self.all())
282
+ if not objs:
283
+ return 0
284
+ self.bulk_delete(objs)
285
+ return len(objs)
286
+
287
+ @transaction.atomic
288
+ def save(self, obj):
289
+ if obj.pk:
290
+ self.bulk_update(
291
+ [obj],
292
+ fields=[field.name for field in obj._meta.fields if field.name != "id"],
293
+ )
294
+ else:
295
+ self.bulk_create([obj])
296
+ return obj
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "django-bulk-hooks"
3
- version = "0.1.86"
3
+ version = "0.1.88"
4
4
  description = "Hook-style hooks for Django bulk operations like bulk_create and bulk_update."
5
5
  authors = ["Konrad Beck <konrad.beck@merchantcapital.co.za>"]
6
6
  readme = "README.md"
@@ -1,208 +0,0 @@
1
- from django.db import models, transaction
2
-
3
- from django_bulk_hooks import engine
4
- from django_bulk_hooks.constants import (
5
- AFTER_CREATE,
6
- AFTER_DELETE,
7
- AFTER_UPDATE,
8
- BEFORE_CREATE,
9
- BEFORE_DELETE,
10
- BEFORE_UPDATE,
11
- VALIDATE_CREATE,
12
- VALIDATE_DELETE,
13
- VALIDATE_UPDATE,
14
- )
15
- from django_bulk_hooks.context import HookContext
16
- from django_bulk_hooks.queryset import HookQuerySet
17
-
18
-
19
- class BulkHookManager(models.Manager):
20
- CHUNK_SIZE = 200
21
-
22
- def get_queryset(self):
23
- return HookQuerySet(self.model, using=self._db)
24
-
25
- @transaction.atomic
26
- def bulk_update(
27
- self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs
28
- ):
29
- if not objs:
30
- return []
31
-
32
- model_cls = self.model
33
-
34
- if any(not isinstance(obj, model_cls) for obj in objs):
35
- raise TypeError(
36
- f"bulk_update expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
37
- )
38
-
39
- if not bypass_hooks:
40
- # Load originals for hook comparison
41
- originals = list(
42
- model_cls.objects.filter(pk__in=[obj.pk for obj in objs])
43
- )
44
-
45
- ctx = HookContext(model_cls)
46
-
47
- # Run validation hooks first
48
- if not bypass_validation:
49
- engine.run(model_cls, VALIDATE_UPDATE, objs, originals, ctx=ctx)
50
-
51
- # Then run business logic hooks
52
- engine.run(model_cls, BEFORE_UPDATE, objs, originals, ctx=ctx)
53
-
54
- # Automatically detect fields that were modified during BEFORE_UPDATE hooks
55
- modified_fields = self._detect_modified_fields(objs, originals)
56
- if modified_fields:
57
- # Convert to set for efficient union operation
58
- fields_set = set(fields)
59
- fields_set.update(modified_fields)
60
- fields = list(fields_set)
61
-
62
- for i in range(0, len(objs), self.CHUNK_SIZE):
63
- chunk = objs[i : i + self.CHUNK_SIZE]
64
- # Call the base implementation to avoid re-triggering this method
65
- super(models.Manager, self).bulk_update(chunk, fields, **kwargs)
66
-
67
- if not bypass_hooks:
68
- engine.run(model_cls, AFTER_UPDATE, objs, originals, ctx=ctx)
69
-
70
- return objs
71
-
72
- def _detect_modified_fields(self, new_instances, original_instances):
73
- """
74
- Detect fields that were modified during BEFORE_UPDATE hooks by comparing
75
- new instances with their original values.
76
- """
77
- if not original_instances:
78
- return set()
79
-
80
- # Create a mapping of pk to original instance for efficient lookup
81
- original_map = {obj.pk: obj for obj in original_instances if obj.pk is not None}
82
-
83
- modified_fields = set()
84
-
85
- for new_instance in new_instances:
86
- if new_instance.pk is None:
87
- continue
88
-
89
- original = original_map.get(new_instance.pk)
90
- if not original:
91
- continue
92
-
93
- # Compare all fields to detect changes
94
- for field in new_instance._meta.fields:
95
- if field.name == "id":
96
- continue
97
-
98
- new_value = getattr(new_instance, field.name)
99
- original_value = getattr(original, field.name)
100
-
101
- # Handle different field types appropriately
102
- if field.is_relation:
103
- # For foreign keys, compare the pk values
104
- new_pk = new_value.pk if new_value else None
105
- original_pk = original_value.pk if original_value else None
106
- if new_pk != original_pk:
107
- modified_fields.add(field.name)
108
- else:
109
- # For regular fields, use direct comparison
110
- if new_value != original_value:
111
- modified_fields.add(field.name)
112
-
113
- return modified_fields
114
-
115
- @transaction.atomic
116
- def bulk_create(self, objs, bypass_hooks=False, bypass_validation=False, **kwargs):
117
- model_cls = self.model
118
-
119
- if any(not isinstance(obj, model_cls) for obj in objs):
120
- raise TypeError(
121
- f"bulk_create expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
122
- )
123
-
124
- result = []
125
-
126
- if not bypass_hooks:
127
- ctx = HookContext(model_cls)
128
-
129
- # Run validation hooks first
130
- if not bypass_validation:
131
- engine.run(model_cls, VALIDATE_CREATE, objs, ctx=ctx)
132
-
133
- # Then run business logic hooks
134
- engine.run(model_cls, BEFORE_CREATE, objs, ctx=ctx)
135
-
136
- for i in range(0, len(objs), self.CHUNK_SIZE):
137
- chunk = objs[i : i + self.CHUNK_SIZE]
138
- result.extend(super(models.Manager, self).bulk_create(chunk, **kwargs))
139
-
140
- if not bypass_hooks:
141
- engine.run(model_cls, AFTER_CREATE, result, ctx=ctx)
142
-
143
- return result
144
-
145
- @transaction.atomic
146
- def bulk_delete(
147
- self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False
148
- ):
149
- if not objs:
150
- return []
151
-
152
- model_cls = self.model
153
-
154
- if any(not isinstance(obj, model_cls) for obj in objs):
155
- raise TypeError(
156
- f"bulk_delete expected instances of {model_cls.__name__}, but got {set(type(obj).__name__ for obj in objs)}"
157
- )
158
-
159
- ctx = HookContext(model_cls)
160
-
161
- if not bypass_hooks:
162
- # Run validation hooks first
163
- if not bypass_validation:
164
- engine.run(model_cls, VALIDATE_DELETE, objs, ctx=ctx)
165
-
166
- # Then run business logic hooks
167
- engine.run(model_cls, BEFORE_DELETE, objs, ctx=ctx)
168
-
169
- pks = [obj.pk for obj in objs if obj.pk is not None]
170
-
171
- # Use base manager for the actual deletion to prevent recursion
172
- # The hooks have already been fired above, so we don't need them again
173
- model_cls._base_manager.filter(pk__in=pks).delete()
174
-
175
- if not bypass_hooks:
176
- engine.run(model_cls, AFTER_DELETE, objs, ctx=ctx)
177
-
178
- return objs
179
-
180
- @transaction.atomic
181
- def update(self, **kwargs):
182
- objs = list(self.all())
183
- if not objs:
184
- return 0
185
- for key, value in kwargs.items():
186
- for obj in objs:
187
- setattr(obj, key, value)
188
- self.bulk_update(objs, fields=list(kwargs.keys()))
189
- return len(objs)
190
-
191
- @transaction.atomic
192
- def delete(self):
193
- objs = list(self.all())
194
- if not objs:
195
- return 0
196
- self.bulk_delete(objs)
197
- return len(objs)
198
-
199
- @transaction.atomic
200
- def save(self, obj):
201
- if obj.pk:
202
- self.bulk_update(
203
- [obj],
204
- fields=[field.name for field in obj._meta.fields if field.name != "id"],
205
- )
206
- else:
207
- self.bulk_create([obj])
208
- return obj