django-bulk-hooks 0.2.9__py3-none-any.whl → 0.2.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import logging
2
3
  from functools import wraps
3
4
 
4
5
  from django.core.exceptions import FieldDoesNotExist
@@ -6,6 +7,8 @@ from django.core.exceptions import FieldDoesNotExist
6
7
  from django_bulk_hooks.enums import DEFAULT_PRIORITY
7
8
  from django_bulk_hooks.registry import register_hook
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  def hook(event, *, model, condition=None, priority=DEFAULT_PRIORITY):
11
14
  """
@@ -35,10 +38,10 @@ def select_related(*related_fields):
35
38
  def decorator(func):
36
39
  sig = inspect.signature(func)
37
40
 
38
- def preload_related(records, *, model_cls=None):
41
+ def preload_related(records, *, model_cls=None, skip_fields=None):
39
42
  if not isinstance(records, list):
40
43
  raise TypeError(
41
- f"@select_related expects a list of model instances, got {type(records)}"
44
+ f"@select_related expects a list of model instances, got {type(records)}",
42
45
  )
43
46
 
44
47
  if not records:
@@ -47,11 +50,14 @@ def select_related(*related_fields):
47
50
  if model_cls is None:
48
51
  model_cls = records[0].__class__
49
52
 
53
+ if skip_fields is None:
54
+ skip_fields = set()
55
+
50
56
  # Validate field notation upfront
51
57
  for field in related_fields:
52
58
  if "." in field:
53
59
  raise ValueError(
54
- f"Invalid field notation '{field}'. Use Django ORM __ notation (e.g., 'parent__field')"
60
+ f"Invalid field notation '{field}'. Use Django ORM __ notation (e.g., 'parent__field')",
55
61
  )
56
62
 
57
63
  direct_relation_fields = {}
@@ -70,17 +76,11 @@ def select_related(*related_fields):
70
76
  except (FieldDoesNotExist, AttributeError):
71
77
  continue
72
78
 
73
- if (
74
- relation_field.is_relation
75
- and not relation_field.many_to_many
76
- and not relation_field.one_to_many
77
- ):
79
+ if relation_field.is_relation and not relation_field.many_to_many and not relation_field.one_to_many:
78
80
  validated_fields.append(field)
79
81
  direct_relation_fields[field] = relation_field
80
82
 
81
- unsaved_related_ids_by_field = {
82
- field: set() for field in direct_relation_fields.keys()
83
- }
83
+ unsaved_related_ids_by_field = {field: set() for field in direct_relation_fields}
84
84
 
85
85
  saved_ids_to_fetch = []
86
86
  for obj in records:
@@ -88,10 +88,7 @@ def select_related(*related_fields):
88
88
  needs_fetch = False
89
89
  if hasattr(obj, "_state") and hasattr(obj._state, "fields_cache"):
90
90
  try:
91
- needs_fetch = any(
92
- field not in obj._state.fields_cache
93
- for field in related_fields
94
- )
91
+ needs_fetch = any(field not in obj._state.fields_cache for field in related_fields)
95
92
  except (TypeError, AttributeError):
96
93
  needs_fetch = True
97
94
  else:
@@ -123,14 +120,12 @@ def select_related(*related_fields):
123
120
  if base_manager is not None:
124
121
  try:
125
122
  fetched_saved = base_manager.select_related(
126
- *validated_fields
123
+ *validated_fields,
127
124
  ).in_bulk(saved_ids_to_fetch)
128
125
  except Exception:
129
126
  fetched_saved = {}
130
127
 
131
- fetched_unsaved_by_field = {
132
- field: {} for field in direct_relation_fields.keys()
133
- }
128
+ fetched_unsaved_by_field = {field: {} for field in direct_relation_fields}
134
129
 
135
130
  for field_name, relation_field in direct_relation_fields.items():
136
131
  related_ids = unsaved_related_ids_by_field[field_name]
@@ -161,6 +156,10 @@ def select_related(*related_fields):
161
156
  continue
162
157
 
163
158
  for field in related_fields:
159
+ # Skip preloading if this relationship conflicts with FK field being updated
160
+ if field in skip_fields:
161
+ continue
162
+
164
163
  if fields_cache is not None and field in fields_cache:
165
164
  continue
166
165
 
@@ -179,6 +178,10 @@ def select_related(*related_fields):
179
178
  continue
180
179
 
181
180
  for field_name, relation_field in direct_relation_fields.items():
181
+ # Skip preloading if this relationship conflicts with FK field being updated
182
+ if field_name in skip_fields:
183
+ continue
184
+
182
185
  if fields_cache is not None and field_name in fields_cache:
183
186
  continue
184
187
 
@@ -198,6 +201,12 @@ def select_related(*related_fields):
198
201
  if fields_cache is not None:
199
202
  fields_cache[field_name] = rel_obj
200
203
 
204
+ def preload_with_skip_fields(records, *, model_cls=None, skip_fields=None):
205
+ """Wrapper that applies skip_fields logic to the preload function"""
206
+ if skip_fields is None:
207
+ skip_fields = set()
208
+ return preload_related(records, model_cls=model_cls, skip_fields=skip_fields)
209
+
201
210
  @wraps(func)
202
211
  def wrapper(*args, **kwargs):
203
212
  bound = sig.bind_partial(*args, **kwargs)
@@ -205,18 +214,32 @@ def select_related(*related_fields):
205
214
 
206
215
  if "new_records" not in bound.arguments:
207
216
  raise TypeError(
208
- "@preload_related requires a 'new_records' argument in the decorated function"
217
+ "@preload_related requires a 'new_records' argument in the decorated function",
209
218
  )
210
219
 
211
220
  new_records = bound.arguments["new_records"]
212
221
 
213
- model_cls_override = bound.arguments.get("model_cls")
222
+ if not isinstance(new_records, list):
223
+ raise TypeError(
224
+ f"@select_related expects a list of model instances, got {type(new_records)}",
225
+ )
214
226
 
215
- preload_related(new_records, model_cls=model_cls_override)
227
+ if not new_records:
228
+ # Empty list, nothing to preload
229
+ return func(*args, **kwargs)
230
+
231
+ # Validate field notation upfront (same as in preload_related)
232
+ for field in related_fields:
233
+ if "." in field:
234
+ raise ValueError(
235
+ f"Invalid field notation '{field}'. Use Django ORM __ notation (e.g., 'parent__field')",
236
+ )
216
237
 
217
- return func(*bound.args, **bound.kwargs)
238
+ # Don't preload here - let the dispatcher handle it
239
+ # The dispatcher will call the preload function with skip_fields
240
+ return func(*args, **kwargs)
218
241
 
219
- wrapper._select_related_preload = preload_related
242
+ wrapper._select_related_preload = preload_with_skip_fields
220
243
  wrapper._select_related_fields = related_fields
221
244
 
222
245
  return wrapper
@@ -241,8 +264,27 @@ def bulk_hook(model_cls, event, when=None, priority=None):
241
264
  def __init__(self):
242
265
  self.func = func
243
266
 
244
- def handle(self, new_records=None, old_records=None, **kwargs):
245
- return self.func(new_records, old_records, **kwargs)
267
+ def handle(self, changeset=None, new_records=None, old_records=None, **kwargs):
268
+ # Support both old and new hook signatures for backward compatibility
269
+ # Old signature: def hook(self, new_records, old_records, **kwargs)
270
+ # New signature: def hook(self, changeset, new_records, old_records, **kwargs)
271
+
272
+ # Check function signature to determine which format to use
273
+ import inspect
274
+
275
+ sig = inspect.signature(func)
276
+ params = list(sig.parameters.keys())
277
+
278
+ if "changeset" in params:
279
+ # New signature with changeset
280
+ return self.func(changeset, new_records, old_records, **kwargs)
281
+ # Old signature without changeset
282
+ # Only pass changeset in kwargs if the function accepts **kwargs
283
+ if "kwargs" in params or any(param.startswith("**") for param in sig.parameters):
284
+ kwargs["changeset"] = changeset
285
+ return self.func(new_records, old_records, **kwargs)
286
+ # Function doesn't accept **kwargs, just call with positional args
287
+ return self.func(new_records, old_records)
246
288
 
247
289
  # Register the hook using the registry
248
290
  register_hook(