congrads 1.0.7__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
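At a glance, the changes in this release that are visible in the diff below: constraint inputs are addressed by tags rather than neuron names, the `monitor_only` flag is replaced by an inverted `enforce` flag, `check_constraint` now returns a per-sample relevance mask instead of a count, `PythagoreanIdentityConstraint` is removed, and `MonotonicityConstraint`, `GroupedMonotonicityConstraint`, and `ANDConstraint` are added. A minimal before/after sketch, assuming a hypothetical tag named "price" that has been registered in the descriptor:

```python
# Hypothetical migration sketch; "price" is an illustrative tag name,
# not something shipped with the package.
from torch import ge

from congrads.constraints import ScalarConstraint

# congrads 1.0.7: neuron-based wording, monitor_only flag
# constraint = ScalarConstraint("price", ge, 0.0, monitor_only=True)

# congrads 1.1.0: tag-based wording, inverted enforce flag
constraint = ScalarConstraint("price", ge, 0.0, enforce=False)
```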
congrads/constraints.py CHANGED
@@ -1,65 +1,55 @@
- """
- This module provides a set of constraint classes for guiding neural network
- training by enforcing specific conditions on the network's outputs.
-
- The constraints in this module include:
-
- - `Constraint`: The base class for all constraint types, defining the
- interface and core behavior.
- - `ImplicationConstraint`: A constraint that enforces one condition only if
- another condition is met, useful for modeling implications between network
- outputs.
- - `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
- a network's output.
- - `BinaryConstraint`: A constraint that enforces a binary comparison between
- two neurons in the network, using a comparison function (e.g., less than,
- greater than).
- - `SumConstraint`: A constraint that enforces that the sum of certain neurons'
- outputs equals a specified value, which can be used to control total output.
- - `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
- on a set of neurons, ensuring that the square of one neuron's output is equal
- to the sum of the squares of other outputs.
-
- These constraints can be used to steer the learning process by applying
- conditions such as logical implications or numerical bounds.
+ """Module providing constraint classes for guiding neural network training.
+
+ This module defines constraints that enforce specific conditions on network outputs
+ to steer learning. Available constraint types include:
+
+ - `Constraint`: Base class for all constraint types, defining the interface and core
+ behavior.
+ - `ImplicationConstraint`: Enforces one condition only if another condition is met,
+ useful for modeling implications between outputs.
+ - `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
+ - `BinaryConstraint`: Enforces a binary comparison between two tags using a
+ comparison function (e.g., less than, greater than).
+ - `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
+ value, controlling total output.
+
+ These constraints can steer the learning process by applying logical implications
+ or numerical bounds.

  Usage:
  1. Define a custom constraint class by inheriting from `Constraint`.
- 2. Apply the constraint to your neural network during training to
- enforce desired output behaviors.
- 3. Use the helper classes like `IdentityTransformation` for handling
- transformations and comparisons in constraints.
+ 2. Apply the constraint to your neural network during training.
+ 3. Use helper classes like `IdentityTransformation` for transformations and
+ comparisons in constraints.

- Dependencies:
- - PyTorch (`torch`)
  """

  import random
  import string
  import warnings
  from abc import ABC, abstractmethod
+ from collections.abc import Callable
  from numbers import Number
- from typing import Callable, Dict, Union
+ from typing import Literal

  from torch import (
  Tensor,
- count_nonzero,
+ argsort,
+ eq,
  ge,
  gt,
- isclose,
  le,
+ logical_and,
  logical_not,
  logical_or,
  lt,
- numel,
  ones,
  ones_like,
  reshape,
  sign,
- sqrt,
- square,
  stack,
  tensor,
+ unique,
  zeros_like,
  )
  from torch.nn.functional import normalize
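The Usage notes above say to start by subclassing `Constraint`. A minimal sketch of what that looks like against the 1.1.0 interface shown in this diff; the tag "y", its prior registration in the descriptor, and the framework wiring are assumptions, not shipped code:

```python
# Illustrative sketch only: a custom constraint written against the 1.1.0
# interface. Assumes the framework has attached a descriptor in which the
# hypothetical tag "y" is registered (e.g. via descriptor.add(...)).
from torch import Tensor, ge, ones_like, zeros_like

from congrads.constraints import Constraint


class NonNegativeConstraint(Constraint):
    """Toy constraint: the output referenced by tag "y" must be >= 0."""

    def __init__(self) -> None:
        super().__init__({"y"}, name="y non-negative")
        self.tag = "y"

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        selection = self.descriptor.select(self.tag, data)
        result = ge(selection, 0).float()  # 1.0 = satisfied, 0.0 = violated
        return result, ones_like(result)   # every sample is relevant

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        # Nudge the tagged output upward; zeros elsewhere on the layer.
        layer, index = self.descriptor.location(self.tag)
        direction = zeros_like(data[layer][0], device=self.device)
        direction[index] = 1
        return {layer: direction}
```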
@@ -70,8 +60,7 @@ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
 
 
  class Constraint(ABC):
- """
- Abstract base class for defining constraints applied to neural networks.
+ """Abstract base class for defining constraints applied to neural networks.

  A `Constraint` specifies conditions that the neural network outputs
  should satisfy. It supports monitoring constraint satisfaction
@@ -79,18 +68,18 @@ class Constraint(ABC):
  must implement the `check_constraint` and `calculate_direction` methods.

  Args:
- neurons (set[str]): Names of the neurons this constraint applies to.
+ tags (set[str]): Tags referencing the parts of the network to which this constraint applies.
  name (str, optional): A unique name for the constraint. If not provided,
  a name is generated based on the class name and a random suffix.
- monitor_only (bool, optional): If True, only monitor the constraint
- without adjusting the loss. Defaults to False.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
  rescale_factor (Number, optional): Factor to scale the
  constraint-adjusted loss. Defaults to 1.5. Should be greater
  than 1 to give weight to the constraint.

  Raises:
  TypeError: If a provided attribute has an incompatible type.
- ValueError: If any neuron in `neurons` is not
+ ValueError: If any tag in `tags` is not
  defined in the `descriptor`.

  Note:
@@ -104,29 +93,44 @@ class Constraint(ABC):
  device = None

  def __init__(
- self,
- neurons: set[str],
- name: str = None,
- monitor_only: bool = False,
- rescale_factor: Number = 1.5,
+ self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
  ) -> None:
- """
- Initializes a new Constraint instance.
- """
+ """Initializes a new Constraint instance.

+ Args:
+ tags (set[str]): Tags referencing the parts of the network to which this constraint applies.
+ name (str, optional): A unique name for the constraint. If not
+ provided, a name is generated based on the class name and a
+ random suffix.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
+ than 1 to give weight to the constraint.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.
+ ValueError: If any tag in `tags` is not defined in the `descriptor`.
+
+ Note:
+ - If `rescale_factor <= 1`, a warning is issued.
+ - If `name` is not provided, a name is auto-generated, and a
+ warning is logged.
+ """
  # Init parent class
  super().__init__()

  # Type checking
- validate_iterable("neurons", neurons, str)
+ validate_iterable("tags", tags, str)
  validate_type("name", name, str, allow_none=True)
- validate_type("monitor_only", monitor_only, bool)
+ validate_type("enforce", enforce, bool)
  validate_type("rescale_factor", rescale_factor, Number)

  # Init object variables
- self.neurons = neurons
+ self.tags = tags
  self.rescale_factor = rescale_factor
- self.monitor_only = monitor_only
+ self.initial_rescale_factor = rescale_factor
+ self.enforce = enforce

  # Perform checks
  if rescale_factor <= 1:
@@ -135,120 +139,102 @@ class Constraint(ABC):
  "will favor general loss over the constraint-adjusted loss. "
  "Is this intended behavior? Normally, the rescale factor "
  "should always be larger than 1.",
+ stacklevel=2,
  )
- else:
- self.rescale_factor = rescale_factor

  # If no constraint_name is set, generate one based
  # on the class name and a random suffix
  if name:
  self.name = name
  else:
- random_suffix = "".join(
- random.choices(string.ascii_uppercase + string.digits, k=6)
- )
+ random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
  self.name = f"{self.__class__.__name__}_{random_suffix}"
- warnings.warn(
- f"Name for constraint is not set. Using {self.name}.",
- )
+ warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)

- # Infer layers from descriptor and neurons
+ # Infer layers from descriptor and tags
  self.layers = set()
- for neuron in self.neurons:
- if neuron not in self.descriptor.neuron_to_layer.keys():
+ for tag in self.tags:
+ if not self.descriptor.exists(tag):
  raise ValueError(
- f"The neuron name {neuron} used with constraint "
+ f"The tag {tag} used with constraint "
  f"{self.name} is not defined in the descriptor. Please "
  "add it to the correct layer using "
  "descriptor.add('layer', ...)."
  )

- self.layers.add(self.descriptor.neuron_to_layer[neuron])
+ layer, _ = self.descriptor.location(tag)
+ self.layers.add(layer)

  @abstractmethod
- def check_constraint(
- self, prediction: dict[str, Tensor]
- ) -> tuple[Tensor, int]:
- """
- Evaluates whether the given model predictions satisfy the constraint.
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Evaluates whether the given model predictions satisfy the constraint.
+
+ A value of 1 means satisfied; a value of 0 means not satisfied.

  Args:
- prediction (dict[str, Tensor]): Model predictions for the neurons.
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

  Returns:
- tuple[Tensor, int]: A tuple where the first element is a tensor
- indicating whether the constraint is satisfied (with `True`
- for satisfaction, `False` for non-satisfaction, and `torch.nan`
- for irrelevant results), and the second element is an integer
- value representing the number of relevant constraints.
+ tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
+ indicating whether the constraint is satisfied (with value 1.0
+ for satisfaction and 0.0 for non-satisfaction), and the second element is a tensor
+ mask that indicates the relevance of each sample (`True` for relevant
+ samples and `False` for irrelevant ones).

  Raises:
  NotImplementedError: If not implemented in a subclass.
  """
-
  raise NotImplementedError

  @abstractmethod
- def calculate_direction(
- self, prediction: dict[str, Tensor]
- ) -> Dict[str, Tensor]:
- """
- Calculates adjustment directions for neurons to
- better satisfy the constraint.
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions to better satisfy the constraint.
+
+ Given the model predictions, input batch, and context, this method calculates the direction
+ in which the predictions referenced by a tag should be adjusted to satisfy the constraint.

  Args:
- prediction (dict[str, Tensor]): Model predictions for the neurons.
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

  Returns:
- Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
- specifying the adjustment direction for each neuron.
+ dict[str, Tensor]: Dictionary mapping network layers to tensors that
+ specify the adjustment direction for each tag.

  Raises:
- NotImplementedError: If not implemented in a subclass.
+ NotImplementedError: Must be implemented by subclasses.
  """
-
  raise NotImplementedError


  class ImplicationConstraint(Constraint):
- """
- Represents an implication constraint between two
- constraints (head and body).
+ """Represents an implication constraint between two constraints (head and body).

  The implication constraint ensures that the `body` constraint only applies
  when the `head` constraint is satisfied. If the `head` constraint is not
  satisfied, the `body` constraint does not apply.
-
- Args:
- head (Constraint): The head of the implication. If this constraint
- is satisfied, the body constraint must also be satisfied.
- body (Constraint): The body of the implication. This constraint
- is enforced only when the head constraint is satisfied.
- name (str, optional): A unique name for the constraint. If not
- provided, the name is generated in the format
- "{body.name} if {head.name}". Defaults to None.
- monitor_only (bool, optional): If True, the constraint is only
- monitored without adjusting the loss. Defaults to False.
- rescale_factor (Number, optional): The scaling factor for the
- constraint-adjusted loss. Defaults to 1.5.
-
- Raises:
- TypeError: If a provided attribute has an incompatible type.
-
  """

  def __init__(
  self,
  head: Constraint,
  body: Constraint,
- name=None,
- monitor_only=False,
- rescale_factor=1.5,
+ name: str = None,
  ):
- """
- Initializes an ImplicationConstraint instance.
- """
+ """Initializes an ImplicationConstraint instance.
+
+ Uses `enforce` and `rescale_factor` from the body constraint.
+
+ Args:
+ head (Constraint): Constraint defining the head of the implication.
+ body (Constraint): Constraint defining the body of the implication.
+ name (str, optional): A unique name for the constraint. If not
+ provided, a name is generated based on the class name and a
+ random suffix.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.

+ """
  # Type checking
  validate_type("head", head, Constraint)
  validate_type("body", body, Constraint)
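`ImplicationConstraint` now inherits `enforce` and `rescale_factor` from its body constraint, and the `check_constraint` rewrite in the next hunk encodes the material implication "not head, or body", returning the head satisfaction as the relevance mask instead of a `count_nonzero` count. A self-contained torch illustration of that truth table:

```python
# Pure-torch illustration of the implication semantics: a sample violates the
# constraint only when the head holds and the body fails.
from torch import logical_not, logical_or, tensor

head_satisfaction = tensor([1.0, 1.0, 0.0, 0.0])
body_satisfaction = tensor([1.0, 0.0, 1.0, 0.0])

result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()
print(result)  # tensor([1., 0., 1., 1.])
```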
@@ -257,64 +243,82 @@ class ImplicationConstraint(Constraint):
  name = f"{body.name} if {head.name}"

  # Init parent class
- super().__init__(
- head.neurons | body.neurons,
- name,
- monitor_only,
- rescale_factor,
- )
+ super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)

  self.head = head
  self.body = body

- def check_constraint(
- self, prediction: dict[str, Tensor]
- ) -> tuple[Tensor, int]:
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Check whether the implication constraint is satisfied.
+
+ Evaluates the `head` and `body` constraints. The `body` constraint
+ is enforced only if the `head` constraint is satisfied. If the
+ `head` constraint is not satisfied, the `body` constraint does not
+ affect the result.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

+ Returns:
+ tuple[Tensor, Tensor]:
+ - result: Tensor indicating satisfaction of the implication
+ constraint (1 if satisfied, 0 otherwise).
+ - head_satisfaction: Tensor indicating satisfaction of the
+ head constraint alone.
+ """
  # Check satisfaction of head and body constraints
- head_satisfaction, _ = self.head.check_constraint(prediction)
- body_satisfaction, _ = self.body.check_constraint(prediction)
+ head_satisfaction, _ = self.head.check_constraint(data)
+ body_satisfaction, _ = self.body.check_constraint(data)

  # If head constraint is satisfied (returning 1),
  # the body constraint matters (and should return 0/1 based on body)
  # If head constraint is not satisfied (returning 0),
  # the body constraint does not apply (and should return 1)
- result = logical_or(
- logical_not(head_satisfaction), body_satisfaction
- ).float()
+ result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()
+
+ return result, head_satisfaction

- return result, count_nonzero(head_satisfaction)
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions for tags to satisfy the constraint.

- def calculate_direction(
- self, prediction: dict[str, Tensor]
- ) -> Dict[str, Tensor]:
+ Uses the `body` constraint directions as the update vector. Only
+ applies updates if the `head` constraint is satisfied. Currently,
+ this method only works for dense layers due to tag-to-index
+ translation limitations.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+ Returns:
+ dict[str, Tensor]: Dictionary mapping tags to tensors
+ specifying the adjustment direction for each tag.
+ """
  # NOTE currently only works for dense layers
- # due to neuron to index translation
+ # due to tag to index translation

  # Use directions of constraint body as update vector
- return self.body.calculate_direction(prediction)
+ return self.body.calculate_direction(data)


  class ScalarConstraint(Constraint):
- """
- A constraint that enforces scalar-based comparisons on a specific neuron.
+ """A constraint that enforces scalar-based comparisons on a specific tag.

- This class ensures that the output of a specified neuron satisfies a scalar
+ This class ensures that the output of a specified tag satisfies a scalar
  comparison operation (e.g., less than, greater than, etc.). It uses a
  comparator function to validate the condition and calculates adjustment
  directions accordingly.

  Args:
- operand (Union[str, Transformation]): Name of the neuron or a
+ operand (Union[str, Transformation]): Name of the tag or a
  transformation to apply.
  comparator (Callable[[Tensor, Number], Tensor]): A comparison
  function (e.g., `torch.ge`, `torch.lt`).
  scalar (Number): The scalar value to compare against.
  name (str, optional): A unique name for the constraint. If not
  provided, a name is auto-generated in the format
- "<neuron_name> <comparator> <scalar>".
- monitor_only (bool, optional): If True, only monitor the constraint
- without adjusting the loss. Defaults to False.
+ "<tag> <comparator> <scalar>".
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
  rescale_factor (Number, optional): Factor to scale the
  constraint-adjusted loss. Defaults to 1.5.

@@ -322,87 +326,120 @@ class ScalarConstraint(Constraint):
  TypeError: If a provided attribute has an incompatible type.

  Notes:
- - The `neuron_name` must be defined in the `descriptor` mapping.
- - The constraint name is composed using the neuron name,
- comparator, and scalar value.
+ - The `tag` must be defined in the `descriptor` mapping.
+ - The constraint name is composed using the tag, comparator, and scalar value.

  """

  def __init__(
  self,
- operand: Union[str, Transformation],
+ operand: str | Transformation,
  comparator: Callable[[Tensor, Number], Tensor],
  scalar: Number,
  name: str = None,
- monitor_only: bool = False,
+ enforce: bool = True,
  rescale_factor: Number = 1.5,
  ) -> None:
- """
- Initializes a ScalarConstraint instance.
- """
+ """Initializes a ScalarConstraint instance.
+
+ Args:
+ operand (Union[str, Transformation]): Function that needs to be
+ performed on the network variables before applying the
+ constraint.
+ comparator (Callable[[Tensor, Number], Tensor]): Comparison
+ operator used in the constraint. Supported types are
+ {torch.lt, torch.le, torch.gt, torch.ge}.
+ scalar (Number): Constant to compare the variable to.
+ name (str, optional): A unique name for the constraint. If not
+ provided, a name is generated based on the class name and a
+ random suffix.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5. Should be greater
+ than 1 to give weight to the constraint.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.

+ Notes:
+ - The `tag` must be defined in the `descriptor` mapping.
+ - The constraint name is composed using the tag, comparator, and scalar value.
+ """
  # Type checking
  validate_type("operand", operand, (str, Transformation))
  validate_comparator_pytorch("comparator", comparator)
- validate_comparator_pytorch("comparator", comparator)
  validate_type("scalar", scalar, Number)

- # If transformation is provided, get neuron name,
- # else use IdentityTransformation
+ # If transformation is provided, get tag name, else use IdentityTransformation
  if isinstance(operand, Transformation):
- neuron_name = operand.neuron_name
+ tag = operand.tag
  transformation = operand
  else:
- neuron_name = operand
- transformation = IdentityTransformation(neuron_name)
+ tag = operand
+ transformation = IdentityTransformation(tag)

  # Compose constraint name
- name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
+ name = f"{tag} {comparator.__name__} {str(scalar)}"

  # Init parent class
- super().__init__({neuron_name}, name, monitor_only, rescale_factor)
+ super().__init__({tag}, name, enforce, rescale_factor)

  # Init variables
+ self.tag = tag
  self.comparator = comparator
  self.scalar = scalar
  self.transformation = transformation

- # Get layer name and feature index from neuron_name
- self.layer = self.descriptor.neuron_to_layer[neuron_name]
- self.index = self.descriptor.neuron_to_index[neuron_name]
-
  # Calculate directions based on constraint operator
  if self.comparator in [lt, le]:
- self.direction = -1
- elif self.comparator in [gt, ge]:
  self.direction = 1
+ elif self.comparator in [gt, ge]:
+ self.direction = -1
+
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Check if the scalar constraint is satisfied for a given tag.

- def check_constraint(
- self, prediction: dict[str, Tensor]
- ) -> tuple[Tensor, int]:
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

+ Returns:
+ tuple[Tensor, Tensor]:
+ - result: Tensor indicating whether the tag satisfies the constraint.
+ - ones_like(result): Tensor of ones with same shape as `result`.
+ """
  # Select relevant columns
- selection = prediction[self.layer][:, self.index]
+ selection = self.descriptor.select(self.tag, data)

  # Apply transformation
  selection = self.transformation(selection)

  # Calculate current constraint result
  result = self.comparator(selection, self.scalar).float()
- return result, numel(result)
+ return result, ones_like(result)
+
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions to satisfy the scalar constraint.
+
+ Only works for dense layers due to tag-to-index translation.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

- def calculate_direction(
- self, prediction: dict[str, Tensor]
- ) -> Dict[str, Tensor]:
+ Returns:
+ dict[str, Tensor]: Dictionary mapping layers to tensors specifying
+ the adjustment direction for each tag.
+ """
  # NOTE currently only works for dense layers due
- # to neuron to index translation
+ # to tag to index translation

  output = {}

  for layer in self.layers:
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
+ output[layer] = zeros_like(data[layer][0], device=self.device)

- output[self.layer][self.index] = self.direction
+ layer, index = self.descriptor.location(self.tag)
+ output[layer][index] = self.direction

  for layer in self.layers:
  output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
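A hedged usage sketch for the updated `ScalarConstraint`; the tag "efficiency" is illustrative and must already be registered in the descriptor. Note two behavioral changes visible above: the relevance value is now an all-ones mask rather than `numel(result)`, and the sign convention of `self.direction` for `lt`/`le` versus `gt`/`ge` is inverted relative to 1.0.7.

```python
# Sketch under the assumption that "efficiency" is a registered tag.
from torch import le

from congrads.constraints import ScalarConstraint

# Keep the tagged output at or below 1.0; enforced by default (enforce=True).
constraint = ScalarConstraint("efficiency", le, 1.0, rescale_factor=1.5)
```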
@@ -411,26 +448,24 @@ class ScalarConstraint(Constraint):
 
 
  class BinaryConstraint(Constraint):
- """
- A constraint that enforces a binary comparison between two neurons.
+ """A constraint that enforces a binary comparison between two tags.

- This class ensures that the output of one neuron satisfies a comparison
- operation with the output of another neuron
- (e.g., less than, greater than, etc.). It uses a comparator function to
- validate the condition and calculates adjustment directions accordingly.
+ This class ensures that the output of one tag satisfies a comparison
+ operation with the output of another tag (e.g., less than, greater than, etc.).
+ It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

  Args:
  operand_left (Union[str, Transformation]): Name of the left
- neuron or a transformation to apply.
+ tag or a transformation to apply.
  comparator (Callable[[Tensor, Number], Tensor]): A comparison
  function (e.g., `torch.ge`, `torch.lt`).
  operand_right (Union[str, Transformation]): Name of the right
- neuron or a transformation to apply.
+ tag or a transformation to apply.
  name (str, optional): A unique name for the constraint. If not
  provided, a name is auto-generated in the format
- "<neuron_name_left> <comparator> <neuron_name_right>".
- monitor_only (bool, optional): If True, only monitor the constraint
- without adjusting the loss. Defaults to False.
+ "<operand_left> <comparator> <operand_right>".
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
  rescale_factor (Number, optional): Factor to scale the
  constraint-adjusted loss. Defaults to 1.5.

@@ -438,84 +473,107 @@ class BinaryConstraint(Constraint):
  TypeError: If a provided attribute has an incompatible type.

  Notes:
- - The neuron names must be defined in the `descriptor` mapping.
- - The constraint name is composed using the left neuron name,
- comparator, and right neuron name.
+ - The tags must be defined in the `descriptor` mapping.
+ - The constraint name is composed using the left tag, comparator, and right tag.

  """

  def __init__(
  self,
- operand_left: Union[str, Transformation],
+ operand_left: str | Transformation,
  comparator: Callable[[Tensor, Number], Tensor],
- operand_right: Union[str, Transformation],
+ operand_right: str | Transformation,
  name: str = None,
- monitor_only: bool = False,
+ enforce: bool = True,
  rescale_factor: Number = 1.5,
  ) -> None:
- """
- Initializes a BinaryConstraint instance.
- """
+ """Initializes a BinaryConstraint instance.

+ Args:
+ operand_left (Union[str, Transformation]): Name of the left
+ tag or a transformation to apply.
+ comparator (Callable[[Tensor, Number], Tensor]): A comparison
+ function (e.g., `torch.ge`, `torch.lt`).
+ operand_right (Union[str, Transformation]): Name of the right
+ tag or a transformation to apply.
+ name (str, optional): A unique name for the constraint. If not
+ provided, a name is auto-generated in the format
+ "<operand_left> <comparator> <operand_right>".
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.
+
+ Notes:
+ - The tags must be defined in the `descriptor` mapping.
+ - The constraint name is composed using the left tag,
+ comparator, and right tag.
+ """
  # Type checking
  validate_type("operand_left", operand_left, (str, Transformation))
  validate_comparator_pytorch("comparator", comparator)
  validate_comparator_pytorch("comparator", comparator)
  validate_type("operand_right", operand_right, (str, Transformation))

- # If transformation is provided, get neuron name,
- # else use IdentityTransformation
+ # If transformation is provided, get tag name, else use IdentityTransformation
  if isinstance(operand_left, Transformation):
- neuron_name_left = operand_left.neuron_name
+ tag_left = operand_left.tag
  transformation_left = operand_left
  else:
- neuron_name_left = operand_left
- transformation_left = IdentityTransformation(neuron_name_left)
+ tag_left = operand_left
+ transformation_left = IdentityTransformation(tag_left)

  if isinstance(operand_right, Transformation):
- neuron_name_right = operand_right.neuron_name
+ tag_right = operand_right.tag
  transformation_right = operand_right
  else:
- neuron_name_right = operand_right
- transformation_right = IdentityTransformation(neuron_name_right)
+ tag_right = operand_right
+ transformation_right = IdentityTransformation(tag_right)

  # Compose constraint name
- name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
+ name = f"{tag_left} {comparator.__name__} {tag_right}"

  # Init parent class
- super().__init__(
- {neuron_name_left, neuron_name_right},
- name,
- monitor_only,
- rescale_factor,
- )
+ super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)

  # Init variables
  self.comparator = comparator
+ self.tag_left = tag_left
+ self.tag_right = tag_right
  self.transformation_left = transformation_left
  self.transformation_right = transformation_right

- # Get layer name and feature index from neuron_name
- self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
- self.layer_right = self.descriptor.neuron_to_layer[neuron_name_right]
- self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
- self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
-
  # Calculate directions based on constraint operator
  if self.comparator in [lt, le]:
- self.direction_left = -1
- self.direction_right = 1
- else:
  self.direction_left = 1
  self.direction_right = -1
+ else:
+ self.direction_left = -1
+ self.direction_right = 1
+
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Evaluate whether the binary constraint is satisfied for the current predictions.

- def check_constraint(
- self, prediction: dict[str, Tensor]
- ) -> tuple[Tensor, int]:
+ The constraint compares the outputs of two tags using the specified
+ comparator function. A result of `1` indicates the constraint is satisfied
+ for a sample, and `0` indicates it is violated.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

+ Returns:
+ tuple[Tensor, Tensor]:
+ - result (Tensor): Binary tensor indicating constraint satisfaction
+ (1 for satisfied, 0 for violated) for each sample.
+ - mask (Tensor): Tensor of ones with the same shape as `result`,
+ used for constraint aggregation.
+ """
  # Select relevant columns
- selection_left = prediction[self.layer_left][:, self.index_left]
- selection_right = prediction[self.layer_right][:, self.index_right]
+ selection_left = self.descriptor.select(self.tag_left, data)
+ selection_right = self.descriptor.select(self.tag_right, data)

  # Apply transformations
  selection_left = self.transformation_left(selection_left)
@@ -523,21 +581,34 @@ class BinaryConstraint(Constraint):
 
  result = self.comparator(selection_left, selection_right).float()

- return result, numel(result)
+ return result, ones_like(result)
+
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions for the tags involved in the binary constraint.
+
+ The returned directions indicate how to adjust each tag's output to
+ satisfy the constraint. Only currently supported for dense layers.

- def calculate_direction(
- self, prediction: dict[str, Tensor]
- ) -> Dict[str, Tensor]:
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+ Returns:
+ dict[str, Tensor]: A mapping from layer names to tensors specifying
+ the normalized adjustment directions for each tag involved in the
+ constraint.
+ """
  # NOTE currently only works for dense layers due
- # to neuron to index translation
+ # to tag to index translation

  output = {}

  for layer in self.layers:
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
+ output[layer] = zeros_like(data[layer][0], device=self.device)

- output[self.layer_left][self.index_left] = self.direction_left
- output[self.layer_right][self.index_right] = self.direction_right
+ layer_left, index_left = self.descriptor.location(self.tag_left)
+ layer_right, index_right = self.descriptor.location(self.tag_right)
+ output[layer_left][index_left] = self.direction_left
+ output[layer_right][index_right] = self.direction_right

  for layer in self.layers:
  output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
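A similar hedged sketch for `BinaryConstraint`; both tag names are illustrative and must be registered in the descriptor. As with `ScalarConstraint`, the per-operand direction signs are inverted relative to 1.0.7, and the direction vectors are still L2-normalized per layer before being returned.

```python
# Sketch under the assumption that both tags are registered.
from torch import lt

from congrads.constraints import BinaryConstraint

# Require the "min_temp" output to stay strictly below "max_temp".
constraint = BinaryConstraint("min_temp", lt, "max_temp")
```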
@@ -546,142 +617,119 @@ class BinaryConstraint(Constraint):
 
 
  class SumConstraint(Constraint):
- """
- A constraint that enforces a weighted summation comparison
- between two groups of neurons.
+ """A constraint that enforces a weighted summation comparison between two groups of tags.

  This class evaluates whether the weighted sum of outputs from one set of
- neurons satisfies a comparison operation with the weighted sum of
- outputs from another set of neurons.
-
- Args:
- operands_left (list[Union[str, Transformation]]): List of neuron
- names or transformations on the left side.
- comparator (Callable[[Tensor, Number], Tensor]): A comparison
- function for the constraint.
- operands_right (list[Union[str, Transformation]]): List of neuron
- names or transformations on the right side.
- weights_left (list[Number], optional): Weights for the left neurons.
- Defaults to None.
- weights_right (list[Number], optional): Weights for the right
- neurons. Defaults to None.
- name (str, optional): Unique name for the constraint.
- If None, it's auto-generated. Defaults to None.
- monitor_only (bool, optional): If True, only monitor the constraint
- without adjusting the loss. Defaults to False.
- rescale_factor (Number, optional): Factor to scale the
- constraint-adjusted loss. Defaults to 1.5.
-
- Raises:
- TypeError: If a provided attribute has an incompatible type.
- ValueError: If the dimensions of neuron names and weights mismatch.
-
+ tags satisfies a comparison operation with the weighted sum of
+ outputs from another set of tags.
  """

  def __init__(
  self,
- operands_left: list[Union[str, Transformation]],
+ operands_left: list[str | Transformation],
  comparator: Callable[[Tensor, Number], Tensor],
- operands_right: list[Union[str, Transformation]],
+ operands_right: list[str | Transformation],
  weights_left: list[Number] = None,
  weights_right: list[Number] = None,
  name: str = None,
- monitor_only: bool = False,
+ enforce: bool = True,
  rescale_factor: Number = 1.5,
  ) -> None:
- """
- Initializes the SumConstraint.
- """
+ """Initializes the SumConstraint.

+ Args:
+ operands_left (list[Union[str, Transformation]]): List of tags
+ or transformations on the left side.
+ comparator (Callable[[Tensor, Number], Tensor]): A comparison
+ function for the constraint.
+ operands_right (list[Union[str, Transformation]]): List of tags
+ or transformations on the right side.
+ weights_left (list[Number], optional): Weights for the left
+ tags. Defaults to None.
+ weights_right (list[Number], optional): Weights for the right
+ tags. Defaults to None.
+ name (str, optional): Unique name for the constraint.
+ If None, it's auto-generated. Defaults to None.
+ enforce (bool, optional): If False, only monitor the constraint
+ without adjusting the loss. Defaults to True.
+ rescale_factor (Number, optional): Factor to scale the
+ constraint-adjusted loss. Defaults to 1.5.
+
+ Raises:
+ TypeError: If a provided attribute has an incompatible type.
+ ValueError: If the dimensions of tags and weights mismatch.
+ """
  # Type checking
  validate_iterable("operands_left", operands_left, (str, Transformation))
  validate_comparator_pytorch("comparator", comparator)
  validate_comparator_pytorch("comparator", comparator)
- validate_iterable(
- "operands_right", operands_right, (str, Transformation)
- )
+ validate_iterable("operands_right", operands_right, (str, Transformation))
  validate_iterable("weights_left", weights_left, Number, allow_none=True)
- validate_iterable(
- "weights_right", weights_right, Number, allow_none=True
- )
+ validate_iterable("weights_right", weights_right, Number, allow_none=True)

- # If transformation is provided, get neuron name,
- # else use IdentityTransformation
- neuron_names_left: list[str] = []
+ # If transformation is provided, get tag, else use IdentityTransformation
+ tags_left: list[str] = []
  transformations_left: list[Transformation] = []
  for operand_left in operands_left:
  if isinstance(operand_left, Transformation):
- neuron_name_left = operand_left.neuron_name
- neuron_names_left.append(neuron_name_left)
+ tag_left = operand_left.tag
+ tags_left.append(tag_left)
  transformations_left.append(operand_left)
  else:
- neuron_name_left = operand_left
- neuron_names_left.append(neuron_name_left)
- transformations_left.append(
- IdentityTransformation(neuron_name_left)
- )
+ tag_left = operand_left
+ tags_left.append(tag_left)
+ transformations_left.append(IdentityTransformation(tag_left))

- neuron_names_right: list[str] = []
+ tags_right: list[str] = []
  transformations_right: list[Transformation] = []
  for operand_right in operands_right:
  if isinstance(operand_right, Transformation):
- neuron_name_right = operand_right.neuron_name
- neuron_names_right.append(neuron_name_right)
+ tag_right = operand_right.tag
+ tags_right.append(tag_right)
  transformations_right.append(operand_right)
  else:
- neuron_name_right = operand_right
- neuron_names_right.append(neuron_name_right)
- transformations_right.append(
- IdentityTransformation(neuron_name_right)
- )
+ tag_right = operand_right
+ tags_right.append(tag_right)
+ transformations_right.append(IdentityTransformation(tag_right))

  # Compose constraint name
- w_left = weights_left or [""] * len(neuron_names_left)
- w_right = weights_right or [""] * len(neuron_names_right)
- left_expr = " + ".join(
- f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
- )
- right_expr = " + ".join(
- f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
- )
+ w_left = weights_left or [""] * len(tags_left)
+ w_right = weights_right or [""] * len(tags_right)
+ left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
+ right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
  comparator_name = comparator.__name__
  name = f"{left_expr} {comparator_name} {right_expr}"

  # Init parent class
- neuron_names = set(neuron_names_left) | set(neuron_names_right)
- super().__init__(neuron_names, name, monitor_only, rescale_factor)
+ tags = set(tags_left) | set(tags_right)
+ super().__init__(tags, name, enforce, rescale_factor)

  # Init variables
  self.comparator = comparator
- self.neuron_names_left = neuron_names_left
- self.neuron_names_right = neuron_names_right
+ self.tags_left = tags_left
+ self.tags_right = tags_right
  self.transformations_left = transformations_left
  self.transformations_right = transformations_right

- # If feature list dimensions don't match
- # weight list dimensions, raise error
- if weights_left and (len(neuron_names_left) != len(weights_left)):
+ # If feature list dimensions don't match weight list dimensions, raise error
+ if weights_left and (len(tags_left) != len(weights_left)):
  raise ValueError(
- "The dimensions of neuron_names_left don't match with the "
- "dimensions of weights_left."
+ "The dimensions of tags_left don't match with the dimensions of weights_left."
  )
- if weights_right and (len(neuron_names_right) != len(weights_right)):
+ if weights_right and (len(tags_right) != len(weights_right)):
  raise ValueError(
- "The dimensions of neuron_names_right don't match with the "
- "dimensions of weights_right."
+ "The dimensions of tags_right don't match with the dimensions of weights_right."
  )

  # If weights are provided for summation, transform them to Tensors
  if weights_left:
  self.weights_left = tensor(weights_left, device=self.device)
  else:
- self.weights_left = ones(len(neuron_names_left), device=self.device)
+ self.weights_left = ones(len(tags_left), device=self.device)
  if weights_right:
  self.weights_right = tensor(weights_right, device=self.device)
  else:
- self.weights_right = ones(
- len(neuron_names_right), device=self.device
- )
+ self.weights_right = ones(len(tags_right), device=self.device)

  # Calculate directions based on constraint operator
  if self.comparator in [lt, le]:
@@ -691,80 +739,82 @@ class SumConstraint(Constraint):
  self.direction_left = 1
  self.direction_right = -1

- def check_constraint(
- self, prediction: dict[str, Tensor]
- ) -> tuple[Tensor, int]:
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Evaluate whether the weighted sum constraint is satisfied.
+
+ Computes the weighted sum of outputs from the left and right tags,
+ applies the specified comparator function, and returns a binary result for
+ each sample.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+ Returns:
+ tuple[Tensor, Tensor]:
+ - result (Tensor): Binary tensor indicating whether the constraint
+ is satisfied (1) or violated (0) for each sample.
+ - mask (Tensor): Tensor of ones, used for constraint aggregation.
+ """

  def compute_weighted_sum(
- neuron_names: list[str],
+ tags: list[str],
  transformations: list[Transformation],
- weights: tensor,
- ) -> tensor:
- layers = [
- self.descriptor.neuron_to_layer[neuron_name]
- for neuron_name in neuron_names
- ]
- indices = [
- self.descriptor.neuron_to_index[neuron_name]
- for neuron_name in neuron_names
- ]
-
- # Select relevant column
- selections = [
- prediction[layer][:, index]
- for layer, index in zip(layers, indices)
- ]
+ weights: Tensor,
+ ) -> Tensor:
+ # Select relevant columns
+ selections = [self.descriptor.select(tag, data) for tag in tags]

  # Apply transformations
  results = []
- for transformation, selection in zip(transformations, selections):
+ for transformation, selection in zip(transformations, selections, strict=False):
  results.append(transformation(selection))

- # Extract predictions for all neurons and apply weights in bulk
- predictions = stack(
- results,
- dim=1,
- )
+ # Extract predictions for all tags and apply weights in bulk
+ predictions = stack(results)

  # Calculate weighted sum
- return (predictions * weights.unsqueeze(0)).sum(dim=1)
+ return (predictions * weights.view(-1, 1, 1)).sum(dim=0)

  # Compute weighted sums
  weighted_sum_left = compute_weighted_sum(
- self.neuron_names_left,
- self.transformations_left,
- self.weights_left,
+ self.tags_left, self.transformations_left, self.weights_left
  )
  weighted_sum_right = compute_weighted_sum(
- self.neuron_names_right,
- self.transformations_right,
- self.weights_right,
+ self.tags_right, self.transformations_right, self.weights_right
  )

  # Apply the comparator and calculate the result
  result = self.comparator(weighted_sum_left, weighted_sum_right).float()

- return result, numel(result)
+ return result, ones_like(result)
+
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Compute adjustment directions for tags involved in the weighted sum constraint.
+
+ The directions indicate how to adjust each tag's output to satisfy the
+ constraint. Only dense layers are currently supported.
+
+ Args:
+ data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

- def calculate_direction(
- self, prediction: dict[str, Tensor]
- ) -> Dict[str, Tensor]:
+ Returns:
+ dict[str, Tensor]: Mapping from layer names to normalized tensors
+ specifying adjustment directions for each tag involved in the constraint.
+ """
  # NOTE currently only works for dense layers
- # due to neuron to index translation
+ # due to tag to index translation

  output = {}

  for layer in self.layers:
- output[layer] = zeros_like(prediction[layer][0], device=self.device)
+ output[layer] = zeros_like(data[layer][0], device=self.device)

- for neuron_name_left in self.neuron_names_left:
- layer = self.descriptor.neuron_to_layer[neuron_name_left]
- index = self.descriptor.neuron_to_index[neuron_name_left]
+ for tag_left in self.tags_left:
+ layer, index = self.descriptor.location(tag_left)
  output[layer][index] = self.direction_left

- for neuron_name_right in self.neuron_names_right:
- layer = self.descriptor.neuron_to_layer[neuron_name_right]
- index = self.descriptor.neuron_to_index[neuron_name_right]
+ for tag_right in self.tags_right:
+ layer, index = self.descriptor.location(tag_right)
  output[layer][index] = self.direction_right

  for layer in self.layers:
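The weighted sum above now stacks the per-tag selections along a new leading dimension and broadcasts the weights with `view(-1, 1, 1)`, where 1.0.7 stacked along `dim=1` and used `unsqueeze(0)`. This suggests `descriptor.select` yields `[batch, 1]` columns in 1.1.0; a shape check under that assumption:

```python
# Shape sketch, assuming each per-tag selection is a [batch, 1] column.
from torch import ones, stack, tensor

batch = 4
selections = [ones(batch, 1), 2 * ones(batch, 1)]  # two tags
weights = tensor([0.5, 1.0])

predictions = stack(selections)                                   # [2, 4, 1]
weighted_sum = (predictions * weights.view(-1, 1, 1)).sum(dim=0)  # [4, 1]
print(weighted_sum.squeeze(1))  # tensor([2.5000, 2.5000, 2.5000, 2.5000])
```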
@@ -773,134 +823,434 @@ class SumConstraint(Constraint):
  return output


- class PythagoreanIdentityConstraint(Constraint):
+ class MonotonicityConstraint(Constraint):
+ """Constraint that enforces a monotonic relationship between two tags.
+
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
  """
- A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
- where `a` and `b` are neurons or transformations.

- This constraint checks that the sum of the squares of two specified
- neurons (or their transformations) is approximately equal to 1.
- The constraint is evaluated using relative and absolute
- tolerance (`rtol` and `atol`) and is applied during the forward pass.
+ def __init__(
+ self,
+ tag_prediction: str,
+ tag_reference: str,
+ rescale_factor_lower: float = 1.5,
+ rescale_factor_upper: float = 1.75,
+ stable: bool = True,
+ direction: Literal["ascending", "descending"] = "ascending",
+ name: str = None,
+ enforce: bool = True,
+ ):
+ """Constraint that enforces monotonicity on a predicted output.

- Args:
- a (Union[str, Transformation]): The first input, either a
- neuron name (str) or a Transformation.
- b (Union[str, Transformation]): The second input, either a
- neuron name (str) or a Transformation.
- rtol (float, optional): The relative tolerance for the
- comparison (default is 0.00001).
- atol (float, optional): The absolute tolerance for the
- comparison (default is 1e-8).
- name (str, optional): The name of the constraint
- (default is None, and it is generated automatically).
- monitor_only (bool, optional): Flag indicating whether the
- constraint is only for monitoring (default is False).
- rescale_factor (Number, optional): A factor used for
- rescaling (default is 1.5).
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).

- Raises:
- TypeError: If a provided attribute has an incompatible type.
+ Args:
+ tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
+ tag_reference (str): Name of the tag that acts as the monotonic reference.
+ rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
+ rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
+ stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
+ direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
+ name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
+ enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
+ """
+ # Type checking
+ validate_type("rescale_factor_lower", rescale_factor_lower, float)
+ validate_type("rescale_factor_upper", rescale_factor_upper, float)
+ validate_type("stable", stable, bool)
+ validate_type("direction", direction, str)
+
+ # Compose constraint name
+ if name is None:
+ name = f"{tag_prediction} monotonically {direction} by {tag_reference}"
+
+ # Init parent class
+ super().__init__({tag_prediction}, name, enforce, 1.0)
+
+ # Init variables
+ self.tag_prediction = tag_prediction
+ self.tag_reference = tag_reference
+ self.rescale_factor_lower = rescale_factor_lower
+ self.rescale_factor_upper = rescale_factor_upper
+ self.stable = stable
+ self.descending = direction == "descending"
+
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ """Evaluate whether the monotonicity constraint is satisfied."""
+ # Select relevant columns
+ preds = self.descriptor.select(self.tag_prediction, data)
+ targets = self.descriptor.select(self.tag_reference, data)
+
+ # Utility: convert values -> ranks (0 ... num_features-1)
+ def compute_ranks(x: Tensor, descending: bool) -> Tensor:
+ return argsort(
+ argsort(x, descending=descending, stable=self.stable, dim=0),
+ descending=False,
+ stable=self.stable,
+ dim=0,
+ )

+ # Compute predicted and target ranks
+ pred_ranks = compute_ranks(preds, descending=self.descending)
+ target_ranks = compute_ranks(targets, descending=False)
+
+ # Rank difference
+ rank_diff = pred_ranks - target_ranks
+
+ # Rescale differences into [rescale_factor_lower, rescale_factor_upper]
+ batch_size = preds.shape[0]
+ invert_direction = -1 if self.descending else 1
+ self.compared_rankings = (
+ (rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
+ + self.rescale_factor_lower * sign(rank_diff)
+ ) * invert_direction
+
+ # Calculate satisfaction
+ incorrect_rankings = eq(self.compared_rankings, 0).float()
+
+ return incorrect_rankings, ones_like(incorrect_rankings)
+
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+ """Calculates ranking adjustments for monotonicity enforcement."""
+ layer, _ = self.descriptor.location(self.tag_prediction)
+ return {layer: self.compared_rankings}
+
+
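The new `MonotonicityConstraint` derives per-sample ranks with a double `argsort`: the inner call sorts the values, the outer call recovers each element's position in that ordering. A self-contained illustration of the trick:

```python
# Rank computation via double argsort, as used by compute_ranks above.
from torch import argsort, tensor

values = tensor([0.3, 0.1, 0.9, 0.5])
ranks = argsort(
    argsort(values, descending=False, stable=True, dim=0),
    descending=False,
    stable=True,
    dim=0,
)
print(ranks)  # tensor([1, 0, 3, 2]): 0.1 gets rank 0, 0.9 gets rank 3
```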
921
+ class GroupedMonotonicityConstraint(MonotonicityConstraint):
922
+ """Constraint that enforces a monotonic relationship between two tags.
923
+
924
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
925
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
805
926
  """
806
927
 
807
928
  def __init__(
808
929
  self,
809
- a: Union[str, Transformation],
810
- b: Union[str, Transformation],
811
- rtol: float = 0.00001,
812
- atol: float = 1e-8,
930
+ tag_prediction: str,
931
+ tag_reference: str,
932
+ tag_group_identifier: str,
933
+ rescale_factor_lower: float = 1.5,
934
+ rescale_factor_upper: float = 1.75,
935
+ stable: bool = True,
936
+ direction: Literal["ascending", "descending"] = "ascending",
937
+ name: str = None,
938
+ enforce: bool = True,
939
+ ):
940
+ """Constraint that enforces monotonicity on a predicted output.
941
+
942
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
943
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
944
+
945
+ Args:
946
+ tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
947
+ tag_reference (str): Name of the tag that acts as the monotonic reference.
948
+ tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
949
+ rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
950
+ rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
951
+ stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
952
+ direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
953
+ name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
954
+ enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
955
+ """
956
+ # Compose constraint name
957
+ if name is None:
958
+ name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
959
+
960
+ # Init parent class
961
+ super().__init__(
962
+ tag_prediction=tag_prediction,
963
+ tag_reference=tag_reference,
964
+ rescale_factor_lower=rescale_factor_lower,
965
+ rescale_factor_upper=rescale_factor_upper,
966
+ stable=stable,
967
+ direction=direction,
968
+ name=name,
969
+ enforce=enforce,
970
+ )
971
+
972
+ # Init variables
973
+ self.tag_prediction = tag_prediction
974
+ self.tag_reference = tag_reference
975
+ self.tag_group_identifier = tag_group_identifier
976
+
977
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
978
+ """Evaluate whether the monotonicity constraint is satisfied."""
979
+ # Select group identifiers and convert to unique list
980
+ group_identifiers = self.descriptor.select(self.tag_group_identifier, data)
981
+ unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()
982
+
983
+ # Initialize checks and directions
984
+ checks = zeros_like(group_identifiers, device=self.device)
985
+ self.directions = zeros_like(group_identifiers, device=self.device)
986
+
987
+ # Get prediction and target keys
988
+ preds_key, _ = self.descriptor.location(self.tag_prediction)
989
+ targets_key, _ = self.descriptor.location(self.tag_reference)
990
+
991
+ for group_identifier in unique_group_identifiers:
992
+ # Create mask for the samples in this group
993
+ group_mask = (group_identifiers == group_identifier).squeeze(1)
994
+
995
+ # Create mini-batch for the group
996
+ group_data = {
997
+ preds_key: data[preds_key][group_mask],
998
+ targets_key: data[targets_key][group_mask],
999
+ }
1000
+
1001
+ # Call super on the mini-batch
1002
+ checks[group_mask], _ = super().check_constraint(group_data)
1003
+ self.directions[group_mask] = self.compared_rankings
1004
+
1005
+ return checks, ones_like(checks)
1006
+
1007
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
1008
+ """Calculates ranking adjustments for monotonicity enforcement."""
1009
+ layer, _ = self.descriptor.location(self.tag_prediction)
1010
+ return {layer: self.directions}
1011
+
1012
+
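
A grouped variant, in the same hypothetical setting: monotonicity is checked and corrected separately for the samples sharing each group identifier, so rankings are never compared across groups.

    # Hypothetical tags: within each patient's samples, the predicted
    # "response" should rank in the same order as the administered "dose".
    grouped = GroupedMonotonicityConstraint(
        tag_prediction="response",
        tag_reference="dose",
        tag_group_identifier="patient_id",
        direction="ascending",
    )
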
1013
+ class ANDConstraint(Constraint):
1014
+ """A composite constraint that enforces the logical AND of multiple constraints.
1015
+
1016
+ This class combines multiple sub-constraints and evaluates them jointly:
1017
+
1018
+ * The satisfaction of the AND constraint is `True` only if all sub-constraints
1019
+ are satisfied (elementwise logical AND).
1020
+ * The corrective direction is computed by weighting each sub-constraint's
1021
+ direction with its satisfaction mask and summing across all sub-constraints.
1022
+ """
1023
+
1024
+ def __init__(
1025
+ self,
1026
+ *constraints: Constraint,
813
1027
  name: str = None,
814
1028
  monitor_only: bool = False,
815
1029
  rescale_factor: Number = 1.5,
816
1030
  ) -> None:
1031
+ """A composite constraint that enforces the logical AND of multiple constraints.
1032
+
1033
+ This class combines multiple sub-constraints and evaluates them jointly:
1034
+
1035
+ * The satisfaction of the AND constraint is `True` only if all sub-constraints
1036
+ are satisfied (elementwise logical AND).
1037
+ * The corrective direction is computed by weighting each sub-constraint's
1038
+ direction with its satisfaction mask and summing across all sub-constraints.
1039
+
1040
+ Args:
1041
+ *constraints (Constraint): One or more `Constraint` instances to be combined.
1042
+ name (str, optional): A custom name for this constraint. If not provided,
1043
+ the name will be composed from the sub-constraint names joined with
1044
+ " AND ".
1045
+ monitor_only (bool, optional): If True, the constraint will be monitored
1046
+ but not enforced. Defaults to False.
1047
+ rescale_factor (Number, optional): A scaling factor applied when rescaling
1048
+ corrections. Defaults to 1.5.
1049
+
1050
+ Attributes:
1051
+ constraints (tuple[Constraint, ...]): The sub-constraints being combined.
1052
+ tags (set): The union of tags referenced by the sub-constraints.
1053
+ name (str): The name of the constraint (composed or custom).
817
1054
  """
818
- Initialize the PythagoreanIdentityConstraint.
1055
+ # Type checking
1056
+ validate_iterable("constraints", constraints, Constraint)
1057
+
1058
+ # Compose constraint name
1059
+ if not name:
1060
+ name = " AND ".join([constraint.name for constraint in constraints])
1061
+
1062
+ # Init parent class
1063
+ super().__init__(
1064
+ set().union(*(constraint.tags for constraint in constraints)),
1065
+ name,
1066
+ monitor_only,
1067
+ rescale_factor,
1068
+ )
1069
+
1070
+ # Init variables
1071
+ self.constraints = constraints
1072
+
1073
+ def check_constraint(self, data: dict[str, Tensor]):
1074
+ """Evaluate whether all sub-constraints are satisfied.
1075
+
1076
+ Args:
1077
+ data: Model predictions and associated batch/context information.
1078
+
1079
+ Returns:
1080
+ tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
1081
+ * `total_satisfaction`: A boolean or numeric tensor indicating
1082
+ elementwise whether all constraints are satisfied
1083
+ (logical AND).
1084
+ * `mask`: A tensor of ones with the same shape as
1085
+ `total_satisfaction`. Typically used as a weighting mask
1086
+ in downstream processing.
819
1087
  """
1088
+ total_satisfaction: Tensor = None
1089
+ total_mask: Tensor = None
1090
+
1091
+ # TODO vectorize this loop
1092
+ for constraint in self.constraints:
1093
+ satisfaction, mask = constraint.check_constraint(data)
1094
+ if total_satisfaction is None:
1095
+ total_satisfaction = satisfaction
1096
+ total_mask = mask
1097
+ else:
1098
+ total_satisfaction = logical_and(total_satisfaction, satisfaction)
1099
+ total_mask = logical_or(total_mask, mask)
820
1100
 
821
- # Type checking
822
- validate_type("a", a, (str, Transformation))
823
- validate_type("b", b, (str, Transformation))
824
- validate_type("rtol", rtol, float)
825
- validate_type("atol", atol, float)
826
-
827
- # If transformation is provided, get neuron name,
828
- # else use IdentityTransformation
829
- if isinstance(a, Transformation):
830
- neuron_name_a = a.neuron_name
831
- transformation_a = a
832
- else:
833
- neuron_name_a = a
834
- transformation_a = IdentityTransformation(neuron_name_a)
1101
+ return total_satisfaction.float(), total_mask.float()
835
1102
 
836
- if isinstance(b, Transformation):
837
- neuron_name_b = b.neuron_name
838
- transformation_b = b
839
- else:
840
- neuron_name_b = b
841
- transformation_b = IdentityTransformation(neuron_name_b)
1103
+ def calculate_direction(self, data: dict[str, Tensor]):
1104
+ """Compute the corrective direction by aggregating sub-constraint directions.
1105
+
1106
+ Each sub-constraint contributes its corrective direction, weighted
1107
+ by its satisfaction mask. The directions are summed across constraints
1108
+ for each affected layer.
1109
+
1110
+ Args:
1111
+ data: Model predictions and associated batch/context information.
1112
+
1113
+ Returns:
1114
+ dict[str, Tensor]: A mapping from layer identifiers to correction
1115
+ tensors. Each entry represents the aggregated correction to apply
1116
+ to that layer, based on the satisfaction-weighted sum of
1117
+ sub-constraint directions.
1118
+ """
1119
+ total_direction: dict[str, Tensor] = {}
1120
+
1121
+ # TODO vectorize this loop
1122
+ for constraint in self.constraints:
1123
+ # TODO improve efficiency by avoiding double computation?
1124
+ satisfaction, _ = constraint.check_constraint(data)
1125
+ direction = constraint.calculate_direction(data)
1126
+
1127
+ for layer, layer_direction in direction.items():
1128
+ if layer not in total_direction:
1129
+ total_direction[layer] = satisfaction.unsqueeze(1) * layer_direction
1130
+ else:
1131
+ total_direction[layer] += satisfaction.unsqueeze(1) * layer_direction
1132
+
1133
+ return total_direction
1134
+
1135
+
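
A sketch of logical composition, assuming `lower` and `upper` are two previously constructed `Constraint` instances (their origin is immaterial here). When no name is passed, the composite's name is the sub-constraint names joined with " AND ". Once the constraint is wired into the framework, its elementwise satisfaction can be queried directly:

    # Satisfied elementwise only where both hypothetical sub-constraints hold.
    both = ANDConstraint(lower, upper)
    satisfaction, mask = both.check_constraint(data)  # data: dict[str, Tensor]
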
1136
+ class ORConstraint(Constraint):
1137
+ """A composite constraint that enforces the logical OR of multiple constraints.
1138
+
1139
+ This class combines multiple sub-constraints and evaluates them jointly:
1140
+
1141
+ * The satisfaction of the OR constraint is `True` if at least one sub-constraint
1142
+ is satisfied (elementwise logical OR).
1143
+ * The corrective direction is computed by weighting each sub-constraint's
1144
+ direction with its satisfaction mask and summing across all sub-constraints.
1145
+ """
1146
+
1147
+ def __init__(
1148
+ self,
1149
+ *constraints: Constraint,
1150
+ name: str = None,
1151
+ monitor_only: bool = False,
1152
+ rescale_factor: Number = 1.5,
1153
+ ) -> None:
1154
+ """A composite constraint that enforces the logical OR of multiple constraints.
1155
+
1156
+ This class combines multiple sub-constraints and evaluates them jointly:
1157
+
1158
+ * The satisfaction of the OR constraint is `True` if at least one sub-constraint
1159
+ is satisfied (elementwise logical OR).
1160
+ * The corrective direction is computed by weighting each sub-constraint's
1161
+ direction with its satisfaction mask and summing across all sub-constraints.
1162
+
1163
+ Args:
1164
+ *constraints (Constraint): One or more `Constraint` instances to be combined.
1165
+ name (str, optional): A custom name for this constraint. If not provided,
1166
+ the name will be composed from the sub-constraint names joined with
1167
+ " OR ".
1168
+ monitor_only (bool, optional): If True, the constraint will be monitored
1169
+ but not enforced. Defaults to False.
1170
+ rescale_factor (Number, optional): A scaling factor applied when rescaling
1171
+ corrections. Defaults to 1.5.
1172
+
1173
+ Attributes:
1174
+ constraints (tuple[Constraint, ...]): The sub-constraints being combined.
1175
+ tags (set): The union of tags referenced by the sub-constraints.
1176
+ name (str): The name of the constraint (composed or custom).
1177
+ """
1178
+ # Type checking
1179
+ validate_iterable("constraints", constraints, Constraint)
842
1180
 
843
1181
  # Compose constraint name
844
- name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
1182
+ if not name:
1183
+ name = " OR ".join([constraint.name for constraint in constraints])
845
1184
 
846
1185
  # Init parent class
847
1186
  super().__init__(
848
- {neuron_name_a, neuron_name_b},
1187
+ set().union(*(constraint.tags for constraint in constraints)),
849
1188
  name,
850
1189
  monitor_only,
851
1190
  rescale_factor,
852
1191
  )
853
1192
 
854
1193
  # Init variables
855
- self.transformation_a = transformation_a
856
- self.transformation_b = transformation_b
857
- self.rtol = rtol
858
- self.atol = atol
1194
+ self.constraints = constraints
859
1195
 
860
- # Get layer name and feature index from neuron_name
861
- self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
862
- self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
863
- self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
864
- self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
1196
+ def check_constraint(self, data: dict[str, Tensor]):
1197
+ """Evaluate whether any sub-constraints are satisfied.
865
1198
 
866
- def check_constraint(
867
- self, prediction: dict[str, Tensor]
868
- ) -> tuple[Tensor, int]:
1199
+ Args:
1200
+ data: Model predictions and associated batch/context information.
869
1201
 
870
- # Select relevant columns
871
- selection_a = prediction[self.layer_a][:, self.index_a]
872
- selection_b = prediction[self.layer_b][:, self.index_b]
1202
+ Returns:
1203
+ tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
1204
+ * `total_satisfaction`: A boolean or numeric tensor indicating
1205
+ elementwise whether any constraints are satisfied
1206
+ (logical OR).
1207
+ * `mask`: A tensor of ones with the same shape as
1208
+ `total_satisfaction`. Typically used as a weighting mask
1209
+ in downstream processing.
1210
+ """
1211
+ total_satisfaction: Tensor = None
1212
+ total_mask: Tensor = None
1213
+
1214
+ # TODO vectorize this loop
1215
+ for constraint in self.constraints:
1216
+ satisfaction, mask = constraint.check_constraint(data)
1217
+ if total_satisfaction is None:
1218
+ total_satisfaction = satisfaction
1219
+ total_mask = mask
1220
+ else:
1221
+ total_satisfaction = logical_or(total_satisfaction, satisfaction)
1222
+ total_mask = logical_or(total_mask, mask)
873
1223
 
874
- # Apply transformations
875
- selection_a = self.transformation_a(selection_a)
876
- selection_b = self.transformation_b(selection_b)
877
-
878
- # Calculate result
879
- result = isclose(
880
- square(selection_a) + square(selection_b),
881
- ones_like(selection_a, device=self.device),
882
- rtol=self.rtol,
883
- atol=self.atol,
884
- ).float()
885
-
886
- return result, numel(result)
887
-
888
- def calculate_direction(
889
- self, prediction: dict[str, Tensor]
890
- ) -> Dict[str, Tensor]:
891
- # NOTE currently only works for dense layers due
892
- # to neuron to index translation
1224
+ return total_satisfaction.float(), total_mask.float()
893
1225
 
894
- output = {}
1226
+ def calculate_direction(self, data: dict[str, Tensor]):
1227
+ """Compute the corrective direction by aggregating sub-constraint directions.
895
1228
 
896
- for layer in self.layers:
897
- output[layer] = zeros_like(prediction[layer], device=self.device)
1229
+ Each sub-constraint contributes its corrective direction, weighted
1230
+ by its satisfaction mask. The directions are summed across constraints
1231
+ for each affected layer.
1232
+
1233
+ Args:
1234
+ data: Model predictions and associated batch/context information.
898
1235
 
899
- a = prediction[self.layer_a][:, self.index_a]
900
- b = prediction[self.layer_b][:, self.index_b]
901
- m = sqrt(square(a) + square(b))
1236
+ Returns:
1237
+ dict[str, Tensor]: A mapping from layer identifiers to correction
1238
+ tensors. Each entry represents the aggregated correction to apply
1239
+ to that layer, based on the satisfaction-weighted sum of
1240
+ sub-constraint directions.
1241
+ """
1242
+ total_direction: dict[str, Tensor] = {}
902
1243
 
903
- output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
904
- output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
1244
+ # TODO vectorize this loop
1245
+ for constraint in self.constraints:
1246
+ # TODO improve efficiency by avoiding double computation?
1247
+ satisfaction, _ = constraint.check_constraint(data)
1248
+ direction = constraint.calculate_direction(data)
905
1249
 
906
- return output
1250
+ for layer, layer_direction in direction.items():
1251
+ if layer not in total_direction:
1252
+ total_direction[layer] = satisfaction.unsqueeze(1) * layer_direction
1253
+ else:
1254
+ total_direction[layer] += satisfaction.unsqueeze(1) * layer_direction
1255
+
1256
+ return total_direction
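
The OR counterpart under the same assumptions: the composite is satisfied wherever at least one sub-constraint holds, and passing `monitor_only=True` tracks satisfaction without enforcing corrections.

    # Hypothetical sub-constraints; satisfied where either one holds.
    either = ORConstraint(lower, upper, monitor_only=True)
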