congrads-1.0.7-py3-none-any.whl → congrads-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +804 -454
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +118 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.7.dist-info → congrads-1.1.0.dist-info}/METADATA +48 -41
- congrads-1.1.0.dist-info/RECORD +14 -0
- congrads-1.1.0.dist-info/WHEEL +4 -0
- congrads-1.0.7.dist-info/LICENSE +0 -26
- congrads-1.0.7.dist-info/RECORD +0 -15
- congrads-1.0.7.dist-info/WHEEL +0 -5
- congrads-1.0.7.dist-info/top_level.txt +0 -1
congrads/constraints.py
CHANGED
```diff
@@ -1,65 +1,55 @@
-"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-  outputs equals a specified value, which can be used to control total output.
-- `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
-  on a set of neurons, ensuring that the square of one neuron's output is equal
-  to the sum of the squares of other outputs.
-
-These constraints can be used to steer the learning process by applying
-conditions such as logical implications or numerical bounds.
+"""Module providing constraint classes for guiding neural network training.
+
+This module defines constraints that enforce specific conditions on network outputs
+to steer learning. Available constraint types include:
+
+- `Constraint`: Base class for all constraint types, defining the interface and core
+  behavior.
+- `ImplicationConstraint`: Enforces one condition only if another condition is met,
+  useful for modeling implications between outputs.
+- `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
+- `BinaryConstraint`: Enforces a binary comparison between two tags using a
+  comparison function (e.g., less than, greater than).
+- `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
+  value, controlling total output.
+
+These constraints can steer the learning process by applying logical implications
+or numerical bounds.
 
 Usage:
 1. Define a custom constraint class by inheriting from `Constraint`.
-2. Apply the constraint to your neural network during training
-
-
-   transformations and comparisons in constraints.
+2. Apply the constraint to your neural network during training.
+3. Use helper classes like `IdentityTransformation` for transformations and
+   comparisons in constraints.
 
-Dependencies:
-- PyTorch (`torch`)
 """
 
 import random
 import string
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from numbers import Number
-from typing import
+from typing import Literal
 
 from torch import (
     Tensor,
-
+    argsort,
+    eq,
     ge,
     gt,
-    isclose,
     le,
+    logical_and,
     logical_not,
     logical_or,
     lt,
-    numel,
     ones,
     ones_like,
     reshape,
     sign,
-    sqrt,
-    square,
     stack,
     tensor,
+    unique,
     zeros_like,
 )
 from torch.nn.functional import normalize
@@ -70,8 +60,7 @@ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
 
 
 class Constraint(ABC):
-    """
-    Abstract base class for defining constraints applied to neural networks.
+    """Abstract base class for defining constraints applied to neural networks.
 
     A `Constraint` specifies conditions that the neural network outputs
     should satisfy. It supports monitoring constraint satisfaction
@@ -79,18 +68,18 @@ class Constraint(ABC):
     must implement the `check_constraint` and `calculate_direction` methods.
 
     Args:
-
+        tags (set[str]): Tags referencing parts of the network where this constraint applies to.
         name (str, optional): A unique name for the constraint. If not provided,
             a name is generated based on the class name and a random suffix.
-
-            without adjusting the loss. Defaults to
+        enforce (bool, optional): If False, only monitor the constraint
+            without adjusting the loss. Defaults to True.
         rescale_factor (Number, optional): Factor to scale the
             constraint-adjusted loss. Defaults to 1.5. Should be greater
             than 1 to give weight to the constraint.
 
     Raises:
         TypeError: If a provided attribute has an incompatible type.
-        ValueError: If any
+        ValueError: If any tag in `tags` is not
             defined in the `descriptor`.
 
     Note:
@@ -104,29 +93,44 @@ class Constraint(ABC):
     device = None
 
     def __init__(
-        self,
-        neurons: set[str],
-        name: str = None,
-        monitor_only: bool = False,
-        rescale_factor: Number = 1.5,
+        self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
     ) -> None:
-        """
-        Initializes a new Constraint instance.
-        """
+        """Initializes a new Constraint instance.
 
+        Args:
+            tags (set[str]): Tags referencing parts of the network where this constraint applies to.
+            name (str, optional): A unique name for the constraint. If not
+                provided, a name is generated based on the class name and a
+                random suffix.
+            enforce (bool, optional): If False, only monitor the constraint
+                without adjusting the loss. Defaults to True.
+            rescale_factor (Number, optional): Factor to scale the
+                constraint-adjusted loss. Defaults to 1.5. Should be greater
+                than 1 to give weight to the constraint.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            ValueError: If any tag in `tags` is not defined in the `descriptor`.
+
+        Note:
+            - If `rescale_factor <= 1`, a warning is issued.
+            - If `name` is not provided, a name is auto-generated, and a
+                warning is logged.
+        """
         # Init parent class
         super().__init__()
 
         # Type checking
-        validate_iterable("
+        validate_iterable("tags", tags, str)
         validate_type("name", name, str, allow_none=True)
-        validate_type("
+        validate_type("enforce", enforce, bool)
         validate_type("rescale_factor", rescale_factor, Number)
 
         # Init object variables
-        self.
+        self.tags = tags
         self.rescale_factor = rescale_factor
-        self.
+        self.initial_rescale_factor = rescale_factor
+        self.enforce = enforce
 
         # Perform checks
         if rescale_factor <= 1:
```
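The hunks above rename `neurons` to `tags` and replace `monitor_only: bool = False` with `enforce: bool = True`, inverting the flag's polarity. A minimal migration sketch, assuming a hypothetical `"price"` tag (the 1.0.7 call is reconstructed from the old docstrings and may not match the removed signature exactly):

```python
from torch import ge

from congrads.constraints import ScalarConstraint

# congrads 1.0.7: monitor the constraint without adjusting the loss
# constraint = ScalarConstraint("price", ge, 0, monitor_only=True)

# congrads 1.1.0: the same intent is expressed by disabling enforcement
constraint = ScalarConstraint("price", ge, 0, enforce=False)
```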
```diff
@@ -135,120 +139,102 @@ class Constraint(ABC):
                 "will favor general loss over the constraint-adjusted loss. "
                 "Is this intended behavior? Normally, the rescale factor "
                 "should always be larger than 1.",
+                stacklevel=2,
             )
-        else:
-            self.rescale_factor = rescale_factor
 
         # If no constraint_name is set, generate one based
         # on the class name and a random suffix
         if name:
             self.name = name
         else:
-            random_suffix = "".join(
-                random.choices(string.ascii_uppercase + string.digits, k=6)
-            )
+            random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
             self.name = f"{self.__class__.__name__}_{random_suffix}"
-            warnings.warn(
-                f"Name for constraint is not set. Using {self.name}.",
-            )
+            warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
 
-        # Infer layers from descriptor and
+        # Infer layers from descriptor and tags
         self.layers = set()
-        for
-            if
+        for tag in self.tags:
+            if not self.descriptor.exists(tag):
                 raise ValueError(
-                    f"The
+                    f"The tag {tag} used with constraint "
                     f"{self.name} is not defined in the descriptor. Please "
                     "add it to the correct layer using "
                     "descriptor.add('layer', ...)."
                 )
 
-            self.
+            layer, _ = self.descriptor.location(tag)
+            self.layers.add(layer)
 
     @abstractmethod
-    def check_constraint(
-
-
-
-        Evaluates whether the given model predictions satisfy the constraint.
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Evaluates whether the given model predictions satisfy the constraint.
+
+        1 IS SATISFIED, 0 IS NOT SATISFIED
 
         Args:
-
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
         Returns:
-            tuple[Tensor,
-                indicating whether the constraint is satisfied (with
-                for satisfaction,
-
-
+            tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
+                indicating whether the constraint is satisfied (with value 1.0
+                for satisfaction, and 0.0 for non-satisfaction, and the second element is a tensor
+                mask that indicates the relevance of each sample (`True` for relevant
+                samples and `False` for irrelevant ones).
 
         Raises:
             NotImplementedError: If not implemented in a subclass.
         """
-
         raise NotImplementedError
 
     @abstractmethod
-    def calculate_direction(
-
-
-
-
-        better satisfy the constraint.
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Compute adjustment directions to better satisfy the constraint.
+
+        Given the model predictions, input batch, and context, this method calculates the direction
+        in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
 
         Args:
-
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
         Returns:
-
-
+            dict[str, Tensor]: Dictionary mapping network layers to tensors that
+                specify the adjustment direction for each tag.
 
         Raises:
-            NotImplementedError:
+            NotImplementedError: Must be implemented by subclasses.
         """
-
         raise NotImplementedError
 
 
 class ImplicationConstraint(Constraint):
-    """
-    Represents an implication constraint between two
-    constraints (head and body).
+    """Represents an implication constraint between two constraints (head and body).
 
     The implication constraint ensures that the `body` constraint only applies
     when the `head` constraint is satisfied. If the `head` constraint is not
     satisfied, the `body` constraint does not apply.
-
-    Args:
-        head (Constraint): The head of the implication. If this constraint
-            is satisfied, the body constraint must also be satisfied.
-        body (Constraint): The body of the implication. This constraint
-            is enforced only when the head constraint is satisfied.
-        name (str, optional): A unique name for the constraint. If not
-            provided, the name is generated in the format
-            "{body.name} if {head.name}". Defaults to None.
-        monitor_only (bool, optional): If True, the constraint is only
-            monitored without adjusting the loss. Defaults to False.
-        rescale_factor (Number, optional): The scaling factor for the
-            constraint-adjusted loss. Defaults to 1.5.
-
-    Raises:
-        TypeError: If a provided attribute has an incompatible type.
-
     """
 
     def __init__(
         self,
         head: Constraint,
         body: Constraint,
-        name=None,
-        monitor_only=False,
-        rescale_factor=1.5,
+        name: str = None,
     ):
-        """
-
-
+        """Initializes an ImplicationConstraint instance.
+
+        Uses `enforce` and `rescale_factor` from the body constraint.
+
+        Args:
+            head (Constraint): Constraint defining the head of the implication.
+            body (Constraint): Constraint defining the body of the implication.
+            name (str, optional): A unique name for the constraint. If not
+                provided, a name is generated based on the class name and a
+                random suffix.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
 
+        """
         # Type checking
         validate_type("head", head, Constraint)
         validate_type("body", body, Constraint)
```
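The abstract interface now takes a single `data` mapping and returns tensors on both channels: `check_constraint` yields a per-sample satisfaction tensor plus a relevance mask, and `calculate_direction` yields per-layer adjustment tensors. Below is a minimal sketch of a custom subclass against these signatures; the `NonNegativeConstraint` class, the `"price"` tag, and the satisfaction rule are invented for illustration, and the class-level `descriptor` is assumed to be configured elsewhere (via `descriptor.add('layer', ...)`, as the error message in the diff suggests):

```python
from torch import Tensor, ge, ones_like, zeros_like

from congrads.constraints import Constraint


class NonNegativeConstraint(Constraint):
    """Hypothetical constraint: the tagged output should be >= 0."""

    def __init__(self, tag: str, enforce: bool = True):
        super().__init__({tag}, name=f"{tag} non-negative", enforce=enforce)
        self.tag = tag

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        # 1.0 where the tagged column is non-negative, 0.0 otherwise;
        # the second tensor marks every sample as relevant.
        selection = self.descriptor.select(self.tag, data)
        result = ge(selection, 0).float()
        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        # Push the tagged output upward, mirroring ScalarConstraint's
        # dense-layer tag-to-index translation.
        layer, index = self.descriptor.location(self.tag)
        direction = zeros_like(data[layer][0], device=self.device)
        direction[index] = 1
        return {layer: direction}
```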
```diff
@@ -257,64 +243,82 @@ class ImplicationConstraint(Constraint):
             name = f"{body.name} if {head.name}"
 
         # Init parent class
-        super().__init__(
-            head.neurons | body.neurons,
-            name,
-            monitor_only,
-            rescale_factor,
-        )
+        super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)
 
         self.head = head
         self.body = body
 
-    def check_constraint(
-
-
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Check whether the implication constraint is satisfied.
+
+        Evaluates the `head` and `body` constraints. The `body` constraint
+        is enforced only if the `head` constraint is satisfied. If the
+        `head` constraint is not satisfied, the `body` constraint does not
+        affect the result.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
+        Returns:
+            tuple[Tensor, Tensor]:
+                - result: Tensor indicating satisfaction of the implication
+                    constraint (1 if satisfied, 0 otherwise).
+                - head_satisfaction: Tensor indicating satisfaction of the
+                    head constraint alone.
+        """
         # Check satisfaction of head and body constraints
-        head_satisfaction, _ = self.head.check_constraint(
-        body_satisfaction, _ = self.body.check_constraint(
+        head_satisfaction, _ = self.head.check_constraint(data)
+        body_satisfaction, _ = self.body.check_constraint(data)
 
         # If head constraint is satisfied (returning 1),
         # the body constraint matters (and should return 0/1 based on body)
         # If head constraint is not satisfied (returning 0),
         # the body constraint does not apply (and should return 1)
-        result = logical_or(
-
-
+        result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()
+
+        return result, head_satisfaction
 
-
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Compute adjustment directions for tags to satisfy the constraint.
 
-
-
-
+        Uses the `body` constraint directions as the update vector. Only
+        applies updates if the `head` constraint is satisfied. Currently,
+        this method only works for dense layers due to tag-to-index
+        translation limitations.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+        Returns:
+            dict[str, Tensor]: Dictionary mapping tags to tensors
+                specifying the adjustment direction for each tag.
+        """
         # NOTE currently only works for dense layers
-        # due to
+        # due to tag to index translation
 
         # Use directions of constraint body as update vector
-        return self.body.calculate_direction(
+        return self.body.calculate_direction(data)
 
 
 class ScalarConstraint(Constraint):
-    """
-    A constraint that enforces scalar-based comparisons on a specific neuron.
+    """A constraint that enforces scalar-based comparisons on a specific tag.
 
-    This class ensures that the output of a specified
+    This class ensures that the output of a specified tag satisfies a scalar
     comparison operation (e.g., less than, greater than, etc.). It uses a
     comparator function to validate the condition and calculates adjustment
     directions accordingly.
 
     Args:
-        operand (Union[str, Transformation]): Name of the
+        operand (Union[str, Transformation]): Name of the tag or a
             transformation to apply.
         comparator (Callable[[Tensor, Number], Tensor]): A comparison
             function (e.g., `torch.ge`, `torch.lt`).
         scalar (Number): The scalar value to compare against.
         name (str, optional): A unique name for the constraint. If not
             provided, a name is auto-generated in the format
-            "<
-
-            without adjusting the loss. Defaults to
+            "<tag> <comparator> <scalar>".
+        enforce (bool, optional): If False, only monitor the constraint
+            without adjusting the loss. Defaults to True.
         rescale_factor (Number, optional): Factor to scale the
             constraint-adjusted loss. Defaults to 1.5.
 
```
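`ImplicationConstraint` composes two constraints so the body is only required when the head holds, and (per the new docstring) it inherits `enforce` and `rescale_factor` from the body. A sketch with hypothetical tags, using the auto-generated names shown in the hunks above:

```python
from torch import ge, le

from congrads.constraints import ImplicationConstraint, ScalarConstraint

# When "demand" is at least 100, "supply" must stay at or below 500.
head = ScalarConstraint("demand", ge, 100)
body = ScalarConstraint("supply", le, 500)

implication = ImplicationConstraint(head, body)
print(implication.name)  # "supply le 500 if demand ge 100"
```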
```diff
@@ -322,87 +326,120 @@ class ScalarConstraint(Constraint):
         TypeError: If a provided attribute has an incompatible type.
 
     Notes:
-        - The `
-        - The constraint name is composed using the
-            comparator, and scalar value.
+        - The `tag` must be defined in the `descriptor` mapping.
+        - The constraint name is composed using the tag, comparator, and scalar value.
 
     """
 
     def __init__(
         self,
-        operand:
+        operand: str | Transformation,
         comparator: Callable[[Tensor, Number], Tensor],
         scalar: Number,
         name: str = None,
-
+        enforce: bool = True,
         rescale_factor: Number = 1.5,
     ) -> None:
-        """
-
-
+        """Initializes a ScalarConstraint instance.
+
+        Args:
+            operand (Union[str, Transformation]): Function that needs to be
+                performed on the network variables before applying the
+                constraint.
+            comparator (Callable[[Tensor, Number], Tensor]): Comparison
+                operator used in the constraint. Supported types are
+                {torch.lt, torch.le, torch.st, torch.se}.
+            scalar (Number): Constant to compare the variable to.
+            name (str, optional): A unique name for the constraint. If not
+                provided, a name is generated based on the class name and a
+                random suffix.
+            enforce (bool, optional): If False, only monitor the constraint
+                without adjusting the loss. Defaults to True.
+            rescale_factor (Number, optional): Factor to scale the
+                constraint-adjusted loss. Defaults to 1.5. Should be greater
+                than 1 to give weight to the constraint.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
 
+        Notes:
+            - The `tag` must be defined in the `descriptor` mapping.
+            - The constraint name is composed using the tag, comparator, and scalar value.
+        """
         # Type checking
         validate_type("operand", operand, (str, Transformation))
         validate_comparator_pytorch("comparator", comparator)
-        validate_comparator_pytorch("comparator", comparator)
         validate_type("scalar", scalar, Number)
 
-        # If transformation is provided, get
-        # else use IdentityTransformation
+        # If transformation is provided, get tag name, else use IdentityTransformation
         if isinstance(operand, Transformation):
-
+            tag = operand.tag
             transformation = operand
         else:
-
-            transformation = IdentityTransformation(
+            tag = operand
+            transformation = IdentityTransformation(tag)
 
         # Compose constraint name
-        name = f"{
+        name = f"{tag} {comparator.__name__} {str(scalar)}"
 
         # Init parent class
-        super().__init__({
+        super().__init__({tag}, name, enforce, rescale_factor)
 
         # Init variables
+        self.tag = tag
         self.comparator = comparator
         self.scalar = scalar
         self.transformation = transformation
 
-        # Get layer name and feature index from neuron_name
-        self.layer = self.descriptor.neuron_to_layer[neuron_name]
-        self.index = self.descriptor.neuron_to_index[neuron_name]
-
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
-            self.direction = -1
-        elif self.comparator in [gt, ge]:
             self.direction = 1
+        elif self.comparator in [gt, ge]:
+            self.direction = -1
+
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Check if the scalar constraint is satisfied for a given tag.
 
-
-
-    ) -> tuple[Tensor, int]:
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
+        Returns:
+            tuple[Tensor, Tensor]:
+                - result: Tensor indicating whether the tag satisfies the constraint.
+                - ones_like(result): Tensor of ones with same shape as `result`.
+        """
         # Select relevant columns
-        selection =
+        selection = self.descriptor.select(self.tag, data)
 
         # Apply transformation
         selection = self.transformation(selection)
 
         # Calculate current constraint result
         result = self.comparator(selection, self.scalar).float()
-        return result,
+        return result, ones_like(result)
+
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Compute adjustment directions to satisfy the scalar constraint.
+
+        Only works for dense layers due to tag-to-index translation.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
-
-
-
+        Returns:
+            dict[str, Tensor]: Dictionary mapping layers to tensors specifying
+                the adjustment direction for each tag.
+        """
         # NOTE currently only works for dense layers due
-        # to
+        # to tag to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(
+            output[layer] = zeros_like(data[layer][0], device=self.device)
 
-
+        layer, index = self.descriptor.location(self.tag)
+        output[layer][index] = self.direction
 
         for layer in self.layers:
             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
```
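Note that the adjustment direction assigned to `lt`/`le` versus `gt`/`ge` is inverted relative to 1.0.7 in the hunk above. A minimal usage sketch with a hypothetical `"probability"` tag, bounding one output from both sides with two scalar constraints:

```python
from torch import gt, lt

from congrads.constraints import ScalarConstraint

# Keep a predicted probability strictly inside (0, 1).
lower = ScalarConstraint("probability", gt, 0)
upper = ScalarConstraint("probability", lt, 1, rescale_factor=2.0)
print(lower.name)  # "probability gt 0"
```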
```diff
@@ -411,26 +448,24 @@ class ScalarConstraint(Constraint):
 
 
 class BinaryConstraint(Constraint):
-    """
-    A constraint that enforces a binary comparison between two neurons.
+    """A constraint that enforces a binary comparison between two tags.
 
-    This class ensures that the output of one
-    operation with the output of another
-
-    validate the condition and calculates adjustment directions accordingly.
+    This class ensures that the output of one tag satisfies a comparison
+    operation with the output of another tag (e.g., less than, greater than, etc.).
+    It uses a comparator function to validate the condition and calculates adjustment directions accordingly.
 
     Args:
         operand_left (Union[str, Transformation]): Name of the left
-
+            tag or a transformation to apply.
         comparator (Callable[[Tensor, Number], Tensor]): A comparison
             function (e.g., `torch.ge`, `torch.lt`).
         operand_right (Union[str, Transformation]): Name of the right
-
+            tag or a transformation to apply.
         name (str, optional): A unique name for the constraint. If not
             provided, a name is auto-generated in the format
-            "<
-
-            without adjusting the loss. Defaults to
+            "<operand_left> <comparator> <operand_right>".
+        enforce (bool, optional): If False, only monitor the constraint
+            without adjusting the loss. Defaults to True.
         rescale_factor (Number, optional): Factor to scale the
             constraint-adjusted loss. Defaults to 1.5.
 
@@ -438,84 +473,107 @@ class BinaryConstraint(Constraint):
         TypeError: If a provided attribute has an incompatible type.
 
     Notes:
-        - The
-        - The constraint name is composed using the left
-            comparator, and right neuron name.
+        - The tags must be defined in the `descriptor` mapping.
+        - The constraint name is composed using the left tag, comparator, and right tag.
 
     """
 
     def __init__(
         self,
-        operand_left:
+        operand_left: str | Transformation,
         comparator: Callable[[Tensor, Number], Tensor],
-        operand_right:
+        operand_right: str | Transformation,
         name: str = None,
-
+        enforce: bool = True,
         rescale_factor: Number = 1.5,
     ) -> None:
-        """
-        Initializes a BinaryConstraint instance.
-        """
+        """Initializes a BinaryConstraint instance.
 
+        Args:
+            operand_left (Union[str, Transformation]): Name of the left
+                tag or a transformation to apply.
+            comparator (Callable[[Tensor, Number], Tensor]): A comparison
+                function (e.g., `torch.ge`, `torch.lt`).
+            operand_right (Union[str, Transformation]): Name of the right
+                tag or a transformation to apply.
+            name (str, optional): A unique name for the constraint. If not
+                provided, a name is auto-generated in the format
+                "<operand_left> <comparator> <operand_right>".
+            enforce (bool, optional): If False, only monitor the constraint
+                without adjusting the loss. Defaults to True.
+            rescale_factor (Number, optional): Factor to scale the
+                constraint-adjusted loss. Defaults to 1.5.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+
+        Notes:
+            - The tags must be defined in the `descriptor` mapping.
+            - The constraint name is composed using the left tag,
+                comparator, and right tag.
+        """
         # Type checking
         validate_type("operand_left", operand_left, (str, Transformation))
         validate_comparator_pytorch("comparator", comparator)
         validate_comparator_pytorch("comparator", comparator)
         validate_type("operand_right", operand_right, (str, Transformation))
 
-        # If transformation is provided, get
-        # else use IdentityTransformation
+        # If transformation is provided, get tag name, else use IdentityTransformation
         if isinstance(operand_left, Transformation):
-
+            tag_left = operand_left.tag
             transformation_left = operand_left
         else:
-
-            transformation_left = IdentityTransformation(
+            tag_left = operand_left
+            transformation_left = IdentityTransformation(tag_left)
 
         if isinstance(operand_right, Transformation):
-
+            tag_right = operand_right.tag
             transformation_right = operand_right
         else:
-
-            transformation_right = IdentityTransformation(
+            tag_right = operand_right
+            transformation_right = IdentityTransformation(tag_right)
 
         # Compose constraint name
-        name = f"{
+        name = f"{tag_left} {comparator.__name__} {tag_right}"
 
         # Init parent class
-        super().__init__(
-            {neuron_name_left, neuron_name_right},
-            name,
-            monitor_only,
-            rescale_factor,
-        )
+        super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)
 
         # Init variables
         self.comparator = comparator
+        self.tag_left = tag_left
+        self.tag_right = tag_right
         self.transformation_left = transformation_left
         self.transformation_right = transformation_right
 
-        # Get layer name and feature index from neuron_name
-        self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
-        self.layer_right = self.descriptor.neuron_to_layer[neuron_name_right]
-        self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
-        self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
-
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
-            self.direction_left = -1
-            self.direction_right = 1
-        else:
             self.direction_left = 1
             self.direction_right = -1
+        else:
+            self.direction_left = -1
+            self.direction_right = 1
+
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Evaluate whether the binary constraint is satisfied for the current predictions.
 
-
-
-
+        The constraint compares the outputs of two tags using the specified
+        comparator function. A result of `1` indicates the constraint is satisfied
+        for a sample, and `0` indicates it is violated.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
+        Returns:
+            tuple[Tensor, Tensor]:
+                - result (Tensor): Binary tensor indicating constraint satisfaction
+                    (1 for satisfied, 0 for violated) for each sample.
+                - mask (Tensor): Tensor of ones with the same shape as `result`,
+                    used for constraint aggregation.
+        """
         # Select relevant columns
-        selection_left =
-        selection_right =
+        selection_left = self.descriptor.select(self.tag_left, data)
+        selection_right = self.descriptor.select(self.tag_right, data)
 
         # Apply transformations
         selection_left = self.transformation_left(selection_left)
@@ -523,21 +581,34 @@ class BinaryConstraint(Constraint):
 
         result = self.comparator(selection_left, selection_right).float()
 
-        return result,
+        return result, ones_like(result)
+
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Compute adjustment directions for the tags involved in the binary constraint.
+
+        The returned directions indicate how to adjust each tag's output to
+        satisfy the constraint. Only currently supported for dense layers.
 
-
-
-
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+        Returns:
+            dict[str, Tensor]: A mapping from layer names to tensors specifying
+                the normalized adjustment directions for each tag involved in the
+                constraint.
+        """
         # NOTE currently only works for dense layers due
-        # to
+        # to tag to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(
+            output[layer] = zeros_like(data[layer][0], device=self.device)
 
-
-
+        layer_left, index_left = self.descriptor.location(self.tag_left)
+        layer_right, index_right = self.descriptor.location(self.tag_right)
+        output[layer_left][index_left] = self.direction_left
+        output[layer_right][index_right] = self.direction_right
 
         for layer in self.layers:
             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
```
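`BinaryConstraint` relates two tagged outputs instead of a tag and a scalar; its `calculate_direction` nudges the two columns in opposite directions. A sketch with hypothetical tags:

```python
from torch import le

from congrads.constraints import BinaryConstraint

# The predicted minimum should not exceed the predicted maximum.
ordering = BinaryConstraint("min_temperature", le, "max_temperature")
print(ordering.name)  # "min_temperature le max_temperature"
```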
```diff
@@ -546,142 +617,119 @@ class BinaryConstraint(Constraint):
 
 
 class SumConstraint(Constraint):
-    """
-    A constraint that enforces a weighted summation comparison
-    between two groups of neurons.
+    """A constraint that enforces a weighted summation comparison between two groups of tags.
 
     This class evaluates whether the weighted sum of outputs from one set of
-
-    outputs from another set of
-
-    Args:
-        operands_left (list[Union[str, Transformation]]): List of neuron
-            names or transformations on the left side.
-        comparator (Callable[[Tensor, Number], Tensor]): A comparison
-            function for the constraint.
-        operands_right (list[Union[str, Transformation]]): List of neuron
-            names or transformations on the right side.
-        weights_left (list[Number], optional): Weights for the left neurons.
-            Defaults to None.
-        weights_right (list[Number], optional): Weights for the right
-            neurons. Defaults to None.
-        name (str, optional): Unique name for the constraint.
-            If None, it's auto-generated. Defaults to None.
-        monitor_only (bool, optional): If True, only monitor the constraint
-            without adjusting the loss. Defaults to False.
-        rescale_factor (Number, optional): Factor to scale the
-            constraint-adjusted loss. Defaults to 1.5.
-
-    Raises:
-        TypeError: If a provided attribute has an incompatible type.
-        ValueError: If the dimensions of neuron names and weights mismatch.
-
+    tags satisfies a comparison operation with the weighted sum of
+    outputs from another set of tags.
     """
 
     def __init__(
         self,
-        operands_left: list[
+        operands_left: list[str | Transformation],
         comparator: Callable[[Tensor, Number], Tensor],
-        operands_right: list[
+        operands_right: list[str | Transformation],
         weights_left: list[Number] = None,
         weights_right: list[Number] = None,
         name: str = None,
-
+        enforce: bool = True,
         rescale_factor: Number = 1.5,
     ) -> None:
-        """
-        Initializes the SumConstraint.
-        """
+        """Initializes the SumConstraint.
 
+        Args:
+            operands_left (list[Union[str, Transformation]]): List of tags
+                or transformations on the left side.
+            comparator (Callable[[Tensor, Number], Tensor]): A comparison
+                function for the constraint.
+            operands_right (list[Union[str, Transformation]]): List of tags
+                or transformations on the right side.
+            weights_left (list[Number], optional): Weights for the left
+                tags. Defaults to None.
+            weights_right (list[Number], optional): Weights for the right
+                tags. Defaults to None.
+            name (str, optional): Unique name for the constraint.
+                If None, it's auto-generated. Defaults to None.
+            enforce (bool, optional): If False, only monitor the constraint
+                without adjusting the loss. Defaults to True.
+            rescale_factor (Number, optional): Factor to scale the
+                constraint-adjusted loss. Defaults to 1.5.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            ValueError: If the dimensions of tags and weights mismatch.
+        """
         # Type checking
         validate_iterable("operands_left", operands_left, (str, Transformation))
         validate_comparator_pytorch("comparator", comparator)
         validate_comparator_pytorch("comparator", comparator)
-        validate_iterable(
-            "operands_right", operands_right, (str, Transformation)
-        )
+        validate_iterable("operands_right", operands_right, (str, Transformation))
         validate_iterable("weights_left", weights_left, Number, allow_none=True)
-        validate_iterable(
-            "weights_right", weights_right, Number, allow_none=True
-        )
+        validate_iterable("weights_right", weights_right, Number, allow_none=True)
 
-        # If transformation is provided, get
-
-        neuron_names_left: list[str] = []
+        # If transformation is provided, get tag, else use IdentityTransformation
+        tags_left: list[str] = []
         transformations_left: list[Transformation] = []
         for operand_left in operands_left:
             if isinstance(operand_left, Transformation):
-
-
+                tag_left = operand_left.tag
+                tags_left.append(tag_left)
                 transformations_left.append(operand_left)
             else:
-
-
-                transformations_left.append(
-                    IdentityTransformation(neuron_name_left)
-                )
+                tag_left = operand_left
+                tags_left.append(tag_left)
+                transformations_left.append(IdentityTransformation(tag_left))
 
-
+        tags_right: list[str] = []
         transformations_right: list[Transformation] = []
         for operand_right in operands_right:
             if isinstance(operand_right, Transformation):
-
-
+                tag_right = operand_right.tag
+                tags_right.append(tag_right)
                 transformations_right.append(operand_right)
             else:
-
-
-                transformations_right.append(
-                    IdentityTransformation(neuron_name_right)
-                )
+                tag_right = operand_right
+                tags_right.append(tag_right)
+                transformations_right.append(IdentityTransformation(tag_right))
 
         # Compose constraint name
-        w_left = weights_left or [""] * len(
-        w_right = weights_right or [""] * len(
-        left_expr = " + ".join(
-
-        )
-        right_expr = " + ".join(
-            f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
-        )
+        w_left = weights_left or [""] * len(tags_left)
+        w_right = weights_right or [""] * len(tags_right)
+        left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
+        right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
         comparator_name = comparator.__name__
         name = f"{left_expr} {comparator_name} {right_expr}"
 
         # Init parent class
-
-        super().__init__(
+        tags = set(tags_left) | set(tags_right)
+        super().__init__(tags, name, enforce, rescale_factor)
 
         # Init variables
         self.comparator = comparator
-        self.
-        self.
+        self.tags_left = tags_left
+        self.tags_right = tags_right
         self.transformations_left = transformations_left
         self.transformations_right = transformations_right
 
-        # If feature list dimensions don't match
-
-        if weights_left and (len(neuron_names_left) != len(weights_left)):
+        # If feature list dimensions don't match weight list dimensions, raise error
+        if weights_left and (len(tags_left) != len(weights_left)):
             raise ValueError(
-                "The dimensions of
-                "dimensions of weights_left."
+                "The dimensions of tags_left don't match with the dimensions of weights_left."
             )
-        if weights_right and (len(
+        if weights_right and (len(tags_right) != len(weights_right)):
             raise ValueError(
-                "The dimensions of
-                "dimensions of weights_right."
+                "The dimensions of tags_right don't match with the dimensions of weights_right."
             )
 
         # If weights are provided for summation, transform them to Tensors
         if weights_left:
             self.weights_left = tensor(weights_left, device=self.device)
         else:
-            self.weights_left = ones(len(
+            self.weights_left = ones(len(tags_left), device=self.device)
         if weights_right:
             self.weights_right = tensor(weights_right, device=self.device)
         else:
-            self.weights_right = ones(
-                len(neuron_names_right), device=self.device
-            )
+            self.weights_right = ones(len(tags_right), device=self.device)
 
         # Calculate directions based on constraint operator
         if self.comparator in [lt, le]:
@@ -691,80 +739,82 @@ class SumConstraint(Constraint):
             self.direction_left = 1
             self.direction_right = -1
 
-    def check_constraint(
-
-
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Evaluate whether the weighted sum constraint is satisfied.
+
+        Computes the weighted sum of outputs from the left and right tags,
+        applies the specified comparator function, and returns a binary result for
+        each sample.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
+
+        Returns:
+            tuple[Tensor, Tensor]:
+                - result (Tensor): Binary tensor indicating whether the constraint
+                    is satisfied (1) or violated (0) for each sample.
+                - mask (Tensor): Tensor of ones, used for constraint aggregation.
+        """
 
         def compute_weighted_sum(
-
+            tags: list[str],
             transformations: list[Transformation],
-            weights:
-        ) ->
-
-
-                for neuron_name in neuron_names
-            ]
-            indices = [
-                self.descriptor.neuron_to_index[neuron_name]
-                for neuron_name in neuron_names
-            ]
-
-            # Select relevant column
-            selections = [
-                prediction[layer][:, index]
-                for layer, index in zip(layers, indices)
-            ]
+            weights: Tensor,
+        ) -> Tensor:
+            # Select relevant columns
+            selections = [self.descriptor.select(tag, data) for tag in tags]
 
             # Apply transformations
             results = []
-            for transformation, selection in zip(transformations, selections):
+            for transformation, selection in zip(transformations, selections, strict=False):
                 results.append(transformation(selection))
 
-            # Extract predictions for all
-            predictions = stack(
-                results,
-                dim=1,
-            )
+            # Extract predictions for all tags and apply weights in bulk
+            predictions = stack(results)
 
             # Calculate weighted sum
-            return (predictions * weights.
+            return (predictions * weights.view(-1, 1, 1)).sum(dim=0)
 
         # Compute weighted sums
         weighted_sum_left = compute_weighted_sum(
-            self.
-            self.transformations_left,
-            self.weights_left,
+            self.tags_left, self.transformations_left, self.weights_left
         )
         weighted_sum_right = compute_weighted_sum(
-            self.
-            self.transformations_right,
-            self.weights_right,
+            self.tags_right, self.transformations_right, self.weights_right
        )
 
         # Apply the comparator and calculate the result
         result = self.comparator(weighted_sum_left, weighted_sum_right).float()
 
-        return result,
+        return result, ones_like(result)
+
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Compute adjustment directions for tags involved in the weighted sum constraint.
+
+        The directions indicate how to adjust each tag's output to satisfy the
+        constraint. Only dense layers are currently supported.
+
+        Args:
+            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
 
-
-
-
+        Returns:
+            dict[str, Tensor]: Mapping from layer names to normalized tensors
+                specifying adjustment directions for each tag involved in the constraint.
+        """
         # NOTE currently only works for dense layers
-        # due to
+        # due to tag to index translation
 
         output = {}
 
         for layer in self.layers:
-            output[layer] = zeros_like(
+            output[layer] = zeros_like(data[layer][0], device=self.device)
 
-        for
-            layer = self.descriptor.
-            index = self.descriptor.neuron_to_index[neuron_name_left]
+        for tag_left in self.tags_left:
+            layer, index = self.descriptor.location(tag_left)
             output[layer][index] = self.direction_left
 
-        for
-            layer = self.descriptor.
-            index = self.descriptor.neuron_to_index[neuron_name_right]
+        for tag_right in self.tags_right:
+            layer, index = self.descriptor.location(tag_right)
             output[layer][index] = self.direction_right
 
         for layer in self.layers:
```
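`SumConstraint` compares two weighted sums of tagged outputs; unweighted operands default to weight 1, and the auto-generated name spells out the weighted expression. A sketch with hypothetical tags and weights:

```python
from torch import ge

from congrads.constraints import SumConstraint

# 0.3*solar + 0.7*wind should cover the predicted load.
budget = SumConstraint(
    operands_left=["solar", "wind"],
    comparator=ge,
    operands_right=["load"],
    weights_left=[0.3, 0.7],
)
print(budget.name)  # "0.3solar + 0.7wind ge load"
```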
@@ -773,134 +823,434 @@ class SumConstraint(Constraint):
|
|
|
773
823
|
return output
|
|
774
824
|
|
|
775
825
|
|
|
776
|
-
class
|
|
826
|
+
class MonotonicityConstraint(Constraint):
|
|
827
|
+
"""Constraint that enforces a monotonic relationship between two tags.
|
|
828
|
+
|
|
829
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
830
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
777
831
|
"""
|
|
778
|
-
A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
|
|
779
|
-
where `a` and `b` are neurons or transformations.
|
|
780
832
|
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
833
|
+
def __init__(
|
|
834
|
+
self,
|
|
835
|
+
tag_prediction: str,
|
|
836
|
+
tag_reference: str,
|
|
837
|
+
rescale_factor_lower: float = 1.5,
|
|
838
|
+
rescale_factor_upper: float = 1.75,
|
|
839
|
+
stable: bool = True,
|
|
840
|
+
direction: Literal["ascending", "descending"] = "ascending",
|
|
841
|
+
name: str = None,
|
|
842
|
+
enforce: bool = True,
|
|
843
|
+
):
|
|
844
|
+
"""Constraint that enforces monotonicity on a predicted output.
|
|
785
845
|
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
neuron name (str) or a Transformation.
|
|
789
|
-
b (Union[str, Transformation]): The second input, either a
|
|
790
|
-
neuron name (str) or a Transformation.
|
|
791
|
-
rtol (float, optional): The relative tolerance for the
|
|
792
|
-
comparison (default is 0.00001).
|
|
793
|
-
atol (float, optional): The absolute tolerance for the
|
|
794
|
-
comparison (default is 1e-8).
|
|
795
|
-
name (str, optional): The name of the constraint
|
|
796
|
-
(default is None, and it is generated automatically).
|
|
797
|
-
monitor_only (bool, optional): Flag indicating whether the
|
|
798
|
-
constraint is only for monitoring (default is False).
|
|
799
|
-
rescale_factor (Number, optional): A factor used for
|
|
800
|
-
rescaling (default is 1.5).
|
|
846
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
847
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
801
848
|
|
|
802
|
-
|
|
803
|
-
|
|
849
|
+
Args:
|
|
850
|
+
tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
|
|
851
|
+
tag_reference (str): Name of the tag that acts as the monotonic reference.
|
|
852
|
+
rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
|
|
853
|
+
rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
|
|
854
|
+
stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
|
|
855
|
+
direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
|
|
856
|
+
name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
|
|
857
|
+
enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
|
|
858
|
+
"""
|
|
859
|
+
# Type checking
|
|
860
|
+
validate_type("rescale_factor_lower", rescale_factor_lower, float)
|
|
861
|
+
validate_type("rescale_factor_upper", rescale_factor_upper, float)
|
|
862
|
+
validate_type("stable", stable, bool)
|
|
863
|
+
validate_type("direction", direction, str)
|
|
864
|
+
|
|
865
|
+
# Compose constraint name
|
|
866
|
+
if name is None:
|
|
867
|
+
name = f"{tag_prediction} monotonically {direction} by {tag_reference}"
|
|
868
|
+
|
|
869
|
+
# Init parent class
|
|
870
|
+
super().__init__({tag_prediction}, name, enforce, 1.0)
|
|
871
|
+
|
|
872
|
+
# Init variables
|
|
873
|
+
self.tag_prediction = tag_prediction
|
|
874
|
+
self.tag_reference = tag_reference
|
|
875
|
+
self.rescale_factor_lower = rescale_factor_lower
|
|
876
|
+
self.rescale_factor_upper = rescale_factor_upper
|
|
877
|
+
self.stable = stable
|
|
878
|
+
self.descending = direction == "descending"
|
|
879
|
+
|
|
880
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
881
|
+
"""Evaluate whether the monotonicity constraint is satisfied."""
|
|
882
|
+
# Select relevant columns
|
|
883
|
+
preds = self.descriptor.select(self.tag_prediction, data)
|
|
884
|
+
targets = self.descriptor.select(self.tag_reference, data)
|
|
885
|
+
|
|
886
|
+
# Utility: convert values -> ranks (0 ... num_features-1)
|
|
887
|
+
def compute_ranks(x: Tensor, descending: bool) -> Tensor:
|
|
888
|
+
return argsort(
|
|
889
|
+
argsort(x, descending=descending, stable=self.stable, dim=0),
|
|
890
|
+
descending=False,
|
|
891
|
+
stable=self.stable,
|
|
892
|
+
dim=0,
|
|
893
|
+
)
|
|
804
894
|
|
|
895
|
+
# Compute predicted and target ranks
|
|
896
|
+
pred_ranks = compute_ranks(preds, descending=self.descending)
|
|
897
|
+
target_ranks = compute_ranks(targets, descending=False)
|
|
898
|
+
|
|
899
|
+
# Rank difference
|
|
900
|
+
rank_diff = pred_ranks - target_ranks
|
|
901
|
+
|
|
902
|
+
# Rescale differences into [rescale_factor_lower, rescale_factor_upper]
|
|
903
|
+
batch_size = preds.shape[0]
|
|
904
|
+
invert_direction = -1 if self.descending else 1
|
|
905
|
+
self.compared_rankings = (
|
|
906
|
+
(rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
|
|
907
|
+
+ self.rescale_factor_lower * sign(rank_diff)
|
|
908
|
+
) * invert_direction
|
|
909
|
+
|
|
910
|
+
# Calculate satisfaction
|
|
911
|
+
incorrect_rankings = eq(self.compared_rankings, 0).float()
|
|
912
|
+
|
|
913
|
+
return incorrect_rankings, ones_like(incorrect_rankings)
|
|
914
|
+
|
|
915
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
916
|
+
"""Calculates ranking adjustments for monotonicity enforcement."""
|
|
917
|
+
layer, _ = self.descriptor.location(self.tag_prediction)
|
|
918
|
+
return {layer: self.compared_rankings}
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
class GroupedMonotonicityConstraint(MonotonicityConstraint):
|
|
922
|
+
"""Constraint that enforces a monotonic relationship between two tags.
|
|
923
|
+
|
|
924
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
925
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
805
926
|
"""
|
|
806
927
|
|
|
807
928
|
def __init__(
|
|
808
929
|
self,
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
930
|
+
tag_prediction: str,
|
|
931
|
+
tag_reference: str,
|
|
932
|
+
tag_group_identifier: str,
|
|
933
|
+
rescale_factor_lower: float = 1.5,
|
|
934
|
+
rescale_factor_upper: float = 1.75,
|
|
935
|
+
stable: bool = True,
|
|
936
|
+
direction: Literal["ascending", "descending"] = "ascending",
|
|
937
|
+
name: str = None,
|
|
938
|
+
enforce: bool = True,
|
|
939
|
+
):
|
|
940
|
+
"""Constraint that enforces monotonicity on a predicted output.
|
|
941
|
+
|
|
942
|
+
Within each group identified by `tag_group_identifier`, this constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
943
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
|
|
947
|
+
tag_reference (str): Name of the tag that acts as the monotonic reference.
|
|
948
|
+
tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
|
|
949
|
+
rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
|
|
950
|
+
rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
|
|
951
|
+
stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
|
|
952
|
+
direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
|
|
953
|
+
name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
|
|
954
|
+
enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
|
|
955
|
+
"""
|
|
956
|
+
# Compose constraint name
|
|
957
|
+
if name is None:
|
|
958
|
+
name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
|
|
959
|
+
|
|
960
|
+
# Init parent class
|
|
961
|
+
super().__init__(
|
|
962
|
+
tag_prediction=tag_prediction,
|
|
963
|
+
tag_reference=tag_reference,
|
|
964
|
+
rescale_factor_lower=rescale_factor_lower,
|
|
965
|
+
rescale_factor_upper=rescale_factor_upper,
|
|
966
|
+
stable=stable,
|
|
967
|
+
direction=direction,
|
|
968
|
+
name=name,
|
|
969
|
+
enforce=enforce,
|
|
970
|
+
)
|
|
971
|
+
|
|
972
|
+
# Init variables
|
|
973
|
+
self.tag_prediction = tag_prediction
|
|
974
|
+
self.tag_reference = tag_reference
|
|
975
|
+
self.tag_group_identifier = tag_group_identifier
|
|
976
|
+
|
|
977
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
978
|
+
"""Evaluate whether the monotonicity constraint is satisfied."""
|
|
979
|
+
# Select group identifiers and convert to unique list
|
|
980
|
+
group_identifiers = self.descriptor.select(self.tag_group_identifier, data)
|
|
981
|
+
unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()
|
|
982
|
+
|
|
983
|
+
# Initialize checks and directions
|
|
984
|
+
checks = zeros_like(group_identifiers, device=self.device).float()
|
|
985
|
+
self.directions = zeros_like(group_identifiers, device=self.device).float()  # float: corrections are fractional even if group ids are integers
|
|
986
|
+
|
|
987
|
+
# Get prediction and target keys
|
|
988
|
+
preds_key, _ = self.descriptor.location(self.tag_prediction)
|
|
989
|
+
targets_key, _ = self.descriptor.location(self.tag_reference)
|
|
990
|
+
|
|
991
|
+
for group_identifier in unique_group_identifiers:
|
|
992
|
+
# Create mask for the samples in this group
|
|
993
|
+
group_mask = (group_identifiers == group_identifier).squeeze(1)
|
|
994
|
+
|
|
995
|
+
# Create mini-batch for the group
|
|
996
|
+
group_data = {
|
|
997
|
+
preds_key: data[preds_key][group_mask],
|
|
998
|
+
targets_key: data[targets_key][group_mask],
|
|
999
|
+
}
|
|
1000
|
+
|
|
1001
|
+
# Call super on the mini-batch
|
|
1002
|
+
checks[group_mask], _ = super().check_constraint(group_data)
|
|
1003
|
+
self.directions[group_mask] = self.compared_rankings
|
|
1004
|
+
|
|
1005
|
+
return checks, ones_like(checks)
|
|
1006
|
+
|
|
1007
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
1008
|
+
"""Calculates ranking adjustments for monotonicity enforcement."""
|
|
1009
|
+
layer, _ = self.descriptor.location(self.tag_prediction)
|
|
1010
|
+
return {layer: self.directions}
|
|
1011
|
+
|
|
1012
|
+
|
|
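The grouped variant above reduces to boolean masking per unique identifier: each group is carved out of the batch and checked independently by the parent class. A self-contained sketch of that masking pattern (the tensors here are made up for illustration):

```python
import torch

# One group identifier per sample, shape (batch, 1), as in the class above.
group_ids = torch.tensor([[0], [1], [0], [1]])
preds = torch.tensor([[0.2], [0.9], [0.4], [0.1]])

for gid in torch.unique(group_ids, sorted=False).tolist():
    # Boolean mask selecting this group's rows.
    mask = (group_ids == gid).squeeze(1)
    # Mini-batch restricted to the group; ranks are then computed per group.
    group_preds = preds[mask]
    print(gid, group_preds.squeeze(1).tolist())
```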
1013
|
+
class ANDConstraint(Constraint):
|
|
1014
|
+
"""A composite constraint that enforces the logical AND of multiple constraints.
|
|
1015
|
+
|
|
1016
|
+
This class combines multiple sub-constraints and evaluates them jointly:
|
|
1017
|
+
|
|
1018
|
+
* The satisfaction of the AND constraint is `True` only if all sub-constraints
|
|
1019
|
+
are satisfied (elementwise logical AND).
|
|
1020
|
+
* The corrective direction is computed by weighting each sub-constraint's
|
|
1021
|
+
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1022
|
+
"""
|
|
1023
|
+
|
|
1024
|
+
def __init__(
|
|
1025
|
+
self,
|
|
1026
|
+
*constraints: Constraint,
|
|
813
1027
|
name: str = None,
|
|
814
1028
|
monitor_only: bool = False,
|
|
815
1029
|
rescale_factor: Number = 1.5,
|
|
816
1030
|
) -> None:
|
|
1031
|
+
"""A composite constraint that enforces the logical AND of multiple constraints.
|
|
1032
|
+
|
|
1033
|
+
This class combines multiple sub-constraints and evaluates them jointly:
|
|
1034
|
+
|
|
1035
|
+
* The satisfaction of the AND constraint is `True` only if all sub-constraints
|
|
1036
|
+
are satisfied (elementwise logical AND).
|
|
1037
|
+
* The corrective direction is computed by weighting each sub-constraint's
|
|
1038
|
+
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1039
|
+
|
|
1040
|
+
Args:
|
|
1041
|
+
*constraints (Constraint): One or more `Constraint` instances to be combined.
|
|
1042
|
+
name (str, optional): A custom name for this constraint. If not provided,
|
|
1043
|
+
the name will be composed from the sub-constraint names joined with
|
|
1044
|
+
" AND ".
|
|
1045
|
+
monitor_only (bool, optional): If True, the constraint will be monitored
|
|
1046
|
+
but not enforced. Defaults to False.
|
|
1047
|
+
rescale_factor (Number, optional): A scaling factor applied when rescaling
|
|
1048
|
+
corrections. Defaults to 1.5.
|
|
1049
|
+
|
|
1050
|
+
Attributes:
|
|
1051
|
+
constraints (tuple[Constraint, ...]): The sub-constraints being combined.
|
|
1052
|
+
tags (set): The union of tags referenced by the sub-constraints.
|
|
1053
|
+
name (str): The name of the constraint (composed or custom).
|
|
817
1054
|
"""
|
|
818
|
-
|
|
1055
|
+
# Type checking
|
|
1056
|
+
validate_iterable("constraints", constraints, Constraint)
|
|
1057
|
+
|
|
1058
|
+
# Compose constraint name
|
|
1059
|
+
if not name:
|
|
1060
|
+
name = " AND ".join([constraint.name for constraint in constraints])
|
|
1061
|
+
|
|
1062
|
+
# Init parent class
|
|
1063
|
+
super().__init__(
|
|
1064
|
+
set().union(*(constraint.tags for constraint in constraints)),
|
|
1065
|
+
name,
|
|
1066
|
+
monitor_only,
|
|
1067
|
+
rescale_factor,
|
|
1068
|
+
)
|
|
1069
|
+
|
|
1070
|
+
# Init variables
|
|
1071
|
+
self.constraints = constraints
|
|
1072
|
+
|
|
1073
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
1074
|
+
"""Evaluate whether all sub-constraints are satisfied.
|
|
1075
|
+
|
|
1076
|
+
Args:
|
|
1077
|
+
data: Model predictions and associated batch/context information.
|
|
1078
|
+
|
|
1079
|
+
Returns:
|
|
1080
|
+
tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
|
|
1081
|
+
* `total_satisfaction`: A boolean or numeric tensor indicating
|
|
1082
|
+
elementwise whether all constraints are satisfied
|
|
1083
|
+
(logical AND).
|
|
1084
|
+
* `mask`: The elementwise logical OR of the sub-constraint masks,
|
|
1085
|
+
with the same shape as `total_satisfaction`. Typically used as a weighting mask
|
|
1086
|
+
in downstream processing.
|
|
819
1087
|
"""
|
|
1088
|
+
total_satisfaction: Tensor = None
|
|
1089
|
+
total_mask: Tensor = None
|
|
1090
|
+
|
|
1091
|
+
# TODO vectorize this loop
|
|
1092
|
+
for constraint in self.constraints:
|
|
1093
|
+
satisfaction, mask = constraint.check_constraint(data)
|
|
1094
|
+
if total_satisfaction is None:
|
|
1095
|
+
total_satisfaction = satisfaction
|
|
1096
|
+
total_mask = mask
|
|
1097
|
+
else:
|
|
1098
|
+
total_satisfaction = logical_and(total_satisfaction, satisfaction)
|
|
1099
|
+
total_mask = logical_or(total_mask, mask)
|
|
820
1100
|
|
|
821
|
-
|
|
822
|
-
validate_type("a", a, (str, Transformation))
|
|
823
|
-
validate_type("b", b, (str, Transformation))
|
|
824
|
-
validate_type("rtol", rtol, float)
|
|
825
|
-
validate_type("atol", atol, float)
|
|
826
|
-
|
|
827
|
-
# If transformation is provided, get neuron name,
|
|
828
|
-
# else use IdentityTransformation
|
|
829
|
-
if isinstance(a, Transformation):
|
|
830
|
-
neuron_name_a = a.neuron_name
|
|
831
|
-
transformation_a = a
|
|
832
|
-
else:
|
|
833
|
-
neuron_name_a = a
|
|
834
|
-
transformation_a = IdentityTransformation(neuron_name_a)
|
|
1101
|
+
return total_satisfaction.float(), total_mask.float()
|
|
835
1102
|
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
1103
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
1104
|
+
"""Compute the corrective direction by aggregating sub-constraint directions.
|
|
1105
|
+
|
|
1106
|
+
Each sub-constraint contributes its corrective direction, weighted
|
|
1107
|
+
by its satisfaction mask. The directions are summed across constraints
|
|
1108
|
+
for each affected layer.
|
|
1109
|
+
|
|
1110
|
+
Args:
|
|
1111
|
+
data: Model predictions and associated batch/context information.
|
|
1112
|
+
|
|
1113
|
+
Returns:
|
|
1114
|
+
dict[str, Tensor]: A mapping from layer identifiers to correction
|
|
1115
|
+
tensors. Each entry represents the aggregated correction to apply
|
|
1116
|
+
to that layer, based on the satisfaction-weighted sum of
|
|
1117
|
+
sub-constraint directions.
|
|
1118
|
+
"""
|
|
1119
|
+
total_direction: dict[str, Tensor] = {}
|
|
1120
|
+
|
|
1121
|
+
# TODO vectorize this loop
|
|
1122
|
+
for constraint in self.constraints:
|
|
1123
|
+
# TODO improve efficiency by avoiding double computation?
|
|
1124
|
+
satisfaction, _ = constraint.check_constraint(data)
|
|
1125
|
+
direction = constraint.calculate_direction(data)
|
|
1126
|
+
|
|
1127
|
+
for layer, corr in direction.items():
|
|
1128
|
+
if layer not in total_direction:
|
|
1129
|
+
total_direction[layer] = satisfaction.unsqueeze(1) * corr
|
|
1130
|
+
else:
|
|
1131
|
+
total_direction[layer] += satisfaction.unsqueeze(1) * corr
|
|
1132
|
+
|
|
1133
|
+
return total_direction
|
|
1134
|
+
|
|
1135
|
+
|
|
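The two methods of `ANDConstraint` can be illustrated in isolation: satisfaction tensors combine via elementwise `logical_and`, and corrective directions combine as a satisfaction-weighted sum per layer. A runnable sketch with made-up satisfaction and direction tensors (not the package's API):

```python
import torch

# Per-sample satisfaction of two sub-constraints (1.0 = satisfied).
sat_a = torch.tensor([1.0, 1.0, 0.0])
sat_b = torch.tensor([1.0, 0.0, 0.0])

# check_constraint: elementwise logical AND of the satisfactions.
print(torch.logical_and(sat_a, sat_b).float().tolist())  # [1.0, 0.0, 0.0]

# calculate_direction: satisfaction-weighted sum of per-layer directions.
dir_a = torch.tensor([[0.5], [-1.0], [2.0]])
dir_b = torch.tensor([[0.0], [1.5], [-0.5]])
total = sat_a.unsqueeze(1) * dir_a + sat_b.unsqueeze(1) * dir_b
print(total.squeeze(1).tolist())  # [0.5, -1.0, 0.0]
```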
1136
|
+
class ORConstraint(Constraint):
|
|
1137
|
+
"""A composite constraint that enforces the logical OR of multiple constraints.
|
|
1138
|
+
|
|
1139
|
+
This class combines multiple sub-constraints and evaluates them jointly:
|
|
1140
|
+
|
|
1141
|
+
* The satisfaction of the OR constraint is `True` if at least one sub-constraint
|
|
1142
|
+
is satisfied (elementwise logical OR).
|
|
1143
|
+
* The corrective direction is computed by weighting each sub-constraint's
|
|
1144
|
+
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1145
|
+
"""
|
|
1146
|
+
|
|
1147
|
+
def __init__(
|
|
1148
|
+
self,
|
|
1149
|
+
*constraints: Constraint,
|
|
1150
|
+
name: str = None,
|
|
1151
|
+
monitor_only: bool = False,
|
|
1152
|
+
rescale_factor: Number = 1.5,
|
|
1153
|
+
) -> None:
|
|
1154
|
+
"""A composite constraint that enforces the logical OR of multiple constraints.
|
|
1155
|
+
|
|
1156
|
+
This class combines multiple sub-constraints and evaluates them jointly:
|
|
1157
|
+
|
|
1158
|
+
* The satisfaction of the OR constraint is `True` if at least one sub-constraint
|
|
1159
|
+
is satisfied (elementwise logical OR).
|
|
1160
|
+
* The corrective direction is computed by weighting each sub-constraint's
|
|
1161
|
+
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1162
|
+
|
|
1163
|
+
Args:
|
|
1164
|
+
*constraints (Constraint): One or more `Constraint` instances to be combined.
|
|
1165
|
+
name (str, optional): A custom name for this constraint. If not provided,
|
|
1166
|
+
the name will be composed from the sub-constraint names joined with
|
|
1167
|
+
" OR ".
|
|
1168
|
+
monitor_only (bool, optional): If True, the constraint will be monitored
|
|
1169
|
+
but not enforced. Defaults to False.
|
|
1170
|
+
rescale_factor (Number, optional): A scaling factor applied when rescaling
|
|
1171
|
+
corrections. Defaults to 1.5.
|
|
1172
|
+
|
|
1173
|
+
Attributes:
|
|
1174
|
+
constraints (tuple[Constraint, ...]): The sub-constraints being combined.
|
|
1175
|
+
tags (set): The union of tags referenced by the sub-constraints.
|
|
1176
|
+
name (str): The name of the constraint (composed or custom).
|
|
1177
|
+
"""
|
|
1178
|
+
# Type checking
|
|
1179
|
+
validate_iterable("constraints", constraints, Constraint)
|
|
842
1180
|
|
|
843
1181
|
# Compose constraint name
|
|
844
|
-
|
|
1182
|
+
if not name:
|
|
1183
|
+
name = " OR ".join([constraint.name for constraint in constraints])
|
|
845
1184
|
|
|
846
1185
|
# Init parent class
|
|
847
1186
|
super().__init__(
|
|
848
|
-
|
|
1187
|
+
set().union(*(constraint.tags for constraint in constraints)),
|
|
849
1188
|
name,
|
|
850
1189
|
monitor_only,
|
|
851
1190
|
rescale_factor,
|
|
852
1191
|
)
|
|
853
1192
|
|
|
854
1193
|
# Init variables
|
|
855
|
-
self.
|
|
856
|
-
self.transformation_b = transformation_b
|
|
857
|
-
self.rtol = rtol
|
|
858
|
-
self.atol = atol
|
|
1194
|
+
self.constraints = constraints
|
|
859
1195
|
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
|
|
863
|
-
self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
|
|
864
|
-
self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
|
|
1196
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
1197
|
+
"""Evaluate whether any sub-constraints are satisfied.
|
|
865
1198
|
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
) -> tuple[Tensor, int]:
|
|
1199
|
+
Args:
|
|
1200
|
+
data: Model predictions and associated batch/context information.
|
|
869
1201
|
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
1202
|
+
Returns:
|
|
1203
|
+
tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
|
|
1204
|
+
* `total_satisfaction`: A boolean or numeric tensor indicating
|
|
1205
|
+
elementwise whether any constraint is satisfied
|
|
1206
|
+
(logical OR).
|
|
1207
|
+
* `mask`: The elementwise logical OR of the sub-constraint masks,
|
|
1208
|
+
with the same shape as `total_satisfaction`. Typically used as a weighting mask
|
|
1209
|
+
in downstream processing.
|
|
1210
|
+
"""
|
|
1211
|
+
total_satisfaction: Tensor = None
|
|
1212
|
+
total_mask: Tensor = None
|
|
1213
|
+
|
|
1214
|
+
# TODO vectorize this loop
|
|
1215
|
+
for constraint in self.constraints:
|
|
1216
|
+
satisfaction, mask = constraint.check_constraint(data)
|
|
1217
|
+
if total_satisfaction is None:
|
|
1218
|
+
total_satisfaction = satisfaction
|
|
1219
|
+
total_mask = mask
|
|
1220
|
+
else:
|
|
1221
|
+
total_satisfaction = logical_or(total_satisfaction, satisfaction)
|
|
1222
|
+
total_mask = logical_or(total_mask, mask)
|
|
873
1223
|
|
|
874
|
-
|
|
875
|
-
selection_a = self.transformation_a(selection_a)
|
|
876
|
-
selection_b = self.transformation_b(selection_b)
|
|
877
|
-
|
|
878
|
-
# Calculate result
|
|
879
|
-
result = isclose(
|
|
880
|
-
square(selection_a) + square(selection_b),
|
|
881
|
-
ones_like(selection_a, device=self.device),
|
|
882
|
-
rtol=self.rtol,
|
|
883
|
-
atol=self.atol,
|
|
884
|
-
).float()
|
|
885
|
-
|
|
886
|
-
return result, numel(result)
|
|
887
|
-
|
|
888
|
-
def calculate_direction(
|
|
889
|
-
self, prediction: dict[str, Tensor]
|
|
890
|
-
) -> Dict[str, Tensor]:
|
|
891
|
-
# NOTE currently only works for dense layers due
|
|
892
|
-
# to neuron to index translation
|
|
1224
|
+
return total_satisfaction.float(), total_mask.float()
|
|
893
1225
|
|
|
894
|
-
|
|
1226
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
1227
|
+
"""Compute the corrective direction by aggregating sub-constraint directions.
|
|
895
1228
|
|
|
896
|
-
|
|
897
|
-
|
|
1229
|
+
Each sub-constraint contributes its corrective direction, weighted
|
|
1230
|
+
by its satisfaction mask. The directions are summed across constraints
|
|
1231
|
+
for each affected layer.
|
|
1232
|
+
|
|
1233
|
+
Args:
|
|
1234
|
+
data: Model predictions and associated batch/context information.
|
|
898
1235
|
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
1236
|
+
Returns:
|
|
1237
|
+
dict[str, Tensor]: A mapping from layer identifiers to correction
|
|
1238
|
+
tensors. Each entry represents the aggregated correction to apply
|
|
1239
|
+
to that layer, based on the satisfaction-weighted sum of
|
|
1240
|
+
sub-constraint directions.
|
|
1241
|
+
"""
|
|
1242
|
+
total_direction: dict[str, Tensor] = {}
|
|
902
1243
|
|
|
903
|
-
|
|
904
|
-
|
|
1244
|
+
# TODO vectorize this loop
|
|
1245
|
+
for constraint in self.constraints:
|
|
1246
|
+
# TODO improve efficiency by avoiding double computation?
|
|
1247
|
+
satisfaction, _ = constraint.check_constraint(data)
|
|
1248
|
+
direction = constraint.calculate_direction(data)
|
|
905
1249
|
|
|
906
|
-
|
|
1250
|
+
for layer, corr in direction.items():
|
|
1251
|
+
if layer not in total_direction:
|
|
1252
|
+
total_direction[layer] = satisfaction.unsqueeze(1) * corr
|
|
1253
|
+
else:
|
|
1254
|
+
total_direction[layer] += satisfaction.unsqueeze(1) * corr
|
|
1255
|
+
|
|
1256
|
+
return total_direction
|
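`ORConstraint` mirrors `ANDConstraint` except for the combining operator, and since both composites are themselves `Constraint` subclasses they can be nested (an `ORConstraint` over an `ANDConstraint`, for example). A small sketch contrasting the two combiners on toy satisfaction tensors:

```python
import torch

sat_a = torch.tensor([1.0, 1.0, 0.0])
sat_b = torch.tensor([1.0, 0.0, 0.0])

# AND is satisfied only where every sub-constraint is satisfied;
# OR is satisfied wherever at least one sub-constraint is satisfied.
print(torch.logical_and(sat_a, sat_b).float().tolist())  # [1.0, 0.0, 0.0]
print(torch.logical_or(sat_a, sat_b).float().tolist())   # [1.0, 1.0, 0.0]
```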