django-bulk-hooks 0.1.101__py3-none-any.whl → 0.1.102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,11 +24,11 @@ def hook(event, *, model, condition=None, priority=DEFAULT_PRIORITY):
24
24
 
25
25
  def select_related(*related_fields):
26
26
  """
27
- Decorator that marks a hook method to preload related fields.
27
+ Decorator that preloads related fields in-place on `new_records`, before the hook logic runs.
28
28
 
29
- This decorator works in conjunction with the hook system to ensure that
30
- related fields are bulk-loaded before the hook logic runs, preventing
31
- queries in loops.
29
+ This decorator provides bulk loading for performance when you explicitly need it.
30
+ If you don't use this decorator, the framework will automatically detect and load
31
+ foreign keys only when conditions need them, preserving standard Django behavior.
32
32
 
33
33
  - Works with instance methods (resolves `self`)
34
34
  - Avoids replacing model instances
@@ -37,9 +37,114 @@ def select_related(*related_fields):
37
37
  """
38
38
 
39
39
  def decorator(func):
40
- # Store the related fields on the function for later access
41
- func._select_related_fields = related_fields
42
- return func
40
+ sig = inspect.signature(func)
41
+
42
+ @wraps(func)
43
+ def wrapper(*args, **kwargs):
44
+ bound = sig.bind_partial(*args, **kwargs)
45
+ bound.apply_defaults()
46
+
47
+ if "new_records" not in bound.arguments:
48
+ raise TypeError(
49
+ "@select_related requires a 'new_records' argument in the decorated function"
50
+ )
51
+
52
+ new_records = bound.arguments["new_records"]
53
+
54
+ if not isinstance(new_records, list):
55
+ raise TypeError(
56
+ f"@select_related expects a list of model instances, got {type(new_records)}"
57
+ )
58
+
59
+ if not new_records:
60
+ return func(*args, **kwargs)
61
+
62
+ # Determine which instances actually need preloading
63
+ model_cls = new_records[0].__class__
64
+ ids_to_fetch = []
65
+ instances_without_pk = []
66
+
67
+ for obj in new_records:
68
+ if obj.pk is None:
69
+ # For objects without PKs (BEFORE_CREATE), check if foreign key fields are already set
70
+ instances_without_pk.append(obj)
71
+ continue
72
+
73
+ # if any related field is not already cached on the instance,
74
+ # mark it for fetching
75
+ if any(field not in obj._state.fields_cache for field in related_fields):
76
+ ids_to_fetch.append(obj.pk)
77
+
78
+ # Load foreign keys for objects with PKs
79
+ fetched = {}
80
+ if ids_to_fetch:
81
+ fetched = model_cls.objects.select_related(*related_fields).in_bulk(ids_to_fetch)
82
+
83
+ # Apply loaded foreign keys to objects with PKs
84
+ for obj in new_records:
85
+ if obj.pk is None:
86
+ continue
87
+
88
+ preloaded = fetched.get(obj.pk)
89
+ if not preloaded:
90
+ continue
91
+ for field in related_fields:
92
+ if field in obj._state.fields_cache:
93
+ # don't override values that were explicitly set or already loaded
94
+ continue
95
+ if "." in field:
96
+ raise ValueError(
97
+ f"@select_related does not support nested fields like '{field}'"
98
+ )
99
+
100
+ try:
101
+ f = model_cls._meta.get_field(field)
102
+ if not (
103
+ f.is_relation and not f.many_to_many and not f.one_to_many
104
+ ):
105
+ continue
106
+ except FieldDoesNotExist:
107
+ continue
108
+
109
+ try:
110
+ rel_obj = getattr(preloaded, field)
111
+ setattr(obj, field, rel_obj)
112
+ obj._state.fields_cache[field] = rel_obj
113
+ except AttributeError:
114
+ pass
115
+
116
+ # For objects without PKs, ensure foreign key fields are properly set in the cache
117
+ # This prevents RelatedObjectDoesNotExist when accessing foreign keys
118
+ for obj in instances_without_pk:
119
+ for field in related_fields:
120
+ if "." in field:
121
+ raise ValueError(
122
+ f"@select_related does not support nested fields like '{field}'"
123
+ )
124
+
125
+ try:
126
+ f = model_cls._meta.get_field(field)
127
+ if not (
128
+ f.is_relation and not f.many_to_many and not f.one_to_many
129
+ ):
130
+ continue
131
+ except FieldDoesNotExist:
132
+ continue
133
+
134
+ # Check if the foreign key field is set
135
+ fk_field_name = f"{field}_id"
136
+ if hasattr(obj, fk_field_name) and getattr(obj, fk_field_name) is not None:
137
+ # The foreign key ID is set, so we can try to get the related object safely
138
+ rel_obj = safe_get_related_object(obj, field)
139
+ if rel_obj is not None:
140
+ # Ensure it's cached to prevent future queries
141
+ if not hasattr(obj._state, 'fields_cache'):
142
+ obj._state.fields_cache = {}
143
+ obj._state.fields_cache[field] = rel_obj
144
+
145
+ return func(*bound.args, **bound.kwargs)
146
+
147
+ return wrapper
43
148
 
