congrads-0.2.0-py3-none-any.whl → congrads-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +17 -10
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +664 -134
- congrads/core.py +482 -110
- congrads/datasets.py +315 -11
- congrads/descriptor.py +100 -20
- congrads/metrics.py +178 -16
- congrads/networks.py +47 -23
- congrads/requirements.txt +6 -0
- congrads/transformations.py +139 -0
- congrads/utils.py +439 -39
- congrads-1.0.1.dist-info/METADATA +208 -0
- congrads-1.0.1.dist-info/RECORD +16 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/WHEEL +1 -1
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/LICENSE +0 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/top_level.txt +0 -0
congrads/constraints.py
CHANGED
@@ -1,27 +1,105 @@
-
-
+"""
+This module provides a set of constraint classes for guiding neural network
+training by enforcing specific conditions on the network's outputs.
+
+The constraints in this module include:
+
+- `Constraint`: The base class for all constraint types, defining the
+interface and core behavior.
+- `ImplicationConstraint`: A constraint that enforces one condition only if
+another condition is met, useful for modeling implications between network
+outputs.
+- `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
+a network's output.
+- `BinaryConstraint`: A constraint that enforces a binary comparison between
+two neurons in the network, using a comparison function (e.g., less than,
+greater than).
+- `SumConstraint`: A constraint that enforces that the sum of certain neurons'
+outputs equals a specified value, which can be used to control total output.
+- `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
+on a set of neurons, ensuring that the square of one neuron's output is equal
+to the sum of the squares of other outputs.
+
+These constraints can be used to steer the learning process by applying
+conditions such as logical implications or numerical bounds.
+
+Usage:
+1. Define a custom constraint class by inheriting from `Constraint`.
+2. Apply the constraint to your neural network during training to
+enforce desired output behaviors.
+3. Use the helper classes like `IdentityTransformation` for handling
+transformations and comparisons in constraints.
+
+Dependencies:
+- PyTorch (`torch`)
+"""
+
 import random
 import string
-
+import warnings
+from abc import ABC, abstractmethod
+from numbers import Number
+from typing import Callable, Dict, Union
+
 from torch import (
     Tensor,
+    count_nonzero,
     ge,
     gt,
-
+    isclose,
     le,
+    logical_not,
+    logical_or,
+    lt,
+    numel,
+    ones,
+    ones_like,
     reshape,
+    sign,
+    sqrt,
+    square,
     stack,
-    ones,
     tensor,
     zeros_like,
 )
-import logging
 from torch.nn.functional import normalize
 
 from .descriptor import Descriptor
+from .transformations import IdentityTransformation, Transformation
+from .utils import validate_comparator_pytorch, validate_iterable, validate_type
 
 
 class Constraint(ABC):
+    """
+    Abstract base class for defining constraints applied to neural networks.
+
+    A `Constraint` specifies conditions that the neural network outputs
+    should satisfy. It supports monitoring constraint satisfaction
+    during training and can adjust loss to enforce constraints. Subclasses
+    must implement the `check_constraint` and `calculate_direction` methods.
+
+    Args:
+        neurons (set[str]): Names of the neurons this constraint applies to.
+        name (str, optional): A unique name for the constraint. If not provided,
+            a name is generated based on the class name and a random suffix.
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5. Should be greater
+            than 1 to give weight to the constraint.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+        ValueError: If any neuron in `neurons` is not
+            defined in the `descriptor`.
+
+    Note:
+        - If `rescale_factor <= 1`, a warning is issued, and the value is
+          adjusted to a positive value greater than 1.
+        - If `name` is not provided, a name is auto-generated,
+          and a warning is logged.
+
+    """
 
     descriptor: Descriptor = None
     device = None
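The docstring's usage steps can be made concrete with a short sketch. This is illustrative only and not part of the diff: the neuron name "y" and the descriptor wiring are assumptions, and the method bodies simply mirror the `ScalarConstraint` implementation shown further down.

    # Hypothetical sketch of step 1 ("define a custom constraint class"):
    # keep the output of neuron "y" non-negative. Assumes a Descriptor that
    # maps "y" to a dense layer has been attached to Constraint.descriptor
    # (the training core normally takes care of this).
    from torch import Tensor, ge, numel, zeros_like

    class NonNegativeConstraint(Constraint):
        def __init__(self, neuron: str, **kwargs) -> None:
            super().__init__({neuron}, **kwargs)
            self.layer = self.descriptor.neuron_to_layer[neuron]
            self.index = self.descriptor.neuron_to_index[neuron]

        def check_constraint(self, prediction: dict[str, Tensor]) -> tuple[Tensor, int]:
            # 1.0 for each sample that satisfies y >= 0, else 0.0
            result = ge(prediction[self.layer][:, self.index], 0.0).float()
            return result, numel(result)

        def calculate_direction(self, prediction: dict[str, Tensor]) -> dict[str, Tensor]:
            # Zero everywhere except the constrained neuron, which is pushed upward
            output = {layer: zeros_like(prediction[layer][0]) for layer in self.layers}
            output[self.layer][self.index] = 1
            return output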
@@ -30,23 +108,39 @@ class Constraint(ABC):
         self,
         neurons: set[str],
         name: str = None,
-
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
     ) -> None:
+        """
+        Initializes a new Constraint instance.
+        """
 
         # Init parent class
         super().__init__()
 
+        # Type checking
+        validate_iterable("neurons", neurons, str)
+        validate_type("name", name, (str, type(None)))
+        validate_type("monitor_only", monitor_only, bool)
+        validate_type("rescale_factor", rescale_factor, Number)
+
         # Init object variables
         self.neurons = neurons
         self.rescale_factor = rescale_factor
+        self.monitor_only = monitor_only
 
         # Perform checks
         if rescale_factor <= 1:
-
-
+            warnings.warn(
+                "Rescale factor for constraint %s is <= 1. The network \
+                will favor general loss over the constraint-adjusted loss. \
+                Is this intended behavior? Normally, the loss should \
+                always be larger than 1.",
+                name,
             )
 
-        # If no constraint_name is set, generate one based
+        # If no constraint_name is set, generate one based
+        # on the class name and a random suffix
        if name:
            self.name = name
        else:
@@ -54,13 +148,19 @@ class Constraint(ABC):
                 random.choices(string.ascii_uppercase + string.digits, k=6)
             )
             self.name = f"{self.__class__.__name__}_{random_suffix}"
-
+            warnings.warn(
+                "Name for constraint is not set. Using %s.", self.name
+            )
 
         # If rescale factor is not larger than 1, warn user and adjust
-        if
+        if rescale_factor <= 1:
             self.rescale_factor = abs(rescale_factor) + 1.5
-
-
+            warnings.warn(
+                "Rescale factor for constraint %s is < 1, adjusted value \
+                %s to %s.",
+                name,
+                rescale_factor,
+                self.rescale_factor,
             )
         else:
             self.rescale_factor = rescale_factor
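Concretely: a `rescale_factor` of 0.5 fails the check above and is adjusted to `abs(0.5) + 1.5 = 2.0`, so the constraint-adjusted loss still outweighs the general loss.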
@@ -70,83 +170,250 @@ class Constraint(ABC):
         for neuron in self.neurons:
             if neuron not in self.descriptor.neuron_to_layer.keys():
                 raise ValueError(
-                    f'The neuron name {neuron} used with constraint
+                    f'The neuron name {neuron} used with constraint \
+                    {self.name} is not defined in the descriptor. Please \
+                    add it to the correct layer using \
+                    descriptor.add("layer", ...).'
                 )
 
             self.layers.add(self.descriptor.neuron_to_layer[neuron])
 
-    # TODO only denormalize if required for efficiency
-    def _denormalize(self, input: Tensor, neuron_names: list[str]):
-        # Extract min and max for each neuron
-        min_values = tensor(
-            [self.descriptor.neuron_to_minmax[name][0] for name in neuron_names],
-            device=input.device,
-        )
-        max_values = tensor(
-            [self.descriptor.neuron_to_minmax[name][1] for name in neuron_names],
-            device=input.device,
-        )
-
-        # Apply vectorized denormalization
-        return input * (max_values - min_values) + min_values
-
     @abstractmethod
-    def check_constraint(
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
+        """
+        Evaluates whether the given model predictions satisfy the constraint.
+
+        Args:
+            prediction (dict[str, Tensor]): Model predictions for the neurons.
+
+        Returns:
+            tuple[Tensor, int]: A tuple where the first element is a tensor
+                indicating whether the constraint is satisfied (with `True`
+                for satisfaction, `False` for non-satisfaction, and `torch.nan`
+                for irrelevant results), and the second element is an integer
+                value representing the number of relevant constraints.
+
+        Raises:
+            NotImplementedError: If not implemented in a subclass.
+        """
+
         raise NotImplementedError
 
     @abstractmethod
-    def calculate_direction(
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        """
+        Calculates adjustment directions for neurons to
+        better satisfy the constraint.
+
+        Args:
+            prediction (dict[str, Tensor]): Model predictions for the neurons.
+
+        Returns:
+            Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
+                specifying the adjustment direction for each neuron.
+
+        Raises:
+            NotImplementedError: If not implemented in a subclass.
+        """
+
         raise NotImplementedError
 
 
+class ImplicationConstraint(Constraint):
+    """
+    Represents an implication constraint between two
+    constraints (head and body).
+
+    The implication constraint ensures that the `body` constraint only applies
+    when the `head` constraint is satisfied. If the `head` constraint is not
+    satisfied, the `body` constraint does not apply.
+
+    Args:
+        head (Constraint): The head of the implication. If this constraint
+            is satisfied, the body constraint must also be satisfied.
+        body (Constraint): The body of the implication. This constraint
+            is enforced only when the head constraint is satisfied.
+        name (str, optional): A unique name for the constraint. If not
+            provided, the name is generated in the format
+            "{body.name} if {head.name}". Defaults to None.
+        monitor_only (bool, optional): If True, the constraint is only
+            monitored without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): The scaling factor for the
+            constraint-adjusted loss. Defaults to 1.5.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+
+    """
+
+    def __init__(
+        self,
+        head: Constraint,
+        body: Constraint,
+        name=None,
+        monitor_only=False,
+        rescale_factor=1.5,
+    ):
+        """
+        Initializes an ImplicationConstraint instance.
+        """
+
+        # Type checking
+        validate_type("head", head, Constraint)
+        validate_type("body", body, Constraint)
+
+        # Compose constraint name
+        name = f"{body.name} if {head.name}"
+
+        # Init parent class
+        super().__init__(
+            head.neurons | body.neurons,
+            name,
+            monitor_only,
+            rescale_factor,
+        )
+
+        self.head = head
+        self.body = body
+
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
+
+        # Check satisfaction of head and body constraints
+        head_satisfaction, _ = self.head.check_constraint(prediction)
+        body_satisfaction, _ = self.body.check_constraint(prediction)
+
+        # If head constraint is satisfied (returning 1),
+        # the body constraint matters (and should return 0/1 based on body)
+        # If head constraint is not satisfied (returning 0),
+        # the body constraint does not apply (and should return 1)
+        result = logical_or(
+            logical_not(head_satisfaction), body_satisfaction
+        ).float()
+
+        return result, count_nonzero(head_satisfaction)
+
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers
+        # due to neuron to index translation
+
+        # Use directions of constraint body as update vector
+        return self.body.calculate_direction(prediction)
+
+
 class ScalarConstraint(Constraint):
+    """
+    A constraint that enforces scalar-based comparisons on a specific neuron.
+
+    This class ensures that the output of a specified neuron satisfies a scalar
+    comparison operation (e.g., less than, greater than, etc.). It uses a
+    comparator function to validate the condition and calculates adjustment
+    directions accordingly.
+
+    Args:
+        operand (Union[str, Transformation]): Name of the neuron or a
+            transformation to apply.
+        comparator (Callable[[Tensor, Number], Tensor]): A comparison
+            function (e.g., `torch.ge`, `torch.lt`).
+        scalar (Number): The scalar value to compare against.
+        name (str, optional): A unique name for the constraint. If not
+            provided, a name is auto-generated in the format
+            "<neuron_name> <comparator> <scalar>".
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+
+    Notes:
+        - The `neuron_name` must be defined in the `descriptor` mapping.
+        - The constraint name is composed using the neuron name,
+          comparator, and scalar value.
+
+    """
 
     def __init__(
         self,
-
+        operand: Union[str, Transformation],
         comparator: Callable[[Tensor, Number], Tensor],
         scalar: Number,
         name: str = None,
-
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
     ) -> None:
+        """
+        Initializes a ScalarConstraint instance.
+        """
+
+        # Type checking
+        validate_type("operand", operand, (str, Transformation))
+        validate_comparator_pytorch("comparator", comparator)
+        validate_comparator_pytorch("comparator", comparator)
+        validate_type("scalar", scalar, Number)
+
+        # If transformation is provided, get neuron name,
+        # else use IdentityTransformation
+        if isinstance(operand, Transformation):
+            neuron_name = operand.neuron_name
+            transformation = operand
+        else:
+            neuron_name = operand
+            transformation = IdentityTransformation(neuron_name)
 
         # Compose constraint name
-        name = f"{neuron_name}
+        name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
 
         # Init parent class
-        super().__init__({neuron_name}, name, rescale_factor)
+        super().__init__({neuron_name}, name, monitor_only, rescale_factor)
 
         # Init variables
         self.comparator = comparator
         self.scalar = scalar
+        self.transformation = transformation
 
         # Get layer name and feature index from neuron_name
         self.layer = self.descriptor.neuron_to_layer[neuron_name]
         self.index = self.descriptor.neuron_to_index[neuron_name]
 
-        # If comparator function is not supported, raise error
-        if comparator not in [ge, le, gt, lt]:
-            raise ValueError(
-                f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-            )
-
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
             self.direction = -1
         elif self.comparator in [gt, ge]:
             self.direction = 1
 
-    def check_constraint(
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
 
-
+        # Select relevant columns
+        selection = prediction[self.layer][:, self.index]
 
-
-
+        # Apply transformation
+        selection = self.transformation(selection)
+
+        # Calculate current constraint result
+        result = self.comparator(selection, self.scalar).float()
+        return result, numel(result)
+
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers due
+        # to neuron to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(prediction[layer][0])
+            output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
         output[self.layer][self.index] = self.direction
 
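A minimal usage sketch for the two classes added above. Illustrative only: the neuron names "is_raining" and "humidity" and the thresholds are hypothetical, and both neurons are assumed to be registered in the active Descriptor.

    from torch import ge, le

    # "humidity ge 0.9 if is_raining ge 0.5": the body only has to hold on
    # samples where the head holds, since the check reduces to NOT head OR body.
    head = ScalarConstraint("is_raining", ge, 0.5)
    body = ScalarConstraint("humidity", ge, 0.9)
    rule = ImplicationConstraint(head, body, rescale_factor=2.0)

    # A monitor-only constraint is tracked in the metrics but never adjusts the loss.
    upper_bound = ScalarConstraint("humidity", le, 1.0, monitor_only=True)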
@@ -157,28 +424,89 @@ class ScalarConstraint(Constraint):
 
 
 class BinaryConstraint(Constraint):
+    """
+    A constraint that enforces a binary comparison between two neurons.
+
+    This class ensures that the output of one neuron satisfies a comparison
+    operation with the output of another neuron
+    (e.g., less than, greater than, etc.). It uses a comparator function to
+    validate the condition and calculates adjustment directions accordingly.
+
+    Args:
+        operand_left (Union[str, Transformation]): Name of the left
+            neuron or a transformation to apply.
+        comparator (Callable[[Tensor, Number], Tensor]): A comparison
+            function (e.g., `torch.ge`, `torch.lt`).
+        operand_right (Union[str, Transformation]): Name of the right
+            neuron or a transformation to apply.
+        name (str, optional): A unique name for the constraint. If not
+            provided, a name is auto-generated in the format
+            "<neuron_name_left> <comparator> <neuron_name_right>".
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+
+    Notes:
+        - The neuron names must be defined in the `descriptor` mapping.
+        - The constraint name is composed using the left neuron name,
+          comparator, and right neuron name.
+
+    """
 
     def __init__(
         self,
-
+        operand_left: Union[str, Transformation],
         comparator: Callable[[Tensor, Number], Tensor],
-
+        operand_right: Union[str, Transformation],
         name: str = None,
-
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
     ) -> None:
+        """
+        Initializes a BinaryConstraint instance.
+        """
+
+        # Type checking
+        validate_type("operand_left", operand_left, (str, Transformation))
+        validate_comparator_pytorch("comparator", comparator)
+        validate_comparator_pytorch("comparator", comparator)
+        validate_type("operand_right", operand_right, (str, Transformation))
+
+        # If transformation is provided, get neuron name,
+        # else use IdentityTransformation
+        if isinstance(operand_left, Transformation):
+            neuron_name_left = operand_left.neuron_name
+            transformation_left = operand_left
+        else:
+            neuron_name_left = operand_left
+            transformation_left = IdentityTransformation(neuron_name_left)
+
+        if isinstance(operand_right, Transformation):
+            neuron_name_right = operand_right.neuron_name
+            transformation_right = operand_right
+        else:
+            neuron_name_right = operand_right
+            transformation_right = IdentityTransformation(neuron_name_right)
 
         # Compose constraint name
-        name = f"{neuron_name_left}
+        name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
 
         # Init parent class
         super().__init__(
             {neuron_name_left, neuron_name_right},
             name,
+            monitor_only,
             rescale_factor,
         )
 
         # Init variables
         self.comparator = comparator
+        self.transformation_left = transformation_left
+        self.transformation_right = transformation_right
 
         # Get layer name and feature index from neuron_name
         self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
@@ -186,12 +514,6 @@ class BinaryConstraint(Constraint):
         self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
         self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
 
-        # If comparator function is not supported, raise error
-        if comparator not in [ge, le, gt, lt]:
-            raise RuntimeError(
-                f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-            )
-
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
             self.direction_left = -1
@@ -200,20 +522,32 @@ class BinaryConstraint(Constraint):
             self.direction_left = 1
             self.direction_right = -1
 
-    def check_constraint(
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
 
-
-
-
-
+        # Select relevant columns
+        selection_left = prediction[self.layer_left][:, self.index_left]
+        selection_right = prediction[self.layer_right][:, self.index_right]
+
+        # Apply transformations
+        selection_left = self.transformation_left(selection_left)
+        selection_right = self.transformation_right(selection_right)
+
+        result = self.comparator(selection_left, selection_right).float()
 
-
-
+        return result, numel(result)
+
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers due
+        # to neuron to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(prediction[layer][0])
+            output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
         output[self.layer_left][self.index_left] = self.direction_left
         output[self.layer_right][self.index_right] = self.direction_right
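A usage sketch for `BinaryConstraint` (hypothetical neuron names, both assumed to be registered in the Descriptor):

    from torch import le

    # Enforce t_min <= t_max per sample; the auto-generated constraint
    # name is "t_min le t_max".
    ordering = BinaryConstraint("t_min", le, "t_max")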
@@ -225,40 +559,129 @@ class BinaryConstraint(Constraint):
 
 
 class SumConstraint(Constraint):
+    """
+    A constraint that enforces a weighted summation comparison
+    between two groups of neurons.
+
+    This class evaluates whether the weighted sum of outputs from one set of
+    neurons satisfies a comparison operation with the weighted sum of
+    outputs from another set of neurons.
+
+    Args:
+        operands_left (list[Union[str, Transformation]]): List of neuron
+            names or transformations on the left side.
+        comparator (Callable[[Tensor, Number], Tensor]): A comparison
+            function for the constraint.
+        operands_right (list[Union[str, Transformation]]): List of neuron
+            names or transformations on the right side.
+        weights_left (list[Number], optional): Weights for the left neurons.
+            Defaults to None.
+        weights_right (list[Number], optional): Weights for the right
+            neurons. Defaults to None.
+        name (str, optional): Unique name for the constraint.
+            If None, it's auto-generated. Defaults to None.
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+        ValueError: If the dimensions of neuron names and weights mismatch.
+
+    """
+
     def __init__(
         self,
-
+        operands_left: list[Union[str, Transformation]],
         comparator: Callable[[Tensor, Number], Tensor],
-
-        weights_left: list[
-        weights_right: list[
+        operands_right: list[Union[str, Transformation]],
+        weights_left: list[Number] = None,
+        weights_right: list[Number] = None,
         name: str = None,
-
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
     ) -> None:
+        """
+        Initializes the SumConstraint.
+        """
+
+        # Type checking
+        validate_iterable("operands_left", operands_left, (str, Transformation))
+        validate_comparator_pytorch("comparator", comparator)
+        validate_comparator_pytorch("comparator", comparator)
+        validate_iterable(
+            "operands_right", operands_right, (str, Transformation)
+        )
+        validate_iterable("weights_left", weights_left, Number, allow_none=True)
+        validate_iterable(
+            "weights_right", weights_right, Number, allow_none=True
+        )
+
+        # If transformation is provided, get neuron name,
+        # else use IdentityTransformation
+        neuron_names_left: list[str] = []
+        transformations_left: list[Transformation] = []
+        for operand_left in operands_left:
+            if isinstance(operand_left, Transformation):
+                neuron_name_left = operand_left.neuron_name
+                neuron_names_left.append(neuron_name_left)
+                transformations_left.append(operand_left)
+            else:
+                neuron_name_left = operand_left
+                neuron_names_left.append(neuron_name_left)
+                transformations_left.append(
+                    IdentityTransformation(neuron_name_left)
+                )
+
+        neuron_names_right: list[str] = []
+        transformations_right: list[Transformation] = []
+        for operand_right in operands_right:
+            if isinstance(operand_right, Transformation):
+                neuron_name_right = operand_right.neuron_name
+                neuron_names_right.append(neuron_name_right)
+                transformations_right.append(operand_right)
+            else:
+                neuron_name_right = operand_right
+                neuron_names_right.append(neuron_name_right)
+                transformations_right.append(
+                    IdentityTransformation(neuron_name_right)
+                )
+
+        # Compose constraint name
+        w_left = weights_left or [""] * len(neuron_names_left)
+        w_right = weights_right or [""] * len(neuron_names_right)
+        left_expr = " + ".join(
+            f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
+        )
+        right_expr = " + ".join(
+            f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
+        )
+        comparator_name = comparator.__name__
+        name = f"{left_expr} {comparator_name} {right_expr}"
 
         # Init parent class
         neuron_names = set(neuron_names_left) | set(neuron_names_right)
-        super().__init__(neuron_names, name, rescale_factor)
+        super().__init__(neuron_names, name, monitor_only, rescale_factor)
 
         # Init variables
         self.comparator = comparator
         self.neuron_names_left = neuron_names_left
         self.neuron_names_right = neuron_names_right
+        self.transformations_left = transformations_left
+        self.transformations_right = transformations_right
 
-        # If
-
-            raise ValueError(
-                f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-            )
-
-        # If feature list dimensions don't match weight list dimensions, raise error
+        # If feature list dimensions don't match
+        # weight list dimensions, raise error
         if weights_left and (len(neuron_names_left) != len(weights_left)):
             raise ValueError(
-                "The dimensions of neuron_names_left don't match with the
+                "The dimensions of neuron_names_left don't match with the \
+                dimensions of weights_left."
             )
         if weights_right and (len(neuron_names_right) != len(weights_right)):
             raise ValueError(
-                "The dimensions of neuron_names_right don't match with the
+                "The dimensions of neuron_names_right don't match with the \
+                dimensions of weights_right."
             )
 
         # If weights are provided for summation, transform them to Tensors
@@ -269,7 +692,9 @@ class SumConstraint(Constraint):
         if weights_right:
             self.weights_right = tensor(weights_right, device=self.device)
         else:
-            self.weights_right = ones(
+            self.weights_right = ones(
+                len(neuron_names_right), device=self.device
+            )
 
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
@@ -279,9 +704,15 @@ class SumConstraint(Constraint):
             self.direction_left = 1
             self.direction_right = -1
 
-    def check_constraint(
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
 
-        def compute_weighted_sum(
+        def compute_weighted_sum(
+            neuron_names: list[str],
+            transformations: list[Transformation],
+            weights: tensor,
+        ) -> tensor:
             layers = [
                 self.descriptor.neuron_to_layer[neuron_name]
                 for neuron_name in neuron_names
@@ -291,37 +722,53 @@ class SumConstraint(Constraint):
                 for neuron_name in neuron_names
             ]
 
+            # Select relevant column
+            selections = [
+                prediction[layer][:, index]
+                for layer, index in zip(layers, indices)
+            ]
+
+            # Apply transformations
+            results = []
+            for transformation, selection in zip(transformations, selections):
+                results.append(transformation(selection))
+
             # Extract predictions for all neurons and apply weights in bulk
             predictions = stack(
-
+                results,
                 dim=1,
             )
 
-            # Denormalize if required
-            predictions_denorm = self._denormalize(predictions, neuron_names)
-
             # Calculate weighted sum
-
-
-            return weighted_sum
+            return (predictions * weights.unsqueeze(0)).sum(dim=1)
 
+        # Compute weighted sums
         weighted_sum_left = compute_weighted_sum(
-            self.neuron_names_left,
+            self.neuron_names_left,
+            self.transformations_left,
+            self.weights_left,
         )
         weighted_sum_right = compute_weighted_sum(
-            self.neuron_names_right,
+            self.neuron_names_right,
+            self.transformations_right,
+            self.weights_right,
         )
 
         # Apply the comparator and calculate the result
-
+        result = self.comparator(weighted_sum_left, weighted_sum_right).float()
+
+        return result, numel(result)
 
-    def calculate_direction(
-
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers
+        # due to neuron to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(prediction[layer][0])
+            output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
         for neuron_name_left in self.neuron_names_left:
             layer = self.descriptor.neuron_to_layer[neuron_name_left]
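A usage sketch for `SumConstraint` (hypothetical neuron names and weights; each operand is assumed to be registered in the Descriptor). Per batch row, the check computes both weighted sums and compares them, here 2·x1 + 1·x2 ≤ x3:

    from torch import le

    budget = SumConstraint(
        operands_left=["x1", "x2"],
        comparator=le,
        operands_right=["x3"],
        weights_left=[2.0, 1.0],  # omitted weights default to ones
    )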
@@ -339,51 +786,134 @@ class SumConstraint(Constraint):
         return output
 
 
-
-
+class PythagoreanIdentityConstraint(Constraint):
+    """
+    A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
+    where `a` and `b` are neurons or transformations.
+
+    This constraint checks that the sum of the squares of two specified
+    neurons (or their transformations) is approximately equal to 1.
+    The constraint is evaluated using relative and absolute
+    tolerance (`rtol` and `atol`) and is applied during the forward pass.
+
+    Args:
+        a (Union[str, Transformation]): The first input, either a
+            neuron name (str) or a Transformation.
+        b (Union[str, Transformation]): The second input, either a
+            neuron name (str) or a Transformation.
+        rtol (float, optional): The relative tolerance for the
+            comparison (default is 0.00001).
+        atol (float, optional): The absolute tolerance for the
+            comparison (default is 1e-8).
+        name (str, optional): The name of the constraint
+            (default is None, and it is generated automatically).
+        monitor_only (bool, optional): Flag indicating whether the
+            constraint is only for monitoring (default is False).
+        rescale_factor (Number, optional): A factor used for
+            rescaling (default is 1.5).
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+
+    """
 
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        a: Union[str, Transformation],
+        b: Union[str, Transformation],
+        rtol: float = 0.00001,
+        atol: float = 1e-8,
+        name: str = None,
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
+    ) -> None:
+        """
+        Initialize the PythagoreanIdentityConstraint.
+        """
+
+        # Type checking
+        validate_type("a", a, (str, Transformation))
+        validate_type("b", b, (str, Transformation))
+        validate_type("rtol", rtol, float)
+        validate_type("atol", atol, float)
+
+        # If transformation is provided, get neuron name,
+        # else use IdentityTransformation
+        if isinstance(a, Transformation):
+            neuron_name_a = a.neuron_name
+            transformation_a = a
+        else:
+            neuron_name_a = a
+            transformation_a = IdentityTransformation(neuron_name_a)
 
-
-
+        if isinstance(b, Transformation):
+            neuron_name_b = b.neuron_name
+            transformation_b = b
+        else:
+            neuron_name_b = b
+            transformation_b = IdentityTransformation(neuron_name_b)
 
-        #
-
+        # Compose constraint name
+        name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
+
+        # Init parent class
+        super().__init__(
+            {neuron_name_a, neuron_name_b},
+            name,
+            monitor_only,
+            rescale_factor,
+        )
 
-        #
-
-
-
+        # Init variables
+        self.transformation_a = transformation_a
+        self.transformation_b = transformation_b
+        self.rtol = rtol
+        self.atol = atol
 
-        #
-
-
+        # Get layer name and feature index from neuron_name
+        self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
+        self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
+        self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
+        self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
+
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
+
+        # Select relevant columns
+        selection_a = prediction[self.layer_a][:, self.index_a]
+        selection_b = prediction[self.layer_b][:, self.index_b]
+
+        # Apply transformations
+        selection_a = self.transformation_a(selection_a)
+        selection_b = self.transformation_b(selection_b)
+
+        # Calculate result
+        result = isclose(
+            square(selection_a) + square(selection_b),
+            ones_like(selection_a, device=self.device),
+            rtol=self.rtol,
+            atol=self.atol,
+        ).float()
+
+        return result, numel(result)
+
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers due
+        # to neuron to index translation
 
-
-        # # Check if values for column in batch are only increasing
-        # result = ~ge(
-        #     diff(
-        #         prediction[self.layer][:, self.index],
-        #         prepend=zeros_like(
-        #             prediction[self.layer][:, self.index][:1],
-        #             device=prediction[self.layer].device,
-        #         ),
-        #     ),
-        #     0,
-        # )
+        output = {}
 
-
+        for layer in self.layers:
+            output[layer] = zeros_like(prediction[layer], device=self.device)
 
-
-
+        a = prediction[self.layer_a][:, self.index_a]
+        b = prediction[self.layer_b][:, self.index_b]
+        m = sqrt(square(a) + square(b))
 
-
-
+        output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
+        output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
 
-
+        return output
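A usage sketch for `PythagoreanIdentityConstraint` (hypothetical neuron names, assumed to be registered in the Descriptor): if two outputs are meant to encode the cosine and sine of an angle, their squares should sum to 1 within the given tolerances:

    identity = PythagoreanIdentityConstraint(
        "cos_theta",
        "sin_theta",
        atol=1e-6,          # loosen or tighten the tolerance as needed
        monitor_only=True,  # track satisfaction without adjusting the loss
    )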