congrads 0.1.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/constraints.py CHANGED
@@ -1,312 +1,512 @@
- from abc import ABC, abstractmethod
- from numbers import Number
+ """
+ This module provides a set of constraint classes for guiding neural network
+ training by enforcing specific conditions on the network's outputs.
+
+ The constraints in this module include:
+
+ - `Constraint`: The base class for all constraint types, defining the
+   interface and core behavior.
+ - `ImplicationConstraint`: A constraint that enforces one condition only if
+   another condition is met, useful for modeling implications between network
+   outputs.
+ - `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
+   a network's output.
+ - `BinaryConstraint`: A constraint that enforces a binary comparison between
+   two neurons in the network, using a comparison function (e.g., less than,
+   greater than).
+ - `SumConstraint`: A constraint that compares a weighted sum of one group of
+   neurons' outputs against a weighted sum of another group's outputs.
+ - `PythagoreanIdentityConstraint`: A constraint that enforces the
+   Pythagorean identity on a pair of neurons, ensuring that the sum of their
+   squared outputs is approximately equal to 1.
+
+ These constraints can be used to steer the learning process by applying
+ conditions such as logical implications or numerical bounds.
+
+ Usage:
+     1. Define a custom constraint class by inheriting from `Constraint`,
+        or use one of the built-in constraints listed above.
+     2. Apply the constraint to your neural network during training to
+        enforce desired output behaviors.
+     3. Use helper classes like `IdentityTransformation` for handling
+        transformations and comparisons in constraints.
+
+ Dependencies:
+     - PyTorch (`torch`)
+ """
+
  import random
  import string
- from typing import Callable, Dict
- from torch import Tensor, ge, gt, lt, le, zeros, FloatTensor, ones, tensor, float32
- import logging
+ import warnings
+ from abc import ABC, abstractmethod
+ from numbers import Number
+ from typing import Callable, Dict, Union
+
+ from torch import (
+     Tensor,
+     count_nonzero,
+     ge,
+     gt,
+     isclose,
+     le,
+     logical_not,
+     logical_or,
+     lt,
+     numel,
+     ones,
+     ones_like,
+     reshape,
+     sign,
+     sqrt,
+     square,
+     stack,
+     tensor,
+     zeros_like,
+ )
  from torch.nn.functional import normalize
 
  from .descriptor import Descriptor
+ from .transformations import IdentityTransformation, Transformation
+ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
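
A minimal sketch of the intended setup, based only on the API visible in this
diff (the no-argument `Descriptor()` constructor and the `descriptor.add("layer", ...)`
signature are assumptions taken from the error-message hint further down; the
layer and neuron names are hypothetical):

    from torch import ge
    from congrads.constraints import Constraint, ScalarConstraint
    from congrads.descriptor import Descriptor

    # Map named neurons to their layers so constraints can locate them.
    descriptor = Descriptor()
    descriptor.add("output", "price")  # hypothetical layer/neuron names

    # Register the descriptor once, before defining any constraints.
    Constraint.descriptor = descriptor

    # Enforce price >= 0 on the "price" neuron during training.
    non_negative = ScalarConstraint("price", ge, 0)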
 
 
  class Constraint(ABC):
      """
-     Abstract base class for defining constraints that can be applied during optimization in the constraint-guided gradient descent process.
-
-     A constraint guides the optimization by evaluating the model's predictions and adjusting the loss based on certain conditions.
-     Constraints can be applied to specific layers or neurons of the model, and they are scaled by a rescale factor to control the influence of the constraint on the overall loss.
+     Abstract base class for defining constraints applied to neural networks.
+
+     A `Constraint` specifies conditions that the neural network outputs
+     should satisfy. It supports monitoring constraint satisfaction
+     during training and can adjust the loss to enforce constraints. Subclasses
+     must implement the `check_constraint` and `calculate_direction` methods.
+
+     Args:
+         neurons (set[str]): Names of the neurons this constraint applies to.
+         name (str, optional): A unique name for the constraint. If not provided,
+             a name is generated based on the class name and a random suffix.
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5. Should be greater
+             than 1 to give weight to the constraint.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         ValueError: If any neuron in `neurons` is not
+             defined in the `descriptor`.
+
+     Note:
+         - If `rescale_factor <= 1`, a warning is issued, and the value is
+           adjusted to a positive value greater than 1.
+         - If `name` is not provided, a name is auto-generated,
+           and a warning is issued.
 
-     Attributes:
-         descriptor (Descriptor): The descriptor object that provides a mapping of neurons to layers.
-         constraint_name (str): A unique name for the constraint, which can be provided or generated automatically.
-         rescale_factor (float): A factor used to scale the influence of the constraint on the overall loss.
-         neuron_names (set[str]): A set of neuron names that are involved in the constraint.
-         layers (set): A set of layers associated with the neurons specified in `neuron_names`.
      """
 
      descriptor: Descriptor = None
+     device = None
 
      def __init__(
          self,
-         neuron_names: set[str],
-         constraint_name: str = None,
-         rescale_factor: float = 1.5,
+         neurons: set[str],
+         name: str = None,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
          """
-         Initializes the Constraint object with the given neuron names, constraint name, and rescale factor.
-
-         Args:
-             neuron_names (set[str]): A set of neuron names that are affected by the constraint.
-             constraint_name (str, optional): A custom name for the constraint. If not provided, a random name is generated.
-             rescale_factor (float, optional): A factor that scales the influence of the constraint. Defaults to 1.5.
-
-         Raises:
-             ValueError: If the descriptor has not been set or if a neuron name is not found in the descriptor.
+         Initializes a new Constraint instance.
          """
 
          # Init parent class
          super().__init__()
 
+         # Type checking
+         validate_iterable("neurons", neurons, str)
+         validate_type("name", name, (str, type(None)))
+         validate_type("monitor_only", monitor_only, bool)
+         validate_type("rescale_factor", rescale_factor, Number)
+
          # Init object variables
+         self.neurons = neurons
          self.rescale_factor = rescale_factor
-         self.neuron_names = neuron_names
+         self.monitor_only = monitor_only
 
-         # Perform checks
-         if rescale_factor <= 1:
-             logging.warning(
-                 f"Rescale factor for constraint {constraint_name} is <= 1. The network will favour general loss over the constraint-adjusted loss. Is this intended behaviour? Normally, the loss should always be larger than 1."
-             )
-
-         # If no constraint_name is set, generate one based on the class name and a random suffix
-         if constraint_name:
-             self.constraint_name = constraint_name
+         # If no name is set, generate one based
+         # on the class name and a random suffix
+         if name:
+             self.name = name
          else:
              random_suffix = "".join(
                  random.choices(string.ascii_uppercase + string.digits, k=6)
              )
-             self.constraint_name = f"{self.__class__.__name__}_{random_suffix}"
-             logging.warning(
-                 f"Name for constraint is not set. Using {self.constraint_name}."
+             self.name = f"{self.__class__.__name__}_{random_suffix}"
+             warnings.warn(
+                 f"Name for constraint is not set. Using {self.name}."
              )
 
-         if self.descriptor == None:
-             raise ValueError(
-                 "The descriptor of the base Constraint class in not set. Please assign the descriptor to the general Constraint class with 'Constraint.descriptor = descriptor' before defining network-specific contraints."
-             )
-
-         if not rescale_factor > 1:
+         if self.descriptor is None:
+             raise ValueError(
+                 "The descriptor of the base Constraint class is not set. "
+                 "Please assign it with 'Constraint.descriptor = descriptor' "
+                 "before defining network-specific constraints."
+             )
+
+         # If rescale factor is not larger than 1, warn user and adjust;
+         # otherwise the network favors the general loss over the
+         # constraint-adjusted loss.
+         if rescale_factor <= 1:
              self.rescale_factor = abs(rescale_factor) + 1.5
-             logging.warning(
-                 f"Rescale factor for constraint {constraint_name} is < 1, adjusted value {rescale_factor} to {self.rescale_factor}."
+             warnings.warn(
+                 f"Rescale factor for constraint {name} is <= 1, "
+                 f"adjusted value {rescale_factor} to {self.rescale_factor}. "
+                 f"Normally, the rescale factor should always be larger than 1."
              )
          else:
              self.rescale_factor = rescale_factor
 
-         self.neuron_names = neuron_names
-
-         self.run_init_descriptor()
-
-     def run_init_descriptor(self) -> None:
-         """
-         Initializes the layers associated with the constraint by mapping the neuron names to their corresponding layers
-         from the descriptor.
-
-         This method populates the `layers` attribute with layers associated with the neuron names provided in the constraint.
-
-         Raises:
-             ValueError: If a neuron name is not found in the descriptor's mapping of neurons to layers.
-         """
-
+         # Infer layers from descriptor and neurons
          self.layers = set()
-         for neuron_name in self.neuron_names:
-             if neuron_name in self.descriptor.neuron_to_layer.keys():
-                 self.layers.add(self.descriptor.neuron_to_layer[neuron_name])
-             else:
+         for neuron in self.neurons:
+             if neuron not in self.descriptor.neuron_to_layer:
                  raise ValueError(
-                     f'The neuron name {neuron_name} used with constraint {self.constraint_name} is not defined in the descriptor. Please add it to the correct layer using descriptor.add("layer", ...).'
+                     f'The neuron name {neuron} used with constraint '
+                     f'{self.name} is not defined in the descriptor. Please '
+                     f'add it to the correct layer using '
+                     f'descriptor.add("layer", ...).'
                  )
 
+             self.layers.add(self.descriptor.neuron_to_layer[neuron])
+
      @abstractmethod
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
          """
-         Abstract method to check if the constraint is satisfied based on the model's predictions.
-
-         This method should be implemented in subclasses to define the specific logic for evaluating the constraint based on the model's predictions.
+         Evaluates whether the given model predictions satisfy the constraint.
 
          Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
+             prediction (dict[str, Tensor]): Model predictions for the neurons.
 
          Returns:
-             dict[str, Tensor]: A dictionary containing the satisfaction status of the constraint for each layer or neuron.
+             tuple[Tensor, int]: A tuple where the first element is a tensor
+                 indicating whether the constraint is satisfied (with `True`
+                 for satisfaction, `False` for non-satisfaction, and `torch.nan`
+                 for irrelevant results), and the second element is an integer
+                 representing the number of relevant constraints.
 
          Raises:
-             NotImplementedError: If the method is not implemented in a subclass.
+             NotImplementedError: If not implemented in a subclass.
          """
 
          raise NotImplementedError
 
      @abstractmethod
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
          """
-         Abstract method to calculate the direction in which the model's predictions need to be adjusted to satisfy the constraint.
-
-         This method should be implemented in subclasses to define how to adjust the model's predictions based on the constraint.
+         Calculates adjustment directions for neurons to
+         better satisfy the constraint.
 
          Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
+             prediction (dict[str, Tensor]): Model predictions for the neurons.
 
          Returns:
-             dict[str, Tensor]: A dictionary containing the direction for each layer or neuron, to adjust the model's predictions.
+             Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
+                 specifying the adjustment direction for each neuron.
 
          Raises:
-             NotImplementedError: If the method is not implemented in a subclass.
+             NotImplementedError: If not implemented in a subclass.
          """
+
          raise NotImplementedError
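
Usage step 1 from the module docstring in practice: a minimal sketch of a
custom subclass of the new 1.0.1 base class. It assumes the registered
descriptor maps a hypothetical neuron "y" to index 0 of a layer named
"output"; error handling and type hints are elided:

    from torch import numel, zeros_like

    class PositiveConstraint(Constraint):
        """Requires the output of neuron "y" in layer "output" to be positive."""

        def check_constraint(self, prediction):
            # 1.0 where satisfied, 0.0 where violated; every sample is relevant.
            result = (prediction["output"][:, 0] > 0).float()
            return result, numel(result)

        def calculate_direction(self, prediction):
            # Push the constrained neuron upward when the constraint is violated.
            direction = zeros_like(prediction["output"][0], device=self.device)
            direction[0] = 1
            return {"output": direction}

    constraint = PositiveConstraint({"y"})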
 
 
- class ScalarConstraint(Constraint):
+ class ImplicationConstraint(Constraint):
+     """
+     Represents an implication constraint between two
+     constraints (head and body).
+
+     The implication constraint ensures that the `body` constraint only applies
+     when the `head` constraint is satisfied. If the `head` constraint is not
+     satisfied, the `body` constraint does not apply.
+
+     Args:
+         head (Constraint): The head of the implication. If this constraint
+             is satisfied, the body constraint must also be satisfied.
+         body (Constraint): The body of the implication. This constraint
+             is enforced only when the head constraint is satisfied.
+         name (str, optional): A unique name for the constraint. If not
+             provided, the name is generated in the format
+             "{body.name} if {head.name}". Defaults to None.
+         monitor_only (bool, optional): If True, the constraint is only
+             monitored without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): The scaling factor for the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
      """
-     A subclass of the `Constraint` class that applies a scalar constraint on a specific neuron in the model.
 
-     This constraint compares the value of a specific neuron in the model to a scalar value using a specified comparator (e.g., greater than, less than).
-     If the constraint is violated, it adjusts the loss according to the direction defined by the comparator.
+     def __init__(
+         self,
+         head: Constraint,
+         body: Constraint,
+         name=None,
+         monitor_only=False,
+         rescale_factor=1.5,
+     ):
+         """
+         Initializes an ImplicationConstraint instance.
+         """
+
+         # Type checking
+         validate_type("head", head, Constraint)
+         validate_type("body", body, Constraint)
+
+         # Compose constraint name unless one was provided
+         if name is None:
+             name = f"{body.name} if {head.name}"
+
+         # Init parent class
+         super().__init__(
+             head.neurons | body.neurons,
+             name,
+             monitor_only,
+             rescale_factor,
+         )
+
+         self.head = head
+         self.body = body
+
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Check satisfaction of head and body constraints
+         head_satisfaction, _ = self.head.check_constraint(prediction)
+         body_satisfaction, _ = self.body.check_constraint(prediction)
+
+         # If the head constraint is satisfied (1), the result follows the
+         # body constraint (0/1). If the head constraint is not satisfied (0),
+         # the implication holds vacuously and the result is 1.
+         result = logical_or(
+             logical_not(head_satisfaction), body_satisfaction
+         ).float()
+
+         return result, count_nonzero(head_satisfaction)
+
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers
+         # due to neuron to index translation
+
+         # Use directions of constraint body as update vector
+         return self.body.calculate_direction(prediction)
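
A sketch of how an implication reads in practice (hypothetical neuron names;
both operand constraints use the `ScalarConstraint` class defined next):

    from torch import ge, le
    from congrads.constraints import ImplicationConstraint, ScalarConstraint

    # If "is_discounted" is predicted >= 0.5, then "price" must be <= 100.
    head = ScalarConstraint("is_discounted", ge, 0.5)
    body = ScalarConstraint("price", le, 100)
    discount_cap = ImplicationConstraint(head, body)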
+
+
+ class ScalarConstraint(Constraint):
+     """
+     A constraint that enforces scalar-based comparisons on a specific neuron.
+
+     This class ensures that the output of a specified neuron satisfies a scalar
+     comparison operation (e.g., less than, greater than, etc.). It uses a
+     comparator function to validate the condition and calculates adjustment
+     directions accordingly.
+
+     Args:
+         operand (Union[str, Transformation]): Name of the neuron or a
+             transformation to apply.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function (e.g., `torch.ge`, `torch.lt`).
+         scalar (Number): The scalar value to compare against.
+         name (str, optional): A unique name for the constraint. If not
+             provided, a name is auto-generated in the format
+             "<neuron_name> <comparator> <scalar>".
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     Notes:
+         - The neuron name must be defined in the `descriptor` mapping.
+         - The auto-generated constraint name is composed using the neuron
+           name, comparator, and scalar value.
 
-     Attributes:
-         comparator (Callable[[Tensor, Number], Tensor]): A comparator function (e.g., greater than, less than) to evaluate the constraint.
-         scalar (Number): The scalar value to compare the neuron value against.
-         direction (int): The direction in which the constraint should adjust the model's predictions (either 1 or -1 based on the comparator).
-         layer (str): The layer associated with the specified neuron.
-         index (int): The index of the specified neuron within the layer.
      """
 
      def __init__(
          self,
-         neuron_name: str,
+         operand: Union[str, Transformation],
          comparator: Callable[[Tensor, Number], Tensor],
          scalar: Number,
          name: str = None,
-         descriptor: Descriptor = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
          """
-         Initializes the ScalarConstraint with the given neuron name, comparator, scalar value, and other optional parameters.
-
-         Args:
-             neuron_name (str): The name of the neuron that the constraint applies to.
-             comparator (Callable[[Tensor, Number], Tensor]): The comparator function used to evaluate the constraint (e.g., ge, le, gt, lt).
-             scalar (Number): The scalar value that the neuron value is compared to.
-             name (str, optional): A custom name for the constraint. If not provided, a name is generated based on the neuron name, comparator, and scalar.
-             descriptor (Descriptor, optional): The descriptor that maps neurons to layers. If not provided, the global descriptor is used.
-             rescale_factor (float, optional): A factor that scales the influence of the constraint on the overall loss. Defaults to 1.5.
-
-         Raises:
-             ValueError: If the comparator function is not one of the supported comparison operators (ge, le, gt, lt).
+         Initializes a ScalarConstraint instance.
          """
 
+         # Type checking
+         validate_type("operand", operand, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_type("scalar", scalar, Number)
+
+         # If transformation is provided, get neuron name,
+         # else use IdentityTransformation
+         if isinstance(operand, Transformation):
+             neuron_name = operand.neuron_name
+             transformation = operand
+         else:
+             neuron_name = operand
+             transformation = IdentityTransformation(neuron_name)
+
          # Compose constraint name
-         name = f"{neuron_name}_{comparator.__name__}_{str(scalar)}"
+         if name is None:
+             name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
 
          # Init parent class
-         super().__init__({neuron_name}, name, rescale_factor)
+         super().__init__({neuron_name}, name, monitor_only, rescale_factor)
 
          # Init variables
          self.comparator = comparator
          self.scalar = scalar
-
-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
+         self.transformation = transformation
 
          # Get layer name and feature index from neuron_name
          self.layer = self.descriptor.neuron_to_layer[neuron_name]
          self.index = self.descriptor.neuron_to_index[neuron_name]
 
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise ValueError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
-             self.direction = 1
-         elif self.comparator in [gt, ge]:
              self.direction = -1
+         elif self.comparator in [gt, ge]:
+             self.direction = 1
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Checks if the constraint is satisfied based on the model's predictions.
-
-         The constraint is evaluated by applying the comparator to the value of the specified neuron and the scalar value.
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
 
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
+         # Select relevant columns
+         selection = prediction[self.layer][:, self.index]
 
-         Returns:
-             dict[str, Tensor]: A dictionary containing the constraint satisfaction result for the specified layer.
-         """
+         # Apply transformation
+         selection = self.transformation(selection)
 
-         result = ~self.comparator(prediction[self.layer][:, self.index], self.scalar)
+         # Calculate current constraint result
+         result = self.comparator(selection, self.scalar).float()
+         return result, numel(result)
 
-         return {self.layer: result}
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Calculates the direction in which the model's predictions need to be adjusted to satisfy the constraint.
+         output = {}
 
-         The direction is determined by the comparator and represents either a positive or negative adjustment.
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
+         output[self.layer][self.index] = self.direction
 
-         Returns:
-             dict[str, Tensor]: A dictionary containing the direction for each layer or neuron, to adjust the model's predictions.
-         """
-
-         output = zeros(
-             prediction[self.layer].size(),
-             device=prediction[self.layer].device,
-         )
-         output[:, self.index] = self.direction
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
 
-         return {self.layer: output}
+         return output
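
For instance (hypothetical neuron name; the comparator is presumably one of
the PyTorch `ge`/`gt`/`le`/`lt` functions, as in the explicit 0.1.0 check):

    from torch import le
    from congrads.constraints import ScalarConstraint

    # Keep the "humidity" neuron's output at or below 1; because the
    # comparator is `le`, violations are nudged in direction -1 (downward).
    bounded = ScalarConstraint("humidity", le, 1, rescale_factor=2)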
 
 
  class BinaryConstraint(Constraint):
      """
-     A class representing a binary constraint between two neurons in a neural network.
-
-     This class checks and enforces a constraint between two neurons using a
-     comparator function. The constraint is applied between two neurons located
-     in different layers of the neural network. The class also calculates the
-     direction for gradient adjustment based on the comparator.
-
-     Attributes:
-         neuron_name_left (str): The name of the first neuron involved in the constraint.
-         neuron_name_right (str): The name of the second neuron involved in the constraint.
-         comparator (Callable[[Tensor, Number], Tensor]): A function that compares the values of the two neurons.
-         layer_left (str): The layer name for the first neuron.
-         layer_right (str): The layer name for the second neuron.
-         index_left (int): The index of the first neuron within its layer.
-         index_right (int): The index of the second neuron within its layer.
-         direction_left (float): The normalized direction for gradient adjustment of the first neuron.
-         direction_right (float): The normalized direction for gradient adjustment of the second neuron.
+     A constraint that enforces a binary comparison between two neurons.
+
+     This class ensures that the output of one neuron satisfies a comparison
+     operation with the output of another neuron
+     (e.g., less than, greater than, etc.). It uses a comparator function to
+     validate the condition and calculates adjustment directions accordingly.
+
+     Args:
+         operand_left (Union[str, Transformation]): Name of the left
+             neuron or a transformation to apply.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function (e.g., `torch.ge`, `torch.lt`).
+         operand_right (Union[str, Transformation]): Name of the right
+             neuron or a transformation to apply.
+         name (str, optional): A unique name for the constraint. If not
+             provided, a name is auto-generated in the format
+             "<neuron_name_left> <comparator> <neuron_name_right>".
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     Notes:
+         - The neuron names must be defined in the `descriptor` mapping.
+         - The auto-generated constraint name is composed using the left
+           neuron name, comparator, and right neuron name.
+
      """
 
      def __init__(
          self,
-         neuron_name_left: str,
+         operand_left: Union[str, Transformation],
          comparator: Callable[[Tensor, Number], Tensor],
-         neuron_name_right: str,
+         operand_right: Union[str, Transformation],
          name: str = None,
-         descriptor: Descriptor = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
          """
-         Initializes the binary constraint with two neurons, a comparator, and other configuration options.
-
-         Args:
-             neuron_name_left (str): The name of the first neuron in the constraint.
-             comparator (Callable[[Tensor, Number], Tensor]): A function that compares the values of the two neurons.
-             neuron_name_right (str): The name of the second neuron in the constraint.
-             name (str, optional): The name of the constraint. If not provided, a default name is generated.
-             descriptor (Descriptor, optional): The descriptor containing the mapping of neurons to layers.
-             rescale_factor (float, optional): A factor to rescale the constraint value. Default is 1.5.
+         Initializes a BinaryConstraint instance.
          """
 
+         # Type checking
+         validate_type("operand_left", operand_left, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_type("operand_right", operand_right, (str, Transformation))
+
+         # If transformation is provided, get neuron name,
+         # else use IdentityTransformation
+         if isinstance(operand_left, Transformation):
+             neuron_name_left = operand_left.neuron_name
+             transformation_left = operand_left
+         else:
+             neuron_name_left = operand_left
+             transformation_left = IdentityTransformation(neuron_name_left)
+
+         if isinstance(operand_right, Transformation):
+             neuron_name_right = operand_right.neuron_name
+             transformation_right = operand_right
+         else:
+             neuron_name_right = operand_right
+             transformation_right = IdentityTransformation(neuron_name_right)
+
          # Compose constraint name
-         name = f"{neuron_name_left}_{comparator.__name__}_{neuron_name_right}"
+         if name is None:
+             name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
 
          # Init parent class
-         super().__init__({neuron_name_left, neuron_name_right}, name, rescale_factor)
+         super().__init__(
+             {neuron_name_left, neuron_name_right},
+             name,
+             monitor_only,
+             rescale_factor,
+         )
 
          # Init variables
          self.comparator = comparator
-
-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
+         self.transformation_left = transformation_left
+         self.transformation_right = transformation_right
 
          # Get layer name and feature index from neuron_name
          self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
@@ -314,12 +514,6 @@ class BinaryConstraint(Constraint):
          self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
          self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
 
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise RuntimeError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
              self.direction_left = -1
@@ -328,128 +522,179 @@ class BinaryConstraint(Constraint):
              self.direction_left = 1
              self.direction_right = -1
 
-         # Normalize directions
-         normalized_directions = normalize(
-             tensor([self.direction_left, self.direction_right]).type(float32),
-             p=2,
-             dim=0,
-         )
-         self.direction_left = normalized_directions[0]
-         self.direction_right = normalized_directions[1]
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Checks whether the binary constraint is satisfied between the two neurons.
+         # Select relevant columns
+         selection_left = prediction[self.layer_left][:, self.index_left]
+         selection_right = prediction[self.layer_right][:, self.index_right]
 
-         This function applies the comparator to the output values of the two neurons
-         and returns a Boolean result for each neuron.
+         # Apply transformations
+         selection_left = self.transformation_left(selection_left)
+         selection_right = self.transformation_right(selection_right)
 
-         Args:
-             prediction (dict[str, Tensor]): A dictionary containing the predictions for each layer.
+         result = self.comparator(selection_left, selection_right).float()
 
-         Returns:
-             dict[str, Tensor]: A dictionary with the layer names as keys and the constraint satisfaction results as values.
-         """
+         return result, numel(result)
 
-         result = ~self.comparator(
-             prediction[self.layer_left][:, self.index_left],
-             prediction[self.layer_right][:, self.index_right],
-         )
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
 
-         return {self.layer_left: result, self.layer_right: result}
+         output = {}
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Calculates the direction for gradient adjustment for both neurons involved in the constraint.
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
-         The directions are normalized and represent the direction in which the constraint should be enforced.
+         output[self.layer_left][self.index_left] = self.direction_left
+         output[self.layer_right][self.index_right] = self.direction_right
 
-         Args:
-             prediction (dict[str, Tensor]): A dictionary containing the predictions for each layer.
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
 
-         Returns:
-             dict[str, Tensor]: A dictionary with the layer names as keys and the direction vectors as values.
-         """
+         return output
 
-         output_left = zeros(
-             prediction[self.layer_left].size(),
-             device=prediction[self.layer_left].device,
-         )
-         output_left[:, self.index_left] = self.direction_left
 
-         output_right = zeros(
-             prediction[self.layer_right].size(),
-             device=prediction[self.layer_right].device,
-         )
-         output_right[:, self.index_right] = self.direction_right
-
-         return {self.layer_left: output_left, self.layer_right: output_right}
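
A sketch of a binary comparison between two named neurons (hypothetical
names; the comparator is an ordinary PyTorch comparison function):

    from torch import le
    from congrads.constraints import BinaryConstraint

    # Require the predicted minimum temperature to stay at or below
    # the predicted maximum temperature.
    ordering = BinaryConstraint("temperature_min", le, "temperature_max")
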
+ class SumConstraint(Constraint):
+     """
+     A constraint that enforces a weighted summation comparison
+     between two groups of neurons.
+
+     This class evaluates whether the weighted sum of outputs from one set of
+     neurons satisfies a comparison operation with the weighted sum of
+     outputs from another set of neurons.
+
+     Args:
+         operands_left (list[Union[str, Transformation]]): List of neuron
+             names or transformations on the left side.
+         comparator (Callable[[Tensor, Number], Tensor]): A comparison
+             function for the constraint.
+         operands_right (list[Union[str, Transformation]]): List of neuron
+             names or transformations on the right side.
+         weights_left (list[Number], optional): Weights for the left neurons.
+             Defaults to None.
+         weights_right (list[Number], optional): Weights for the right
+             neurons. Defaults to None.
+         name (str, optional): Unique name for the constraint.
+             If None, it's auto-generated. Defaults to None.
+         monitor_only (bool, optional): If True, only monitor the constraint
+             without adjusting the loss. Defaults to False.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         ValueError: If the dimensions of neuron names and weights mismatch.
 
+     """
 
-     # FIXME
-     class SumConstraint(Constraint):
      def __init__(
          self,
-         neuron_names_left: list[str],
+         operands_left: list[Union[str, Transformation]],
          comparator: Callable[[Tensor, Number], Tensor],
-         neuron_names_right: list[str],
-         weights_left: list[float] = None,
-         weights_right: list[float] = None,
+         operands_right: list[Union[str, Transformation]],
+         weights_left: list[Number] = None,
+         weights_right: list[Number] = None,
          name: str = None,
-         descriptor: Descriptor = None,
-         rescale_factor: float = 1.5,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
      ) -> None:
+         """
+         Initializes the SumConstraint.
+         """
 
-         # Init parent class
-         super().__init__(
-             set(neuron_names_left) & set(neuron_names_right), name, rescale_factor
+         # Type checking
+         validate_iterable("operands_left", operands_left, (str, Transformation))
+         validate_comparator_pytorch("comparator", comparator)
+         validate_iterable(
+             "operands_right", operands_right, (str, Transformation)
+         )
+         validate_iterable("weights_left", weights_left, Number, allow_none=True)
+         validate_iterable(
+             "weights_right", weights_right, Number, allow_none=True
          )
 
+         # If transformation is provided, get neuron name,
+         # else use IdentityTransformation
+         neuron_names_left: list[str] = []
+         transformations_left: list[Transformation] = []
+         for operand_left in operands_left:
+             if isinstance(operand_left, Transformation):
+                 neuron_name_left = operand_left.neuron_name
+                 neuron_names_left.append(neuron_name_left)
+                 transformations_left.append(operand_left)
+             else:
+                 neuron_name_left = operand_left
+                 neuron_names_left.append(neuron_name_left)
+                 transformations_left.append(
+                     IdentityTransformation(neuron_name_left)
+                 )
+
+         neuron_names_right: list[str] = []
+         transformations_right: list[Transformation] = []
+         for operand_right in operands_right:
+             if isinstance(operand_right, Transformation):
+                 neuron_name_right = operand_right.neuron_name
+                 neuron_names_right.append(neuron_name_right)
+                 transformations_right.append(operand_right)
+             else:
+                 neuron_name_right = operand_right
+                 neuron_names_right.append(neuron_name_right)
+                 transformations_right.append(
+                     IdentityTransformation(neuron_name_right)
+                 )
+
+         # Compose constraint name unless one was provided
+         if name is None:
+             w_left = weights_left or [""] * len(neuron_names_left)
+             w_right = weights_right or [""] * len(neuron_names_right)
+             left_expr = " + ".join(
+                 f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
+             )
+             right_expr = " + ".join(
+                 f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
+             )
+             name = f"{left_expr} {comparator.__name__} {right_expr}"
+
+         # Init parent class
+         neuron_names = set(neuron_names_left) | set(neuron_names_right)
+         super().__init__(neuron_names, name, monitor_only, rescale_factor)
+
          # Init variables
          self.comparator = comparator
+         self.neuron_names_left = neuron_names_left
+         self.neuron_names_right = neuron_names_right
+         self.transformations_left = transformations_left
+         self.transformations_right = transformations_right
 
-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
-
-         # Get layer names and feature indices from neuron_name
-         self.layers_left = []
-         self.indices_left = []
-         for neuron_name in neuron_names_left:
-             self.layers_left.append(self.descriptor.neuron_to_layer[neuron_name])
-             self.indices_left.append(self.descriptor.neuron_to_index[neuron_name])
-
-         self.layers_right = []
-         self.indices_right = []
-         for neuron_name in neuron_names_right:
-             self.layers_right.append(self.descriptor.neuron_to_layer[neuron_name])
-             self.indices_right.append(self.descriptor.neuron_to_index[neuron_name])
-
-         # If comparator function is not supported, raise error
-         if comparator not in [ge, le, gt, lt]:
-             raise ValueError(
-                 f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-             )
-
-         # If feature list dimensions don't match weight list dimensions, raise error
+         # If feature list dimensions don't match
+         # weight list dimensions, raise error
          if weights_left and (len(neuron_names_left) != len(weights_left)):
              raise ValueError(
-                 "The dimensions of neuron_names_left don't match with the dimensions of weights_left."
+                 "The dimensions of neuron_names_left don't match the "
+                 "dimensions of weights_left."
              )
          if weights_right and (len(neuron_names_right) != len(weights_right)):
              raise ValueError(
-                 "The dimensions of neuron_names_right don't match with the dimensions of weights_right."
+                 "The dimensions of neuron_names_right don't match the "
+                 "dimensions of weights_right."
              )
 
          # If weights are provided for summation, transform them to Tensors
          if weights_left:
-             self.weights_left = FloatTensor(weights_left)
+             self.weights_left = tensor(weights_left, device=self.device)
          else:
-             self.weights_left = ones(len(neuron_names_left))
+             self.weights_left = ones(len(neuron_names_left), device=self.device)
          if weights_right:
-             self.weights_right = FloatTensor(weights_right)
+             self.weights_right = tensor(weights_right, device=self.device)
          else:
-             self.weights_right = ones(len(neuron_names_right))
+             self.weights_right = ones(
+                 len(neuron_names_right), device=self.device
+             )
 
          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
@@ -459,49 +704,216 @@ class SumConstraint(Constraint):
              self.direction_left = 1
              self.direction_right = -1
 
-         # Normalize directions
-         normalized_directions = normalize(
-             tensor(self.direction_left, self.direction_right), p=2, dim=0
-         )
-         self.direction_left = normalized_directions[0]
-         self.direction_right = normalized_directions[1]
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
 
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         raise NotImplementedError
-         # # TODO remove the dynamic to device conversion and do this in initialization one way or another
-         # weighted_sum_left = (
-         #     prediction[layer_left][:, index_left]
-         #     * self.weights_left.to(prediction[layer_left].device)
-         # ).sum(dim=1)
-         # weighted_sum_right = (
-         #     prediction[layer_right][:, index_right]
-         #     * self.weights_right.to(prediction[layer_right].device)
-         # ).sum(dim=1)
-
-         # result = ~self.comparator(weighted_sum_left, weighted_sum_right)
-
-         # return {layer_left: result, layer_right: result}
-         pass
+         def compute_weighted_sum(
+             neuron_names: list[str],
+             transformations: list[Transformation],
+             weights: Tensor,
+         ) -> Tensor:
+             layers = [
+                 self.descriptor.neuron_to_layer[neuron_name]
+                 for neuron_name in neuron_names
+             ]
+             indices = [
+                 self.descriptor.neuron_to_index[neuron_name]
+                 for neuron_name in neuron_names
+             ]
+
+             # Select relevant columns
+             selections = [
+                 prediction[layer][:, index]
+                 for layer, index in zip(layers, indices)
+             ]
+
+             # Apply transformations
+             results = []
+             for transformation, selection in zip(transformations, selections):
+                 results.append(transformation(selection))
+
+             # Stack predictions for all neurons and apply weights in bulk
+             predictions = stack(results, dim=1)
+
+             # Calculate weighted sum
+             return (predictions * weights.unsqueeze(0)).sum(dim=1)
+
+         # Compute weighted sums
+         weighted_sum_left = compute_weighted_sum(
+             self.neuron_names_left,
+             self.transformations_left,
+             self.weights_left,
+         )
+         weighted_sum_right = compute_weighted_sum(
+             self.neuron_names_right,
+             self.transformations_right,
+             self.weights_right,
+         )
 
-     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         raise NotImplementedError
-         # # TODO move this to constructor somehow
-         # layer_left = prediction.neuron_to_layer[self.neuron_name_left]
-         # layer_right = prediction.neuron_to_layer[self.neuron_name_right]
-         # index_left = prediction.neuron_to_index[self.neuron_name_left]
-         # index_right = prediction.neuron_to_index[self.neuron_name_right]
-
-         # output_left = zeros(
-         #     prediction[layer_left].size(),
-         #     device=prediction[layer_left].device,
-         # )
-         # output_left[:, index_left] = self.direction_left
-
-         # output_right = zeros(
-         #     prediction.layer_to_data[layer_right].size(),
-         #     device=prediction.layer_to_data[layer_right].device,
-         # )
-         # output_right[:, index_right] = self.direction_right
-
-         # return {layer_left: output_left, layer_right: output_right}
-         pass
+         # Apply the comparator and calculate the result
+         result = self.comparator(weighted_sum_left, weighted_sum_right).float()
+
+         return result, numel(result)
+
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers
+         # due to neuron to index translation
+
+         output = {}
+
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0], device=self.device)
+
+         for neuron_name_left in self.neuron_names_left:
+             layer = self.descriptor.neuron_to_layer[neuron_name_left]
+             index = self.descriptor.neuron_to_index[neuron_name_left]
+             output[layer][index] = self.direction_left
+
+         for neuron_name_right in self.neuron_names_right:
+             layer = self.descriptor.neuron_to_layer[neuron_name_right]
+             index = self.descriptor.neuron_to_index[neuron_name_right]
+             output[layer][index] = self.direction_right
+
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
+
+         return output
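
A sketch of a weighted-sum comparison (hypothetical neuron names: two part
costs, weighted by quantity, must not exceed the predicted total budget):

    from torch import le
    from congrads.constraints import SumConstraint

    budget = SumConstraint(
        ["cost_parts", "cost_labor"],
        le,
        ["budget_total"],
        weights_left=[2, 1],
    )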
+
+
+ class PythagoreanIdentityConstraint(Constraint):
+     """
+     A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
+     where `a` and `b` are neurons or transformations.
+
+     This constraint checks that the sum of the squares of two specified
+     neurons (or their transformations) is approximately equal to 1.
+     The constraint is evaluated using relative and absolute
+     tolerance (`rtol` and `atol`) and is applied during the forward pass.
+
+     Args:
+         a (Union[str, Transformation]): The first input, either a
+             neuron name (str) or a Transformation.
+         b (Union[str, Transformation]): The second input, either a
+             neuron name (str) or a Transformation.
+         rtol (float, optional): The relative tolerance for the
+             comparison (default is 1e-5).
+         atol (float, optional): The absolute tolerance for the
+             comparison (default is 1e-8).
+         name (str, optional): The name of the constraint
+             (default is None, and it is generated automatically).
+         monitor_only (bool, optional): Flag indicating whether the
+             constraint is only for monitoring (default is False).
+         rescale_factor (Number, optional): A factor used for
+             rescaling (default is 1.5).
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+
+     """
+
+     def __init__(
+         self,
+         a: Union[str, Transformation],
+         b: Union[str, Transformation],
+         rtol: float = 1e-5,
+         atol: float = 1e-8,
+         name: str = None,
+         monitor_only: bool = False,
+         rescale_factor: Number = 1.5,
+     ) -> None:
+         """
+         Initialize the PythagoreanIdentityConstraint.
+         """
+
+         # Type checking
+         validate_type("a", a, (str, Transformation))
+         validate_type("b", b, (str, Transformation))
+         validate_type("rtol", rtol, float)
+         validate_type("atol", atol, float)
+
+         # If transformation is provided, get neuron name,
+         # else use IdentityTransformation
+         if isinstance(a, Transformation):
+             neuron_name_a = a.neuron_name
+             transformation_a = a
+         else:
+             neuron_name_a = a
+             transformation_a = IdentityTransformation(neuron_name_a)
+
+         if isinstance(b, Transformation):
+             neuron_name_b = b.neuron_name
+             transformation_b = b
+         else:
+             neuron_name_b = b
+             transformation_b = IdentityTransformation(neuron_name_b)
+
+         # Compose constraint name unless one was provided
+         if name is None:
+             name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
+
+         # Init parent class
+         super().__init__(
+             {neuron_name_a, neuron_name_b},
+             name,
+             monitor_only,
+             rescale_factor,
+         )
+
+         # Init variables
+         self.transformation_a = transformation_a
+         self.transformation_b = transformation_b
+         self.rtol = rtol
+         self.atol = atol
+
+         # Get layer name and feature index from neuron_name
+         self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
+         self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
+         self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
+         self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
+
+     def check_constraint(
+         self, prediction: dict[str, Tensor]
+     ) -> tuple[Tensor, int]:
+
+         # Select relevant columns
+         selection_a = prediction[self.layer_a][:, self.index_a]
+         selection_b = prediction[self.layer_b][:, self.index_b]
+
+         # Apply transformations
+         selection_a = self.transformation_a(selection_a)
+         selection_b = self.transformation_b(selection_b)
+
+         # Calculate result
+         result = isclose(
+             square(selection_a) + square(selection_b),
+             ones_like(selection_a, device=self.device),
+             rtol=self.rtol,
+             atol=self.atol,
+         ).float()
+
+         return result, numel(result)
+
+     def calculate_direction(
+         self, prediction: dict[str, Tensor]
+     ) -> Dict[str, Tensor]:
+         # NOTE currently only works for dense layers due
+         # to neuron to index translation
+
+         output = {}
+
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer], device=self.device)
+
+         a = prediction[self.layer_a][:, self.index_a]
+         b = prediction[self.layer_b][:, self.index_b]
+         # Distance from the origin; clamp to avoid division by zero
+         # when both outputs are exactly zero.
+         m = sqrt(square(a) + square(b)).clamp_min(1e-12)
+
+         # Push each point radially toward the unit circle.
+         output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
+         output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
+
+         return output
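
A sketch of the identity constraint on a hypothetical pair of outputs that
encode an angle as sine and cosine:

    from congrads.constraints import PythagoreanIdentityConstraint

    # sin²(θ) + cos²(θ) should stay (approximately) equal to 1.
    unit_circle = PythagoreanIdentityConstraint("angle_sin", "angle_cos")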