44
149
  return decorator
45
150
 
@@ -42,7 +42,7 @@ def run(model_cls, event, new_instances, original_instances=None, ctx=None):
42
42
  original_instances = [None] * len(new_instances)
43
43
 
44
44
  # Process all hooks
45
- for handler_cls, method_name, condition, priority, select_related_fields in hooks:
45
+ for handler_cls, method_name, condition, priority in hooks:
46
46
  # Get or create handler instance from cache
47
47
  handler_key = (handler_cls, method_name)
48
48
  if handler_key not in _handler_cache:
@@ -52,87 +52,20 @@ def run(model_cls, event, new_instances, original_instances=None, ctx=None):
52
52
  else:
53
53
  handler_instance, func = _handler_cache[handler_key]
54
54
 
55
- # Apply select_related if specified
56
- if select_related_fields:
57
- new_instances_with_related = _apply_select_related(new_instances, select_related_fields)
58
- else:
59
- new_instances_with_related = new_instances
60
-
61
- # Filter instances based on condition
62
- if condition:
63
- to_process_new = []
64
- to_process_old = []
65
-
66
- logger.debug(f"Checking condition {condition.__class__.__name__} for {len(new_instances)} instances")
67
- for new, original in zip(new_instances_with_related, original_instances, strict=True):
68
- logger.debug(f"Checking instance {new.__class__.__name__}(pk={new.pk})")
69
- try:
70
- matches = condition.check(new, original)
71
- logger.debug(f"Condition check result: {matches}")
72
- if matches:
73
- to_process_new.append(new)
74
- to_process_old.append(original)
75
- except Exception as e:
76
- logger.error(f"Error checking condition: {e}")
77
- raise
78
-
79
- # Only call if we have matching instances
80
- if to_process_new:
81
- logger.debug(f"Running hook for {len(to_process_new)} matching instances")
82
- func(new_records=to_process_new, old_records=to_process_old if any(to_process_old) else None)
83
- else:
84
- logger.debug("No instances matched condition")
85
- else:
86
- # No condition, process all instances
87
- logger.debug("No condition, processing all instances")
88
- func(new_records=new_instances_with_related, old_records=original_instances if any(original_instances) else None)
89
-
90
-
91
- def _apply_select_related(instances, related_fields):
92
- """
93
- Apply select_related to instances to prevent queries in loops.
94
- This function bulk loads related objects and caches them on the instances.
95
- """
96
- if not instances:
97
- return instances
98
-
99
- # Separate instances with and without PKs
100
- instances_with_pk = [obj for obj in instances if obj.pk is not None]
101
- instances_without_pk = [obj for obj in instances if obj.pk is None]
55
+ # If no condition, process all instances at once
56
+ if not condition:
57
+ func(new_records=new_instances, old_records=original_instances if any(original_instances) else None)
58
+ continue
102
59
 
