congrads-0.1.0-py3-none-any.whl → congrads-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +21 -13
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +728 -316
- congrads/core.py +525 -139
- congrads/datasets.py +273 -516
- congrads/descriptor.py +95 -30
- congrads/metrics.py +185 -38
- congrads/networks.py +51 -28
- congrads/requirements.txt +6 -0
- congrads/transformations.py +139 -0
- congrads/utils.py +710 -0
- congrads-1.0.1.dist-info/LICENSE +26 -0
- congrads-1.0.1.dist-info/METADATA +208 -0
- congrads-1.0.1.dist-info/RECORD +16 -0
- {congrads-0.1.0.dist-info → congrads-1.0.1.dist-info}/WHEEL +1 -1
- congrads/learners.py +0 -233
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- {congrads-0.1.0.dist-info → congrads-1.0.1.dist-info}/top_level.txt +0 -0
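The largest change sits in congrads/constraints.py, shown below: constraints now accept operands (a neuron name or a `Transformation`), gain a `monitor_only` flag, and validate their arguments up front. For orientation, here is a minimal sketch of the 1.0.1 constraint surface. The `SimpleNamespace` descriptor is a stand-in: the diff only shows that constraints read `descriptor.neuron_to_layer` and `descriptor.neuron_to_index`, so we fake those two mappings; in practice the real `Descriptor` from congrads/descriptor.py would be configured instead. The neuron and layer names are invented for illustration.

from types import SimpleNamespace

from torch import ge, tensor
from congrads.constraints import Constraint, ScalarConstraint

# Stand-in descriptor: maps each named output neuron to its layer and
# to its column index within that layer's output tensor.
Constraint.descriptor = SimpleNamespace(
    neuron_to_layer={"price": "output", "demand": "output"},
    neuron_to_index={"price": 0, "demand": 1},
)

# New in 1.0.1: monitor_only constraints are tracked but do not adjust loss.
constraint = ScalarConstraint("price", ge, 0.0, monitor_only=True)
print(constraint.name)  # auto-composed: "price ge 0.0"

# Predictions are dicts of layer name -> [batch, features] tensors.
prediction = {"output": tensor([[3.0, 2.0], [-1.0, 4.0]])}
satisfied, n_relevant = constraint.check_constraint(prediction)
print(satisfied)   # tensor([1., 0.]) — the second sample violates price >= 0
print(n_relevant)  # 2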
congrads/constraints.py
CHANGED
@@ -1,312 +1,512 @@
-
-
+"""
+This module provides a set of constraint classes for guiding neural network
+training by enforcing specific conditions on the network's outputs.
+
+The constraints in this module include:
+
+- `Constraint`: The base class for all constraint types, defining the
+  interface and core behavior.
+- `ImplicationConstraint`: A constraint that enforces one condition only if
+  another condition is met, useful for modeling implications between network
+  outputs.
+- `ScalarConstraint`: A constraint that enforces scalar-based comparisons on
+  a network's output.
+- `BinaryConstraint`: A constraint that enforces a binary comparison between
+  two neurons in the network, using a comparison function (e.g., less than,
+  greater than).
+- `SumConstraint`: A constraint that enforces that the sum of certain neurons'
+  outputs equals a specified value, which can be used to control total output.
+- `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
+  on a set of neurons, ensuring that the square of one neuron's output is equal
+  to the sum of the squares of other outputs.
+
+These constraints can be used to steer the learning process by applying
+conditions such as logical implications or numerical bounds.
+
+Usage:
+1. Define a custom constraint class by inheriting from `Constraint`.
+2. Apply the constraint to your neural network during training to
+   enforce desired output behaviors.
+3. Use the helper classes like `IdentityTransformation` for handling
+   transformations and comparisons in constraints.
+
+Dependencies:
+- PyTorch (`torch`)
+"""
+
 import random
 import string
-
-from
-import
+import warnings
+from abc import ABC, abstractmethod
+from numbers import Number
+from typing import Callable, Dict, Union
+
+from torch import (
+    Tensor,
+    count_nonzero,
+    ge,
+    gt,
+    isclose,
+    le,
+    logical_not,
+    logical_or,
+    lt,
+    numel,
+    ones,
+    ones_like,
+    reshape,
+    sign,
+    sqrt,
+    square,
+    stack,
+    tensor,
+    zeros_like,
+)
 from torch.nn.functional import normalize
 
 from .descriptor import Descriptor
+from .transformations import IdentityTransformation, Transformation
+from .utils import validate_comparator_pytorch, validate_iterable, validate_type
 
 
 class Constraint(ABC):
     """
-    Abstract base class for defining constraints
-
-    A
-
+    Abstract base class for defining constraints applied to neural networks.
+
+    A `Constraint` specifies conditions that the neural network outputs
+    should satisfy. It supports monitoring constraint satisfaction
+    during training and can adjust loss to enforce constraints. Subclasses
+    must implement the `check_constraint` and `calculate_direction` methods.
+
+    Args:
+        neurons (set[str]): Names of the neurons this constraint applies to.
+        name (str, optional): A unique name for the constraint. If not provided,
+            a name is generated based on the class name and a random suffix.
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5. Should be greater
+            than 1 to give weight to the constraint.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+        ValueError: If any neuron in `neurons` is not
+            defined in the `descriptor`.
+
+    Note:
+        - If `rescale_factor <= 1`, a warning is issued, and the value is
+          adjusted to a positive value greater than 1.
+        - If `name` is not provided, a name is auto-generated,
+          and a warning is logged.
 
-    Attributes:
-        descriptor (Descriptor): The descriptor object that provides a mapping of neurons to layers.
-        constraint_name (str): A unique name for the constraint, which can be provided or generated automatically.
-        rescale_factor (float): A factor used to scale the influence of the constraint on the overall loss.
-        neuron_names (set[str]): A set of neuron names that are involved in the constraint.
-        layers (set): A set of layers associated with the neurons specified in `neuron_names`.
     """
 
     descriptor: Descriptor = None
+    device = None
 
     def __init__(
         self,
-
-
-
+        neurons: set[str],
+        name: str = None,
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
     ) -> None:
         """
-        Initializes
-
-        Args:
-            neuron_names (set[str]): A set of neuron names that are affected by the constraint.
-            constraint_name (str, optional): A custom name for the constraint. If not provided, a random name is generated.
-            rescale_factor (float, optional): A factor that scales the influence of the constraint. Defaults to 1.5.
-
-        Raises:
-            ValueError: If the descriptor has not been set or if a neuron name is not found in the descriptor.
+        Initializes a new Constraint instance.
         """
 
         # Init parent class
        super().__init__()
 
+        # Type checking
+        validate_iterable("neurons", neurons, str)
+        validate_type("name", name, (str, type(None)))
+        validate_type("monitor_only", monitor_only, bool)
+        validate_type("rescale_factor", rescale_factor, Number)
+
         # Init object variables
+        self.neurons = neurons
         self.rescale_factor = rescale_factor
-        self.
+        self.monitor_only = monitor_only
 
         # Perform checks
         if rescale_factor <= 1:
-
-
+            warnings.warn(
+                "Rescale factor for constraint %s is <= 1. The network \
+                will favor general loss over the constraint-adjusted loss. \
+                Is this intended behavior? Normally, the loss should \
+                always be larger than 1.",
+                name,
             )
 
-        # If no constraint_name is set, generate one based
-
-
+        # If no constraint_name is set, generate one based
+        # on the class name and a random suffix
+        if name:
+            self.name = name
         else:
             random_suffix = "".join(
                 random.choices(string.ascii_uppercase + string.digits, k=6)
             )
-            self.
-
-
+            self.name = f"{self.__class__.__name__}_{random_suffix}"
+            warnings.warn(
+                "Name for constraint is not set. Using %s.", self.name
             )
 
-
-
-                "The descriptor of the base Constraint class in not set. Please assign the descriptor to the general Constraint class with 'Constraint.descriptor = descriptor' before defining network-specific contraints."
-            )
-
-        if not rescale_factor > 1:
+        # If rescale factor is not larger than 1, warn user and adjust
+        if rescale_factor <= 1:
             self.rescale_factor = abs(rescale_factor) + 1.5
-
-
+            warnings.warn(
+                "Rescale factor for constraint %s is < 1, adjusted value \
+                %s to %s.",
+                name,
+                rescale_factor,
+                self.rescale_factor,
             )
         else:
             self.rescale_factor = rescale_factor
 
-
-
-        self.run_init_descriptor()
-
-    def run_init_descriptor(self) -> None:
-        """
-        Initializes the layers associated with the constraint by mapping the neuron names to their corresponding layers
-        from the descriptor.
-
-        This method populates the `layers` attribute with layers associated with the neuron names provided in the constraint.
-
-        Raises:
-            ValueError: If a neuron name is not found in the descriptor's mapping of neurons to layers.
-        """
-
+        # Infer layers from descriptor and neurons
         self.layers = set()
-        for
-        if
-                self.layers.add(self.descriptor.neuron_to_layer[neuron_name])
-            else:
+        for neuron in self.neurons:
+            if neuron not in self.descriptor.neuron_to_layer.keys():
                 raise ValueError(
-                    f'The neuron name {
+                    f'The neuron name {neuron} used with constraint \
+                    {self.name} is not defined in the descriptor. Please \
+                    add it to the correct layer using \
+                    descriptor.add("layer", ...).'
                 )
 
+            self.layers.add(self.descriptor.neuron_to_layer[neuron])
+
     @abstractmethod
-    def check_constraint(
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
         """
-
-
-        This method should be implemented in subclasses to define the specific logic for evaluating the constraint based on the model's predictions.
+        Evaluates whether the given model predictions satisfy the constraint.
 
         Args:
-            prediction (dict[str, Tensor]):
+            prediction (dict[str, Tensor]): Model predictions for the neurons.
 
         Returns:
-
+            tuple[Tensor, int]: A tuple where the first element is a tensor
+                indicating whether the constraint is satisfied (with `True`
+                for satisfaction, `False` for non-satisfaction, and `torch.nan`
+                for irrelevant results), and the second element is an integer
+                value representing the number of relevant constraints.
 
         Raises:
-            NotImplementedError: If
+            NotImplementedError: If not implemented in a subclass.
         """
 
         raise NotImplementedError
 
     @abstractmethod
-    def calculate_direction(
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
         """
-
-
-        This method should be implemented in subclasses to define how to adjust the model's predictions based on the constraint.
+        Calculates adjustment directions for neurons to
+        better satisfy the constraint.
 
         Args:
-            prediction (dict[str, Tensor]):
+            prediction (dict[str, Tensor]): Model predictions for the neurons.
 
         Returns:
-
+            Dict[str, Tensor]: Dictionary mapping neuron layers to tensors
+                specifying the adjustment direction for each neuron.
 
         Raises:
-            NotImplementedError: If
+            NotImplementedError: If not implemented in a subclass.
         """
+
         raise NotImplementedError
 
 
|
-
class
|
|
226
|
+
class ImplicationConstraint(Constraint):
|
|
227
|
+
"""
|
|
228
|
+
Represents an implication constraint between two
|
|
229
|
+
constraints (head and body).
|
|
230
|
+
|
|
231
|
+
The implication constraint ensures that the `body` constraint only applies
|
|
232
|
+
when the `head` constraint is satisfied. If the `head` constraint is not
|
|
233
|
+
satisfied, the `body` constraint does not apply.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
head (Constraint): The head of the implication. If this constraint
|
|
237
|
+
is satisfied, the body constraint must also be satisfied.
|
|
238
|
+
body (Constraint): The body of the implication. This constraint
|
|
239
|
+
is enforced only when the head constraint is satisfied.
|
|
240
|
+
name (str, optional): A unique name for the constraint. If not
|
|
241
|
+
provided, the name is generated in the format
|
|
242
|
+
"{body.name} if {head.name}". Defaults to None.
|
|
243
|
+
monitor_only (bool, optional): If True, the constraint is only
|
|
244
|
+
monitored without adjusting the loss. Defaults to False.
|
|
245
|
+
rescale_factor (Number, optional): The scaling factor for the
|
|
246
|
+
constraint-adjusted loss. Defaults to 1.5.
|
|
247
|
+
|
|
248
|
+
Raises:
|
|
249
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
250
|
+
|
|
149
251
|
"""
|
|
150
|
-
A subclass of the `Constraint` class that applies a scalar constraint on a specific neuron in the model.
|
|
151
252
|
|
|
152
|
-
|
|
153
|
-
|
|
253
|
+
def __init__(
|
|
254
|
+
self,
|
|
255
|
+
head: Constraint,
|
|
256
|
+
body: Constraint,
|
|
257
|
+
name=None,
|
|
258
|
+
monitor_only=False,
|
|
259
|
+
rescale_factor=1.5,
|
|
260
|
+
):
|
|
261
|
+
"""
|
|
262
|
+
Initializes an ImplicationConstraint instance.
|
|
263
|
+
"""
|
|
264
|
+
|
|
265
|
+
# Type checking
|
|
266
|
+
validate_type("head", head, Constraint)
|
|
267
|
+
validate_type("body", body, Constraint)
|
|
268
|
+
|
|
269
|
+
# Compose constraint name
|
|
270
|
+
name = f"{body.name} if {head.name}"
|
|
271
|
+
|
|
272
|
+
# Init parent class
|
|
273
|
+
super().__init__(
|
|
274
|
+
head.neurons | body.neurons,
|
|
275
|
+
name,
|
|
276
|
+
monitor_only,
|
|
277
|
+
rescale_factor,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
self.head = head
|
|
281
|
+
self.body = body
|
|
282
|
+
|
|
283
|
+
def check_constraint(
|
|
284
|
+
self, prediction: dict[str, Tensor]
|
|
285
|
+
) -> tuple[Tensor, int]:
|
|
286
|
+
|
|
287
|
+
# Check satisfaction of head and body constraints
|
|
288
|
+
head_satisfaction, _ = self.head.check_constraint(prediction)
|
|
289
|
+
body_satisfaction, _ = self.body.check_constraint(prediction)
|
|
290
|
+
|
|
291
|
+
# If head constraint is satisfied (returning 1),
|
|
292
|
+
# the body constraint matters (and should return 0/1 based on body)
|
|
293
|
+
# If head constraint is not satisfied (returning 0),
|
|
294
|
+
# the body constraint does not apply (and should return 1)
|
|
295
|
+
result = logical_or(
|
|
296
|
+
logical_not(head_satisfaction), body_satisfaction
|
|
297
|
+
).float()
|
|
298
|
+
|
|
299
|
+
return result, count_nonzero(head_satisfaction)
|
|
300
|
+
|
|
301
|
+
def calculate_direction(
|
|
302
|
+
self, prediction: dict[str, Tensor]
|
|
303
|
+
) -> Dict[str, Tensor]:
|
|
304
|
+
# NOTE currently only works for dense layers
|
|
305
|
+
# due to neuron to index translation
|
|
306
|
+
|
|
307
|
+
# Use directions of constraint body as update vector
|
|
308
|
+
return self.body.calculate_direction(prediction)
|
|
309
|
+
|
|
310
|
+
|
|
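The satisfaction logic is plain material implication, head → body ≡ ¬head ∨ body, with the head's satisfied count reported as the number of relevant samples. In isolation, with tensors standing in for the head and body check results:

from torch import count_nonzero, logical_not, logical_or, tensor

head = tensor([1.0, 1.0, 0.0, 0.0])  # per-sample: head constraint satisfied?
body = tensor([1.0, 0.0, 1.0, 0.0])  # per-sample: body constraint satisfied?

# head -> body: only fails where the head holds but the body does not.
result = logical_or(logical_not(head), body).float()
print(result)               # tensor([1., 0., 1., 1.])
print(count_nonzero(head))  # tensor(2): samples where the implication bites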
+class ScalarConstraint(Constraint):
+    """
+    A constraint that enforces scalar-based comparisons on a specific neuron.
+
+    This class ensures that the output of a specified neuron satisfies a scalar
+    comparison operation (e.g., less than, greater than, etc.). It uses a
+    comparator function to validate the condition and calculates adjustment
+    directions accordingly.
+
+    Args:
+        operand (Union[str, Transformation]): Name of the neuron or a
+            transformation to apply.
+        comparator (Callable[[Tensor, Number], Tensor]): A comparison
+            function (e.g., `torch.ge`, `torch.lt`).
+        scalar (Number): The scalar value to compare against.
+        name (str, optional): A unique name for the constraint. If not
+            provided, a name is auto-generated in the format
+            "<neuron_name> <comparator> <scalar>".
+        monitor_only (bool, optional): If True, only monitor the constraint
+            without adjusting the loss. Defaults to False.
+        rescale_factor (Number, optional): Factor to scale the
+            constraint-adjusted loss. Defaults to 1.5.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+
+    Notes:
+        - The `neuron_name` must be defined in the `descriptor` mapping.
+        - The constraint name is composed using the neuron name,
+          comparator, and scalar value.
 
-    Attributes:
-        comparator (Callable[[Tensor, Number], Tensor]): A comparator function (e.g., greater than, less than) to evaluate the constraint.
-        scalar (Number): The scalar value to compare the neuron value against.
-        direction (int): The direction in which the constraint should adjust the model's predictions (either 1 or -1 based on the comparator).
-        layer (str): The layer associated with the specified neuron.
-        index (int): The index of the specified neuron within the layer.
     """
 
     def __init__(
         self,
-
+        operand: Union[str, Transformation],
         comparator: Callable[[Tensor, Number], Tensor],
         scalar: Number,
         name: str = None,
-
-        rescale_factor:
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
    ) -> None:
        """
-        Initializes
-
-        Args:
-            neuron_name (str): The name of the neuron that the constraint applies to.
-            comparator (Callable[[Tensor, Number], Tensor]): The comparator function used to evaluate the constraint (e.g., ge, le, gt, lt).
-            scalar (Number): The scalar value that the neuron value is compared to.
-            name (str, optional): A custom name for the constraint. If not provided, a name is generated based on the neuron name, comparator, and scalar.
-            descriptor (Descriptor, optional): The descriptor that maps neurons to layers. If not provided, the global descriptor is used.
-            rescale_factor (float, optional): A factor that scales the influence of the constraint on the overall loss. Defaults to 1.5.
-
-        Raises:
-            ValueError: If the comparator function is not one of the supported comparison operators (ge, le, gt, lt).
+        Initializes a ScalarConstraint instance.
        """
 
+        # Type checking
+        validate_type("operand", operand, (str, Transformation))
+        validate_comparator_pytorch("comparator", comparator)
+        validate_comparator_pytorch("comparator", comparator)
+        validate_type("scalar", scalar, Number)
+
+        # If transformation is provided, get neuron name,
+        # else use IdentityTransformation
+        if isinstance(operand, Transformation):
+            neuron_name = operand.neuron_name
+            transformation = operand
+        else:
+            neuron_name = operand
+            transformation = IdentityTransformation(neuron_name)
+
         # Compose constraint name
-        name = f"{neuron_name}
+        name = f"{neuron_name} {comparator.__name__} {str(scalar)}"
 
         # Init parent class
-        super().__init__({neuron_name}, name, rescale_factor)
+        super().__init__({neuron_name}, name, monitor_only, rescale_factor)
 
         # Init variables
         self.comparator = comparator
         self.scalar = scalar
-
-        if descriptor != None:
-            self.descriptor = descriptor
-            self.run_init_descriptor()
+        self.transformation = transformation
 
         # Get layer name and feature index from neuron_name
         self.layer = self.descriptor.neuron_to_layer[neuron_name]
         self.index = self.descriptor.neuron_to_index[neuron_name]
 
-        # If comparator function is not supported, raise error
-        if comparator not in [ge, le, gt, lt]:
-            raise ValueError(
-                f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
-            )
-
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
-            self.direction = 1
-        elif self.comparator in [gt, ge]:
             self.direction = -1
+        elif self.comparator in [gt, ge]:
+            self.direction = 1
 
-    def check_constraint(
-
-
-
-        The constraint is evaluated by applying the comparator to the value of the specified neuron and the scalar value.
+    def check_constraint(
+        self, prediction: dict[str, Tensor]
+    ) -> tuple[Tensor, int]:
 
-
-
+        # Select relevant columns
+        selection = prediction[self.layer][:, self.index]
 
-
-
-        """
+        # Apply transformation
+        selection = self.transformation(selection)
 
-
+        # Calculate current constraint result
+        result = self.comparator(selection, self.scalar).float()
+        return result, numel(result)
 
-
+    def calculate_direction(
+        self, prediction: dict[str, Tensor]
+    ) -> Dict[str, Tensor]:
+        # NOTE currently only works for dense layers due
+        # to neuron to index translation
 
-
-        """
-        Calculates the direction in which the model's predictions need to be adjusted to satisfy the constraint.
+        output = {}
 
-
+        for layer in self.layers:
+            output[layer] = zeros_like(prediction[layer][0], device=self.device)
 
-
-            prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
+        output[self.layer][self.index] = self.direction
 
-
-
-        """
-
-        output = zeros(
-            prediction[self.layer].size(),
-            device=prediction[self.layer].device,
-        )
-        output[:, self.index] = self.direction
+        for layer in self.layers:
+            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
 
-        return
+        return output
 
 
426
|
class BinaryConstraint(Constraint):
|
|
257
427
|
"""
|
|
258
|
-
A
|
|
259
|
-
|
|
260
|
-
This class
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
comparator (Callable[[Tensor, Number], Tensor]): A
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
428
|
+
A constraint that enforces a binary comparison between two neurons.
|
|
429
|
+
|
|
430
|
+
This class ensures that the output of one neuron satisfies a comparison
|
|
431
|
+
operation with the output of another neuron
|
|
432
|
+
(e.g., less than, greater than, etc.). It uses a comparator function to
|
|
433
|
+
validate the condition and calculates adjustment directions accordingly.
|
|
434
|
+
|
|
435
|
+
Args:
|
|
436
|
+
operand_left (Union[str, Transformation]): Name of the left
|
|
437
|
+
neuron or a transformation to apply.
|
|
438
|
+
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
439
|
+
function (e.g., `torch.ge`, `torch.lt`).
|
|
440
|
+
operand_right (Union[str, Transformation]): Name of the right
|
|
441
|
+
neuron or a transformation to apply.
|
|
442
|
+
name (str, optional): A unique name for the constraint. If not
|
|
443
|
+
provided, a name is auto-generated in the format
|
|
444
|
+
"<neuron_name_left> <comparator> <neuron_name_right>".
|
|
445
|
+
monitor_only (bool, optional): If True, only monitor the constraint
|
|
446
|
+
without adjusting the loss. Defaults to False.
|
|
447
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
448
|
+
constraint-adjusted loss. Defaults to 1.5.
|
|
449
|
+
|
|
450
|
+
Raises:
|
|
451
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
452
|
+
|
|
453
|
+
Notes:
|
|
454
|
+
- The neuron names must be defined in the `descriptor` mapping.
|
|
455
|
+
- The constraint name is composed using the left neuron name,
|
|
456
|
+
comparator, and right neuron name.
|
|
457
|
+
|
|
275
458
|
"""
|
|
276
459
|
|
|
277
460
|
def __init__(
|
|
278
461
|
self,
|
|
279
|
-
|
|
462
|
+
operand_left: Union[str, Transformation],
|
|
280
463
|
comparator: Callable[[Tensor, Number], Tensor],
|
|
281
|
-
|
|
464
|
+
operand_right: Union[str, Transformation],
|
|
282
465
|
name: str = None,
|
|
283
|
-
|
|
284
|
-
rescale_factor:
|
|
466
|
+
monitor_only: bool = False,
|
|
467
|
+
rescale_factor: Number = 1.5,
|
|
285
468
|
) -> None:
|
|
286
469
|
"""
|
|
287
|
-
Initializes
|
|
288
|
-
|
|
289
|
-
Args:
|
|
290
|
-
neuron_name_left (str): The name of the first neuron in the constraint.
|
|
291
|
-
comparator (Callable[[Tensor, Number], Tensor]): A function that compares the values of the two neurons.
|
|
292
|
-
neuron_name_right (str): The name of the second neuron in the constraint.
|
|
293
|
-
name (str, optional): The name of the constraint. If not provided, a default name is generated.
|
|
294
|
-
descriptor (Descriptor, optional): The descriptor containing the mapping of neurons to layers.
|
|
295
|
-
rescale_factor (float, optional): A factor to rescale the constraint value. Default is 1.5.
|
|
470
|
+
Initializes a BinaryConstraint instance.
|
|
296
471
|
"""
|
|
297
472
|
|
|
473
|
+
# Type checking
|
|
474
|
+
validate_type("operand_left", operand_left, (str, Transformation))
|
|
475
|
+
validate_comparator_pytorch("comparator", comparator)
|
|
476
|
+
validate_comparator_pytorch("comparator", comparator)
|
|
477
|
+
validate_type("operand_right", operand_right, (str, Transformation))
|
|
478
|
+
|
|
479
|
+
# If transformation is provided, get neuron name,
|
|
480
|
+
# else use IdentityTransformation
|
|
481
|
+
if isinstance(operand_left, Transformation):
|
|
482
|
+
neuron_name_left = operand_left.neuron_name
|
|
483
|
+
transformation_left = operand_left
|
|
484
|
+
else:
|
|
485
|
+
neuron_name_left = operand_left
|
|
486
|
+
transformation_left = IdentityTransformation(neuron_name_left)
|
|
487
|
+
|
|
488
|
+
if isinstance(operand_right, Transformation):
|
|
489
|
+
neuron_name_right = operand_right.neuron_name
|
|
490
|
+
transformation_right = operand_right
|
|
491
|
+
else:
|
|
492
|
+
neuron_name_right = operand_right
|
|
493
|
+
transformation_right = IdentityTransformation(neuron_name_right)
|
|
494
|
+
|
|
298
495
|
# Compose constraint name
|
|
299
|
-
name = f"{neuron_name_left}
|
|
496
|
+
name = f"{neuron_name_left} {comparator.__name__} {neuron_name_right}"
|
|
300
497
|
|
|
301
498
|
# Init parent class
|
|
302
|
-
super().__init__(
|
|
499
|
+
super().__init__(
|
|
500
|
+
{neuron_name_left, neuron_name_right},
|
|
501
|
+
name,
|
|
502
|
+
monitor_only,
|
|
503
|
+
rescale_factor,
|
|
504
|
+
)
|
|
303
505
|
|
|
304
506
|
# Init variables
|
|
305
507
|
self.comparator = comparator
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
self.descriptor = descriptor
|
|
309
|
-
self.run_init_descriptor()
|
|
508
|
+
self.transformation_left = transformation_left
|
|
509
|
+
self.transformation_right = transformation_right
|
|
310
510
|
|
|
311
511
|
# Get layer name and feature index from neuron_name
|
|
312
512
|
self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
|
|
@@ -314,12 +514,6 @@ class BinaryConstraint(Constraint):
|
|
|
314
514
|
self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
|
|
315
515
|
self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
|
|
316
516
|
|
|
317
|
-
# If comparator function is not supported, raise error
|
|
318
|
-
if comparator not in [ge, le, gt, lt]:
|
|
319
|
-
raise RuntimeError(
|
|
320
|
-
f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
|
|
321
|
-
)
|
|
322
|
-
|
|
323
517
|
# Calculate directions based on constraint operator
|
|
324
518
|
if self.comparator in [lt, le]:
|
|
325
519
|
self.direction_left = -1
|
|
@@ -328,128 +522,179 @@ class BinaryConstraint(Constraint):
|
|
|
328
522
|
self.direction_left = 1
|
|
329
523
|
self.direction_right = -1
|
|
330
524
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
p=2,
|
|
335
|
-
dim=0,
|
|
336
|
-
)
|
|
337
|
-
self.direction_left = normalized_directions[0]
|
|
338
|
-
self.direction_right = normalized_directions[1]
|
|
525
|
+
def check_constraint(
|
|
526
|
+
self, prediction: dict[str, Tensor]
|
|
527
|
+
) -> tuple[Tensor, int]:
|
|
339
528
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
529
|
+
# Select relevant columns
|
|
530
|
+
selection_left = prediction[self.layer_left][:, self.index_left]
|
|
531
|
+
selection_right = prediction[self.layer_right][:, self.index_right]
|
|
343
532
|
|
|
344
|
-
|
|
345
|
-
|
|
533
|
+
# Apply transformations
|
|
534
|
+
selection_left = self.transformation_left(selection_left)
|
|
535
|
+
selection_right = self.transformation_right(selection_right)
|
|
346
536
|
|
|
347
|
-
|
|
348
|
-
prediction (dict[str, Tensor]): A dictionary containing the predictions for each layer.
|
|
537
|
+
result = self.comparator(selection_left, selection_right).float()
|
|
349
538
|
|
|
350
|
-
|
|
351
|
-
dict[str, Tensor]: A dictionary with the layer names as keys and the constraint satisfaction results as values.
|
|
352
|
-
"""
|
|
539
|
+
return result, numel(result)
|
|
353
540
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
541
|
+
def calculate_direction(
|
|
542
|
+
self, prediction: dict[str, Tensor]
|
|
543
|
+
) -> Dict[str, Tensor]:
|
|
544
|
+
# NOTE currently only works for dense layers due
|
|
545
|
+
# to neuron to index translation
|
|
358
546
|
|
|
359
|
-
|
|
547
|
+
output = {}
|
|
360
548
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
Calculates the direction for gradient adjustment for both neurons involved in the constraint.
|
|
549
|
+
for layer in self.layers:
|
|
550
|
+
output[layer] = zeros_like(prediction[layer][0], device=self.device)
|
|
364
551
|
|
|
365
|
-
|
|
552
|
+
output[self.layer_left][self.index_left] = self.direction_left
|
|
553
|
+
output[self.layer_right][self.index_right] = self.direction_right
|
|
366
554
|
|
|
367
|
-
|
|
368
|
-
|
|
555
|
+
for layer in self.layers:
|
|
556
|
+
output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
|
|
369
557
|
|
|
370
|
-
|
|
371
|
-
dict[str, Tensor]: A dictionary with the layer names as keys and the direction vectors as values.
|
|
372
|
-
"""
|
|
558
|
+
return output
|
|
373
559
|
|
|
374
|
-
output_left = zeros(
|
|
375
|
-
prediction[self.layer_left].size(),
|
|
376
|
-
device=prediction[self.layer_left].device,
|
|
377
|
-
)
|
|
378
|
-
output_left[:, self.index_left] = self.direction_left
|
|
379
560
|
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
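BinaryConstraint follows the same pattern with one neuron on each side; both columns land in the direction vector, which is then L2-normalized (±1/√2 here). Continuing the illustrative stand-in setup from the first sketch:

from torch import le, tensor
from congrads.constraints import BinaryConstraint

constraint = BinaryConstraint("demand", le, "price")
print(constraint.name)  # "demand le price"

prediction = {"output": tensor([[3.0, 2.0], [1.0, 4.0]])}
satisfied, n = constraint.check_constraint(prediction)
print(satisfied)  # tensor([1., 0.]): 2.0 <= 3.0 holds, 4.0 <= 1.0 does not

direction = constraint.calculate_direction(prediction)
# "demand" (left side of le) is pushed down, "price" (right side) up.
print(direction["output"])  # tensor([[ 0.7071, -0.7071]])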
386
|
-
|
|
561
|
+
class SumConstraint(Constraint):
|
|
562
|
+
"""
|
|
563
|
+
A constraint that enforces a weighted summation comparison
|
|
564
|
+
between two groups of neurons.
|
|
565
|
+
|
|
566
|
+
This class evaluates whether the weighted sum of outputs from one set of
|
|
567
|
+
neurons satisfies a comparison operation with the weighted sum of
|
|
568
|
+
outputs from another set of neurons.
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
operands_left (list[Union[str, Transformation]]): List of neuron
|
|
572
|
+
names or transformations on the left side.
|
|
573
|
+
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
574
|
+
function for the constraint.
|
|
575
|
+
operands_right (list[Union[str, Transformation]]): List of neuron
|
|
576
|
+
names or transformations on the right side.
|
|
577
|
+
weights_left (list[Number], optional): Weights for the left neurons.
|
|
578
|
+
Defaults to None.
|
|
579
|
+
weights_right (list[Number], optional): Weights for the right
|
|
580
|
+
neurons. Defaults to None.
|
|
581
|
+
name (str, optional): Unique name for the constraint.
|
|
582
|
+
If None, it's auto-generated. Defaults to None.
|
|
583
|
+
monitor_only (bool, optional): If True, only monitor the constraint
|
|
584
|
+
without adjusting the loss. Defaults to False.
|
|
585
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
586
|
+
constraint-adjusted loss. Defaults to 1.5.
|
|
587
|
+
|
|
588
|
+
Raises:
|
|
589
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
590
|
+
ValueError: If the dimensions of neuron names and weights mismatch.
|
|
387
591
|
|
|
592
|
+
"""
|
|
388
593
|
|
|
389
|
-
# FIXME
|
|
390
|
-
class SumConstraint(Constraint):
|
|
391
594
|
def __init__(
|
|
392
595
|
self,
|
|
393
|
-
|
|
596
|
+
operands_left: list[Union[str, Transformation]],
|
|
394
597
|
comparator: Callable[[Tensor, Number], Tensor],
|
|
395
|
-
|
|
396
|
-
weights_left: list[
|
|
397
|
-
weights_right: list[
|
|
598
|
+
operands_right: list[Union[str, Transformation]],
|
|
599
|
+
weights_left: list[Number] = None,
|
|
600
|
+
weights_right: list[Number] = None,
|
|
398
601
|
name: str = None,
|
|
399
|
-
|
|
400
|
-
rescale_factor:
|
|
602
|
+
monitor_only: bool = False,
|
|
603
|
+
rescale_factor: Number = 1.5,
|
|
401
604
|
) -> None:
|
|
605
|
+
"""
|
|
606
|
+
Initializes the SumConstraint.
|
|
607
|
+
"""
|
|
402
608
|
|
|
403
|
-
#
|
|
404
|
-
|
|
405
|
-
|
|
609
|
+
# Type checking
|
|
610
|
+
validate_iterable("operands_left", operands_left, (str, Transformation))
|
|
611
|
+
validate_comparator_pytorch("comparator", comparator)
|
|
612
|
+
validate_comparator_pytorch("comparator", comparator)
|
|
613
|
+
validate_iterable(
|
|
614
|
+
"operands_right", operands_right, (str, Transformation)
|
|
615
|
+
)
|
|
616
|
+
validate_iterable("weights_left", weights_left, Number, allow_none=True)
|
|
617
|
+
validate_iterable(
|
|
618
|
+
"weights_right", weights_right, Number, allow_none=True
|
|
406
619
|
)
|
|
407
620
|
|
|
621
|
+
# If transformation is provided, get neuron name,
|
|
622
|
+
# else use IdentityTransformation
|
|
623
|
+
neuron_names_left: list[str] = []
|
|
624
|
+
transformations_left: list[Transformation] = []
|
|
625
|
+
for operand_left in operands_left:
|
|
626
|
+
if isinstance(operand_left, Transformation):
|
|
627
|
+
neuron_name_left = operand_left.neuron_name
|
|
628
|
+
neuron_names_left.append(neuron_name_left)
|
|
629
|
+
transformations_left.append(operand_left)
|
|
630
|
+
else:
|
|
631
|
+
neuron_name_left = operand_left
|
|
632
|
+
neuron_names_left.append(neuron_name_left)
|
|
633
|
+
transformations_left.append(
|
|
634
|
+
IdentityTransformation(neuron_name_left)
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
neuron_names_right: list[str] = []
|
|
638
|
+
transformations_right: list[Transformation] = []
|
|
639
|
+
for operand_right in operands_right:
|
|
640
|
+
if isinstance(operand_right, Transformation):
|
|
641
|
+
neuron_name_right = operand_right.neuron_name
|
|
642
|
+
neuron_names_right.append(neuron_name_right)
|
|
643
|
+
transformations_right.append(operand_right)
|
|
644
|
+
else:
|
|
645
|
+
neuron_name_right = operand_right
|
|
646
|
+
neuron_names_right.append(neuron_name_right)
|
|
647
|
+
transformations_right.append(
|
|
648
|
+
IdentityTransformation(neuron_name_right)
|
|
649
|
+
)
|
|
650
|
+
|
|
651
|
+
# Compose constraint name
|
|
652
|
+
w_left = weights_left or [""] * len(neuron_names_left)
|
|
653
|
+
w_right = weights_right or [""] * len(neuron_names_right)
|
|
654
|
+
left_expr = " + ".join(
|
|
655
|
+
f"{w}{n}" for w, n in zip(w_left, neuron_names_left)
|
|
656
|
+
)
|
|
657
|
+
right_expr = " + ".join(
|
|
658
|
+
f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
|
|
659
|
+
)
|
|
660
|
+
comparator_name = comparator.__name__
|
|
661
|
+
name = f"{left_expr} {comparator_name} {right_expr}"
|
|
662
|
+
|
|
663
|
+
# Init parent class
|
|
664
|
+
neuron_names = set(neuron_names_left) | set(neuron_names_right)
|
|
665
|
+
super().__init__(neuron_names, name, monitor_only, rescale_factor)
|
|
666
|
+
|
|
408
667
|
# Init variables
|
|
409
668
|
self.comparator = comparator
|
|
669
|
+
self.neuron_names_left = neuron_names_left
|
|
670
|
+
self.neuron_names_right = neuron_names_right
|
|
671
|
+
self.transformations_left = transformations_left
|
|
672
|
+
self.transformations_right = transformations_right
|
|
410
673
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
self.run_init_descriptor()
|
|
414
|
-
|
|
415
|
-
# Get layer names and feature indices from neuron_name
|
|
416
|
-
self.layers_left = []
|
|
417
|
-
self.indices_left = []
|
|
418
|
-
for neuron_name in neuron_names_left:
|
|
419
|
-
self.layers_left.append(self.descriptor.neuron_to_layer[neuron_name])
|
|
420
|
-
self.indices_left.append(self.descriptor.neuron_to_index[neuron_name])
|
|
421
|
-
|
|
422
|
-
self.layers_right = []
|
|
423
|
-
self.indices_right = []
|
|
424
|
-
for neuron_name in neuron_names_right:
|
|
425
|
-
self.layers_right.append(self.descriptor.neuron_to_layer[neuron_name])
|
|
426
|
-
self.indices_right.append(self.descriptor.neuron_to_index[neuron_name])
|
|
427
|
-
|
|
428
|
-
# If comparator function is not supported, raise error
|
|
429
|
-
if comparator not in [ge, le, gt, lt]:
|
|
430
|
-
raise ValueError(
|
|
431
|
-
f"Comparator {str(comparator)} used for constraint {name} is not supported. Only ge, le, gt, lt are allowed."
|
|
432
|
-
)
|
|
433
|
-
|
|
434
|
-
# If feature list dimensions don't match weight list dimensions, raise error
|
|
674
|
+
# If feature list dimensions don't match
|
|
675
|
+
# weight list dimensions, raise error
|
|
435
676
|
if weights_left and (len(neuron_names_left) != len(weights_left)):
|
|
436
677
|
raise ValueError(
|
|
437
|
-
"The dimensions of neuron_names_left don't match with the
|
|
678
|
+
"The dimensions of neuron_names_left don't match with the \
|
|
679
|
+
dimensions of weights_left."
|
|
438
680
|
)
|
|
439
681
|
if weights_right and (len(neuron_names_right) != len(weights_right)):
|
|
440
682
|
raise ValueError(
|
|
441
|
-
"The dimensions of neuron_names_right don't match with the
|
|
683
|
+
"The dimensions of neuron_names_right don't match with the \
|
|
684
|
+
dimensions of weights_right."
|
|
442
685
|
)
|
|
443
686
|
|
|
444
687
|
# If weights are provided for summation, transform them to Tensors
|
|
445
688
|
if weights_left:
|
|
446
|
-
self.weights_left =
|
|
689
|
+
self.weights_left = tensor(weights_left, device=self.device)
|
|
447
690
|
else:
|
|
448
|
-
self.weights_left = ones(len(neuron_names_left))
|
|
691
|
+
self.weights_left = ones(len(neuron_names_left), device=self.device)
|
|
449
692
|
if weights_right:
|
|
450
|
-
self.weights_right =
|
|
693
|
+
self.weights_right = tensor(weights_right, device=self.device)
|
|
451
694
|
else:
|
|
452
|
-
self.weights_right = ones(
|
|
695
|
+
self.weights_right = ones(
|
|
696
|
+
len(neuron_names_right), device=self.device
|
|
697
|
+
)
|
|
453
698
|
|
|
454
699
|
# Calculate directions based on constraint operator
|
|
455
700
|
if self.comparator in [lt, le]:
|
|
@@ -459,49 +704,216 @@ class SumConstraint(Constraint):
|
|
|
459
704
|
self.direction_left = 1
|
|
460
705
|
self.direction_right = -1
|
|
461
706
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
707
|
+
def check_constraint(
|
|
708
|
+
self, prediction: dict[str, Tensor]
|
|
709
|
+
) -> tuple[Tensor, int]:
|
|
710
|
+
|
|
711
|
+
def compute_weighted_sum(
|
|
712
|
+
neuron_names: list[str],
|
|
713
|
+
transformations: list[Transformation],
|
|
714
|
+
weights: tensor,
|
|
715
|
+
) -> tensor:
|
|
716
|
+
layers = [
|
|
717
|
+
self.descriptor.neuron_to_layer[neuron_name]
|
|
718
|
+
for neuron_name in neuron_names
|
|
719
|
+
]
|
|
720
|
+
indices = [
|
|
721
|
+
self.descriptor.neuron_to_index[neuron_name]
|
|
722
|
+
for neuron_name in neuron_names
|
|
723
|
+
]
|
|
724
|
+
|
|
725
|
+
# Select relevant column
|
|
726
|
+
selections = [
|
|
727
|
+
prediction[layer][:, index]
|
|
728
|
+
for layer, index in zip(layers, indices)
|
|
729
|
+
]
|
|
730
|
+
|
|
731
|
+
# Apply transformations
|
|
732
|
+
results = []
|
|
733
|
+
for transformation, selection in zip(transformations, selections):
|
|
734
|
+
results.append(transformation(selection))
|
|
735
|
+
|
|
736
|
+
# Extract predictions for all neurons and apply weights in bulk
|
|
737
|
+
predictions = stack(
|
|
738
|
+
results,
|
|
739
|
+
dim=1,
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
# Calculate weighted sum
|
|
743
|
+
return (predictions * weights.unsqueeze(0)).sum(dim=1)
|
|
744
|
+
|
|
745
|
+
# Compute weighted sums
|
|
746
|
+
weighted_sum_left = compute_weighted_sum(
|
|
747
|
+
self.neuron_names_left,
|
|
748
|
+
self.transformations_left,
|
|
749
|
+
self.weights_left,
|
|
750
|
+
)
|
|
751
|
+
weighted_sum_right = compute_weighted_sum(
|
|
752
|
+
self.neuron_names_right,
|
|
753
|
+
self.transformations_right,
|
|
754
|
+
self.weights_right,
|
|
465
755
|
)
|
|
466
|
-
self.direction_left = normalized_directions[0]
|
|
467
|
-
self.direction_right = normalized_directions[1]
|
|
468
756
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
#
|
|
478
|
-
#
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
757
|
+
# Apply the comparator and calculate the result
|
|
758
|
+
result = self.comparator(weighted_sum_left, weighted_sum_right).float()
|
|
759
|
+
|
|
760
|
+
return result, numel(result)
|
|
761
|
+
|
|
762
|
+
def calculate_direction(
|
|
763
|
+
self, prediction: dict[str, Tensor]
|
|
764
|
+
) -> Dict[str, Tensor]:
|
|
765
|
+
# NOTE currently only works for dense layers
|
|
766
|
+
# due to neuron to index translation
|
|
767
|
+
|
|
768
|
+
output = {}
|
|
769
|
+
|
|
770
|
+
for layer in self.layers:
|
|
771
|
+
output[layer] = zeros_like(prediction[layer][0], device=self.device)
|
|
772
|
+
|
|
773
|
+
for neuron_name_left in self.neuron_names_left:
|
|
774
|
+
layer = self.descriptor.neuron_to_layer[neuron_name_left]
|
|
775
|
+
index = self.descriptor.neuron_to_index[neuron_name_left]
|
|
776
|
+
output[layer][index] = self.direction_left
|
|
777
|
+
|
|
778
|
+
for neuron_name_right in self.neuron_names_right:
|
|
779
|
+
layer = self.descriptor.neuron_to_layer[neuron_name_right]
|
|
780
|
+
index = self.descriptor.neuron_to_index[neuron_name_right]
|
|
781
|
+
output[layer][index] = self.direction_right
|
|
782
|
+
|
|
783
|
+
for layer in self.layers:
|
|
784
|
+
output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
|
|
785
|
+
|
|
786
|
+
return output
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
class PythagoreanIdentityConstraint(Constraint):
|
|
790
|
+
"""
|
|
791
|
+
A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
|
|
792
|
+
where `a` and `b` are neurons or transformations.
|
|
793
|
+
|
|
794
|
+
This constraint checks that the sum of the squares of two specified
|
|
795
|
+
neurons (or their transformations) is approximately equal to 1.
|
|
796
|
+
The constraint is evaluated using relative and absolute
|
|
797
|
+
tolerance (`rtol` and `atol`) and is applied during the forward pass.
|
|
798
|
+
|
|
799
|
+
Args:
|
|
800
|
+
a (Union[str, Transformation]): The first input, either a
|
|
801
|
+
neuron name (str) or a Transformation.
|
|
802
|
+
b (Union[str, Transformation]): The second input, either a
|
|
803
|
+
neuron name (str) or a Transformation.
|
|
804
|
+
rtol (float, optional): The relative tolerance for the
|
|
805
|
+
comparison (default is 0.00001).
|
|
806
|
+
atol (float, optional): The absolute tolerance for the
|
|
807
|
+
comparison (default is 1e-8).
|
|
808
|
+
name (str, optional): The name of the constraint
|
|
809
|
+
(default is None, and it is generated automatically).
|
|
810
|
+
monitor_only (bool, optional): Flag indicating whether the
|
|
811
|
+
constraint is only for monitoring (default is False).
|
|
812
|
+
rescale_factor (Number, optional): A factor used for
|
|
813
|
+
rescaling (default is 1.5).
|
|
814
|
+
|
|
815
|
+
Raises:
|
|
816
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
817
|
+
|
|
818
|
+
"""
|
|
819
|
+
|
|
820
|
+
def __init__(
|
|
821
|
+
self,
|
|
822
|
+
a: Union[str, Transformation],
|
|
823
|
+
b: Union[str, Transformation],
|
|
824
|
+
rtol: float = 0.00001,
|
|
825
|
+
atol: float = 1e-8,
|
|
826
|
+
name: str = None,
|
|
827
|
+
monitor_only: bool = False,
|
|
828
|
+
rescale_factor: Number = 1.5,
|
|
829
|
+
) -> None:
|
|
830
|
+
"""
|
|
831
|
+
Initialize the PythagoreanIdentityConstraint.
|
|
832
|
+
"""
|
|
833
|
+
|
|
834
|
+
# Type checking
|
|
835
|
+
validate_type("a", a, (str, Transformation))
|
|
836
|
+
validate_type("b", b, (str, Transformation))
|
|
837
|
+
validate_type("rtol", rtol, float)
|
|
838
|
+
validate_type("atol", atol, float)
|
|
839
|
+
|
|
840
|
+
# If transformation is provided, get neuron name,
|
|
841
|
+
# else use IdentityTransformation
|
|
842
|
+
if isinstance(a, Transformation):
|
|
843
|
+
neuron_name_a = a.neuron_name
|
|
844
|
+
transformation_a = a
|
|
845
|
+
else:
|
|
846
|
+
neuron_name_a = a
|
|
847
|
+
transformation_a = IdentityTransformation(neuron_name_a)
|
|
848
|
+
|
|
849
|
+
if isinstance(b, Transformation):
|
|
850
|
+
neuron_name_b = b.neuron_name
|
|
851
|
+
transformation_b = b
|
|
852
|
+
else:
|
|
853
|
+
neuron_name_b = b
|
|
854
|
+
transformation_b = IdentityTransformation(neuron_name_b)
|
|
855
|
+
|
|
856
|
+
# Compose constraint name
|
|
857
|
+
name = f"{neuron_name_a}² + {neuron_name_b}² ≈ 1"
|
|
858
|
+
|
|
859
|
+
# Init parent class
|
|
860
|
+
super().__init__(
|
|
861
|
+
{neuron_name_a, neuron_name_b},
|
|
862
|
+
name,
|
|
863
|
+
monitor_only,
|
|
864
|
+
rescale_factor,
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
# Init variables
|
|
868
|
+
self.transformation_a = transformation_a
|
|
869
|
+
self.transformation_b = transformation_b
|
|
870
|
+
self.rtol = rtol
|
|
871
|
+
self.atol = atol
|
|
872
|
+
|
|
873
|
+
# Get layer name and feature index from neuron_name
|
|
874
|
+
self.layer_a = self.descriptor.neuron_to_layer[neuron_name_a]
|
|
875
|
+
self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
|
|
876
|
+
self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
|
|
877
|
+
self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
|
|
878
|
+
|
|
879
|
+
def check_constraint(
|
|
880
|
+
self, prediction: dict[str, Tensor]
|
|
881
|
+
) -> tuple[Tensor, int]:
|
|
882
|
+
|
|
883
|
+
# Select relevant columns
|
|
884
|
+
selection_a = prediction[self.layer_a][:, self.index_a]
|
|
885
|
+
selection_b = prediction[self.layer_b][:, self.index_b]
|
|
886
|
+
|
|
887
|
+
# Apply transformations
|
|
888
|
+
selection_a = self.transformation_a(selection_a)
|
|
889
|
+
selection_b = self.transformation_b(selection_b)
|
|
890
|
+
|
|
891
|
+
# Calculate result
|
|
892
|
+
result = isclose(
|
|
893
|
+
square(selection_a) + square(selection_b),
|
|
894
|
+
ones_like(selection_a, device=self.device),
|
|
895
|
+
rtol=self.rtol,
|
|
896
|
+
atol=self.atol,
|
|
897
|
+
).float()
|
|
898
|
+
|
|
899
|
+
return result, numel(result)
|
|
900
|
+
|
|
901
|
+
def calculate_direction(
|
|
902
|
+
self, prediction: dict[str, Tensor]
|
|
903
|
+
) -> Dict[str, Tensor]:
|
|
904
|
+
# NOTE currently only works for dense layers due
|
|
905
|
+
# to neuron to index translation
|
|
906
|
+
|
|
907
|
+
output = {}
|
|
908
|
+
|
|
909
|
+
for layer in self.layers:
|
|
910
|
+
output[layer] = zeros_like(prediction[layer], device=self.device)
|
|
911
|
+
|
|
912
|
+
a = prediction[self.layer_a][:, self.index_a]
|
|
913
|
+
b = prediction[self.layer_b][:, self.index_b]
|
|
914
|
+
m = sqrt(square(a) + square(b))
|
|
915
|
+
|
|
916
|
+
output[self.layer_a][:, self.index_a] = a / m * sign(1 - m)
|
|
917
|
+
output[self.layer_b][:, self.index_b] = b / m * sign(1 - m)
|
|
918
|
+
|
|
919
|
+
return output
|