congrads 1.0.6-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/constraints.py CHANGED
@@ -1,65 +1,55 @@
1
- """
2
- This module provides a set of constraint classes for guiding neural network
3
- training by enforcing specific conditions on the network's outputs.
4
-
5
- The constraints in this module include:
6
-
7
- - `Constraint`: The base class for all constraint types, defining the
8
- interface and core behavior.
9
- - `ImplicationConstraint`: A constraint that enforces one condition only if
10
- another condition is met, useful for modeling implications between network
11
- outputs.
12
- - `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
13
- a network's output.
14
- - `BinaryConstraint`: A constraint that enforces a binary comparison between
15
- two neurons in the network, using a comparison function (e.g., less than,
16
- greater than).
17
- - `SumConstraint`: A constraint that enforces that the sum of certain neurons'
18
- outputs equals a specified value, which can be used to control total output.
19
- - `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
20
- on a set of neurons, ensuring that the square of one neuron's output is equal
21
- to the sum of the squares of other outputs.
22
-
23
- These constraints can be used to steer the learning process by applying
24
- conditions such as logical implications or numerical bounds.
1
+ """Module providing constraint classes for guiding neural network training.
2
+
3
+ This module defines constraints that enforce specific conditions on network outputs
4
+ to steer learning. Available constraint types include:
5
+
6
+ - `Constraint`: Base class for all constraint types, defining the interface and core
7
+ behavior.
8
+ - `ImplicationConstraint`: Enforces one condition only if another condition is met,
9
+ useful for modeling implications between outputs.
10
+ - `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
11
+ - `BinaryConstraint`: Enforces a binary comparison between two tags using a
12
+ comparison function (e.g., less than, greater than).
13
+ - `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
14
+ value, controlling total output.
15
+
16
+ These constraints can steer the learning process by applying logical implications
17
+ or numerical bounds.
25
18
 
26
19
  Usage:
27
20
  1. Define a custom constraint class by inheriting from `Constraint`.
28
- 2. Apply the constraint to your neural network during training to
29
- enforce desired output behaviors.
30
- 3. Use the helper classes like `IdentityTransformation` for handling
31
- transformations and comparisons in constraints.
21
+ 2. Apply the constraint to your neural network during training.
22
+ 3. Use helper classes like `IdentityTransformation` for transformations and
23
+ comparisons in constraints.
32
24
 
33
- Dependencies:
34
- - PyTorch (`torch`)
35
25
  """
36
26
 
37
27
  import random
38
28
  import string
39
29
  import warnings
40
30
  from abc import ABC, abstractmethod
31
+ from collections.abc import Callable
41
32
  from numbers import Number
42
- from typing import Callable, Dict, Union
33
+ from typing import Literal
43
34
 
44
35
  from torch import (
45
36
  Tensor,
46
- count_nonzero,
37
+ argsort,
38
+ eq,
47
39
  ge,
48
40
  gt,
49
- isclose,
50
41
  le,
42
+ logical_and,
51
43
  logical_not,
52
44
  logical_or,
53
45
  lt,
54
- numel,
55
46
  ones,
56
47
  ones_like,
57
48
  reshape,
58
49
  sign,
59
- sqrt,
60
- square,
61
50
  stack,
62
51
  tensor,
52
+ unique,
63
53
  zeros_like,
64
54
  )
65
55
  from torch.nn.functional import normalize
@@ -70,8 +60,7 @@ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
70
60
 
71
61
 
72
62
  class Constraint(ABC):
73
- """
74
- Abstract base class for defining constraints applied to neural networks.
63
+ """Abstract base class for defining constraints applied to neural networks.
75
64
 
76
65
  A `Constraint` specifies conditions that the neural network outputs
77
66
  should satisfy. It supports monitoring constraint satisfaction
@@ -79,23 +68,22 @@ class Constraint(ABC):
79
68
  must implement the `check_constraint` and `calculate_direction` methods.
80
69
 
81
70
  Args:
82
- neurons (set[str]): Names of the neurons this constraint applies to.
71
+ tags (set[str]): Tags referencing the parts of the network this constraint applies to.
83
72
  name (str, optional): A unique name for the constraint. If not provided,
84
73
  a name is generated based on the class name and a random suffix.
85
- monitor_only (bool, optional): If True, only monitor the constraint
86
- without adjusting the loss. Defaults to False.
74
+ enforce (bool, optional): If False, only monitor the constraint
75
+ without adjusting the loss. Defaults to True.
87
76
  rescale_factor (Number, optional): Factor to scale the
88
77
  constraint-adjusted loss. Defaults to 1.5. Should be greater
89
78
  than 1 to give weight to the constraint.
90
79
 
91
80
  Raises:
92
81
  TypeError: If a provided attribute has an incompatible type.
93
- ValueError: If any neuron in `neurons` is not
82
+ ValueError: If any tag in `tags` is not
94
83
  defined in the `descriptor`.
95
84
 
96
85
  Note:
97
- - If `rescale_factor <= 1`, a warning is issued, and the value is
98
- adjusted to a positive value greater than 1.
86
+ - If `rescale_factor <= 1`, a warning is issued.
99
87
  - If `name` is not provided, a name is auto-generated,
100
88
  and a warning is logged.
101
89
 
@@ -105,38 +93,53 @@ class Constraint(ABC):
105
93
  device = None
106
94
 
107
95
  def __init__(
108
- self,
109
- neurons: set[str],
110
- name: str = None,
111
- monitor_only: bool = False,
112
- rescale_factor: Number = 1.5,
96
+ self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
113
97
  ) -> None:
114
- """
115
- Initializes a new Constraint instance.
116
- """
98
+ """Initializes a new Constraint instance.
117
99
 
100
+ Args:
101
+ tags (set[str]): Tags referencing the parts of the network this constraint applies to.
102
+ name (str, optional): A unique name for the constraint. If not
103
+ provided, a name is generated based on the class name and a
104
+ random suffix.
105
+ enforce (bool, optional): If False, only monitor the constraint
106
+ without adjusting the loss. Defaults to True.
107
+ rescale_factor (Number, optional): Factor to scale the
108
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
109
+ than 1 to give weight to the constraint.
110
+
111
+ Raises:
112
+ TypeError: If a provided attribute has an incompatible type.
113
+ ValueError: If any tag in `tags` is not defined in the `descriptor`.
114
+
115
+ Note:
116
+ - If `rescale_factor <= 1`, a warning is issued.
117
+ - If `name` is not provided, a name is auto-generated, and a
118
+ warning is logged.
119
+ """
118
120
  # Init parent class
119
121
  super().__init__()
120
122
 
121
123
  # Type checking
122
- validate_iterable("neurons", neurons, str)
123
- validate_type("name", name, (str, type(None)))
124
- validate_type("monitor_only", monitor_only, bool)
124
+ validate_iterable("tags", tags, str)
125
+ validate_type("name", name, str, allow_none=True)
126
+ validate_type("enforce", enforce, bool)
125
127
  validate_type("rescale_factor", rescale_factor, Number)
126
128
 
127
129
  # Init object variables
128
- self.neurons = neurons
130
+ self.tags = tags
129
131
  self.rescale_factor = rescale_factor
130
- self.monitor_only = monitor_only
132
+ self.initial_rescale_factor = rescale_factor
133
+ self.enforce = enforce
131
134
 
132
135
  # Perform checks
133
136
  if rescale_factor <= 1:
134
137
  warnings.warn(
135
- "Rescale factor for constraint %s is <= 1. The network \
136
- will favor general loss over the constraint-adjusted loss. \
137
- Is this intended behavior? Normally, the loss should \
138
- always be larger than 1.",
139
- name,
138
+ f"Rescale factor for constraint {name} is <= 1. The network "
139
+ "will favor general loss over the constraint-adjusted loss. "
140
+ "Is this intended behavior? Normally, the rescale factor "
141
+ "should always be larger than 1.",
142
+ stacklevel=2,
140
143
  )
141
144
 
142
145
  # If no constraint_name is set, generate one based
@@ -144,124 +147,94 @@ class Constraint(ABC):
144
147
  if name:
145
148
  self.name = name
146
149
  else:
147
- random_suffix = "".join(
148
- random.choices(string.ascii_uppercase + string.digits, k=6)
149
- )
150
+ random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
150
151
  self.name = f"{self.__class__.__name__}_{random_suffix}"
151
- warnings.warn(
152
- "Name for constraint is not set. Using %s.", self.name
153
- )
152
+ warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
154
153
 
155
- # If rescale factor is not larger than 1, warn user and adjust
156
- if rescale_factor <= 1:
157
- self.rescale_factor = abs(rescale_factor) + 1.5
158
- warnings.warn(
159
- "Rescale factor for constraint %s is < 1, adjusted value \
160
- %s to %s.",
161
- name,
162
- rescale_factor,
163
- self.rescale_factor,
164
- )
165
- else:
166
- self.rescale_factor = rescale_factor
167
-
168
- # Infer layers from descriptor and neurons
154
+ # Infer layers from descriptor and tags
169
155
  self.layers = set()
170
- for neuron in self.neurons:
171
- if neuron not in self.descriptor.neuron_to_layer.keys():
156
+ for tag in self.tags:
157
+ if not self.descriptor.exists(tag):
172
158
  raise ValueError(
173
- f'The neuron name {neuron} used with constraint \
174
- {self.name} is not defined in the descriptor. Please \
175
- add it to the correct layer using \
176
- descriptor.add("layer", ...).'
159
+ f"The tag {tag} used with constraint "
160
+ f"{self.name} is not defined in the descriptor. Please "
161
+ "add it to the correct layer using "
162
+ "descriptor.add('layer', ...)."
177
163
  )
178
164
 
179
- self.layers.add(self.descriptor.neuron_to_layer[neuron])
165
+ layer, _ = self.descriptor.location(tag)
166
+ self.layers.add(layer)
180
167
 
181
168
  @abstractmethod
182
- def check_constraint(
183
- self, prediction: dict[str, Tensor]
184
- ) -> tuple[Tensor, int]:
185
- """
186
- Evaluates whether the given model predictions satisfy the constraint.
169
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
170
+ """Evaluates whether the given model predictions satisfy the constraint.
171
+
172
+ A value of 1.0 means the constraint is satisfied; 0.0 means it is not.
187
173
 
188
174
  Args:
189
- prediction (dict[str, Tensor]): Model predictions for the neurons.
175
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
190
176
 
191
177
  Returns:
192
- tuple[Tensor, int]: A tuple where the first element is a tensor
193
- indicating whether the constraint is satisfied (with `True`
194
- for satisfaction, `False` for non-satisfaction, and `torch.nan`
195
- for irrelevant results), and the second element is an integer
196
- value representing the number of relevant constraints.
178
+ tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
179
+ indicating whether the constraint is satisfied (with value 1.0
180
+ for satisfaction and 0.0 for non-satisfaction), and the second element is a tensor
181
+ mask that indicates the relevance of each sample (`True` for relevant
182
+ samples and `False` for irrelevant ones).
197
183
 
198
184
  Raises:
199
185
  NotImplementedError: If not implemented in a subclass.
200
186
  """
201
-
202
187
  raise NotImplementedError
203
188
 
204
189
  @abstractmethod
205
- def calculate_direction(
206
- self, prediction: dict[str, Tensor]
207
- ) -> Dict[str, Tensor]:
208
- """
209
- Calculates adjustment directions for neurons to
210
- better satisfy the constraint.
190
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
191
+ """Compute adjustment directions to better satisfy the constraint.
192
+
193
+ Given the model predictions, input batch, and context, this method calculates the direction
194
+ in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
211
195
 
212
196
  Args:
213
- prediction (dict[str, Tensor]): Model predictions for the neurons.
197
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
214
198
 
215
199
  Returns:
216
- Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
217
- specifying the adjustment direction for each neuron.
200
+ dict[str, Tensor]: Dictionary mapping network layers to tensors that
201
+ specify the adjustment direction for each tag.
218
202
 
219
203
  Raises:
220
- NotImplementedError: If not implemented in a subclass.
204
+ NotImplementedError: Must be implemented by subclasses.
221
205
  """
222
-
223
206
  raise NotImplementedError
224
207
 
225
208
 
226
209
  class ImplicationConstraint(Constraint):
227
- """
228
- Represents an implication constraint between two
229
- constraints (head and body).
210
+ """Represents an implication constraint between two constraints (head and body).
230
211
 
231
212
  The implication constraint ensures that the `body` constraint only applies
232
213
  when the `head` constraint is satisfied. If the `head` constraint is not
233
214
  satisfied, the `body` constraint does not apply.
234
-
235
- Args:
236
- head (Constraint): The head of the implication. If this constraint
237
- is satisfied, the body constraint must also be satisfied.
238
- body (Constraint): The body of the implication. This constraint
239
- is enforced only when the head constraint is satisfied.
240
- name (str, optional): A unique name for the constraint. If not
241
- provided, the name is generated in the format
242
- "{body.name} if {head.name}". Defaults to None.
243
- monitor_only (bool, optional): If True, the constraint is only
244
- monitored without adjusting the loss. Defaults to False.
245
- rescale_factor (Number, optional): The scaling factor for the
246
- constraint-adjusted loss. Defaults to 1.5.
247
-
248
- Raises:
249
- TypeError: If a provided attribute has an incompatible type.
250
-
251
215
  """
252
216
 
253
217
  def __init__(
254
218
  self,
255
219
  head: Constraint,
256
220
  body: Constraint,
257
- name=None,
258
- monitor_only=False,
259
- rescale_factor=1.5,
221
+ name: str = None,
260
222
  ):
261
- """
262
- Initializes an ImplicationConstraint instance.
263
- """
223
+ """Initializes an ImplicationConstraint instance.
224
+
225
+ Uses `enforce` and `rescale_factor` from the body constraint.
226
+
227
+ Args:
228
+ head (Constraint): Constraint defining the head of the implication.
229
+ body (Constraint): Constraint defining the body of the implication.
230
+ name (str, optional): A unique name for the constraint. If not
231
+ provided, a name is generated based on the class name and a
232
+ random suffix.
233
+
234
+ Raises:
235
+ TypeError: If a provided attribute has an incompatible type.
264
236
 
237
+ """
265
238
  # Type checking
266
239
  validate_type("head", head, Constraint)
267
240
  validate_type("body", body, Constraint)
@@ -270,64 +243,82 @@ class ImplicationConstraint(Constraint):
270
243
  name = f"{body.name} if {head.name}"
271
244
 
272
245
  # Init parent class
273
- super().__init__(
274
- head.neurons | body.neurons,
275
- name,
276
- monitor_only,
277
- rescale_factor,
278
- )
246
+ super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)
279
247
 
280
248
  self.head = head
281
249
  self.body = body
282
250
 
283
- def check_constraint(
284
- self, prediction: dict[str, Tensor]
285
- ) -> tuple[Tensor, int]:
251
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
252
+ """Check whether the implication constraint is satisfied.
286
253
 
254
+ Evaluates the `head` and `body` constraints. The `body` constraint
255
+ is enforced only if the `head` constraint is satisfied. If the
256
+ `head` constraint is not satisfied, the `body` constraint does not
257
+ affect the result.
258
+
259
+ Args:
260
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
261
+
262
+ Returns:
263
+ tuple[Tensor, Tensor]:
264
+ - result: Tensor indicating satisfaction of the implication
265
+ constraint (1 if satisfied, 0 otherwise).
266
+ - head_satisfaction: Tensor indicating satisfaction of the
267
+ head constraint alone.
268
+ """
287
269
  # Check satisfaction of head and body constraints
288
- head_satisfaction, _ = self.head.check_constraint(prediction)
289
- body_satisfaction, _ = self.body.check_constraint(prediction)
270
+ head_satisfaction, _ = self.head.check_constraint(data)
271
+ body_satisfaction, _ = self.body.check_constraint(data)
290
272
 
291
273
  # If head constraint is satisfied (returning 1),
292
274
  # the body constraint matters (and should return 0/1 based on body)
293
275
  # If head constraint is not satisfied (returning 0),
294
276
  # the body constraint does not apply (and should return 1)
295
- result = logical_or(
296
- logical_not(head_satisfaction), body_satisfaction
297
- ).float()
277
+ result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()
298
278
 
299
- return result, count_nonzero(head_satisfaction)
279
+ return result, head_satisfaction
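A standalone illustration, with made-up values, of the implication logic just above: the result is "NOT head OR body", so a sample counts as violated only when the head holds but the body does not.

import torch

head = torch.tensor([1.0, 1.0, 0.0, 0.0])
body = torch.tensor([1.0, 0.0, 1.0, 0.0])
result = torch.logical_or(torch.logical_not(head), body).float()
print(result.tolist())  # [1.0, 0.0, 1.0, 1.0]

# Typical construction, assuming head_constraint and body_constraint are
# existing Constraint instances:
# implication = ImplicationConstraint(head=head_constraint, body=body_constraint)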
280
+
281
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
282
+ """Compute adjustment directions for tags to satisfy the constraint.
283
+
284
+ Uses the `body` constraint directions as the update vector. Only
285
+ applies updates if the `head` constraint is satisfied. Currently,
286
+ this method only works for dense layers due to tag-to-index
287
+ translation limitations.
288
+
289
+ Args:
290
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
300
291
 
301
- def calculate_direction(
302
- self, prediction: dict[str, Tensor]
303
- ) -> Dict[str, Tensor]:
292
+ Returns:
293
+ dict[str, Tensor]: Dictionary mapping tags to tensors
294
+ specifying the adjustment direction for each tag.
295
+ """
304
296
  # NOTE currently only works for dense layers
305
- # due to neuron to index translation
297
+ # due to tag to index translation
306
298
 
307
299
  # Use directions of constraint body as update vector
308
- return self.body.calculate_direction(prediction)
300
+ return self.body.calculate_direction(data)
309
301
 
310
302
 
311
303
  class ScalarConstraint(Constraint):
312
- """
313
- A constraint that enforces scalar-based comparisons on a specific neuron.
304
+ """A constraint that enforces scalar-based comparisons on a specific tag.
314
305
 
315
- This class ensures that the output of a specified neuron satisfies a scalar
306
+ This class ensures that the output of a specified tag satisfies a scalar
316
307
  comparison operation (e.g., less than, greater than, etc.). It uses a
317
308
  comparator function to validate the condition and calculates adjustment
318
309
  directions accordingly.
319
310
 
320
311
  Args:
321
- operand (Union[str, Transformation]): Name of the neuron or a
312
+ operand (Union[str, Transformation]): Name of the tag or a
322
313
  transformation to apply.
323
314
  comparator (Callable[[Tensor, Number], Tensor]): A comparison
324
315
  function (e.g., `torch.ge`, `torch.lt`).
325
316
  scalar (Number): The scalar value to compare against.
326
317
  name (str, optional): A unique name for the constraint. If not
327
318
  provided, a name is auto-generated in the format
328
- "<neuron_name> <comparator> <scalar>".
329
- monitor_only (bool, optional): If True, only monitor the constraint
330
- without adjusting the loss. Defaults to False.
319
+ "<tag> <comparator> <scalar>".
320
+ enforce (bool, optional): If False, only monitor the constraint
321
+ without adjusting the loss. Defaults to True.
331
322
  rescale_factor (Number, optional): Factor to scale the
332
323
  constraint-adjusted loss. Defaults to 1.5.
333
324
 
@@ -335,87 +326,120 @@ class ScalarConstraint(Constraint):
335
326
  TypeError: If a provided attribute has an incompatible type.
336
327
 
337
328
  Notes:
338
- - The `neuron_name` must be defined in the `descriptor` mapping.
339
- - The constraint name is composed using the neuron name,
340
- comparator, and scalar value.
329
+ - The `tag` must be defined in the `descriptor` mapping.
330
+ - The constraint name is composed using the tag, comparator, and scalar value.
341
331
 
342
332
  """
343
333
 
344
334
  def __init__(
345
335
  self,
346
- operand: Union[str, Transformation],
336
+ operand: str | Transformation,
347
337
  comparator: Callable[[Tensor, Number], Tensor],
348
338
  scalar: Number,
349
339
  name: str = None,
350
- monitor_only: bool = False,
340
+ enforce: bool = True,
351
341
  rescale_factor: Number = 1.5,
352
342
  ) -> None:
353
- """
354
- Initializes a ScalarConstraint instance.
355
- """
343
+ """Initializes a ScalarConstraint instance.
344
+
345
+ Args:
346
+ operand (Union[str, Transformation]): Function that needs to be
347
+ performed on the network variables before applying the
348
+ constraint.
349
+ comparator (Callable[[Tensor, Number], Tensor]): Comparison
350
+ operator used in the constraint. Supported types are
351
+ {torch.lt, torch.le, torch.gt, torch.ge}.
352
+ scalar (Number): Constant to compare the variable to.
353
+ name (str, optional): A unique name for the constraint. If not
354
+ provided, a name is generated based on the class name and a
355
+ random suffix.
356
+ enforce (bool, optional): If False, only monitor the constraint
357
+ without adjusting the loss. Defaults to True.
358
+ rescale_factor (Number, optional): Factor to scale the
359
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
360
+ than 1 to give weight to the constraint.
356
361
 
362
+ Raises:
363
+ TypeError: If a provided attribute has an incompatible type.
364
+
365
+ Notes:
366
+ - The `tag` must be defined in the `descriptor` mapping.
367
+ - The constraint name is composed using the tag, comparator, and scalar value.
368
+ """
357
369
  # Type checking
358
370
  validate_type("operand", operand, (str, Transformation))
359
371
  validate_comparator_pytorch("comparator", comparator)
360
- validate_comparator_pytorch("comparator", comparator)
361
372
  validate_type("scalar", scalar, Number)
362
373
 
363
- # If transformation is provided, get neuron name,
364
- # else use IdentityTransformation
374
+ # If transformation is provided, get tag name, else use IdentityTransformation
365
375
  if isinstance(operand, Transformation):
366
- neuron_name = operand.neuron_name
376
+ tag = operand.tag
367
377
  transformation = operand
368
378
  else:
369
- neuron_name = operand
370
- transformation = IdentityTransformation(neuron_name)
379
+ tag = operand
380
+ transformation = IdentityTransformation(tag)
371
381
 
372
382
  # Compose constraint name
373
- name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
383
+ name = f"{tag} {comparator.__name__} {str(scalar)}"
374
384
 
375
385
  # Init parent class
376
- super().__init__({neuron_name}, name, monitor_only, rescale_factor)
386
+ super().__init__({tag}, name, enforce, rescale_factor)
377
387
 
378
388
  # Init variables
389
+ self.tag = tag
379
390
  self.comparator = comparator
380
391
  self.scalar = scalar
381
392
  self.transformation = transformation
382
393
 
383
- # Get layer name and feature index from neuron_name
384
- self.layer = self.descriptor.neuron_to_layer[neuron_name]
385
- self.index = self.descriptor.neuron_to_index[neuron_name]
386
-
387
394
  # Calculate directions based on constraint operator
388
395
  if self.comparator in [lt, le]:
389
- self.direction = -1
390
- elif self.comparator in [gt, ge]:
391
396
  self.direction = 1
397
+ elif self.comparator in [gt, ge]:
398
+ self.direction = -1
392
399
 
393
- def check_constraint(
394
- self, prediction: dict[str, Tensor]
395
- ) -> tuple[Tensor, int]:
400
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
401
+ """Check if the scalar constraint is satisfied for a given tag.
396
402
 
403
+ Args:
404
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
405
+
406
+ Returns:
407
+ tuple[Tensor, Tensor]:
408
+ - result: Tensor indicating whether the tag satisfies the constraint.
409
+ - ones_like(result): Tensor of ones with the same shape as `result`.
410
+ """
397
411
  # Select relevant columns
398
- selection = prediction[self.layer][:, self.index]
412
+ selection = self.descriptor.select(self.tag, data)
399
413
 
400
414
  # Apply transformation
401
415
  selection = self.transformation(selection)
402
416
 
403
417
  # Calculate current constraint result
404
418
  result = self.comparator(selection, self.scalar).float()
405
- return result, numel(result)
419
+ return result, ones_like(result)
420
+
421
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
422
+ """Compute adjustment directions to satisfy the scalar constraint.
406
423
 
407
- def calculate_direction(
408
- self, prediction: dict[str, Tensor]
409
- ) -> Dict[str, Tensor]:
424
+ Only works for dense layers due to tag-to-index translation.
425
+
426
+ Args:
427
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
428
+
429
+ Returns:
430
+ dict[str, Tensor]: Dictionary mapping layers to tensors specifying
431
+ the adjustment direction for each tag.
432
+ """
410
433
  # NOTE currently only works for dense layers due
411
- # to neuron to index translation
434
+ # to tag to index translation
412
435
 
413
436
  output = {}
414
437
 
415
438
  for layer in self.layers:
416
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
439
+ output[layer] = zeros_like(data[layer][0], device=self.device)
417
440
 
418
- output[self.layer][self.index] = self.direction
441
+ layer, index = self.descriptor.location(self.tag)
442
+ output[layer][index] = self.direction
419
443
 
420
444
  for layer in self.layers:
421
445
  output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
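A hedged usage sketch for ScalarConstraint as defined above. The tag "output:probability" is illustrative and must be registered in the descriptor; the comparator should be one of the PyTorch comparison functions the module imports (lt, le, gt, ge), as checked by validate_comparator_pytorch.

from torch import le

from congrads.constraints import ScalarConstraint

# Enforced constraint; the auto-generated name is "output:probability le 1.0"
upper_bound = ScalarConstraint("output:probability", le, 1.0)

# Monitor-only variant: satisfaction is tracked, but the loss is not adjusted
monitored = ScalarConstraint("output:probability", le, 1.0, enforce=False)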
@@ -424,26 +448,24 @@ class ScalarConstraint(Constraint):
424
448
 
425
449
 
426
450
  class BinaryConstraint(Constraint):
427
- """
428
- A constraint that enforces a binary comparison between two neurons.
451
+ """A constraint that enforces a binary comparison between two tags.
429
452
 
430
- This class ensures that the output of one neuron satisfies a comparison
431
- operation with the output of another neuron
432
- (e.g., less than, greater than, etc.). It uses a comparator function to
433
- validate the condition and calculates adjustment directions accordingly.
453
+ This class ensures that the output of one tag satisfies a comparison
454
+ operation with the output of another tag (e.g., less than, greater than, etc.).
455
+ It uses a comparator function to validate the condition and calculates adjustment directions accordingly.
434
456
 
435
457
  Args:
436
458
  operand_left (Union[str, Transformation]): Name of the left
437
- neuron or a transformation to apply.
459
+ tag or a transformation to apply.
438
460
  comparator (Callable[[Tensor, Number], Tensor]): A comparison
439
461
  function (e.g., `torch.ge`, `torch.lt`).
440
462
  operand_right (Union[str, Transformation]): Name of the right
441
- neuron or a transformation to apply.
463
+ tag or a transformation to apply.
442
464
  name (str, optional): A unique name for the constraint. If not
443
465
  provided, a name is auto-generated in the format
444
- "<neuron_name_left> <comparator> <neuron_name_right>".
445
- monitor_only (bool, optional): If True, only monitor the constraint
446
- without adjusting the loss. Defaults to False.
466
+ "<operand_left> <comparator> <operand_right>".
467
+ enforce (bool, optional): If False, only monitor the constraint
468
+ without adjusting the loss. Defaults to True.
447
469
  rescale_factor (Number, optional): Factor to scale the
448
470
  constraint-adjusted loss. Defaults to 1.5.
449
471
 
@@ -451,84 +473,107 @@ class BinaryConstraint(Constraint):
451
473
  TypeError: If a provided attribute has an incompatible type.
452
474
 
453
475
  Notes:
454
- - The neuron names must be defined in the `descriptor` mapping.
455
- - The constraint name is composed using the left neuron name,
456
- comparator, and right neuron name.
476
+ - The tags must be defined in the `descriptor` mapping.
477
+ - The constraint name is composed using the left tag, comparator, and right tag.
457
478
 
458
479
  """
459
480
 
460
481
  def __init__(
461
482
  self,
462
- operand_left: Union[str, Transformation],
483
+ operand_left: str | Transformation,
463
484
  comparator: Callable[[Tensor, Number], Tensor],
464
- operand_right: Union[str, Transformation],
485
+ operand_right: str | Transformation,
465
486
  name: str = None,
466
- monitor_only: bool = False,
487
+ enforce: bool = True,
467
488
  rescale_factor: Number = 1.5,
468
489
  ) -> None:
469
- """
470
- Initializes a BinaryConstraint instance.
471
- """
490
+ """Initializes a BinaryConstraint instance.
491
+
492
+ Args:
493
+ operand_left (Union[str, Transformation]): Name of the left
494
+ tag or a transformation to apply.
495
+ comparator (Callable[[Tensor, Number], Tensor]): A comparison
496
+ function (e.g., `torch.ge`, `torch.lt`).
497
+ operand_right (Union[str, Transformation]): Name of the right
498
+ tag or a transformation to apply.
499
+ name (str, optional): A unique name for the constraint. If not
500
+ provided, a name is auto-generated in the format
501
+ "<operand_left> <comparator> <operand_right>".
502
+ enforce (bool, optional): If False, only monitor the constraint
503
+ without adjusting the loss. Defaults to True.
504
+ rescale_factor (Number, optional): Factor to scale the
505
+ constraint-adjusted loss. Defaults to 1.5.
506
+
507
+ Raises:
508
+ TypeError: If a provided attribute has an incompatible type.
472
509
 
510
+ Notes:
511
+ - The tags must be defined in the `descriptor` mapping.
512
+ - The constraint name is composed using the left tag,
513
+ comparator, and right tag.
514
+ """
473
515
  # Type checking
474
516
  validate_type("operand_left", operand_left, (str, Transformation))
475
517
  validate_comparator_pytorch("comparator", comparator)
476
518
  validate_comparator_pytorch("comparator", comparator)
477
519
  validate_type("operand_right", operand_right, (str, Transformation))
478
520
 
479
- # If transformation is provided, get neuron name,
480
- # else use IdentityTransformation
521
+ # If transformation is provided, get tag name, else use IdentityTransformation
481
522
  if isinstance(operand_left, Transformation):
482
- neuron_name_left = operand_left.neuron_name
523
+ tag_left = operand_left.tag
483
524
  transformation_left = operand_left
484
525
  else:
485
- neuron_name_left = operand_left
486
- transformation_left = IdentityTransformation(neuron_name_left)
526
+ tag_left = operand_left
527
+ transformation_left = IdentityTransformation(tag_left)
487
528
 
488
529
  if isinstance(operand_right, Transformation):
489
- neuron_name_right = operand_right.neuron_name
530
+ tag_right = operand_right.tag
490
531
  transformation_right = operand_right
491
532
  else:
492
- neuron_name_right = operand_right
493
- transformation_right = IdentityTransformation(neuron_name_right)
533
+ tag_right = operand_right
534
+ transformation_right = IdentityTransformation(tag_right)
494
535
 
495
536
  # Compose constraint name
496
- name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
537
+ name = f"{tag_left} {comparator.__name__} {tag_right}"
497
538
 
498
539
  # Init parent class
499
- super().__init__(
500
- {neuron_name_left, neuron_name_right},
501
- name,
502
- monitor_only,
503
- rescale_factor,
504
- )
540
+ super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)
505
541
 
506
542
  # Init variables
507
543
  self.comparator = comparator
544
+ self.tag_left = tag_left
545
+ self.tag_right = tag_right
508
546
  self.transformation_left = transformation_left
509
547
  self.transformation_right = transformation_right
510
548
 
511
- # Get layer name and feature index from neuron_name
512
- self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
513
- self.layer_right = self.descriptor.neuron_to_layer[neuron_name_right]
514
- self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
515
- self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
516
-
517
549
  # Calculate directions based on constraint operator
518
550
  if self.comparator in [lt, le]:
519
- self.direction_left = -1
520
- self.direction_right = 1
521
- else:
522
551
  self.direction_left = 1
523
552
  self.direction_right = -1
553
+ else:
554
+ self.direction_left = -1
555
+ self.direction_right = 1
524
556
 
525
- def check_constraint(
526
- self, prediction: dict[str, Tensor]
527
- ) -> tuple[Tensor, int]:
557
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
558
+ """Evaluate whether the binary constraint is satisfied for the current predictions.
528
559
 
560
+ The constraint compares the outputs of two tags using the specified
561
+ comparator function. A result of `1` indicates the constraint is satisfied
562
+ for a sample, and `0` indicates it is violated.
563
+
564
+ Args:
565
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
566
+
567
+ Returns:
568
+ tuple[Tensor, Tensor]:
569
+ - result (Tensor): Binary tensor indicating constraint satisfaction
570
+ (1 for satisfied, 0 for violated) for each sample.
571
+ - mask (Tensor): Tensor of ones with the same shape as `result`,
572
+ used for constraint aggregation.
573
+ """
529
574
  # Select relevant columns
530
- selection_left = prediction[self.layer_left][:, self.index_left]
531
- selection_right = prediction[self.layer_right][:, self.index_right]
575
+ selection_left = self.descriptor.select(self.tag_left, data)
576
+ selection_right = self.descriptor.select(self.tag_right, data)
532
577
 
533
578
  # Apply transformations
534
579
  selection_left = self.transformation_left(selection_left)
@@ -536,21 +581,34 @@ class BinaryConstraint(Constraint):
536
581
 
537
582
  result = self.comparator(selection_left, selection_right).float()
538
583
 
539
- return result, numel(result)
584
+ return result, ones_like(result)
585
+
586
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
587
+ """Compute adjustment directions for the tags involved in the binary constraint.
588
+
589
+ The returned directions indicate how to adjust each tag's output to
590
+ satisfy the constraint. Currently only supported for dense layers.
591
+
592
+ Args:
593
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
540
594
 
541
- def calculate_direction(
542
- self, prediction: dict[str, Tensor]
543
- ) -> Dict[str, Tensor]:
595
+ Returns:
596
+ dict[str, Tensor]: A mapping from layer names to tensors specifying
597
+ the normalized adjustment directions for each tag involved in the
598
+ constraint.
599
+ """
544
600
  # NOTE currently only works for dense layers due
545
- # to neuron to index translation
601
+ # to tag to index translation
546
602
 
547
603
  output = {}
548
604
 
549
605
  for layer in self.layers:
550
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
606
+ output[layer] = zeros_like(data[layer][0], device=self.device)
551
607
 
552
- output[self.layer_left][self.index_left] = self.direction_left
553
- output[self.layer_right][self.index_right] = self.direction_right
608
+ layer_left, index_left = self.descriptor.location(self.tag_left)
609
+ layer_right, index_right = self.descriptor.location(self.tag_right)
610
+ output[layer_left][index_left] = self.direction_left
611
+ output[layer_right][index_right] = self.direction_right
554
612
 
555
613
  for layer in self.layers:
556
614
  output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
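A hedged usage sketch for BinaryConstraint as defined above. Both tags are illustrative and must be registered in the descriptor.

from torch import lt

from congrads.constraints import BinaryConstraint

# The auto-generated name is "output:min_temp lt output:max_temp"
ordering = BinaryConstraint("output:min_temp", lt, "output:max_temp")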
@@ -559,142 +617,119 @@ class BinaryConstraint(Constraint):
559
617
 
560
618
 
561
619
  class SumConstraint(Constraint):
562
- """
563
- A constraint that enforces a weighted summation comparison
564
- between two groups of neurons.
620
+ """A constraint that enforces a weighted summation comparison between two groups of tags.
565
621
 
566
622
  This class evaluates whether the weighted sum of outputs from one set of
567
- neurons satisfies a comparison operation with the weighted sum of
568
- outputs from another set of neurons.
569
-
570
- Args:
571
- operands_left (list[Union[str, Transformation]]): List of neuron
572
- names or transformations on the left side.
573
- comparator (Callable[[Tensor, Number], Tensor]): A comparison
574
- function for the constraint.
575
- operands_right (list[Union[str, Transformation]]): List of neuron
576
- names or transformations on the right side.
577
- weights_left (list[Number], optional): Weights for the left neurons.
578
- Defaults to None.
579
- weights_right (list[Number], optional): Weights for the right
580
- neurons. Defaults to None.
581
- name (str, optional): Unique name for the constraint.
582
- If None, it's auto-generated. Defaults to None.
583
- monitor_only (bool, optional): If True, only monitor the constraint
584
- without adjusting the loss. Defaults to False.
585
- rescale_factor (Number, optional): Factor to scale the
586
- constraint-adjusted loss. Defaults to 1.5.
587
-
588
- Raises:
589
- TypeError: If a provided attribute has an incompatible type.
590
- ValueError: If the dimensions of neuron names and weights mismatch.
591
-
623
+ tags satisfies a comparison operation with the weighted sum of
624
+ outputs from another set of tags.
592
625
  """
593
626
 
594
627
  def __init__(
595
628
  self,
596
- operands_left: list[Union[str, Transformation]],
629
+ operands_left: list[str | Transformation],
597
630
  comparator: Callable[[Tensor, Number], Tensor],
598
- operands_right: list[Union[str, Transformation]],
631
+ operands_right: list[str | Transformation],
599
632
  weights_left: list[Number] = None,
600
633
  weights_right: list[Number] = None,
601
634
  name: str = None,
602
- monitor_only: bool = False,
635
+ enforce: bool = True,
603
636
  rescale_factor: Number = 1.5,
604
637
  ) -> None:
605
- """
606
- Initializes the SumConstraint.
607
- """
638
+ """Initializes the SumConstraint.
639
+
640
+ Args:
641
+ operands_left (list[Union[str, Transformation]]): List of tags
642
+ or transformations on the left side.
643
+ comparator (Callable[[Tensor, Number], Tensor]): A comparison
644
+ function for the constraint.
645
+ operands_right (list[Union[str, Transformation]]): List of tags
646
+ or transformations on the right side.
647
+ weights_left (list[Number], optional): Weights for the left
648
+ tags. Defaults to None.
649
+ weights_right (list[Number], optional): Weights for the right
650
+ tags. Defaults to None.
651
+ name (str, optional): Unique name for the constraint.
652
+ If None, it's auto-generated. Defaults to None.
653
+ enforce (bool, optional): If False, only monitor the constraint
654
+ without adjusting the loss. Defaults to True.
655
+ rescale_factor (Number, optional): Factor to scale the
656
+ constraint-adjusted loss. Defaults to 1.5.
608
657
 
658
+ Raises:
659
+ TypeError: If a provided attribute has an incompatible type.
660
+ ValueError: If the dimensions of tags and weights mismatch.
661
+ """
609
662
  # Type checking
610
663
  validate_iterable("operands_left", operands_left, (str, Transformation))
611
664
  validate_comparator_pytorch("comparator", comparator)
612
665
  validate_comparator_pytorch("comparator", comparator)
613
- validate_iterable(
614
- "operands_right", operands_right, (str, Transformation)
615
- )
666
+ validate_iterable("operands_right", operands_right, (str, Transformation))
616
667
  validate_iterable("weights_left", weights_left, Number, allow_none=True)
617
- validate_iterable(
618
- "weights_right", weights_right, Number, allow_none=True
619
- )
668
+ validate_iterable("weights_right", weights_right, Number, allow_none=True)
620
669
 
621
- # If transformation is provided, get neuron name,
622
- # else use IdentityTransformation
623
- neuron_names_left: list[str] = []
670
+ # If transformation is provided, get tag, else use IdentityTransformation
671
+ tags_left: list[str] = []
624
672
  transformations_left: list[Transformation] = []
625
673
  for operand_left in operands_left:
626
674
  if isinstance(operand_left, Transformation):
627
- neuron_name_left = operand_left.neuron_name
628
- neuron_names_left.append(neuron_name_left)
675
+ tag_left = operand_left.tag
676
+ tags_left.append(tag_left)
629
677
  transformations_left.append(operand_left)
630
678
  else:
631
- neuron_name_left = operand_left
632
- neuron_names_left.append(neuron_name_left)
633
- transformations_left.append(
634
- IdentityTransformation(neuron_name_left)
635
- )
679
+ tag_left = operand_left
680
+ tags_left.append(tag_left)
681
+ transformations_left.append(IdentityTransformation(tag_left))
636
682
 
637
- neuron_names_right: list[str] = []
683
+ tags_right: list[str] = []
638
684
  transformations_right: list[Transformation] = []
639
685
  for operand_right in operands_right:
640
686
  if isinstance(operand_right, Transformation):
641
- neuron_name_right = operand_right.neuron_name
642
- neuron_names_right.append(neuron_name_right)
687
+ tag_right = operand_right.tag
688
+ tags_right.append(tag_right)
643
689
  transformations_right.append(operand_right)
644
690
  else:
645
- neuron_name_right = operand_right
646
- neuron_names_right.append(neuron_name_right)
647
- transformations_right.append(
648
- IdentityTransformation(neuron_name_right)
649
- )
691
+ tag_right = operand_right
692
+ tags_right.append(tag_right)
693
+ transformations_right.append(IdentityTransformation(tag_right))
650
694
 
651
695
  # Compose constraint name
652
- w_left = weights_left or [""] * len(neuron_names_left)
653
- w_right = weights_right or [""] * len(neuron_names_right)
654
- left_expr = " + ".join(
655
- f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
656
- )
657
- right_expr = " + ".join(
658
- f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
659
- )
696
+ w_left = weights_left or [""] * len(tags_left)
697
+ w_right = weights_right or [""] * len(tags_right)
698
+ left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
699
+ right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
660
700
  comparator_name = comparator.__name__
661
701
  name = f"{left_expr} {comparator_name} {right_expr}"
662
702
 
663
703
  # Init parent class
664
- neuron_names = set(neuron_names_left) | set(neuron_names_right)
665
- super().__init__(neuron_names, name, monitor_only, rescale_factor)
704
+ tags = set(tags_left) | set(tags_right)
705
+ super().__init__(tags, name, enforce, rescale_factor)
666
706
 
667
707
  # Init variables
668
708
  self.comparator = comparator
669
- self.neuron_names_left = neuron_names_left
670
- self.neuron_names_right = neuron_names_right
709
+ self.tags_left = tags_left
710
+ self.tags_right = tags_right
671
711
  self.transformations_left = transformations_left
672
712
  self.transformations_right = transformations_right
673
713
 
674
- # If feature list dimensions don't match
675
- # weight list dimensions, raise error
676
- if weights_left and (len(neuron_names_left) != len(weights_left)):
714
+ # If feature list dimensions don't match weight list dimensions, raise error
715
+ if weights_left and (len(tags_left) != len(weights_left)):
677
716
  raise ValueError(
678
- "The dimensions of neuron_names_left don't match with the \
679
- dimensions of weights_left."
717
+ "The dimensions of tags_left don't match with the dimensions of weights_left."
680
718
  )
681
- if weights_right and (len(neuron_names_right) != len(weights_right)):
719
+ if weights_right and (len(tags_right) != len(weights_right)):
682
720
  raise ValueError(
683
- "The dimensions of neuron_names_right don't match with the \
684
- dimensions of weights_right."
721
+ "The dimensions of tags_right don't match with the dimensions of weights_right."
685
722
  )
686
723
 
687
724
  # If weights are provided for summation, transform them to Tensors
688
725
  if weights_left:
689
726
  self.weights_left = tensor(weights_left, device=self.device)
690
727
  else:
691
- self.weights_left = ones(len(neuron_names_left), device=self.device)
728
+ self.weights_left = ones(len(tags_left), device=self.device)
692
729
  if weights_right:
693
730
  self.weights_right = tensor(weights_right, device=self.device)
694
731
  else:
695
- self.weights_right = ones(
696
- len(neuron_names_right), device=self.device
697
- )
732
+ self.weights_right = ones(len(tags_right), device=self.device)
698
733
 
699
734
  # Calculate directions based on constraint operator
700
735
  if self.comparator in [lt, le]:
@@ -704,80 +739,82 @@ class SumConstraint(Constraint):
704
739
  self.direction_left = 1
705
740
  self.direction_right = -1
706
741
 
707
- def check_constraint(
708
- self, prediction: dict[str, Tensor]
709
- ) -> tuple[Tensor, int]:
742
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
743
+ """Evaluate whether the weighted sum constraint is satisfied.
744
+
745
+ Computes the weighted sum of outputs from the left and right tags,
746
+ applies the specified comparator function, and returns a binary result for
747
+ each sample.
748
+
749
+ Args:
750
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
751
+
752
+ Returns:
753
+ tuple[Tensor, Tensor]:
754
+ - result (Tensor): Binary tensor indicating whether the constraint
755
+ is satisfied (1) or violated (0) for each sample.
756
+ - mask (Tensor): Tensor of ones, used for constraint aggregation.
757
+ """
710
758
 
711
759
  def compute_weighted_sum(
712
- neuron_names: list[str],
760
+ tags: list[str],
713
761
  transformations: list[Transformation],
714
- weights: tensor,
715
- ) -> tensor:
716
- layers = [
717
- self.descriptor.neuron_to_layer[neuron_name]
718
- for neuron_name in neuron_names
719
- ]
720
- indices = [
721
- self.descriptor.neuron_to_index[neuron_name]
722
- for neuron_name in neuron_names
723
- ]
724
-
725
- # Select relevant column
726
- selections = [
727
- prediction[layer][:, index]
728
- for layer, index in zip(layers, indices)
729
- ]
762
+ weights: Tensor,
763
+ ) -> Tensor:
764
+ # Select relevant columns
765
+ selections = [self.descriptor.select(tag, data) for tag in tags]
730
766
 
731
767
  # Apply transformations
732
768
  results = []
733
- for transformation, selection in zip(transformations, selections):
769
+ for transformation, selection in zip(transformations, selections, strict=False):
734
770
  results.append(transformation(selection))
735
771
 
736
- # Extract predictions for all neurons and apply weights in bulk
737
- predictions = stack(
738
- results,
739
- dim=1,
740
- )
772
+ # Extract predictions for all tags and apply weights in bulk
773
+ predictions = stack(results)
741
774
 
742
775
  # Calculate weighted sum
743
- return (predictions * weights.unsqueeze(0)).sum(dim=1)
776
+ return (predictions * weights.view(-1, 1, 1)).sum(dim=0)
744
777
 
745
778
  # Compute weighted sums
746
779
  weighted_sum_left = compute_weighted_sum(
747
- self.neuron_names_left,
748
- self.transformations_left,
749
- self.weights_left,
780
+ self.tags_left, self.transformations_left, self.weights_left
750
781
  )
751
782
  weighted_sum_right = compute_weighted_sum(
752
- self.neuron_names_right,
753
- self.transformations_right,
754
- self.weights_right,
783
+ self.tags_right, self.transformations_right, self.weights_right
755
784
  )
756
785
 
757
786
  # Apply the comparator and calculate the result
758
787
  result = self.comparator(weighted_sum_left, weighted_sum_right).float()
759
788
 
760
- return result, numel(result)
789
+ return result, ones_like(result)
790
+
791
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
792
+ """Compute adjustment directions for tags involved in the weighted sum constraint.
793
+
794
+ The directions indicate how to adjust each tag's output to satisfy the
795
+ constraint. Only dense layers are currently supported.
761
796
 
762
- def calculate_direction(
763
- self, prediction: dict[str, Tensor]
764
- ) -> Dict[str, Tensor]:
797
+ Args:
798
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
799
+
800
+ Returns:
801
+ dict[str, Tensor]: Mapping from layer names to normalized tensors
802
+ specifying adjustment directions for each tag involved in the constraint.
803
+ """
765
804
  # NOTE currently only works for dense layers
766
- # due to neuron to index translation
805
+ # due to tag to index translation
767
806
 
768
807
  output = {}
769
808
 
770
809
  for layer in self.layers:
771
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
810
+ output[layer] = zeros_like(data[layer][0], device=self.device)
772
811
 
773
- for neuron_name_left in self.neuron_names_left:
774
- layer = self.descriptor.neuron_to_layer[neuron_name_left]
775
- index = self.descriptor.neuron_to_index[neuron_name_left]
812
+ for tag_left in self.tags_left:
813
+ layer, index = self.descriptor.location(tag_left)
776
814
  output[layer][index] = self.direction_left
777
815
 
778
- for neuron_name_right in self.neuron_names_right:
779
- layer = self.descriptor.neuron_to_layer[neuron_name_right]
780
- index = self.descriptor.neuron_to_index[neuron_name_right]
816
+ for tag_right in self.tags_right:
817
+ layer, index = self.descriptor.location(tag_right)
781
818
  output[layer][index] = self.direction_right
782
819
 
783
820
  for layer in self.layers:
@@ -786,134 +823,434 @@ class SumConstraint(Constraint):
786
823
  return output
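A hedged usage sketch for SumConstraint as defined above: require that a weighted sum of two tagged outputs stays at or below a third. All tag names are illustrative; omitted weights default to ones, and each weight list must match the length of its operand list.

from torch import le

from congrads.constraints import SumConstraint

budget = SumConstraint(
    operands_left=["output:part_a", "output:part_b"],
    comparator=le,
    operands_right=["output:total"],
    weights_left=[0.5, 0.5],
)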
787
824
 
788
825
 
789
- class PythagoreanIdentityConstraint(Constraint):
826
+ class MonotonicityConstraint(Constraint):
827
+ """Constraint that enforces a monotonic relationship between two tags.
828
+
829
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
830
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
790
831
  """
791
- A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
792
- where `a` and `b` are neurons or transformations.
793
832
 
794
- This constraint checks that the sum of the squares of two specified
795
- neurons (or their transformations) is approximately equal to 1.
796
- The constraint is evaluated using relative and absolute
797
- tolerance (`rtol` and `atol`) and is applied during the forward pass.
833
+ def __init__(
834
+ self,
835
+ tag_prediction: str,
836
+ tag_reference: str,
837
+ rescale_factor_lower: float = 1.5,
838
+ rescale_factor_upper: float = 1.75,
839
+ stable: bool = True,
840
+ direction: Literal["ascending", "descending"] = "ascending",
841
+ name: str = None,
842
+ enforce: bool = True,
843
+ ):
844
+ """Constraint that enforces monotonicity on a predicted output.
798
845
 
799
- Args:
800
- a (Union[str, Transformation]): The first input, either a
801
- neuron name (str) or a Transformation.
802
- b (Union[str, Transformation]): The second input, either a
803
- neuron name (str) or a Transformation.
804
- rtol (float, optional): The relative tolerance for the
805
- comparison (default is 0.00001).
806
- atol (float, optional): The absolute tolerance for the
807
- comparison (default is 1e-8).
808
- name (str, optional): The name of the constraint
809
- (default is None, and it is generated automatically).
810
- monitor_only (bool, optional): Flag indicating whether the
811
- constraint is only for monitoring (default is False).
812
- rescale_factor (Number, optional): A factor used for
813
- rescaling (default is 1.5).
846
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
847
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
814
848
 
815
- Raises:
816
- TypeError: If a provided attribute has an incompatible type.
849
+ Args:
850
+ tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
851
+ tag_reference (str): Name of the tag that acts as the monotonic reference.
852
+ rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
853
+ rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
854
+ stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
855
+ direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
856
+ name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
857
+ enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
858
+ """
859
+ # Type checking
860
+ validate_type("rescale_factor_lower", rescale_factor_lower, float)
861
+ validate_type("rescale_factor_upper", rescale_factor_upper, float)
862
+ validate_type("stable", stable, bool)
863
+ validate_type("direction", direction, str)
817
864
 
865
+ # Compose constraint name
866
+ if name is None:
867
+ name = f"{tag_prediction} monotonically {direction} by {tag_reference}"
868
+
869
+ # Init parent class
870
+ super().__init__({tag_prediction}, name, enforce, 1.0)
871
+
872
+ # Init variables
873
+ self.tag_prediction = tag_prediction
874
+ self.tag_reference = tag_reference
875
+ self.rescale_factor_lower = rescale_factor_lower
876
+ self.rescale_factor_upper = rescale_factor_upper
877
+ self.stable = stable
878
+ self.descending = direction == "descending"
879
+
880
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
881
+ """Evaluate whether the monotonicity constraint is satisfied."""
882
+ # Select relevant columns
883
+ preds = self.descriptor.select(self.tag_prediction, data)
884
+ targets = self.descriptor.select(self.tag_reference, data)
885
+
886
+ # Utility: convert values -> ranks (0 ... num_features-1)
887
+ def compute_ranks(x: Tensor, descending: bool) -> Tensor:
888
+ return argsort(
889
+ argsort(x, descending=descending, stable=self.stable, dim=0),
890
+ descending=False,
891
+ stable=self.stable,
892
+ dim=0,
893
+ )
894
+
895
+ # Compute predicted and target ranks
896
+ pred_ranks = compute_ranks(preds, descending=self.descending)
897
+ target_ranks = compute_ranks(targets, descending=False)
898
+
899
+ # Rank difference
900
+ rank_diff = pred_ranks - target_ranks
901
+
902
+ # Rescale differences into [rescale_factor_lower, rescale_factor_upper]
903
+ batch_size = preds.shape[0]
904
+ invert_direction = -1 if self.descending else 1
905
+ self.compared_rankings = (
906
+ (rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
907
+ + self.rescale_factor_lower * sign(rank_diff)
908
+ ) * invert_direction
909
+
910
+ # Calculate satisfaction
911
+ incorrect_rankings = eq(self.compared_rankings, 0).float()
912
+
913
+ return incorrect_rankings, ones_like(incorrect_rankings)
914
+
915
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
916
+ """Calculates ranking adjustments for monotonicity enforcement."""
917
+ layer, _ = self.descriptor.location(self.tag_prediction)
918
+ return {layer: self.compared_rankings}
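A standalone illustration, with made-up values, of the double argsort used in compute_ranks above: sorting the sort indices turns raw values into ranks 0 ... batch_size - 1 along the batch dimension.

import torch

preds = torch.tensor([[0.3], [0.9], [0.1]])
ranks = torch.argsort(torch.argsort(preds, dim=0), dim=0)
print(ranks.squeeze(1).tolist())  # [1, 2, 0]: 0.1 gets the lowest rank, 0.9 the highest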
919
+
920
+
921
+ class GroupedMonotonicityConstraint(MonotonicityConstraint):
922
+ """Constraint that enforces a monotonic relationship between two tags.
923
+
924
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
925
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`) within each group identified by `tag_group_identifier`.
818
926
  """
819
927
 
820
928
  def __init__(
821
929
  self,
822
- a: Union[str, Transformation],
823
- b: Union[str, Transformation],
824
- rtol: float = 0.00001,
825
- atol: float = 1e-8,
930
+ tag_prediction: str,
931
+ tag_reference: str,
932
+ tag_group_identifier: str,
933
+ rescale_factor_lower: float = 1.5,
934
+ rescale_factor_upper: float = 1.75,
935
+ stable: bool = True,
936
+ direction: Literal["ascending", "descending"] = "ascending",
937
+ name: str = None,
938
+ enforce: bool = True,
939
+ ):
940
+ """Constraint that enforces monotonicity on a predicted output.
941
+
942
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
943
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`), enforced separately within each group identified by `tag_group_identifier`.
944
+
945
+ Args:
946
+ tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
947
+ tag_reference (str): Name of the tag that acts as the monotonic reference.
948
+ tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
949
+ rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
950
+ rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
951
+ stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
952
+ direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
953
+ name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
954
+ enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
955
+ """
956
+ # Compose constraint name
957
+ if name is None:
958
+ name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
959
+
960
+ # Init parent class
961
+ super().__init__(
962
+ tag_prediction=tag_prediction,
963
+ tag_reference=tag_reference,
964
+ rescale_factor_lower=rescale_factor_lower,
965
+ rescale_factor_upper=rescale_factor_upper,
966
+ stable=stable,
967
+ direction=direction,
968
+ name=name,
969
+ enforce=enforce,
970
+ )
971
+
972
+ # Init variables
973
+ self.tag_prediction = tag_prediction
974
+ self.tag_reference = tag_reference
975
+ self.tag_group_identifier = tag_group_identifier
976
+
977
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
978
+ """Evaluate whether the monotonicity constraint is satisfied."""
979
+ # Select group identifiers and convert to unique list
980
+ group_identifiers = self.descriptor.select(self.tag_group_identifier, data)
981
+ unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()
982
+
983
+ # Initialize checks and directions
984
+ checks = zeros_like(group_identifiers, device=self.device)
985
+ self.directions = zeros_like(group_identifiers, device=self.device)
986
+
987
+ # Get prediction and target keys
988
+ preds_key, _ = self.descriptor.location(self.tag_prediction)
989
+ targets_key, _ = self.descriptor.location(self.tag_reference)
990
+
991
+ for group_identifier in unique_group_identifiers:
992
+ # Create mask for the samples in this group
993
+ group_mask = (group_identifiers == group_identifier).squeeze(1)
994
+
995
+ # Create mini-batch for the group
996
+ group_data = {
997
+ preds_key: data[preds_key][group_mask],
998
+ targets_key: data[targets_key][group_mask],
999
+ }
1000
+
1001
+ # Call super on the mini-batch
1002
+ checks[group_mask], _ = super().check_constraint(group_data)
1003
+ self.directions[group_mask] = self.compared_rankings
1004
+
1005
+ return checks, ones_like(checks)
1006
+
1007
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
1008
+ """Calculates ranking adjustments for monotonicity enforcement."""
1009
+ layer, _ = self.descriptor.location(self.tag_prediction)
1010
+ return {layer: self.directions}
1011
+
1012
+
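As a brief usage sketch (the tag names below are hypothetical, and attaching constraints to a descriptor and training loop is configured elsewhere in the library and is not shown here), the grouped variant is constructed like the plain monotonicity constraint plus a group tag:

    from congrads.constraints import (
        GroupedMonotonicityConstraint,
        MonotonicityConstraint,
    )

    # Dose predictions should rise with severity across the whole batch ...
    global_trend = MonotonicityConstraint(
        tag_prediction="predicted_dose",
        tag_reference="severity",
        direction="ascending",
    )

    # ... or, more strictly, within each patient's own samples.
    per_patient_trend = GroupedMonotonicityConstraint(
        tag_prediction="predicted_dose",
        tag_reference="severity",
        tag_group_identifier="patient_id",
        direction="ascending",
    )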
1013
+ class ANDConstraint(Constraint):
1014
+ """A composite constraint that enforces the logical AND of multiple constraints.
1015
+
1016
+ This class combines multiple sub-constraints and evaluates them jointly:
1017
+
1018
+ * The satisfaction of the AND constraint is `True` only if all sub-constraints
1019
+ are satisfied (elementwise logical AND).
1020
+ * The corrective direction is computed by weighting each sub-constraint's
1021
+ direction with its satisfaction mask and summing across all sub-constraints.
1022
+ """
1023
+
1024
+ def __init__(
1025
+ self,
1026
+ *constraints: Constraint,
826
1027
  name: str = None,
827
1028
  monitor_only: bool = False,
828
1029
  rescale_factor: Number = 1.5,
829
1030
  ) -> None:
1031
+ """A composite constraint that enforces the logical AND of multiple constraints.
1032
+
1033
+ This class combines multiple sub-constraints and evaluates them jointly:
1034
+
1035
+ * The satisfaction of the AND constraint is `True` only if all sub-constraints
1036
+ are satisfied (elementwise logical AND).
1037
+ * The corrective direction is computed by weighting each sub-constraint's
1038
+ direction with its satisfaction mask and summing across all sub-constraints.
1039
+
1040
+ Args:
1041
+ *constraints (Constraint): One or more `Constraint` instances to be combined.
1042
+ name (str, optional): A custom name for this constraint. If not provided,
1043
+ the name will be composed from the sub-constraint names joined with
1044
+ " AND ".
1045
+ monitor_only (bool, optional): If True, the constraint will be monitored
1046
+ but not enforced. Defaults to False.
1047
+ rescale_factor (Number, optional): A scaling factor applied when rescaling
1048
+ corrections. Defaults to 1.5.
1049
+
1050
+ Attributes:
1051
+ constraints (tuple[Constraint, ...]): The sub-constraints being combined.
1052
+ tags (set): The union of tags referenced by the sub-constraints.
1053
+ name (str): The name of the constraint (composed or custom).
830
1054
  """
831
- Initialize the PythagoreanIdentityConstraint.
1055
+ # Type checking
1056
+ validate_iterable("constraints", constraints, Constraint)
1057
+
1058
+ # Compose constraint name
1059
+ if not name:
1060
+ name = " AND ".join([constraint.name for constraint in constraints])
1061
+
1062
+ # Init parent class
1063
+ super().__init__(
1064
+ set().union(*(constraint.tags for constraint in constraints)),
1065
+ name,
1066
+ monitor_only,
1067
+ rescale_factor,
1068
+ )
1069
+
1070
+ # Init variables
1071
+ self.constraints = constraints
1072
+
1073
+ def check_constraint(self, data: dict[str, Tensor]):
1074
+ """Evaluate whether all sub-constraints are satisfied.
1075
+
1076
+ Args:
1077
+ data: Model predictions and associated batch/context information.
1078
+
1079
+ Returns:
1080
+ tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
1081
+ * `total_satisfaction`: A boolean or numeric tensor indicating
1082
+ elementwise whether all constraints are satisfied
1083
+ (logical AND).
1084
+ * `mask`: A tensor of ones with the same shape as
1085
+ `total_satisfaction`. Typically used as a weighting mask
1086
+ in downstream processing.
832
1087
  """
1088
+ total_satisfaction: Tensor = None
1089
+ total_mask: Tensor = None
1090
+
1091
+ # TODO vectorize this loop
1092
+ for constraint in self.constraints:
1093
+ satisfaction, mask = constraint.check_constraint(data)
1094
+ if total_satisfaction is None:
1095
+ total_satisfaction = satisfaction
1096
+ total_mask = mask
1097
+ else:
1098
+ total_satisfaction = logical_and(total_satisfaction, satisfaction)
1099
+ total_mask = logical_or(total_mask, mask)
833
1100
 
834
- # Type checking
835
- validate_type("a", a, (str, Transformation))
836
- validate_type("b", b, (str, Transformation))
837
- validate_type("rtol", rtol, float)
838
- validate_type("atol", atol, float)
839
-
840
- # If transformation is provided, get neuron name,
841
- # else use IdentityTransformation
842
- if isinstance(a, Transformation):
843
- neuron_name_a = a.neuron_name
844
- transformation_a = a
845
- else:
846
- neuron_name_a = a
847
- transformation_a = IdentityTransformation(neuron_name_a)
1101
+ return total_satisfaction.float(), total_mask.float()
848
1102
 
849
- if isinstance(b, Transformation):
850
- neuron_name_b = b.neuron_name
851
- transformation_b = b
852
- else:
853
- neuron_name_b = b
854
- transformation_b = IdentityTransformation(neuron_name_b)
1103
+ def calculate_direction(self, data: dict[str, Tensor]):
1104
+ """Compute the corrective direction by aggregating sub-constraint directions.
1105
+
1106
+ Each sub-constraint contributes its corrective direction, weighted
1107
+ by its satisfaction mask. The directions are summed across constraints
1108
+ for each affected layer.
1109
+
1110
+ Args:
1111
+ data: Model predictions and associated batch/context information.
1112
+
1113
+ Returns:
1114
+ dict[str, Tensor]: A mapping from layer identifiers to correction
1115
+ tensors. Each entry represents the aggregated correction to apply
1116
+ to that layer, based on the satisfaction-weighted sum of
1117
+ sub-constraint directions.
1118
+ """
1119
+ total_direction: dict[str, Tensor] = {}
1120
+
1121
+ # TODO vectorize this loop
1122
+ for constraint in self.constraints:
1123
+ # TODO improve efficiency by avoiding double computation?
1124
+ satisfaction, _ = constraint.check_constraint(data)
1125
+ direction = constraint.calculate_direction(data)
1126
+
1127
+ for layer, dir in direction.items():
1128
+ if layer not in total_direction:
1129
+ total_direction[layer] = satisfaction.unsqueeze(1) * dir
1130
+ else:
1131
+ total_direction[layer] += satisfaction.unsqueeze(1) * dir
1132
+
1133
+ return total_direction
1134
+
1135
+
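A short sketch of composing constraints with `ANDConstraint`; the tag names are hypothetical, and the sub-constraints use only the constructor arguments visible above:

    from congrads.constraints import ANDConstraint, MonotonicityConstraint

    rises_with_load = MonotonicityConstraint(
        tag_prediction="power_draw", tag_reference="cpu_load"
    )
    rises_with_temp = MonotonicityConstraint(
        tag_prediction="power_draw", tag_reference="ambient_temp"
    )

    # Satisfied only where both orderings hold; a custom name overrides the
    # auto-composed "<name> AND <name>" default.
    both = ANDConstraint(
        rises_with_load,
        rises_with_temp,
        name="power_draw monotone in load AND temperature",
    )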
1136
+ class ORConstraint(Constraint):
1137
+ """A composite constraint that enforces the logical OR of multiple constraints.
1138
+
1139
+ This class combines multiple sub-constraints and evaluates them jointly:
1140
+
1141
+ * The satisfaction of the OR constraint is `True` if at least one sub-constraint
1142
+ is satisfied (elementwise logical OR).
1143
+ * The corrective direction is computed by weighting each sub-constraint's
1144
+ direction with its satisfaction mask and summing across all sub-constraints.
1145
+ """
1146
+
1147
+ def __init__(
1148
+ self,
1149
+ *constraints: Constraint,
1150
+ name: str = None,
1151
+ monitor_only: bool = False,
1152
+ rescale_factor: Number = 1.5,
1153
+ ) -> None:
1154
+ """A composite constraint that enforces the logical OR of multiple constraints.
1155
+
1156
+ This class combines multiple sub-constraints and evaluates them jointly:
1157
+
1158
+ * The satisfaction of the OR constraint is `True` if at least one sub-constraint
1159
+ is satisfied (elementwise logical OR).
1160
+ * The corrective direction is computed by weighting each sub-constraint's
1161
+ direction with its satisfaction mask and summing across all sub-constraints.
1162
+
1163
+ Args:
1164
+ *constraints (Constraint): One or more `Constraint` instances to be combined.
1165
+ name (str, optional): A custom name for this constraint. If not provided,
1166
+ the name will be composed from the sub-constraint names joined with
1167
+ " OR ".
1168
+ monitor_only (bool, optional): If True, the constraint will be monitored
1169
+ but not enforced. Defaults to False.
1170
+ rescale_factor (Number, optional): A scaling factor applied when rescaling
1171
+ corrections. Defaults to 1.5.
1172
+
1173
+ Attributes:
1174
+ constraints (tuple[Constraint, ...]): The sub-constraints being combined.
1175
+ tags (set): The union of tags referenced by the sub-constraints.
1176
+ name (str): The name of the constraint (composed or custom).
1177
+ """
1178
+ # Type checking
1179
+ validate_iterable("constraints", constraints, Constraint)
855
1180
 
856
1181
  # Compose constraint name
857
- name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
1182
+ if not name:
1183
+ name = " OR ".join([constraint.name for constraint in constraints])
858
1184
 
859
1185
  # Init parent class
860
1186
  super().__init__(
861
- {neuron_name_a, neuron_name_b},
1187
+ set().union(*(constraint.tags for constraint in constraints)),
862
1188
  name,
863
1189
  monitor_only,
864
1190
  rescale_factor,
865
1191
  )
866
1192
 
867
1193
  # Init variables
868
- self.transformation_a = transformation_a
869
- self.transformation_b = transformation_b
870
- self.rtol = rtol
871
- self.atol = atol
1194
+ self.constraints = constraints
872
1195
 
873
- # Get layer name and feature index from neuron_name
874
- self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
875
- self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
876
- self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
877
- self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
1196
+ def check_constraint(self, data: dict[str, Tensor]):
1197
+ """Evaluate whether any sub-constraints are satisfied.
878
1198
 
879
- def check_constraint(
880
- self, prediction: dict[str, Tensor]
881
- ) -> tuple[Tensor, int]:
1199
+ Args:
1200
+ data: Model predictions and associated batch/context information.
882
1201
 
883
- # Select relevant columns
884
- selection_a = prediction[self.layer_a][:, self.index_a]
885
- selection_b = prediction[self.layer_b][:, self.index_b]
1202
+ Returns:
1203
+ tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
1204
+ * `total_satisfaction`: A boolean or numeric tensor indicating
1205
+ elementwise whether any constraints are satisfied
1206
+ (logical OR).
1207
+ * `mask`: A tensor of ones with the same shape as
1208
+ `total_satisfaction`. Typically used as a weighting mask
1209
+ in downstream processing.
1210
+ """
1211
+ total_satisfaction: Tensor = None
1212
+ total_mask: Tensor = None
1213
+
1214
+ # TODO vectorize this loop
1215
+ for constraint in self.constraints:
1216
+ satisfaction, mask = constraint.check_constraint(data)
1217
+ if total_satisfaction is None:
1218
+ total_satisfaction = satisfaction
1219
+ total_mask = mask
1220
+ else:
1221
+ total_satisfaction = logical_or(total_satisfaction, satisfaction)
1222
+ total_mask = logical_or(total_mask, mask)
886
1223
 
887
- # Apply transformations
888
- selection_a = self.transformation_a(selection_a)
889
- selection_b = self.transformation_b(selection_b)
890
-
891
- # Calculate result
892
- result = isclose(
893
- square(selection_a) + square(selection_b),
894
- ones_like(selection_a, device=self.device),
895
- rtol=self.rtol,
896
- atol=self.atol,
897
- ).float()
898
-
899
- return result, numel(result)
900
-
901
- def calculate_direction(
902
- self, prediction: dict[str, Tensor]
903
- ) -> Dict[str, Tensor]:
904
- # NOTE currently only works for dense layers due
905
- # to neuron to index translation
1224
+ return total_satisfaction.float(), total_mask.float()
906
1225
 
907
- output = {}
1226
+ def calculate_direction(self, data: dict[str, Tensor]):
1227
+ """Compute the corrective direction by aggregating sub-constraint directions.
908
1228
 
909
- for layer in self.layers:
910
- output[layer] = zeros_like(prediction[layer], device=self.device)
1229
+ Each sub-constraint contributes its corrective direction, weighted
1230
+ by its satisfaction mask. The directions are summed across constraints
1231
+ for each affected layer.
1232
+
1233
+ Args:
1234
+ data: Model predictions and associated batch/context information.
911
1235
 
912
- a = prediction[self.layer_a][:, self.index_a]
913
- b = prediction[self.layer_b][:, self.index_b]
914
- m = sqrt(square(a) + square(b))
1236
+ Returns:
1237
+ dict[str, Tensor]: A mapping from layer identifiers to correction
1238
+ tensors. Each entry represents the aggregated correction to apply
1239
+ to that layer, based on the satisfaction-weighted sum of
1240
+ sub-constraint directions.
1241
+ """
1242
+ total_direction: dict[str, Tensor] = {}
915
1243
 
916
- output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
917
- output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
1244
+ # TODO vectorize this loop
1245
+ for constraint in self.constraints:
1246
+ # TODO improve efficiency by avoiding double computation?
1247
+ satisfaction, _ = constraint.check_constraint(data)
1248
+ direction = constraint.calculate_direction(data)
918
1249
 
919
- return output
1250
+ for layer, dir in direction.items():
1251
+ if layer not in total_direction:
1252
+ total_direction[layer] = satisfaction.unsqueeze(1) * dir
1253
+ else:
1254
+ total_direction[layer] += satisfaction.unsqueeze(1) * dir
1255
+
1256
+ return total_direction
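And a matching sketch for `ORConstraint`, again with hypothetical tag names; per the docstring above, `monitor_only=True` tracks satisfaction without steering training:

    from congrads.constraints import MonotonicityConstraint, ORConstraint

    ascending = MonotonicityConstraint(
        tag_prediction="score", tag_reference="difficulty", direction="ascending"
    )
    descending = MonotonicityConstraint(
        tag_prediction="score", tag_reference="difficulty", direction="descending"
    )

    # Satisfied where either ordering holds; here only monitored, not enforced.
    either_trend = ORConstraint(
        ascending,
        descending,
        name="score ordered by difficulty (either direction)",
        monitor_only=True,
    )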