103
- # Bulk load related objects for instances with PKs
104
- if instances_with_pk:
105
- model_cls = instances_with_pk[0].__class__
106
- pks = [obj.pk for obj in instances_with_pk]
107
-
108
- # Bulk fetch with select_related
109
- fetched_instances = model_cls.objects.select_related(*related_fields).in_bulk(pks)
110
-
111
- # Apply cached related objects to original instances
112
- for obj in instances_with_pk:
113
- fetched_obj = fetched_instances.get(obj.pk)
114
- if fetched_obj:
115
- for field in related_fields:
116
- if field not in obj._state.fields_cache:
117
- try:
118
- rel_obj = getattr(fetched_obj, field)
119
- setattr(obj, field, rel_obj)
120
- obj._state.fields_cache[field] = rel_obj
121
- except AttributeError:
122
- pass
60
+ # For conditional hooks, filter instances first
61
+ to_process_new = []
62
+ to_process_old = []
123
63
 
124
- # Handle instances without PKs (e.g., BEFORE_CREATE)
125
- for obj in instances_without_pk:
126
- for field in related_fields:
127
- # Check if the foreign key field is set
128
- fk_field_name = f"{field}_id"
129
- if hasattr(obj, fk_field_name) and getattr(obj, fk_field_name) is not None:
130
- # The foreign key ID is set, so we can try to get the related object safely
131
- rel_obj = safe_get_related_object(obj, field)
132
- if rel_obj is not None:
133
- # Ensure it's cached to prevent future queries
134
- if not hasattr(obj._state, 'fields_cache'):
135
- obj._state.fields_cache = {}
136
- obj._state.fields_cache[field] = rel_obj
64
+ for new, original in zip(new_instances, original_instances, strict=True):
65
+ if condition.check(new, original):
66
+ to_process_new.append(new)
67
+ to_process_old.append(original)
137
68
 
138
- return instances
69
+ if to_process_new:
70
+ # Call the function with keyword arguments
71
+ func(new_records=to_process_new, old_records=to_process_old if any(to_process_old) else None)
@@ -1,4 +1,3 @@
1
- import inspect
2
1
  import logging
3
2
  import threading
4
3
  from collections import deque
@@ -75,11 +74,6 @@ class HookMeta(type):
75
74
  for model_cls, event, condition, priority in method.hooks_hooks:
76
75
  key = (model_cls, event, cls, method_name)
77
76
  if key not in HookMeta._registered:
78
- # Check if the method has been decorated with select_related
79
- select_related_fields = getattr(
80
- method, "_select_related_fields", None
81
- )
82
-
83
77
  register_hook(
84
78
  model=model_cls,
85
79
  event=event,
@@ -87,7 +81,6 @@ class HookMeta(type):
87
81
  method_name=method_name,
88
82
  condition=condition,
89
83
  priority=priority,
90
- select_related_fields=select_related_fields,
91
84
  )
92
85
  HookMeta._registered.add(key)
93
86
  return cls
@@ -138,17 +131,10 @@ class HookHandler(metaclass=HookMeta):
138
131
  if len(old_local) < len(new_local):
139
132
  old_local += [None] * (len(new_local) - len(old_local))
140
133
 
141
- for handler_cls, method_name, condition, priority, select_related_fields in hooks:
142
- # Apply select_related if specified to prevent queries in loops
143
- if select_related_fields:
144
- from django_bulk_hooks.engine import _apply_select_related
145
- new_local_with_related = _apply_select_related(new_local, select_related_fields)
146
- else:
147
- new_local_with_related = new_local
148
-
134
+ for handler_cls, method_name, condition, priority in hooks:
149
135
  if condition is not None:
150
136
  checks = [
151
- condition.check(n, o) for n, o in zip(new_local_with_related, old_local)
137
+ condition.check(n, o) for n, o in zip(new_local, old_local)
152
138
  ]
153
139
  if not any(checks):
154
140
  continue
@@ -156,21 +142,10 @@ class HookHandler(metaclass=HookMeta):
156
142
  handler = handler_cls()
157
143
  method = getattr(handler, method_name)
158
144
 
