congrads 0.2.0-py3-none-any.whl → 1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/constraints.py CHANGED
@@ -1,27 +1,105 @@
- from abc import ABC, abstractmethod
- from numbers import Number
+ """
+ This module provides a set of constraint classes for guiding neural network
+ training by enforcing specific conditions on the network's outputs.
+
+ The constraints in this module include:
+
+ - `Constraint`: The base class for all constraint types, defining the
+   interface and core behavior.
+ - `ImplicationConstraint`: A constraint that enforces one condition only if
+   another condition is met, useful for modeling implications between network
+   outputs.
+ - `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
+   a network's output.
+ - `BinaryConstraint`: A constraint that enforces a binary comparison between
+   two neurons in the network, using a comparison function (e.g., less than,
+   greater than).
+ - `SumConstraint`: A constraint that compares the (weighted) sums of two
+   groups of neurons' outputs, which can be used to control totals.
+ - `PythagoreanIdentityConstraint`: A constraint that enforces the Pythagorean
+   identity on a pair of neurons, ensuring that the sum of the squares of
+   their outputs is approximately equal to 1.
+
+ These constraints can be used to steer the learning process by applying
+ conditions such as logical implications or numerical bounds.
+
+ Usage:
+     1. Define a custom constraint class by inheriting from `Constraint`.
+     2. Apply the constraint to your neural network during training to
+        enforce desired output behaviors.
+     3. Use helper classes like `IdentityTransformation` for handling
+        transformations and comparisons in constraints.
+
+ Dependencies:
+     - PyTorch (`torch`)
+ """
+
  import random
  import string
- from typing import Callable, Dict
+ import warnings
+ from abc import ABC, abstractmethod
+ from numbers import Number
+ from typing import Callable, Dict, Union
+
  from torch import (
      Tensor,
+     count_nonzero,
      ge,
      gt,
-     lt,
+     isclose,
      le,
+     logical_not,
+     logical_or,
+     lt,
+     numel,
+     ones,
+     ones_like,
      reshape,
+     sign,
+     sqrt,
+     square,
      stack,
-     ones,
      tensor,
      zeros_like,
  )
- import logging
  from torch.nn.functional import normalize
 
  from .descriptor import Descriptor
+ from .transformations import IdentityTransformation, Transformation
+ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
 
 
  class Constraint(ABC):
+     """
+     Abstract base class for defining constraints applied to neural networks.
+
+     A `Constraint` specifies conditions that the neural network outputs
+     should satisfy. It supports monitoring constraint satisfaction
+     during training and can adjust the loss to enforce constraints.
+     Subclasses must implement the `check_constraint` and
+     `calculate_direction` methods.
+
+     Args:
+         neurons (set[str]): Names of the neurons this constraint applies to.
+         name (str, optional): A unique name for the constraint. If not
+             provided, a name is generated from the class name and a
+             random suffix.
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5. Should be greater
+             than 1 to give weight to the constraint.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         ValueError: If any neuron in `neurons` is not
+             defined in the `descriptor`.
+
+     Note:
+         - If `rescale_factor <= 1`, a warning is issued and the value is
+           adjusted to a positive value greater than 1.
+         - If `name` is not provided, a name is auto-generated
+           and a warning is issued.
+
+     """
 
      descriptor: Descriptor = None
      device = None
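
To illustrate the `check_constraint`/`calculate_direction` contract this release introduces, here is a minimal sketch of a custom constraint. It is not part of the diff: the `NonNegativeConstraint` name is invented, and it assumes the framework has already attached a configured `Descriptor` to `Constraint.descriptor`, as the base class expects.

    from torch import Tensor, ge, numel, zeros_like

    from congrads.constraints import Constraint


    class NonNegativeConstraint(Constraint):
        """Hypothetical example: enforce `neuron_name` >= 0."""

        def __init__(self, neuron_name: str, **kwargs) -> None:
            super().__init__({neuron_name}, f"{neuron_name} ge 0", **kwargs)
            self.layer = self.descriptor.neuron_to_layer[neuron_name]
            self.index = self.descriptor.neuron_to_index[neuron_name]

        def check_constraint(
            self, prediction: dict[str, Tensor]
        ) -> tuple[Tensor, int]:
            # 1.0 where satisfied, 0.0 where violated; every sample is relevant
            result = ge(prediction[self.layer][:, self.index], 0).float()
            return result, numel(result)

        def calculate_direction(
            self, prediction: dict[str, Tensor]
        ) -> dict[str, Tensor]:
            # Push the constrained neuron's output upward
            output = {
                layer: zeros_like(prediction[layer][0]) for layer in self.layers
            }
            output[self.layer][self.index] = 1
            return output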
@@ -30,23 +108,39 @@ class Constraint(ABC):
          self,
          neurons: set[str],
          name: str = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
+         """
+         Initializes a new Constraint instance.
+         """
 
          # Init parent class
          super().__init__()
 
+         # Type checking
+         validate_iterable("neurons", neurons, str)
+         validate_type("name", name, (str, type(None)))
+         validate_type("monitor_only", monitor_only, bool)
+         validate_type("rescale_factor", rescale_factor, Number)
+
          # Init object variables
          self.neurons = neurons
          self.rescale_factor = rescale_factor
+         self.monitor_only = monitor_only
 
          # Perform checks
          if rescale_factor <= 1:
-             logging.warning(
-                 f"Rescale factor for constraint {name} is <= 1. The network will favor general loss over the constraint-adjusted loss. Is this intended behaviour? Normally, the loss should always be larger than 1."
+             warnings.warn(
+                 f"Rescale factor for constraint {name} is <= 1. The network "
+                 "will favor the general loss over the constraint-adjusted "
+                 "loss. Is this intended behavior? Normally, the rescale "
+                 "factor should always be larger than 1."
              )
 
-         # If no constraint_name is set, generate one based on the class name and a random suffix
+         # If no constraint name is set, generate one based
+         # on the class name and a random suffix
          if name:
              self.name = name
          else:
@@ -54,13 +148,19 @@ class Constraint(ABC):
                  random.choices(string.ascii_uppercase + string.digits, k=6)
              )
              self.name = f"{self.__class__.__name__}_{random_suffix}"
-             logging.warning(f"Name for constraint is not set. Using {self.name}.")
+             warnings.warn(
+                 f"Name for constraint is not set. Using {self.name}."
+             )
 
          # If rescale factor is not larger than 1, warn user and adjust
-         if not rescale_factor > 1:
+         if rescale_factor <= 1:
              self.rescale_factor = abs(rescale_factor) + 1.5
-             logging.warning(
-                 f"Rescale factor for constraint {name} is < 1, adjusted value {rescale_factor} to {self.rescale_factor}."
+             warnings.warn(
+                 f"Rescale factor for constraint {name} is <= 1, adjusted "
+                 f"value {rescale_factor} to {self.rescale_factor}."
              )
          else:
              self.rescale_factor = rescale_factor
@@ -70,83 +170,250 @@ class Constraint(ABC):
          for neuron in self.neurons:
              if neuron not in self.descriptor.neuron_to_layer.keys():
                  raise ValueError(
-                     f'The neuron name {neuron} used with constraint {self.name} is not defined in the descriptor. Please add it to the correct layer using descriptor.add("layer", ...).'
+                     f'The neuron name {neuron} used with constraint '
+                     f'{self.name} is not defined in the descriptor. Please '
+                     'add it to the correct layer using '
+                     'descriptor.add("layer", ...).'
                  )
 
              self.layers.add(self.descriptor.neuron_to_layer[neuron])
 
-     # TODO only denormalize if required for efficiency
-     def _denormalize(self, input: Tensor, neuron_names: list[str]):
-         # Extract min and max for each neuron
-         min_values = tensor(
-             [self.descriptor.neuron_to_minmax[name][0] for name in neuron_names],
-             device=input.device,
-         )
-         max_values = tensor(
-             [self.descriptor.neuron_to_minmax[name][1] for name in neuron_names],
-             device=input.device,
-         )
-
-         # Apply vectorized denormalization
-         return input * (max_values - min_values) + min_values
-
      @abstractmethod
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+         """
+         Evaluates whether the given model predictions satisfy the constraint.
+
+         Args:
+             prediction (dict[str, Tensor]): Model predictions for the neurons.
+
+         Returns:
+             tuple[Tensor, int]: A tuple whose first element is a tensor
+                 indicating whether the constraint is satisfied (`True`
+                 for satisfaction, `False` for non-satisfaction, and
+                 `torch.nan` for irrelevant results), and whose second
+                 element is an integer giving the number of relevant
+                 constraint evaluations.
+
+         Raises:
+             NotImplementedError: If not implemented in a subclass.
+         """
+
          raise NotImplementedError
 
      @abstractmethod
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         """
+         Calculates adjustment directions for neurons to
+         better satisfy the constraint.
+
+         Args:
+             prediction (dict[str, Tensor]): Model predictions for the neurons.
+
+         Returns:
+             Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
+                 specifying the adjustment direction for each neuron.
+
+         Raises:
+             NotImplementedError: If not implemented in a subclass.
+         """
+
          raise NotImplementedError
 
 
+ class ImplicationConstraint(Constraint):
+     """
+     Represents an implication constraint between two
+     constraints (head and body).
+
+     The implication constraint ensures that the `body` constraint only
+     applies when the `head` constraint is satisfied. If the `head`
+     constraint is not satisfied, the `body` constraint does not apply.
+
+     Args:
+         head (Constraint): The head of the implication. If this constraint
+             is satisfied, the body constraint must also be satisfied.
+         body (Constraint): The body of the implication. This constraint
+             is enforced only when the head constraint is satisfied.
+         name (str, optional): A unique name for the constraint. If not
+             provided, the name is generated in the format
+             "{body.name} if {head.name}". Defaults to None.
+         monitor_only (bool, optional): If True, the constraint is only
+             monitored without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): The scaling factor for the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     """
+
+     def __init__(
+         self,
+         head: Constraint,
+         body: Constraint,
+         name=None,
+         monitor_only=False,
+         rescale_factor=1.5,
+     ):
+         """
+         Initializes an ImplicationConstraint instance.
+         """
+
+         # Type checking
+         validate_type("head", head, Constraint)
+         validate_type("body", body, Constraint)
+
+         # Compose constraint name
+         name = f"{body.name} if {head.name}"
+
+         # Init parent class
+         super().__init__(
+             head.neurons | body.neurons,
+             name,
+             monitor_only,
+             rescale_factor,
+         )
+
+         self.head = head
+         self.body = body
+
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Check satisfaction of head and body constraints
+         head_satisfaction, _ = self.head.check_constraint(prediction)
+         body_satisfaction, _ = self.body.check_constraint(prediction)
+
+         # If the head constraint is satisfied (returning 1), the body
+         # constraint matters (and the result is 0/1 based on the body).
+         # If the head constraint is not satisfied (returning 0), the body
+         # constraint does not apply (and the result is 1).
+         result = logical_or(
+             logical_not(head_satisfaction), body_satisfaction
+         ).float()
+
+         return result, count_nonzero(head_satisfaction)
+
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers
+         # due to neuron to index translation
+
+         # Use directions of constraint body as update vector
+         return self.body.calculate_direction(prediction)
+
+
  class ScalarConstraint(Constraint):
+     """
+     A constraint that enforces scalar-based comparisons on a specific neuron.
+
+     This class ensures that the output of a specified neuron satisfies a
+     scalar comparison operation (e.g., less than, greater than). It uses a
+     comparator function to validate the condition and calculates adjustment
+     directions accordingly.
+
+     Args:
+         operand (Union[str, Transformation]): Name of the neuron or a
+             transformation to apply.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function (e.g., `torch.ge`, `torch.lt`).
+         scalar (Number): The scalar value to compare against.
+         name (str, optional): A unique name for the constraint. If not
+             provided, a name is auto-generated in the format
+             "<neuron_name> <comparator> <scalar>".
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     Notes:
+         - The `neuron_name` must be defined in the `descriptor` mapping.
+         - The constraint name is composed using the neuron name,
+           comparator, and scalar value.
+
+     """
 
      def __init__(
          self,
-         neuron_name: str,
+         operand: Union[str, Transformation],
          comparator: Callable[[Tensor, Number], Tensor],
          scalar: Number,
          name: str = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
+         """
+         Initializes a ScalarConstraint instance.
+         """
+
+         # Type checking
+         validate_type("operand", operand, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_type("scalar", scalar, Number)
+
+         # If a transformation is provided, get its neuron name,
+         # else use IdentityTransformation
+         if isinstance(operand, Transformation):
+             neuron_name = operand.neuron_name
+             transformation = operand
+         else:
+             neuron_name = operand
+             transformation = IdentityTransformation(neuron_name)
 
          # Compose constraint name
-         name = f"{neuron_name}_{comparator.__name__}_{str(scalar)}"
+         name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
 
          # Init parent class
-         super().__init__({neuron_name}, name, rescale_factor)
+         super().__init__({neuron_name}, name, monitor_only, rescale_factor)
 
          # Init variables
          self.comparator = comparator
          self.scalar = scalar
+         self.transformation = transformation
 
          # Get layer name and feature index from neuron_name
          self.layer = self.descriptor.neuron_to_layer[neuron_name]
          self.index = self.descriptor.neuron_to_index[neuron_name]
 
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise ValueError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
              self.direction = -1
          elif self.comparator in [gt, ge]:
              self.direction = 1
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
-
-         return ~self.comparator(prediction[self.layer][:, self.index], self.scalar)
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Select relevant column
+         selection = prediction[self.layer][:, self.index]
+
+         # Apply transformation
+         selection = self.transformation(selection)
+
+         # Calculate current constraint result
+         result = self.comparator(selection, self.scalar).float()
+         return result, numel(result)
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         # NOTE currently only works for dense layers due to neuron to index translation
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
 
          output = {}
 
          for layer in self.layers:
-             output[layer] = zeros_like(prediction[layer][0])
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
          output[self.layer][self.index] = self.direction
 
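As a usage sketch for the two classes above (hedged: the neuron names "income" and "age" and the thresholds are invented, and the descriptor/network setup is assumed to happen elsewhere), a scalar bound and an implication compose like this:

    from torch import ge, gt

    from congrads.constraints import ImplicationConstraint, ScalarConstraint

    # Head: income > 0; body: age >= 18. The body is only enforced on
    # samples where the head holds; elsewhere the implication counts
    # as satisfied.
    head = ScalarConstraint("income", gt, 0.0)
    body = ScalarConstraint("age", ge, 18.0)
    implication = ImplicationConstraint(head, body)
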
@@ -157,28 +424,89 @@ class ScalarConstraint(Constraint):
 
 
  class BinaryConstraint(Constraint):
+     """
+     A constraint that enforces a binary comparison between two neurons.
+
+     This class ensures that the output of one neuron satisfies a comparison
+     operation with the output of another neuron (e.g., less than, greater
+     than). It uses a comparator function to validate the condition and
+     calculates adjustment directions accordingly.
+
+     Args:
+         operand_left (Union[str, Transformation]): Name of the left
+             neuron or a transformation to apply.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function (e.g., `torch.ge`, `torch.lt`).
+         operand_right (Union[str, Transformation]): Name of the right
+             neuron or a transformation to apply.
+         name (str, optional): A unique name for the constraint. If not
+             provided, a name is auto-generated in the format
+             "<neuron_name_left> <comparator> <neuron_name_right>".
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     Notes:
+         - The neuron names must be defined in the `descriptor` mapping.
+         - The constraint name is composed using the left neuron name,
+           comparator, and right neuron name.
+
+     """
 
      def __init__(
          self,
-         neuron_name_left: str,
+         operand_left: Union[str, Transformation],
          comparator: Callable[[Tensor, Number], Tensor],
-         neuron_name_right: str,
+         operand_right: Union[str, Transformation],
          name: str = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
+         """
+         Initializes a BinaryConstraint instance.
+         """
+
+         # Type checking
+         validate_type("operand_left", operand_left, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_type("operand_right", operand_right, (str, Transformation))
+
+         # If a transformation is provided, get its neuron name,
+         # else use IdentityTransformation
+         if isinstance(operand_left, Transformation):
+             neuron_name_left = operand_left.neuron_name
+             transformation_left = operand_left
+         else:
+             neuron_name_left = operand_left
+             transformation_left = IdentityTransformation(neuron_name_left)
+
+         if isinstance(operand_right, Transformation):
+             neuron_name_right = operand_right.neuron_name
+             transformation_right = operand_right
+         else:
+             neuron_name_right = operand_right
+             transformation_right = IdentityTransformation(neuron_name_right)
 
          # Compose constraint name
-         name = f"{neuron_name_left}_{comparator.__name__}_{neuron_name_right}"
+         name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
 
          # Init parent class
          super().__init__(
              {neuron_name_left, neuron_name_right},
              name,
+             monitor_only,
              rescale_factor,
          )
 
          # Init variables
          self.comparator = comparator
+         self.transformation_left = transformation_left
+         self.transformation_right = transformation_right
 
          # Get layer name and feature index from neuron_name
          self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
@@ -186,12 +514,6 @@ class BinaryConstraint(Constraint):
          self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
          self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
 
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise RuntimeError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
              self.direction_left = -1
@@ -200,20 +522,32 @@ class BinaryConstraint(Constraint):
              self.direction_left = 1
              self.direction_right = -1
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
-
-         return ~self.comparator(
-             prediction[self.layer_left][:, self.index_left],
-             prediction[self.layer_right][:, self.index_right],
-         )
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Select relevant columns
+         selection_left = prediction[self.layer_left][:, self.index_left]
+         selection_right = prediction[self.layer_right][:, self.index_right]
+
+         # Apply transformations
+         selection_left = self.transformation_left(selection_left)
+         selection_right = self.transformation_right(selection_right)
+
+         result = self.comparator(selection_left, selection_right).float()
+
+         return result, numel(result)
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         # NOTE currently only works for dense layers due to neuron to index translation
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
 
          output = {}
 
          for layer in self.layers:
-             output[layer] = zeros_like(prediction[layer][0])
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
          output[self.layer_left][self.index_left] = self.direction_left
          output[self.layer_right][self.index_right] = self.direction_right
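
A brief usage sketch (the neuron names are hypothetical; the descriptor setup is assumed elsewhere):

    from torch import le

    from congrads.constraints import BinaryConstraint

    # Enforce that the predicted minimum never exceeds the predicted maximum
    ordering = BinaryConstraint("min_temperature", le, "max_temperature")
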
@@ -225,40 +559,129 @@ class BinaryConstraint(Constraint):
 
 
  class SumConstraint(Constraint):
+     """
+     A constraint that enforces a weighted summation comparison
+     between two groups of neurons.
+
+     This class evaluates whether the weighted sum of outputs from one set of
+     neurons satisfies a comparison operation with the weighted sum of
+     outputs from another set of neurons.
+
+     Args:
+         operands_left (list[Union[str, Transformation]]): List of neuron
+             names or transformations on the left side.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function for the constraint.
+         operands_right (list[Union[str, Transformation]]): List of neuron
+             names or transformations on the right side.
+         weights_left (list[Number], optional): Weights for the left neurons.
+             Defaults to None.
+         weights_right (list[Number], optional): Weights for the right
+             neurons. Defaults to None.
+         name (str, optional): Unique name for the constraint.
+             If None, it is auto-generated. Defaults to None.
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         ValueError: If the dimensions of neuron names and weights mismatch.
+
+     """
+
      def __init__(
          self,
-         neuron_names_left: list[str],
+         operands_left: list[Union[str, Transformation]],
          comparator: Callable[[Tensor, Number], Tensor],
-         neuron_names_right: list[str],
-         weights_left: list[float] = None,
-         weights_right: list[float] = None,
+         operands_right: list[Union[str, Transformation]],
+         weights_left: list[Number] = None,
+         weights_right: list[Number] = None,
          name: str = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
+         """
+         Initializes the SumConstraint.
+         """
+
+         # Type checking
+         validate_iterable("operands_left", operands_left, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_iterable(
+             "operands_right", operands_right, (str, Transformation)
+         )
+         validate_iterable("weights_left", weights_left, Number, allow_none=True)
+         validate_iterable(
+             "weights_right", weights_right, Number, allow_none=True
+         )
+
+         # If a transformation is provided, get its neuron name,
+         # else use IdentityTransformation
+         neuron_names_left: list[str] = []
+         transformations_left: list[Transformation] = []
+         for operand_left in operands_left:
+             if isinstance(operand_left, Transformation):
+                 neuron_name_left = operand_left.neuron_name
+                 neuron_names_left.append(neuron_name_left)
+                 transformations_left.append(operand_left)
+             else:
+                 neuron_name_left = operand_left
+                 neuron_names_left.append(neuron_name_left)
+                 transformations_left.append(
+                     IdentityTransformation(neuron_name_left)
+                 )
+
+         neuron_names_right: list[str] = []
+         transformations_right: list[Transformation] = []
+         for operand_right in operands_right:
+             if isinstance(operand_right, Transformation):
+                 neuron_name_right = operand_right.neuron_name
+                 neuron_names_right.append(neuron_name_right)
+                 transformations_right.append(operand_right)
+             else:
+                 neuron_name_right = operand_right
+                 neuron_names_right.append(neuron_name_right)
+                 transformations_right.append(
+                     IdentityTransformation(neuron_name_right)
+                 )
+
+         # Compose constraint name
+         w_left = weights_left or [""] * len(neuron_names_left)
+         w_right = weights_right or [""] * len(neuron_names_right)
+         left_expr = " + ".join(
+             f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
+         )
+         right_expr = " + ".join(
+             f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
+         )
+         comparator_name = comparator.__name__
+         name = f"{left_expr} {comparator_name} {right_expr}"
 
          # Init parent class
          neuron_names = set(neuron_names_left) | set(neuron_names_right)
-         super().__init__(neuron_names, name, rescale_factor)
+         super().__init__(neuron_names, name, monitor_only, rescale_factor)
 
          # Init variables
          self.comparator = comparator
          self.neuron_names_left = neuron_names_left
          self.neuron_names_right = neuron_names_right
+         self.transformations_left = transformations_left
+         self.transformations_right = transformations_right
 
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise ValueError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
-         # If feature list dimensions don't match weight list dimensions, raise error
+         # If feature list dimensions don't match
+         # weight list dimensions, raise error
          if weights_left and (len(neuron_names_left) != len(weights_left)):
              raise ValueError(
-                 "The dimensions of neuron_names_left don't match with the dimensions of weights_left."
+                 "The dimensions of neuron_names_left don't match the "
+                 "dimensions of weights_left."
              )
          if weights_right and (len(neuron_names_right) != len(weights_right)):
              raise ValueError(
-                 "The dimensions of neuron_names_right don't match with the dimensions of weights_right."
+                 "The dimensions of neuron_names_right don't match the "
+                 "dimensions of weights_right."
              )
 
          # If weights are provided for summation, transform them to Tensors
@@ -269,7 +692,9 @@ class SumConstraint(Constraint):
          if weights_right:
              self.weights_right = tensor(weights_right, device=self.device)
          else:
-             self.weights_right = ones(len(neuron_names_right), device=self.device)
+             self.weights_right = ones(
+                 len(neuron_names_right), device=self.device
+             )
 
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
@@ -279,9 +704,15 @@ class SumConstraint(Constraint):
              self.direction_left = 1
              self.direction_right = -1
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
 
-         def compute_weighted_sum(neuron_names: list[str], weights: tensor) -> tensor:
+         def compute_weighted_sum(
+             neuron_names: list[str],
+             transformations: list[Transformation],
+             weights: Tensor,
+         ) -> Tensor:
              layers = [
                  self.descriptor.neuron_to_layer[neuron_name]
                  for neuron_name in neuron_names
@@ -291,37 +722,53 @@ class SumConstraint(Constraint):
                  for neuron_name in neuron_names
              ]
 
+             # Select relevant columns
+             selections = [
+                 prediction[layer][:, index]
+                 for layer, index in zip(layers, indices)
+             ]
+
+             # Apply transformations
+             results = []
+             for transformation, selection in zip(transformations, selections):
+                 results.append(transformation(selection))
+
              # Extract predictions for all neurons and apply weights in bulk
              predictions = stack(
-                 [prediction[layer][:, index] for layer, index in zip(layers, indices)],
+                 results,
                  dim=1,
              )
 
-             # Denormalize if required
-             predictions_denorm = self._denormalize(predictions, neuron_names)
-
              # Calculate weighted sum
-             weighted_sum = (predictions_denorm * weights.unsqueeze(0)).sum(dim=1)
-
-             return weighted_sum
+             return (predictions * weights.unsqueeze(0)).sum(dim=1)
 
+         # Compute weighted sums
          weighted_sum_left = compute_weighted_sum(
-             self.neuron_names_left, self.weights_left
+             self.neuron_names_left,
+             self.transformations_left,
+             self.weights_left,
          )
          weighted_sum_right = compute_weighted_sum(
-             self.neuron_names_right, self.weights_right
+             self.neuron_names_right,
+             self.transformations_right,
+             self.weights_right,
          )
 
          # Apply the comparator and calculate the result
-         return ~self.comparator(weighted_sum_left, weighted_sum_right)
+         result = self.comparator(weighted_sum_left, weighted_sum_right).float()
+
+         return result, numel(result)
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         # NOTE currently only works for dense layers due to neuron to index translation
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers
+         # due to neuron to index translation
 
          output = {}
 
          for layer in self.layers:
-             output[layer] = zeros_like(prediction[layer][0])
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
          for neuron_name_left in self.neuron_names_left:
              layer = self.descriptor.neuron_to_layer[neuron_name_left]
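
A usage sketch for the weighted form (the neuron names and weights are hypothetical; the left-hand weights average two outputs, which must stay below a third output):

    from torch import le

    from congrads.constraints import SumConstraint

    # Enforce 0.5*x1 + 0.5*x2 <= x3
    budget = SumConstraint(
        ["x1", "x2"], le, ["x3"], weights_left=[0.5, 0.5]
    )
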
@@ -339,51 +786,134 @@ class SumConstraint(Constraint):
          return output
 
 
- # class MonotonicityConstraint(Constraint):
- #     # TODO docstring
+ class PythagoreanIdentityConstraint(Constraint):
+     """
+     A constraint that enforces the Pythagorean identity a² + b² ≈ 1,
+     where `a` and `b` are neurons or transformations.
+
+     This constraint checks that the sum of the squares of two specified
+     neurons (or their transformations) is approximately equal to 1.
+     The constraint is evaluated using relative and absolute
+     tolerance (`rtol` and `atol`) and is applied during the forward pass.
+
+     Args:
+         a (Union[str, Transformation]): The first input, either a
+             neuron name (str) or a Transformation.
+         b (Union[str, Transformation]): The second input, either a
+             neuron name (str) or a Transformation.
+         rtol (float, optional): The relative tolerance for the
+             comparison (default is 0.00001).
+         atol (float, optional): The absolute tolerance for the
+             comparison (default is 1e-8).
+         name (str, optional): The name of the constraint (default is
+             None, in which case it is generated automatically).
+         monitor_only (bool, optional): Flag indicating whether the
+             constraint is only for monitoring (default is False).
+         rescale_factor (Number, optional): A factor used for
+             rescaling (default is 1.5).
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     """
 
- #     def __init__(
- #         self,
- #         neuron_name: str,
- #         name: str = None,
- #         descriptor: Descriptor = None,
- #         rescale_factor: float = 1.5,
- #     ) -> None:
+     def __init__(
+         self,
+         a: Union[str, Transformation],
+         b: Union[str, Transformation],
+         rtol: float = 0.00001,
+         atol: float = 1e-8,
+         name: str = None,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
+     ) -> None:
+         """
+         Initialize the PythagoreanIdentityConstraint.
+         """
+
+         # Type checking
+         validate_type("a", a, (str, Transformation))
+         validate_type("b", b, (str, Transformation))
+         validate_type("rtol", rtol, float)
+         validate_type("atol", atol, float)
+
+         # If a transformation is provided, get its neuron name,
+         # else use IdentityTransformation
+         if isinstance(a, Transformation):
+             neuron_name_a = a.neuron_name
+             transformation_a = a
+         else:
+             neuron_name_a = a
+             transformation_a = IdentityTransformation(neuron_name_a)
 
- #         # Compose constraint name
- #         name = f"Monotonicity_{neuron_name}"
+         if isinstance(b, Transformation):
+             neuron_name_b = b.neuron_name
+             transformation_b = b
+         else:
+             neuron_name_b = b
+             transformation_b = IdentityTransformation(neuron_name_b)
 
- #         # Init parent class
- #         super().__init__({neuron_name}, name, rescale_factor)
+         # Compose constraint name
+         name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
+
+         # Init parent class
+         super().__init__(
+             {neuron_name_a, neuron_name_b},
+             name,
+             monitor_only,
+             rescale_factor,
+         )
 
- #         # Init variables
- #         if descriptor != None:
- #             self.descriptor = descriptor
- #             self.run_init_descriptor()
+         # Init variables
+         self.transformation_a = transformation_a
+         self.transformation_b = transformation_b
+         self.rtol = rtol
+         self.atol = atol
 
- #         # Get layer name and feature index from neuron_name
- #         self.layer = self.descriptor.neuron_to_layer[neuron_name]
- #         self.index = self.descriptor.neuron_to_index[neuron_name]
+         # Get layer name and feature index from neuron_name
+         self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
+         self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
+         self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
+         self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
+
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Select relevant columns
+         selection_a = prediction[self.layer_a][:, self.index_a]
+         selection_b = prediction[self.layer_b][:, self.index_b]
+
+         # Apply transformations
+         selection_a = self.transformation_a(selection_a)
+         selection_b = self.transformation_b(selection_b)
+
+         # Calculate result
+         result = isclose(
+             square(selection_a) + square(selection_b),
+             ones_like(selection_a, device=self.device),
+             rtol=self.rtol,
+             atol=self.atol,
+         ).float()
+
+         return result, numel(result)
+
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
 
- #     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
- #         # Check if values for column in batch are only increasing
- #         result = ~ge(
- #             diff(
- #                 prediction[self.layer][:, self.index],
- #                 prepend=zeros_like(
- #                     prediction[self.layer][:, self.index][:1],
- #                     device=prediction[self.layer].device,
- #                 ),
- #             ),
- #             0,
- #         )
+         output = {}
 
- #         return {self.layer: result}
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer], device=self.device)
 
- #     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
- #         # TODO implement
+         a = prediction[self.layer_a][:, self.index_a]
+         b = prediction[self.layer_b][:, self.index_b]
+         m = sqrt(square(a) + square(b))
 
- #         output = {self.layer: zeros_like(prediction[self.layer][0])}
- #         output[self.layer][self.index] = 1
+         output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
+         output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
 
- #         return output
+         return output
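
The direction computed in the final hunk is radial: for outputs (a, b) with magnitude m = sqrt(a² + b²), the vector (a/m, b/m) points away from the origin, and the sign(1 - m) factor flips it so the update always points toward the unit circle (outward when m < 1, inward when m > 1). A small standalone check, independent of the package:

    from torch import sign, sqrt, square, tensor

    a, b = tensor([0.3, 2.0]), tensor([0.4, 0.0])
    m = sqrt(square(a) + square(b))  # magnitudes: 0.5 and 2.0
    da = a / m * sign(1 - m)
    db = b / m * sign(1 - m)
    # Point inside the circle is pushed outward: (da[0], db[0]) == (0.6, 0.8)
    # Point outside is pushed inward: (da[1], db[1]) == (-1.0, 0.0)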