159
- # Inspect the method signature to determine parameter order
160
- import inspect
161
-
162
- sig = inspect.signature(method)
163
- params = list(sig.parameters.keys())
164
-
165
- # Remove 'self' from params if it exists
166
- if params and params[0] == "self":
167
- params = params[1:]
168
-
169
- # Always call with keyword arguments to make order irrelevant
170
145
  try:
171
146
  method(
147
+ new_records=new_local,
172
148
  old_records=old_local,
173
- new_records=new_local_with_related,
174
149
  **kwargs,
175
150
  )
176
151
  except Exception:
@@ -20,7 +20,7 @@ class BulkHookManager(models.Manager):
20
20
  # Default chunk sizes - can be overridden per model
21
21
  DEFAULT_CHUNK_SIZE = 200
22
22
  DEFAULT_RELATED_CHUNK_SIZE = 500 # Higher for related object fetching
23
-
23
+
24
24
  def __init__(self):
25
25
  super().__init__()
26
26
  self._chunk_size = self.DEFAULT_CHUNK_SIZE
@@ -28,11 +28,16 @@ class BulkHookManager(models.Manager):
28
28
  self._prefetch_related_fields = set()
29
29
  self._select_related_fields = set()
30
30
 
31
- def configure(self, chunk_size=None, related_chunk_size=None,
32
- select_related=None, prefetch_related=None):
31
+ def configure(
32
+ self,
33
+ chunk_size=None,
34
+ related_chunk_size=None,
35
+ select_related=None,
36
+ prefetch_related=None,
37
+ ):
33
38
  """
34
39
  Configure bulk operation parameters for this manager.
35
-
40
+
36
41
  Args:
37
42
  chunk_size: Number of objects to process in each bulk operation chunk
38
43
  related_chunk_size: Number of objects to fetch in each related object query
@@ -53,24 +58,24 @@ class BulkHookManager(models.Manager):
53
58
  Optimized loading of original instances with smart batching and field selection.
54
59
  """
55
60
  queryset = self.model.objects.filter(pk__in=pks)
56
-
57
- # Only select specific fields if provided
58
- if fields_to_fetch:
59
- queryset = queryset.only('pk', *fields_to_fetch)
60
-
61
+
62
+ # Only select specific fields if provided and not empty
63
+ if fields_to_fetch and len(fields_to_fetch) > 0:
64
+ queryset = queryset.only("pk", *fields_to_fetch)
65
+
61
66
  # Apply configured related field optimizations
62
67
  if self._select_related_fields:
63
68
  queryset = queryset.select_related(*self._select_related_fields)
64
69
  if self._prefetch_related_fields:
65
70
  queryset = queryset.prefetch_related(*self._prefetch_related_fields)
66
-
71
+
67
72
  # Batch load in chunks to avoid memory issues
68
73
  all_originals = []
69
74
  for i in range(0, len(pks), self._related_chunk_size):
70
- chunk_pks = pks[i:i + self._related_chunk_size]
75
+ chunk_pks = pks[i : i + self._related_chunk_size]
71
76
  chunk_originals = list(queryset.filter(pk__in=chunk_pks))
72
77
  all_originals.extend(chunk_originals)
73
-
78
+
74
79
  return all_originals
75
80
 
76
81
  def _get_fields_to_fetch(self, objs, fields):
@@ -79,20 +84,51 @@ class BulkHookManager(models.Manager):
79
84
  and what's needed for hooks.
80
85
  """
81
86
  fields_to_fetch = set(fields)
82
-
87
+
83
88
  # Add fields needed by registered hooks
84
89
  from django_bulk_hooks.registry import get_hooks
85
- hooks = get_hooks(self.model, "before_update") + get_hooks(self.model, "after_update")
86
-
90
+
91
+ hooks = get_hooks(self.model, "before_update") + get_hooks(
92
+ self.model, "after_update"
93
+ )
94
+
87
95
  for handler_cls, method_name, condition, _ in hooks:
88
96
  if condition:
89
97
  # If there's a condition, we need all fields it might access
90
98
  fields_to_fetch.update(condition.get_required_fields())
91
-
92
- return fields_to_fetch
99
+
100
+ # Filter out fields that don't exist on the model
101
+ valid_fields = set()
102
+ invalid_fields = set()
103
+ for field_name in fields_to_fetch:
104
+ try:
105
+ self.model._meta.get_field(field_name)
106
+ valid_fields.add(field_name)
107
+ except Exception as e:
108
+ # Field doesn't exist, skip it
109
+ invalid_fields.add(field_name)
110
+ import logging
111
+
112
+ logger = logging.getLogger(__name__)
113
+ logger.debug(
114
+ f"Field '{field_name}' requested by hook condition but doesn't exist on {self.model.__name__}: {e}"
115
+ )
116
+ continue
117
+
118
+ if invalid_fields:
119
+ import logging
120
+ logger = logging.getLogger(__name__)
121
+ logger.warning(
122
+ f"Invalid fields requested for {self.model.__name__}: {invalid_fields}. "
123
+ f"These fields were ignored to prevent errors."
124
+ )
125
+
126
+ return valid_fields
93
127
 
94
128
  @transaction.atomic
95
- def bulk_update(self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs):
129
+ def bulk_update(
130
+ self, objs, fields, bypass_hooks=False, bypass_validation=False, **kwargs
131
+ ):
96
132
  if not objs:
97
133
  return []
98
134
 
@@ -106,14 +142,14 @@ class BulkHookManager(models.Manager):
106
142
  if not bypass_hooks:
107
143
  # Determine which fields we need to fetch
108
144
  fields_to_fetch = self._get_fields_to_fetch(objs, fields)
109
-
145
+
110
146
  # Load originals efficiently
111
147
  pks = [obj.pk for obj in objs if obj.pk is not None]
112
148
  originals = self._load_originals_optimized(pks, fields_to_fetch)
113
-
149
+
114
150
  # Create a mapping for quick lookup
115
151
  original_map = {obj.pk: obj for obj in originals}
116
-
152
+
117
153
  # Align originals with new instances
118
154
  aligned_originals = [original_map.get(obj.pk) for obj in objs]
119
155
 
@@ -135,7 +171,7 @@ class BulkHookManager(models.Manager):
135
171
 
136
172
  # Process in chunks
137
173
  for i in range(0, len(objs), self._chunk_size):
138
- chunk = objs[i:i + self._chunk_size]
174
+ chunk = objs[i : i + self._chunk_size]
139
175
  super(models.Manager, self).bulk_update(chunk, fields, **kwargs)
140
176
 
141
177
  if not bypass_hooks:
@@ -205,30 +241,32 @@ class BulkHookManager(models.Manager):
205
241
  # Process validation in chunks to avoid memory issues
206
242
  if not bypass_validation:
207
243
  for i in range(0, len(objs), self._chunk_size):
208
- chunk = objs[i:i + self._chunk_size]
244
+ chunk = objs[i : i + self._chunk_size]
209
245
  engine.run(model_cls, VALIDATE_CREATE, chunk, ctx=ctx)
210
246
 
211
247
  # Process before_create hooks in chunks
212
248
  for i in range(0, len(objs), self._chunk_size):
213
- chunk = objs[i:i + self._chunk_size]
249
+ chunk = objs[i : i + self._chunk_size]
214
250
  engine.run(model_cls, BEFORE_CREATE, chunk, ctx=ctx)
215
251
 
216
252
  # Perform bulk create in chunks
217
253
  for i in range(0, len(objs), self._chunk_size):
218
- chunk = objs[i:i + self._chunk_size]
254
+ chunk = objs[i : i + self._chunk_size]
219
255
  created_chunk = super(models.Manager, self).bulk_create(chunk, **kwargs)
220
256
  result.extend(created_chunk)
221
257
 
222
258
  if not bypass_hooks:
223
259
  # Process after_create hooks in chunks
224
260
  for i in range(0, len(result), self._chunk_size):
225
- chunk = result[i:i + self._chunk_size]
261
+ chunk = result[i : i + self._chunk_size]
226
262
  engine.run(model_cls, AFTER_CREATE, chunk, ctx=ctx)
227
263
 
228
264
  return result
229
265
 
230
266
  @transaction.atomic
231
- def bulk_delete(self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False):
267
+ def bulk_delete(
268
+ self, objs, batch_size=None, bypass_hooks=False, bypass_validation=False
269
+ ):
232
270
  if not objs:
233
271
  return []
234
272
 
@@ -245,8 +283,8 @@ class BulkHookManager(models.Manager):
245
283
  if not bypass_hooks:
246
284
  # Process hooks in chunks
247
285
  for i in range(0, len(objs), chunk_size):
248
- chunk = objs[i:i + chunk_size]
249
-
286
+ chunk = objs[i : i + chunk_size]
287
+
250
288
  if not bypass_validation:
251
289
  engine.run(model_cls, VALIDATE_DELETE, chunk, ctx=ctx)
252
290
  engine.run(model_cls, BEFORE_DELETE, chunk, ctx=ctx)
@@ -254,13 +292,13 @@ class BulkHookManager(models.Manager):
254
292
  # Collect PKs and delete in chunks
255
293
  pks = [obj.pk for obj in objs if obj.pk is not None]
256
294
  for i in range(0, len(pks), chunk_size):
257
- chunk_pks = pks[i:i + chunk_size]
295
+ chunk_pks = pks[i : i + chunk_size]
258
296
  model_cls._base_manager.filter(pk__in=chunk_pks).delete()
259
297
 
260
298
  if not bypass_hooks:
261
299
  # Process after_delete hooks in chunks
262
300
  for i in range(0, len(objs), chunk_size):
263
- chunk = objs[i:i + chunk_size]
301
+ chunk = objs[i : i + chunk_size]
264
302
  engine.run(model_cls, AFTER_DELETE, chunk, ctx=ctx)
265
303
 
266
304
  return objs
@@ -1,8 +1,4 @@
1
- import contextlib
2
- from functools import wraps
3
-
4
1
  from django.db import models, transaction
5
- from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor
6
2
 
7
3
  from django_bulk_hooks.constants import (
8
4
  AFTER_CREATE,
@@ -18,6 +14,9 @@ from django_bulk_hooks.constants import (
18
14
  from django_bulk_hooks.context import HookContext
19
15
  from django_bulk_hooks.engine import run
20
16
  from django_bulk_hooks.manager import BulkHookManager
17
+ from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor
18
+ from functools import wraps
19
+ import contextlib
21
20
 
22
21
 
23
22
  @contextlib.contextmanager
@@ -27,7 +26,7 @@ def patch_foreign_key_behavior():
27
26
  RelatedObjectDoesNotExist when accessing an unset foreign key field.
28
27
  """
29
28
  original_get = ForwardManyToOneDescriptor.__get__
30
-
29
+
31
30
  @wraps(original_get)
32
31
  def safe_get(self, instance, cls=None):
33
32
  if instance is None:
@@ -36,7 +35,7 @@ def patch_foreign_key_behavior():
36
35
  return original_get(self, instance, cls)
37
36
  except self.RelatedObjectDoesNotExist:
38
37
  return None
39
-
38
+
40
39
  # Patch the descriptor
41
40
  ForwardManyToOneDescriptor.__get__ = safe_get
42
41
  try:
@@ -64,7 +63,7 @@ class HookModelMixin(models.Model):
64
63
  # Skip hook validation during admin form validation
65
64
  # This prevents RelatedObjectDoesNotExist errors when Django hasn't
66
65
  # fully set up the object's relationships yet
67
- if hasattr(self, "_state") and getattr(self._state, "validating", False):
66
+ if hasattr(self, '_state') and getattr(self._state, 'validating', False):
68
67
  return
69
68
 
70
69
  # Determine if this is a create or update operation
@@ -81,9 +80,7 @@ class HookModelMixin(models.Model):
81
80
  old_instance = self.__class__.objects.get(pk=self.pk)
82
81
  ctx = HookContext(self.__class__)
83
82
  with patch_foreign_key_behavior():
84
- run(
85
- self.__class__, VALIDATE_UPDATE, [self], [old_instance], ctx=ctx
86
- )
83
+ run(self.__class__, VALIDATE_UPDATE, [self], [old_instance], ctx=ctx)
87
84
  except self.__class__.DoesNotExist:
88
85
  # If the old instance doesn't exist, treat as create
89
86
  ctx = HookContext(self.__class__)
@@ -94,40 +91,24 @@ class HookModelMixin(models.Model):
94
91
  is_create = self.pk is None
95
92
  ctx = HookContext(self.__class__)
96
93
 
97
- # Run BEFORE hooks before saving to allow field modifications
94
+ # Use a single context manager for all hooks
98
95
  with patch_foreign_key_behavior():
99
96
  if is_create:
100
97
  # For create operations
101
- run(self.__class__, VALIDATE_CREATE, [self], ctx=ctx)
102
98
  run(self.__class__, BEFORE_CREATE, [self], ctx=ctx)
103
- else:
104
- # For update operations
105
- try:
106
- old_instance = self.__class__.objects.get(pk=self.pk)
107
- run(
108
- self.__class__, VALIDATE_UPDATE, [self], [old_instance], ctx=ctx
109
- )
110
- run(self.__class__, BEFORE_UPDATE, [self], [old_instance], ctx=ctx)
111
- except self.__class__.DoesNotExist:
112
- # If the old instance doesn't exist, treat as create
113
- run(self.__class__, VALIDATE_CREATE, [self], ctx=ctx)
114
- run(self.__class__, BEFORE_CREATE, [self], ctx=ctx)
115
-
116
- # Now let Django save with any modifications from BEFORE hooks
117
- super().save(*args, **kwargs)
118
-
119
- # Then run AFTER hooks
120
- with patch_foreign_key_behavior():
121
- if is_create:
122
- # For create operations
99
+ super().save(*args, **kwargs)
123
100
  run(self.__class__, AFTER_CREATE, [self], ctx=ctx)
124
101
  else:
125
102
  # For update operations
126
103
  try:
127
104
  old_instance = self.__class__.objects.get(pk=self.pk)
105
+ run(self.__class__, BEFORE_UPDATE, [self], [old_instance], ctx=ctx)
106
+ super().save(*args, **kwargs)
128
107
  run(self.__class__, AFTER_UPDATE, [self], [old_instance], ctx=ctx)
129
108
  except self.__class__.DoesNotExist:
130
109
  # If the old instance doesn't exist, treat as create
110
+ run(self.__class__, BEFORE_CREATE, [self], ctx=ctx)
111
+ super().save(*args, **kwargs)
131
112
  run(self.__class__, AFTER_CREATE, [self], ctx=ctx)
132
113
 
133
114
  return self
@@ -141,5 +122,5 @@ class HookModelMixin(models.Model):
141
122
  run(self.__class__, BEFORE_DELETE, [self], ctx=ctx)
142
123
  result = super().delete(*args, **kwargs)
143
124
  run(self.__class__, AFTER_DELETE, [self], ctx=ctx)
144
-
125
+
145
126
  return result
@@ -3,15 +3,15 @@ from typing import Union
3
3
 
4
4
  from django_bulk_hooks.enums import Priority
5
5
 
6
- _hooks: dict[tuple[type, str], list[tuple[type, str, Callable, int, tuple]]] = {}
6
+ _hooks: dict[tuple[type, str], list[tuple[type, str, Callable, int]]] = {}
7
7
 
8
8
 
9
9
  def register_hook(
10
- model, event, handler_cls, method_name, condition, priority: Union[int, Priority], select_related_fields=None
10
+ model, event, handler_cls, method_name, condition, priority: Union[int, Priority]
11
11
  ):
12
12
  key = (model, event)
13
13
  hooks = _hooks.setdefault(key, [])
14
- hooks.append((handler_cls, method_name, condition, priority, select_related_fields))
14
+ hooks.append((handler_cls, method_name, condition, priority))
15
15
  # keep sorted by priority
16
16
  hooks.sort(key=lambda x: x[3])
17
17