congrads 1.0.6-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +813 -476
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +118 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.6.dist-info → congrads-1.1.0.dist-info}/METADATA +48 -41
- congrads-1.1.0.dist-info/RECORD +14 -0
- congrads-1.1.0.dist-info/WHEEL +4 -0
- congrads-1.0.6.dist-info/LICENSE +0 -26
- congrads-1.0.6.dist-info/RECORD +0 -15
- congrads-1.0.6.dist-info/WHEEL +0 -5
- congrads-1.0.6.dist-info/top_level.txt +0 -1
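The bulk of the `congrads/constraints.py` diff below is an API rename: constraints now reference network outputs through *tags* resolved by the descriptor (`neurons` → `tags`, with `descriptor.location(tag)` replacing the old `neuron_to_layer`/`neuron_to_index` lookups), and the `monitor_only: bool = False` flag is inverted into `enforce: bool = True`. As orientation, here is a minimal sketch of constructing constraints against the 1.1.0 signatures shown in the diff; the tag names are made up, and the descriptor setup (each tag must first be registered via `descriptor.add('layer', ...)`) is omitted:

```python
from torch import ge, le

from congrads.constraints import BinaryConstraint, ScalarConstraint

# Hypothetical tags; each must be registered on the shared descriptor
# (descriptor.add("layer", ...)) before a constraint referencing it is built.
non_negative = ScalarConstraint("flow_out", ge, 0.0)    # flow_out >= 0
ordering = BinaryConstraint("flow_in", ge, "flow_out")  # flow_in >= flow_out

# What was monitor_only=True in 1.0.6 becomes enforce=False in 1.1.0:
watched = ScalarConstraint("flow_out", le, 1.0, enforce=False)
```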
congrads/constraints.py
CHANGED
|
@@ -1,65 +1,55 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
outputs equals a specified value, which can be used to control total output.
|
|
19
|
-
- `PythagoreanConstraint`: A constraint that enforces the Pythagorean theorem
|
|
20
|
-
on a set of neurons, ensuring that the square of one neuron's output is equal
|
|
21
|
-
to the sum of the squares of other outputs.
|
|
22
|
-
|
|
23
|
-
These constraints can be used to steer the learning process by applying
|
|
24
|
-
conditions such as logical implications or numerical bounds.
|
|
1
|
+
"""Module providing constraint classes for guiding neural network training.
|
|
2
|
+
|
|
3
|
+
This module defines constraints that enforce specific conditions on network outputs
|
|
4
|
+
to steer learning. Available constraint types include:
|
|
5
|
+
|
|
6
|
+
- `Constraint`: Base class for all constraint types, defining the interface and core
|
|
7
|
+
behavior.
|
|
8
|
+
- `ImplicationConstraint`: Enforces one condition only if another condition is met,
|
|
9
|
+
useful for modeling implications between outputs.
|
|
10
|
+
- `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
|
|
11
|
+
- `BinaryConstraint`: Enforces a binary comparison between two tags using a
|
|
12
|
+
comparison function (e.g., less than, greater than).
|
|
13
|
+
- `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
|
|
14
|
+
value, controlling total output.
|
|
15
|
+
|
|
16
|
+
These constraints can steer the learning process by applying logical implications
|
|
17
|
+
or numerical bounds.
|
|
25
18
|
|
|
26
19
|
Usage:
|
|
27
20
|
1. Define a custom constraint class by inheriting from `Constraint`.
|
|
28
|
-
2. Apply the constraint to your neural network during training
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
transformations and comparisons in constraints.
|
|
21
|
+
2. Apply the constraint to your neural network during training.
|
|
22
|
+
3. Use helper classes like `IdentityTransformation` for transformations and
|
|
23
|
+
comparisons in constraints.
|
|
32
24
|
|
|
33
|
-
Dependencies:
|
|
34
|
-
- PyTorch (`torch`)
|
|
35
25
|
"""
|
|
36
26
|
|
|
37
27
|
import random
|
|
38
28
|
import string
|
|
39
29
|
import warnings
|
|
40
30
|
from abc import ABC, abstractmethod
|
|
31
|
+
from collections.abc import Callable
|
|
41
32
|
from numbers import Number
|
|
42
|
-
from typing import
|
|
33
|
+
from typing import Literal
|
|
43
34
|
|
|
44
35
|
from torch import (
|
|
45
36
|
Tensor,
|
|
46
|
-
|
|
37
|
+
argsort,
|
|
38
|
+
eq,
|
|
47
39
|
ge,
|
|
48
40
|
gt,
|
|
49
|
-
isclose,
|
|
50
41
|
le,
|
|
42
|
+
logical_and,
|
|
51
43
|
logical_not,
|
|
52
44
|
logical_or,
|
|
53
45
|
lt,
|
|
54
|
-
numel,
|
|
55
46
|
ones,
|
|
56
47
|
ones_like,
|
|
57
48
|
reshape,
|
|
58
49
|
sign,
|
|
59
|
-
sqrt,
|
|
60
|
-
square,
|
|
61
50
|
stack,
|
|
62
51
|
tensor,
|
|
52
|
+
unique,
|
|
63
53
|
zeros_like,
|
|
64
54
|
)
|
|
65
55
|
from torch.nn.functional import normalize
|
|
@@ -70,8 +60,7 @@ from .utils import validate_comparator_pytorch, validate_iterable, validate_type
|
|
|
70
60
|
|
|
71
61
|
|
|
72
62
|
class Constraint(ABC):
|
|
73
|
-
"""
|
|
74
|
-
Abstract base class for defining constraints applied to neural networks.
|
|
63
|
+
"""Abstract base class for defining constraints applied to neural networks.
|
|
75
64
|
|
|
76
65
|
A `Constraint` specifies conditions that the neural network outputs
|
|
77
66
|
should satisfy. It supports monitoring constraint satisfaction
|
|
@@ -79,23 +68,22 @@ class Constraint(ABC):
|
|
|
79
68
|
must implement the `check_constraint` and `calculate_direction` methods.
|
|
80
69
|
|
|
81
70
|
Args:
|
|
82
|
-
|
|
71
|
+
tags (set[str]): Tags referencing parts of the network where this constraint applies to.
|
|
83
72
|
name (str, optional): A unique name for the constraint. If not provided,
|
|
84
73
|
a name is generated based on the class name and a random suffix.
|
|
85
|
-
|
|
86
|
-
without adjusting the loss. Defaults to
|
|
74
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
75
|
+
without adjusting the loss. Defaults to True.
|
|
87
76
|
rescale_factor (Number, optional): Factor to scale the
|
|
88
77
|
constraint-adjusted loss. Defaults to 1.5. Should be greater
|
|
89
78
|
than 1 to give weight to the constraint.
|
|
90
79
|
|
|
91
80
|
Raises:
|
|
92
81
|
TypeError: If a provided attribute has an incompatible type.
|
|
93
|
-
ValueError: If any
|
|
82
|
+
ValueError: If any tag in `tags` is not
|
|
94
83
|
defined in the `descriptor`.
|
|
95
84
|
|
|
96
85
|
Note:
|
|
97
|
-
- If `rescale_factor <= 1`, a warning is issued
|
|
98
|
-
adjusted to a positive value greater than 1.
|
|
86
|
+
- If `rescale_factor <= 1`, a warning is issued.
|
|
99
87
|
- If `name` is not provided, a name is auto-generated,
|
|
100
88
|
and a warning is logged.
|
|
101
89
|
|
|
@@ -105,38 +93,53 @@ class Constraint(ABC):
|
|
|
105
93
|
device = None
|
|
106
94
|
|
|
107
95
|
def __init__(
|
|
108
|
-
self,
|
|
109
|
-
neurons: set[str],
|
|
110
|
-
name: str = None,
|
|
111
|
-
monitor_only: bool = False,
|
|
112
|
-
rescale_factor: Number = 1.5,
|
|
96
|
+
self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
|
|
113
97
|
) -> None:
|
|
114
|
-
"""
|
|
115
|
-
Initializes a new Constraint instance.
|
|
116
|
-
"""
|
|
98
|
+
"""Initializes a new Constraint instance.
|
|
117
99
|
|
|
100
|
+
Args:
|
|
101
|
+
tags (set[str]): Tags referencing parts of the network where this constraint applies to.
|
|
102
|
+
name (str, optional): A unique name for the constraint. If not
|
|
103
|
+
provided, a name is generated based on the class name and a
|
|
104
|
+
random suffix.
|
|
105
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
106
|
+
without adjusting the loss. Defaults to True.
|
|
107
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
108
|
+
constraint-adjusted loss. Defaults to 1.5. Should be greater
|
|
109
|
+
than 1 to give weight to the constraint.
|
|
110
|
+
|
|
111
|
+
Raises:
|
|
112
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
113
|
+
ValueError: If any tag in `tags` is not defined in the `descriptor`.
|
|
114
|
+
|
|
115
|
+
Note:
|
|
116
|
+
- If `rescale_factor <= 1`, a warning is issued.
|
|
117
|
+
- If `name` is not provided, a name is auto-generated, and a
|
|
118
|
+
warning is logged.
|
|
119
|
+
"""
|
|
118
120
|
# Init parent class
|
|
119
121
|
super().__init__()
|
|
120
122
|
|
|
121
123
|
# Type checking
|
|
122
|
-
validate_iterable("
|
|
123
|
-
validate_type("name", name,
|
|
124
|
-
validate_type("
|
|
124
|
+
validate_iterable("tags", tags, str)
|
|
125
|
+
validate_type("name", name, str, allow_none=True)
|
|
126
|
+
validate_type("enforce", enforce, bool)
|
|
125
127
|
validate_type("rescale_factor", rescale_factor, Number)
|
|
126
128
|
|
|
127
129
|
# Init object variables
|
|
128
|
-
self.
|
|
130
|
+
self.tags = tags
|
|
129
131
|
self.rescale_factor = rescale_factor
|
|
130
|
-
self.
|
|
132
|
+
self.initial_rescale_factor = rescale_factor
|
|
133
|
+
self.enforce = enforce
|
|
131
134
|
|
|
132
135
|
# Perform checks
|
|
133
136
|
if rescale_factor <= 1:
|
|
134
137
|
warnings.warn(
|
|
135
|
-
"Rescale factor for constraint
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
138
|
+
f"Rescale factor for constraint {name} is <= 1. The network "
|
|
139
|
+
"will favor general loss over the constraint-adjusted loss. "
|
|
140
|
+
"Is this intended behavior? Normally, the rescale factor "
|
|
141
|
+
"should always be larger than 1.",
|
|
142
|
+
stacklevel=2,
|
|
140
143
|
)
|
|
141
144
|
|
|
142
145
|
# If no constraint_name is set, generate one based
|
|
@@ -144,124 +147,94 @@ class Constraint(ABC):
|
|
|
144
147
|
if name:
|
|
145
148
|
self.name = name
|
|
146
149
|
else:
|
|
147
|
-
random_suffix = "".join(
|
|
148
|
-
random.choices(string.ascii_uppercase + string.digits, k=6)
|
|
149
|
-
)
|
|
150
|
+
random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
|
|
150
151
|
self.name = f"{self.__class__.__name__}_{random_suffix}"
|
|
151
|
-
warnings.warn(
|
|
152
|
-
"Name for constraint is not set. Using %s.", self.name
|
|
153
|
-
)
|
|
152
|
+
warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
|
|
154
153
|
|
|
155
|
-
#
|
|
156
|
-
if rescale_factor <= 1:
|
|
157
|
-
self.rescale_factor = abs(rescale_factor) + 1.5
|
|
158
|
-
warnings.warn(
|
|
159
|
-
"Rescale factor for constraint %s is < 1, adjusted value \
|
|
160
|
-
%s to %s.",
|
|
161
|
-
name,
|
|
162
|
-
rescale_factor,
|
|
163
|
-
self.rescale_factor,
|
|
164
|
-
)
|
|
165
|
-
else:
|
|
166
|
-
self.rescale_factor = rescale_factor
|
|
167
|
-
|
|
168
|
-
# Infer layers from descriptor and neurons
|
|
154
|
+
# Infer layers from descriptor and tags
|
|
169
155
|
self.layers = set()
|
|
170
|
-
for
|
|
171
|
-
if
|
|
156
|
+
for tag in self.tags:
|
|
157
|
+
if not self.descriptor.exists(tag):
|
|
172
158
|
raise ValueError(
|
|
173
|
-
f
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
159
|
+
f"The tag {tag} used with constraint "
|
|
160
|
+
f"{self.name} is not defined in the descriptor. Please "
|
|
161
|
+
"add it to the correct layer using "
|
|
162
|
+
"descriptor.add('layer', ...)."
|
|
177
163
|
)
|
|
178
164
|
|
|
179
|
-
self.
|
|
165
|
+
layer, _ = self.descriptor.location(tag)
|
|
166
|
+
self.layers.add(layer)
|
|
180
167
|
|
|
181
168
|
@abstractmethod
|
|
182
|
-
def check_constraint(
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
Evaluates whether the given model predictions satisfy the constraint.
|
|
169
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
170
|
+
"""Evaluates whether the given model predictions satisfy the constraint.
|
|
171
|
+
|
|
172
|
+
1 IS SATISFIED, 0 IS NOT SATISFIED
|
|
187
173
|
|
|
188
174
|
Args:
|
|
189
|
-
|
|
175
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
190
176
|
|
|
191
177
|
Returns:
|
|
192
|
-
tuple[Tensor,
|
|
193
|
-
indicating whether the constraint is satisfied (with
|
|
194
|
-
for satisfaction,
|
|
195
|
-
|
|
196
|
-
|
|
178
|
+
tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
|
|
179
|
+
indicating whether the constraint is satisfied (with value 1.0
|
|
180
|
+
for satisfaction, and 0.0 for non-satisfaction, and the second element is a tensor
|
|
181
|
+
mask that indicates the relevance of each sample (`True` for relevant
|
|
182
|
+
samples and `False` for irrelevant ones).
|
|
197
183
|
|
|
198
184
|
Raises:
|
|
199
185
|
NotImplementedError: If not implemented in a subclass.
|
|
200
186
|
"""
|
|
201
|
-
|
|
202
187
|
raise NotImplementedError
|
|
203
188
|
|
|
204
189
|
@abstractmethod
|
|
205
|
-
def calculate_direction(
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
better satisfy the constraint.
|
|
190
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
191
|
+
"""Compute adjustment directions to better satisfy the constraint.
|
|
192
|
+
|
|
193
|
+
Given the model predictions, input batch, and context, this method calculates the direction
|
|
194
|
+
in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
|
|
211
195
|
|
|
212
196
|
Args:
|
|
213
|
-
|
|
197
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
214
198
|
|
|
215
199
|
Returns:
|
|
216
|
-
|
|
217
|
-
|
|
200
|
+
dict[str, Tensor]: Dictionary mapping network layers to tensors that
|
|
201
|
+
specify the adjustment direction for each tag.
|
|
218
202
|
|
|
219
203
|
Raises:
|
|
220
|
-
NotImplementedError:
|
|
204
|
+
NotImplementedError: Must be implemented by subclasses.
|
|
221
205
|
"""
|
|
222
|
-
|
|
223
206
|
raise NotImplementedError
|
|
224
207
|
|
|
225
208
|
|
|
226
209
|
class ImplicationConstraint(Constraint):
|
|
227
|
-
"""
|
|
228
|
-
Represents an implication constraint between two
|
|
229
|
-
constraints (head and body).
|
|
210
|
+
"""Represents an implication constraint between two constraints (head and body).
|
|
230
211
|
|
|
231
212
|
The implication constraint ensures that the `body` constraint only applies
|
|
232
213
|
when the `head` constraint is satisfied. If the `head` constraint is not
|
|
233
214
|
satisfied, the `body` constraint does not apply.
|
|
234
|
-
|
|
235
|
-
Args:
|
|
236
|
-
head (Constraint): The head of the implication. If this constraint
|
|
237
|
-
is satisfied, the body constraint must also be satisfied.
|
|
238
|
-
body (Constraint): The body of the implication. This constraint
|
|
239
|
-
is enforced only when the head constraint is satisfied.
|
|
240
|
-
name (str, optional): A unique name for the constraint. If not
|
|
241
|
-
provided, the name is generated in the format
|
|
242
|
-
"{body.name} if {head.name}". Defaults to None.
|
|
243
|
-
monitor_only (bool, optional): If True, the constraint is only
|
|
244
|
-
monitored without adjusting the loss. Defaults to False.
|
|
245
|
-
rescale_factor (Number, optional): The scaling factor for the
|
|
246
|
-
constraint-adjusted loss. Defaults to 1.5.
|
|
247
|
-
|
|
248
|
-
Raises:
|
|
249
|
-
TypeError: If a provided attribute has an incompatible type.
|
|
250
|
-
|
|
251
215
|
"""
|
|
252
216
|
|
|
253
217
|
def __init__(
|
|
254
218
|
self,
|
|
255
219
|
head: Constraint,
|
|
256
220
|
body: Constraint,
|
|
257
|
-
name=None,
|
|
258
|
-
monitor_only=False,
|
|
259
|
-
rescale_factor=1.5,
|
|
221
|
+
name: str = None,
|
|
260
222
|
):
|
|
261
|
-
"""
|
|
262
|
-
|
|
263
|
-
|
|
223
|
+
"""Initializes an ImplicationConstraint instance.
|
|
224
|
+
|
|
225
|
+
Uses `enforce` and `rescale_factor` from the body constraint.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
head (Constraint): Constraint defining the head of the implication.
|
|
229
|
+
body (Constraint): Constraint defining the body of the implication.
|
|
230
|
+
name (str, optional): A unique name for the constraint. If not
|
|
231
|
+
provided, a name is generated based on the class name and a
|
|
232
|
+
random suffix.
|
|
233
|
+
|
|
234
|
+
Raises:
|
|
235
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
264
236
|
|
|
237
|
+
"""
|
|
265
238
|
# Type checking
|
|
266
239
|
validate_type("head", head, Constraint)
|
|
267
240
|
validate_type("body", body, Constraint)
|
|
@@ -270,64 +243,82 @@ class ImplicationConstraint(Constraint):
|
|
|
270
243
|
name = f"{body.name} if {head.name}"
|
|
271
244
|
|
|
272
245
|
# Init parent class
|
|
273
|
-
super().__init__(
|
|
274
|
-
head.neurons | body.neurons,
|
|
275
|
-
name,
|
|
276
|
-
monitor_only,
|
|
277
|
-
rescale_factor,
|
|
278
|
-
)
|
|
246
|
+
super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)
|
|
279
247
|
|
|
280
248
|
self.head = head
|
|
281
249
|
self.body = body
|
|
282
250
|
|
|
283
|
-
def check_constraint(
|
|
284
|
-
|
|
285
|
-
) -> tuple[Tensor, int]:
|
|
251
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
252
|
+
"""Check whether the implication constraint is satisfied.
|
|
286
253
|
|
|
254
|
+
Evaluates the `head` and `body` constraints. The `body` constraint
|
|
255
|
+
is enforced only if the `head` constraint is satisfied. If the
|
|
256
|
+
`head` constraint is not satisfied, the `body` constraint does not
|
|
257
|
+
affect the result.
|
|
258
|
+
|
|
259
|
+
Args:
|
|
260
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
tuple[Tensor, Tensor]:
|
|
264
|
+
- result: Tensor indicating satisfaction of the implication
|
|
265
|
+
constraint (1 if satisfied, 0 otherwise).
|
|
266
|
+
- head_satisfaction: Tensor indicating satisfaction of the
|
|
267
|
+
head constraint alone.
|
|
268
|
+
"""
|
|
287
269
|
# Check satisfaction of head and body constraints
|
|
288
|
-
head_satisfaction, _ = self.head.check_constraint(
|
|
289
|
-
body_satisfaction, _ = self.body.check_constraint(
|
|
270
|
+
head_satisfaction, _ = self.head.check_constraint(data)
|
|
271
|
+
body_satisfaction, _ = self.body.check_constraint(data)
|
|
290
272
|
|
|
291
273
|
# If head constraint is satisfied (returning 1),
|
|
292
274
|
# the body constraint matters (and should return 0/1 based on body)
|
|
293
275
|
# If head constraint is not satisfied (returning 0),
|
|
294
276
|
# the body constraint does not apply (and should return 1)
|
|
295
|
-
result = logical_or(
|
|
296
|
-
logical_not(head_satisfaction), body_satisfaction
|
|
297
|
-
).float()
|
|
277
|
+
result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()
|
|
298
278
|
|
|
299
|
-
return result,
|
|
279
|
+
return result, head_satisfaction
|
|
280
|
+
|
|
281
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
282
|
+
"""Compute adjustment directions for tags to satisfy the constraint.
|
|
283
|
+
|
|
284
|
+
Uses the `body` constraint directions as the update vector. Only
|
|
285
|
+
applies updates if the `head` constraint is satisfied. Currently,
|
|
286
|
+
this method only works for dense layers due to tag-to-index
|
|
287
|
+
translation limitations.
|
|
288
|
+
|
|
289
|
+
Args:
|
|
290
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
300
291
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
292
|
+
Returns:
|
|
293
|
+
dict[str, Tensor]: Dictionary mapping tags to tensors
|
|
294
|
+
specifying the adjustment direction for each tag.
|
|
295
|
+
"""
|
|
304
296
|
# NOTE currently only works for dense layers
|
|
305
|
-
# due to
|
|
297
|
+
# due to tag to index translation
|
|
306
298
|
|
|
307
299
|
# Use directions of constraint body as update vector
|
|
308
|
-
return self.body.calculate_direction(
|
|
300
|
+
return self.body.calculate_direction(data)
|
|
309
301
|
|
|
310
302
|
|
|
311
303
|
class ScalarConstraint(Constraint):
|
|
312
|
-
"""
|
|
313
|
-
A constraint that enforces scalar-based comparisons on a specific neuron.
|
|
304
|
+
"""A constraint that enforces scalar-based comparisons on a specific tag.
|
|
314
305
|
|
|
315
|
-
This class ensures that the output of a specified
|
|
306
|
+
This class ensures that the output of a specified tag satisfies a scalar
|
|
316
307
|
comparison operation (e.g., less than, greater than, etc.). It uses a
|
|
317
308
|
comparator function to validate the condition and calculates adjustment
|
|
318
309
|
directions accordingly.
|
|
319
310
|
|
|
320
311
|
Args:
|
|
321
|
-
operand (Union[str, Transformation]): Name of the
|
|
312
|
+
operand (Union[str, Transformation]): Name of the tag or a
|
|
322
313
|
transformation to apply.
|
|
323
314
|
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
324
315
|
function (e.g., `torch.ge`, `torch.lt`).
|
|
325
316
|
scalar (Number): The scalar value to compare against.
|
|
326
317
|
name (str, optional): A unique name for the constraint. If not
|
|
327
318
|
provided, a name is auto-generated in the format
|
|
328
|
-
"<
|
|
329
|
-
|
|
330
|
-
without adjusting the loss. Defaults to
|
|
319
|
+
"<tag> <comparator> <scalar>".
|
|
320
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
321
|
+
without adjusting the loss. Defaults to True.
|
|
331
322
|
rescale_factor (Number, optional): Factor to scale the
|
|
332
323
|
constraint-adjusted loss. Defaults to 1.5.
|
|
333
324
|
|
|
@@ -335,87 +326,120 @@ class ScalarConstraint(Constraint):
|
|
|
335
326
|
TypeError: If a provided attribute has an incompatible type.
|
|
336
327
|
|
|
337
328
|
Notes:
|
|
338
|
-
- The `
|
|
339
|
-
- The constraint name is composed using the
|
|
340
|
-
comparator, and scalar value.
|
|
329
|
+
- The `tag` must be defined in the `descriptor` mapping.
|
|
330
|
+
- The constraint name is composed using the tag, comparator, and scalar value.
|
|
341
331
|
|
|
342
332
|
"""
|
|
343
333
|
|
|
344
334
|
def __init__(
|
|
345
335
|
self,
|
|
346
|
-
operand:
|
|
336
|
+
operand: str | Transformation,
|
|
347
337
|
comparator: Callable[[Tensor, Number], Tensor],
|
|
348
338
|
scalar: Number,
|
|
349
339
|
name: str = None,
|
|
350
|
-
|
|
340
|
+
enforce: bool = True,
|
|
351
341
|
rescale_factor: Number = 1.5,
|
|
352
342
|
) -> None:
|
|
353
|
-
"""
|
|
354
|
-
|
|
355
|
-
|
|
343
|
+
"""Initializes a ScalarConstraint instance.
|
|
344
|
+
|
|
345
|
+
Args:
|
|
346
|
+
operand (Union[str, Transformation]): Function that needs to be
|
|
347
|
+
performed on the network variables before applying the
|
|
348
|
+
constraint.
|
|
349
|
+
comparator (Callable[[Tensor, Number], Tensor]): Comparison
|
|
350
|
+
operator used in the constraint. Supported types are
|
|
351
|
+
{torch.lt, torch.le, torch.st, torch.se}.
|
|
352
|
+
scalar (Number): Constant to compare the variable to.
|
|
353
|
+
name (str, optional): A unique name for the constraint. If not
|
|
354
|
+
provided, a name is generated based on the class name and a
|
|
355
|
+
random suffix.
|
|
356
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
357
|
+
without adjusting the loss. Defaults to True.
|
|
358
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
359
|
+
constraint-adjusted loss. Defaults to 1.5. Should be greater
|
|
360
|
+
than 1 to give weight to the constraint.
|
|
356
361
|
|
|
362
|
+
Raises:
|
|
363
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
364
|
+
|
|
365
|
+
Notes:
|
|
366
|
+
- The `tag` must be defined in the `descriptor` mapping.
|
|
367
|
+
- The constraint name is composed using the tag, comparator, and scalar value.
|
|
368
|
+
"""
|
|
357
369
|
# Type checking
|
|
358
370
|
validate_type("operand", operand, (str, Transformation))
|
|
359
371
|
validate_comparator_pytorch("comparator", comparator)
|
|
360
|
-
validate_comparator_pytorch("comparator", comparator)
|
|
361
372
|
validate_type("scalar", scalar, Number)
|
|
362
373
|
|
|
363
|
-
# If transformation is provided, get
|
|
364
|
-
# else use IdentityTransformation
|
|
374
|
+
# If transformation is provided, get tag name, else use IdentityTransformation
|
|
365
375
|
if isinstance(operand, Transformation):
|
|
366
|
-
|
|
376
|
+
tag = operand.tag
|
|
367
377
|
transformation = operand
|
|
368
378
|
else:
|
|
369
|
-
|
|
370
|
-
transformation = IdentityTransformation(
|
|
379
|
+
tag = operand
|
|
380
|
+
transformation = IdentityTransformation(tag)
|
|
371
381
|
|
|
372
382
|
# Compose constraint name
|
|
373
|
-
name = f"{
|
|
383
|
+
name = f"{tag} {comparator.__name__} {str(scalar)}"
|
|
374
384
|
|
|
375
385
|
# Init parent class
|
|
376
|
-
super().__init__({
|
|
386
|
+
super().__init__({tag}, name, enforce, rescale_factor)
|
|
377
387
|
|
|
378
388
|
# Init variables
|
|
389
|
+
self.tag = tag
|
|
379
390
|
self.comparator = comparator
|
|
380
391
|
self.scalar = scalar
|
|
381
392
|
self.transformation = transformation
|
|
382
393
|
|
|
383
|
-
# Get layer name and feature index from neuron_name
|
|
384
|
-
self.layer = self.descriptor.neuron_to_layer[neuron_name]
|
|
385
|
-
self.index = self.descriptor.neuron_to_index[neuron_name]
|
|
386
|
-
|
|
387
394
|
# Calculate directions based on constraint operator
|
|
388
395
|
if self.comparator in [lt, le]:
|
|
389
|
-
self.direction = -1
|
|
390
|
-
elif self.comparator in [gt, ge]:
|
|
391
396
|
self.direction = 1
|
|
397
|
+
elif self.comparator in [gt, ge]:
|
|
398
|
+
self.direction = -1
|
|
392
399
|
|
|
393
|
-
def check_constraint(
|
|
394
|
-
|
|
395
|
-
) -> tuple[Tensor, int]:
|
|
400
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
401
|
+
"""Check if the scalar constraint is satisfied for a given tag.
|
|
396
402
|
|
|
403
|
+
Args:
|
|
404
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
405
|
+
|
|
406
|
+
Returns:
|
|
407
|
+
tuple[Tensor, Tensor]:
|
|
408
|
+
- result: Tensor indicating whether the tag satisfies the constraint.
|
|
409
|
+
- ones_like(result): Tensor of ones with same shape as `result`.
|
|
410
|
+
"""
|
|
397
411
|
# Select relevant columns
|
|
398
|
-
selection =
|
|
412
|
+
selection = self.descriptor.select(self.tag, data)
|
|
399
413
|
|
|
400
414
|
# Apply transformation
|
|
401
415
|
selection = self.transformation(selection)
|
|
402
416
|
|
|
403
417
|
# Calculate current constraint result
|
|
404
418
|
result = self.comparator(selection, self.scalar).float()
|
|
405
|
-
return result,
|
|
419
|
+
return result, ones_like(result)
|
|
420
|
+
|
|
421
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
422
|
+
"""Compute adjustment directions to satisfy the scalar constraint.
|
|
406
423
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
424
|
+
Only works for dense layers due to tag-to-index translation.
|
|
425
|
+
|
|
426
|
+
Args:
|
|
427
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
428
|
+
|
|
429
|
+
Returns:
|
|
430
|
+
dict[str, Tensor]: Dictionary mapping layers to tensors specifying
|
|
431
|
+
the adjustment direction for each tag.
|
|
432
|
+
"""
|
|
410
433
|
# NOTE currently only works for dense layers due
|
|
411
|
-
# to
|
|
434
|
+
# to tag to index translation
|
|
412
435
|
|
|
413
436
|
output = {}
|
|
414
437
|
|
|
415
438
|
for layer in self.layers:
|
|
416
|
-
output[layer] = zeros_like(
|
|
439
|
+
output[layer] = zeros_like(data[layer][0], device=self.device)
|
|
417
440
|
|
|
418
|
-
|
|
441
|
+
layer, index = self.descriptor.location(self.tag)
|
|
442
|
+
output[layer][index] = self.direction
|
|
419
443
|
|
|
420
444
|
for layer in self.layers:
|
|
421
445
|
output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
|
|
@@ -424,26 +448,24 @@ class ScalarConstraint(Constraint):
|
|
|
424
448
|
|
|
425
449
|
|
|
426
450
|
class BinaryConstraint(Constraint):
|
|
427
|
-
"""
|
|
428
|
-
A constraint that enforces a binary comparison between two neurons.
|
|
451
|
+
"""A constraint that enforces a binary comparison between two tags.
|
|
429
452
|
|
|
430
|
-
This class ensures that the output of one
|
|
431
|
-
operation with the output of another
|
|
432
|
-
|
|
433
|
-
validate the condition and calculates adjustment directions accordingly.
|
|
453
|
+
This class ensures that the output of one tag satisfies a comparison
|
|
454
|
+
operation with the output of another tag (e.g., less than, greater than, etc.).
|
|
455
|
+
It uses a comparator function to validate the condition and calculates adjustment directions accordingly.
|
|
434
456
|
|
|
435
457
|
Args:
|
|
436
458
|
operand_left (Union[str, Transformation]): Name of the left
|
|
437
|
-
|
|
459
|
+
tag or a transformation to apply.
|
|
438
460
|
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
439
461
|
function (e.g., `torch.ge`, `torch.lt`).
|
|
440
462
|
operand_right (Union[str, Transformation]): Name of the right
|
|
441
|
-
|
|
463
|
+
tag or a transformation to apply.
|
|
442
464
|
name (str, optional): A unique name for the constraint. If not
|
|
443
465
|
provided, a name is auto-generated in the format
|
|
444
|
-
"<
|
|
445
|
-
|
|
446
|
-
without adjusting the loss. Defaults to
|
|
466
|
+
"<operand_left> <comparator> <operand_right>".
|
|
467
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
468
|
+
without adjusting the loss. Defaults to True.
|
|
447
469
|
rescale_factor (Number, optional): Factor to scale the
|
|
448
470
|
constraint-adjusted loss. Defaults to 1.5.
|
|
449
471
|
|
|
@@ -451,84 +473,107 @@ class BinaryConstraint(Constraint):
|
|
|
451
473
|
TypeError: If a provided attribute has an incompatible type.
|
|
452
474
|
|
|
453
475
|
Notes:
|
|
454
|
-
- The
|
|
455
|
-
- The constraint name is composed using the left
|
|
456
|
-
comparator, and right neuron name.
|
|
476
|
+
- The tags must be defined in the `descriptor` mapping.
|
|
477
|
+
- The constraint name is composed using the left tag, comparator, and right tag.
|
|
457
478
|
|
|
458
479
|
"""
|
|
459
480
|
|
|
460
481
|
def __init__(
|
|
461
482
|
self,
|
|
462
|
-
operand_left:
|
|
483
|
+
operand_left: str | Transformation,
|
|
463
484
|
comparator: Callable[[Tensor, Number], Tensor],
|
|
464
|
-
operand_right:
|
|
485
|
+
operand_right: str | Transformation,
|
|
465
486
|
name: str = None,
|
|
466
|
-
|
|
487
|
+
enforce: bool = True,
|
|
467
488
|
rescale_factor: Number = 1.5,
|
|
468
489
|
) -> None:
|
|
469
|
-
"""
|
|
470
|
-
|
|
471
|
-
|
|
490
|
+
"""Initializes a BinaryConstraint instance.
|
|
491
|
+
|
|
492
|
+
Args:
|
|
493
|
+
operand_left (Union[str, Transformation]): Name of the left
|
|
494
|
+
tag or a transformation to apply.
|
|
495
|
+
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
496
|
+
function (e.g., `torch.ge`, `torch.lt`).
|
|
497
|
+
operand_right (Union[str, Transformation]): Name of the right
|
|
498
|
+
tag or a transformation to apply.
|
|
499
|
+
name (str, optional): A unique name for the constraint. If not
|
|
500
|
+
provided, a name is auto-generated in the format
|
|
501
|
+
"<operand_left> <comparator> <operand_right>".
|
|
502
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
503
|
+
without adjusting the loss. Defaults to True.
|
|
504
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
505
|
+
constraint-adjusted loss. Defaults to 1.5.
|
|
506
|
+
|
|
507
|
+
Raises:
|
|
508
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
472
509
|
|
|
510
|
+
Notes:
|
|
511
|
+
- The tags must be defined in the `descriptor` mapping.
|
|
512
|
+
- The constraint name is composed using the left tag,
|
|
513
|
+
comparator, and right tag.
|
|
514
|
+
"""
|
|
473
515
|
# Type checking
|
|
474
516
|
validate_type("operand_left", operand_left, (str, Transformation))
|
|
475
517
|
validate_comparator_pytorch("comparator", comparator)
|
|
476
518
|
validate_comparator_pytorch("comparator", comparator)
|
|
477
519
|
validate_type("operand_right", operand_right, (str, Transformation))
|
|
478
520
|
|
|
479
|
-
# If transformation is provided, get
|
|
480
|
-
# else use IdentityTransformation
|
|
521
|
+
# If transformation is provided, get tag name, else use IdentityTransformation
|
|
481
522
|
if isinstance(operand_left, Transformation):
|
|
482
|
-
|
|
523
|
+
tag_left = operand_left.tag
|
|
483
524
|
transformation_left = operand_left
|
|
484
525
|
else:
|
|
485
|
-
|
|
486
|
-
transformation_left = IdentityTransformation(
|
|
526
|
+
tag_left = operand_left
|
|
527
|
+
transformation_left = IdentityTransformation(tag_left)
|
|
487
528
|
|
|
488
529
|
if isinstance(operand_right, Transformation):
|
|
489
|
-
|
|
530
|
+
tag_right = operand_right.tag
|
|
490
531
|
transformation_right = operand_right
|
|
491
532
|
else:
|
|
492
|
-
|
|
493
|
-
transformation_right = IdentityTransformation(
|
|
533
|
+
tag_right = operand_right
|
|
534
|
+
transformation_right = IdentityTransformation(tag_right)
|
|
494
535
|
|
|
495
536
|
# Compose constraint name
|
|
496
|
-
name = f"{
|
|
537
|
+
name = f"{tag_left} {comparator.__name__} {tag_right}"
|
|
497
538
|
|
|
498
539
|
# Init parent class
|
|
499
|
-
super().__init__(
|
|
500
|
-
{neuron_name_left, neuron_name_right},
|
|
501
|
-
name,
|
|
502
|
-
monitor_only,
|
|
503
|
-
rescale_factor,
|
|
504
|
-
)
|
|
540
|
+
super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)
|
|
505
541
|
|
|
506
542
|
# Init variables
|
|
507
543
|
self.comparator = comparator
|
|
544
|
+
self.tag_left = tag_left
|
|
545
|
+
self.tag_right = tag_right
|
|
508
546
|
self.transformation_left = transformation_left
|
|
509
547
|
self.transformation_right = transformation_right
|
|
510
548
|
|
|
511
|
-
# Get layer name and feature index from neuron_name
|
|
512
|
-
self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
|
|
513
|
-
self.layer_right = self.descriptor.neuron_to_layer[neuron_name_right]
|
|
514
|
-
self.index_left = self.descriptor.neuron_to_index[neuron_name_left]
|
|
515
|
-
self.index_right = self.descriptor.neuron_to_index[neuron_name_right]
|
|
516
|
-
|
|
517
549
|
# Calculate directions based on constraint operator
|
|
518
550
|
if self.comparator in [lt, le]:
|
|
519
|
-
self.direction_left = -1
|
|
520
|
-
self.direction_right = 1
|
|
521
|
-
else:
|
|
522
551
|
self.direction_left = 1
|
|
523
552
|
self.direction_right = -1
|
|
553
|
+
else:
|
|
554
|
+
self.direction_left = -1
|
|
555
|
+
self.direction_right = 1
|
|
524
556
|
|
|
525
|
-
def check_constraint(
|
|
526
|
-
|
|
527
|
-
) -> tuple[Tensor, int]:
|
|
557
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
558
|
+
"""Evaluate whether the binary constraint is satisfied for the current predictions.
|
|
528
559
|
|
|
560
|
+
The constraint compares the outputs of two tags using the specified
|
|
561
|
+
comparator function. A result of `1` indicates the constraint is satisfied
|
|
562
|
+
for a sample, and `0` indicates it is violated.
|
|
563
|
+
|
|
564
|
+
Args:
|
|
565
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
566
|
+
|
|
567
|
+
Returns:
|
|
568
|
+
tuple[Tensor, Tensor]:
|
|
569
|
+
- result (Tensor): Binary tensor indicating constraint satisfaction
|
|
570
|
+
(1 for satisfied, 0 for violated) for each sample.
|
|
571
|
+
- mask (Tensor): Tensor of ones with the same shape as `result`,
|
|
572
|
+
used for constraint aggregation.
|
|
573
|
+
"""
|
|
529
574
|
# Select relevant columns
|
|
530
|
-
selection_left =
|
|
531
|
-
selection_right =
|
|
575
|
+
selection_left = self.descriptor.select(self.tag_left, data)
|
|
576
|
+
selection_right = self.descriptor.select(self.tag_right, data)
|
|
532
577
|
|
|
533
578
|
# Apply transformations
|
|
534
579
|
selection_left = self.transformation_left(selection_left)
|
|
@@ -536,21 +581,34 @@ class BinaryConstraint(Constraint):
|
|
|
536
581
|
|
|
537
582
|
result = self.comparator(selection_left, selection_right).float()
|
|
538
583
|
|
|
539
|
-
return result,
|
|
584
|
+
return result, ones_like(result)
|
|
585
|
+
|
|
586
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
587
|
+
"""Compute adjustment directions for the tags involved in the binary constraint.
|
|
588
|
+
|
|
589
|
+
The returned directions indicate how to adjust each tag's output to
|
|
590
|
+
satisfy the constraint. Only currently supported for dense layers.
|
|
591
|
+
|
|
592
|
+
Args:
|
|
593
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
540
594
|
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
595
|
+
Returns:
|
|
596
|
+
dict[str, Tensor]: A mapping from layer names to tensors specifying
|
|
597
|
+
the normalized adjustment directions for each tag involved in the
|
|
598
|
+
constraint.
|
|
599
|
+
"""
|
|
544
600
|
# NOTE currently only works for dense layers due
|
|
545
|
-
# to
|
|
601
|
+
# to tag to index translation
|
|
546
602
|
|
|
547
603
|
output = {}
|
|
548
604
|
|
|
549
605
|
for layer in self.layers:
|
|
550
|
-
output[layer] = zeros_like(
|
|
606
|
+
output[layer] = zeros_like(data[layer][0], device=self.device)
|
|
551
607
|
|
|
552
|
-
|
|
553
|
-
|
|
608
|
+
layer_left, index_left = self.descriptor.location(self.tag_left)
|
|
609
|
+
layer_right, index_right = self.descriptor.location(self.tag_right)
|
|
610
|
+
output[layer_left][index_left] = self.direction_left
|
|
611
|
+
output[layer_right][index_right] = self.direction_right
|
|
554
612
|
|
|
555
613
|
for layer in self.layers:
|
|
556
614
|
output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
|
|
@@ -559,142 +617,119 @@ class BinaryConstraint(Constraint):
|
|
|
559
617
|
|
|
560
618
|
|
|
561
619
|
class SumConstraint(Constraint):
|
|
562
|
-
"""
|
|
563
|
-
A constraint that enforces a weighted summation comparison
|
|
564
|
-
between two groups of neurons.
|
|
620
|
+
"""A constraint that enforces a weighted summation comparison between two groups of tags.
|
|
565
621
|
|
|
566
622
|
This class evaluates whether the weighted sum of outputs from one set of
|
|
567
|
-
|
|
568
|
-
outputs from another set of
|
|
569
|
-
|
|
570
|
-
Args:
|
|
571
|
-
operands_left (list[Union[str, Transformation]]): List of neuron
|
|
572
|
-
names or transformations on the left side.
|
|
573
|
-
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
574
|
-
function for the constraint.
|
|
575
|
-
operands_right (list[Union[str, Transformation]]): List of neuron
|
|
576
|
-
names or transformations on the right side.
|
|
577
|
-
weights_left (list[Number], optional): Weights for the left neurons.
|
|
578
|
-
Defaults to None.
|
|
579
|
-
weights_right (list[Number], optional): Weights for the right
|
|
580
|
-
neurons. Defaults to None.
|
|
581
|
-
name (str, optional): Unique name for the constraint.
|
|
582
|
-
If None, it's auto-generated. Defaults to None.
|
|
583
|
-
monitor_only (bool, optional): If True, only monitor the constraint
|
|
584
|
-
without adjusting the loss. Defaults to False.
|
|
585
|
-
rescale_factor (Number, optional): Factor to scale the
|
|
586
|
-
constraint-adjusted loss. Defaults to 1.5.
|
|
587
|
-
|
|
588
|
-
Raises:
|
|
589
|
-
TypeError: If a provided attribute has an incompatible type.
|
|
590
|
-
ValueError: If the dimensions of neuron names and weights mismatch.
|
|
591
|
-
|
|
623
|
+
tags satisfies a comparison operation with the weighted sum of
|
|
624
|
+
outputs from another set of tags.
|
|
592
625
|
"""
|
|
593
626
|
|
|
594
627
|
def __init__(
|
|
595
628
|
self,
|
|
596
|
-
operands_left: list[
|
|
629
|
+
operands_left: list[str | Transformation],
|
|
597
630
|
comparator: Callable[[Tensor, Number], Tensor],
|
|
598
|
-
operands_right: list[
|
|
631
|
+
operands_right: list[str | Transformation],
|
|
599
632
|
weights_left: list[Number] = None,
|
|
600
633
|
weights_right: list[Number] = None,
|
|
601
634
|
name: str = None,
|
|
602
|
-
|
|
635
|
+
enforce: bool = True,
|
|
603
636
|
rescale_factor: Number = 1.5,
|
|
604
637
|
) -> None:
|
|
605
|
-
"""
|
|
606
|
-
|
|
607
|
-
|
|
638
|
+
"""Initializes the SumConstraint.
|
|
639
|
+
|
|
640
|
+
Args:
|
|
641
|
+
operands_left (list[Union[str, Transformation]]): List of tags
|
|
642
|
+
or transformations on the left side.
|
|
643
|
+
comparator (Callable[[Tensor, Number], Tensor]): A comparison
|
|
644
|
+
function for the constraint.
|
|
645
|
+
operands_right (list[Union[str, Transformation]]): List of tags
|
|
646
|
+
or transformations on the right side.
|
|
647
|
+
weights_left (list[Number], optional): Weights for the left
|
|
648
|
+
tags. Defaults to None.
|
|
649
|
+
weights_right (list[Number], optional): Weights for the right
|
|
650
|
+
tags. Defaults to None.
|
|
651
|
+
name (str, optional): Unique name for the constraint.
|
|
652
|
+
If None, it's auto-generated. Defaults to None.
|
|
653
|
+
enforce (bool, optional): If False, only monitor the constraint
|
|
654
|
+
without adjusting the loss. Defaults to True.
|
|
655
|
+
rescale_factor (Number, optional): Factor to scale the
|
|
656
|
+
constraint-adjusted loss. Defaults to 1.5.
|
|
608
657
|
|
|
658
|
+
Raises:
|
|
659
|
+
TypeError: If a provided attribute has an incompatible type.
|
|
660
|
+
ValueError: If the dimensions of tags and weights mismatch.
|
|
661
|
+
"""
|
|
609
662
|
# Type checking
|
|
610
663
|
validate_iterable("operands_left", operands_left, (str, Transformation))
|
|
611
664
|
validate_comparator_pytorch("comparator", comparator)
|
|
612
665
|
validate_comparator_pytorch("comparator", comparator)
|
|
613
|
-
validate_iterable(
|
|
614
|
-
"operands_right", operands_right, (str, Transformation)
|
|
615
|
-
)
|
|
666
|
+
validate_iterable("operands_right", operands_right, (str, Transformation))
|
|
616
667
|
validate_iterable("weights_left", weights_left, Number, allow_none=True)
|
|
617
|
-
validate_iterable(
|
|
618
|
-
"weights_right", weights_right, Number, allow_none=True
|
|
619
|
-
)
|
|
668
|
+
validate_iterable("weights_right", weights_right, Number, allow_none=True)
|
|
620
669
|
|
|
621
|
-
# If transformation is provided, get
|
|
622
|
-
|
|
623
|
-
neuron_names_left: list[str] = []
|
|
670
|
+
# If transformation is provided, get tag, else use IdentityTransformation
|
|
671
|
+
tags_left: list[str] = []
|
|
624
672
|
transformations_left: list[Transformation] = []
|
|
625
673
|
for operand_left in operands_left:
|
|
626
674
|
if isinstance(operand_left, Transformation):
|
|
627
|
-
|
|
628
|
-
|
|
675
|
+
tag_left = operand_left.tag
|
|
676
|
+
tags_left.append(tag_left)
|
|
629
677
|
transformations_left.append(operand_left)
|
|
630
678
|
else:
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
transformations_left.append(
|
|
634
|
-
IdentityTransformation(neuron_name_left)
|
|
635
|
-
)
|
|
679
|
+
tag_left = operand_left
|
|
680
|
+
tags_left.append(tag_left)
|
|
681
|
+
transformations_left.append(IdentityTransformation(tag_left))
|
|
636
682
|
|
|
637
|
-
|
|
683
|
+
tags_right: list[str] = []
|
|
638
684
|
transformations_right: list[Transformation] = []
|
|
639
685
|
for operand_right in operands_right:
|
|
640
686
|
if isinstance(operand_right, Transformation):
|
|
641
|
-
|
|
642
|
-
|
|
687
|
+
tag_right = operand_right.tag
|
|
688
|
+
tags_right.append(tag_right)
|
|
643
689
|
transformations_right.append(operand_right)
|
|
644
690
|
else:
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
transformations_right.append(
|
|
648
|
-
IdentityTransformation(neuron_name_right)
|
|
649
|
-
)
|
|
691
|
+
tag_right = operand_right
|
|
692
|
+
tags_right.append(tag_right)
|
|
693
|
+
transformations_right.append(IdentityTransformation(tag_right))
|
|
650
694
|
|
|
651
695
|
# Compose constraint name
|
|
652
|
-
w_left = weights_left or [""] * len(
|
|
653
|
-
w_right = weights_right or [""] * len(
|
|
654
|
-
left_expr = " + ".join(
|
|
655
|
-
|
|
656
|
-
)
|
|
657
|
-
right_expr = " + ".join(
|
|
658
|
-
f"{w}{n}" for w, n in zip(w_right, neuron_names_right)
|
|
659
|
-
)
|
|
696
|
+
w_left = weights_left or [""] * len(tags_left)
|
|
697
|
+
w_right = weights_right or [""] * len(tags_right)
|
|
698
|
+
left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
|
|
699
|
+
right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
|
|
660
700
|
comparator_name = comparator.__name__
|
|
661
701
|
name = f"{left_expr} {comparator_name} {right_expr}"
|
|
662
702
|
|
|
663
703
|
# Init parent class
|
|
664
|
-
|
|
665
|
-
super().__init__(
|
|
704
|
+
tags = set(tags_left) | set(tags_right)
|
|
705
|
+
super().__init__(tags, name, enforce, rescale_factor)
|
|
666
706
|
|
|
667
707
|
# Init variables
|
|
668
708
|
self.comparator = comparator
|
|
669
|
-
self.
|
|
670
|
-
self.
|
|
709
|
+
self.tags_left = tags_left
|
|
710
|
+
self.tags_right = tags_right
|
|
671
711
|
self.transformations_left = transformations_left
|
|
672
712
|
self.transformations_right = transformations_right
|
|
673
713
|
|
|
674
|
-
# If feature list dimensions don't match
|
|
675
|
-
|
|
676
|
-
if weights_left and (len(neuron_names_left) != len(weights_left)):
|
|
714
|
+
# If feature list dimensions don't match weight list dimensions, raise error
|
|
715
|
+
if weights_left and (len(tags_left) != len(weights_left)):
|
|
677
716
|
raise ValueError(
|
|
678
|
-
"The dimensions of
|
|
679
|
-
dimensions of weights_left."
|
|
717
|
+
"The dimensions of tags_left don't match with the dimensions of weights_left."
|
|
680
718
|
)
|
|
681
|
-
if weights_right and (len(
|
|
719
|
+
if weights_right and (len(tags_right) != len(weights_right)):
|
|
682
720
|
raise ValueError(
|
|
683
|
-
"The dimensions of
|
|
684
|
-
dimensions of weights_right."
|
|
721
|
+
"The dimensions of tags_right don't match with the dimensions of weights_right."
|
|
685
722
|
)
|
|
686
723
|
|
|
687
724
|
# If weights are provided for summation, transform them to Tensors
|
|
688
725
|
if weights_left:
|
|
689
726
|
self.weights_left = tensor(weights_left, device=self.device)
|
|
690
727
|
else:
|
|
691
|
-
self.weights_left = ones(len(
|
|
728
|
+
self.weights_left = ones(len(tags_left), device=self.device)
|
|
692
729
|
if weights_right:
|
|
693
730
|
self.weights_right = tensor(weights_right, device=self.device)
|
|
694
731
|
else:
|
|
695
|
-
self.weights_right = ones(
|
|
696
|
-
len(neuron_names_right), device=self.device
|
|
697
|
-
)
|
|
732
|
+
self.weights_right = ones(len(tags_right), device=self.device)
|
|
698
733
|
|
|
699
734
|
# Calculate directions based on constraint operator
|
|
700
735
|
if self.comparator in [lt, le]:
|
|
@@ -704,80 +739,82 @@ class SumConstraint(Constraint):
|
|
|
704
739
|
self.direction_left = 1
|
|
705
740
|
self.direction_right = -1
|
|
706
741
|
|
|
707
|
-
def check_constraint(
|
|
708
|
-
|
|
709
|
-
|
|
742
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
743
|
+
"""Evaluate whether the weighted sum constraint is satisfied.
|
|
744
|
+
|
|
745
|
+
Computes the weighted sum of outputs from the left and right tags,
|
|
746
|
+
applies the specified comparator function, and returns a binary result for
|
|
747
|
+
each sample.
|
|
748
|
+
|
|
749
|
+
Args:
|
|
750
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
751
|
+
|
|
752
|
+
Returns:
|
|
753
|
+
tuple[Tensor, Tensor]:
|
|
754
|
+
- result (Tensor): Binary tensor indicating whether the constraint
|
|
755
|
+
is satisfied (1) or violated (0) for each sample.
|
|
756
|
+
- mask (Tensor): Tensor of ones, used for constraint aggregation.
|
|
757
|
+
"""
|
|
710
758
|
|
|
711
759
|
def compute_weighted_sum(
|
|
712
|
-
|
|
760
|
+
tags: list[str],
|
|
713
761
|
transformations: list[Transformation],
|
|
714
|
-
weights:
|
|
715
|
-
) ->
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
for neuron_name in neuron_names
|
|
719
|
-
]
|
|
720
|
-
indices = [
|
|
721
|
-
self.descriptor.neuron_to_index[neuron_name]
|
|
722
|
-
for neuron_name in neuron_names
|
|
723
|
-
]
|
|
724
|
-
|
|
725
|
-
# Select relevant column
|
|
726
|
-
selections = [
|
|
727
|
-
prediction[layer][:, index]
|
|
728
|
-
for layer, index in zip(layers, indices)
|
|
729
|
-
]
|
|
762
|
+
weights: Tensor,
|
|
763
|
+
) -> Tensor:
|
|
764
|
+
# Select relevant columns
|
|
765
|
+
selections = [self.descriptor.select(tag, data) for tag in tags]
|
|
730
766
|
|
|
731
767
|
# Apply transformations
|
|
732
768
|
results = []
|
|
733
|
-
for transformation, selection in zip(transformations, selections):
|
|
769
|
+
for transformation, selection in zip(transformations, selections, strict=False):
|
|
734
770
|
results.append(transformation(selection))
|
|
735
771
|
|
|
736
|
-
# Extract predictions for all
|
|
737
|
-
predictions = stack(
|
|
738
|
-
results,
|
|
739
|
-
dim=1,
|
|
740
|
-
)
|
|
772
|
+
# Extract predictions for all tags and apply weights in bulk
|
|
773
|
+
predictions = stack(results)
|
|
741
774
|
|
|
742
775
|
# Calculate weighted sum
|
|
743
|
-
return (predictions * weights.
|
|
776
|
+
return (predictions * weights.view(-1, 1, 1)).sum(dim=0)
|
|
744
777
|
|
|
745
778
|
# Compute weighted sums
|
|
746
779
|
weighted_sum_left = compute_weighted_sum(
|
|
747
|
-
self.
|
|
748
|
-
self.transformations_left,
|
|
749
|
-
self.weights_left,
|
|
780
|
+
self.tags_left, self.transformations_left, self.weights_left
|
|
750
781
|
)
|
|
751
782
|
weighted_sum_right = compute_weighted_sum(
|
|
752
|
-
self.
|
|
753
|
-
self.transformations_right,
|
|
754
|
-
self.weights_right,
|
|
783
|
+
self.tags_right, self.transformations_right, self.weights_right
|
|
755
784
|
)
|
|
756
785
|
|
|
757
786
|
# Apply the comparator and calculate the result
|
|
758
787
|
result = self.comparator(weighted_sum_left, weighted_sum_right).float()
|
|
759
788
|
|
|
760
|
-
return result,
|
|
789
|
+
return result, ones_like(result)
|
|
790
|
+
|
|
791
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
792
|
+
"""Compute adjustment directions for tags involved in the weighted sum constraint.
|
|
793
|
+
|
|
794
|
+
The directions indicate how to adjust each tag's output to satisfy the
|
|
795
|
+
constraint. Only dense layers are currently supported.
|
|
761
796
|
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
797
|
+
Args:
|
|
798
|
+
data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
|
|
799
|
+
|
|
800
|
+
Returns:
|
|
801
|
+
dict[str, Tensor]: Mapping from layer names to normalized tensors
|
|
802
|
+
specifying adjustment directions for each tag involved in the constraint.
|
|
803
|
+
"""
|
|
765
804
|
# NOTE currently only works for dense layers
|
|
766
|
-
# due to
|
|
805
|
+
# due to tag to index translation
|
|
767
806
|
|
|
768
807
|
output = {}
|
|
769
808
|
|
|
770
809
|
for layer in self.layers:
|
|
771
|
-
output[layer] = zeros_like(
|
|
810
|
+
output[layer] = zeros_like(data[layer][0], device=self.device)
|
|
772
811
|
|
|
773
|
-
for
|
|
774
|
-
layer = self.descriptor.
|
|
775
|
-
index = self.descriptor.neuron_to_index[neuron_name_left]
|
|
812
|
+
for tag_left in self.tags_left:
|
|
813
|
+
layer, index = self.descriptor.location(tag_left)
|
|
776
814
|
output[layer][index] = self.direction_left
|
|
777
815
|
|
|
778
|
-
for
|
|
779
|
-
layer = self.descriptor.
|
|
780
|
-
index = self.descriptor.neuron_to_index[neuron_name_right]
|
|
816
|
+
for tag_right in self.tags_right:
|
|
817
|
+
layer, index = self.descriptor.location(tag_right)
|
|
781
818
|
output[layer][index] = self.direction_right
|
|
782
819
|
|
|
783
820
|
for layer in self.layers:
|
|
@@ -786,134 +823,434 @@ class SumConstraint(Constraint):
|
|
|
786
823
|
return output
|
|
787
824
|
|
|
788
825
|
|
|
789
|
-
class
|
|
826
|
+
class MonotonicityConstraint(Constraint):
|
|
827
|
+
"""Constraint that enforces a monotonic relationship between two tags.
|
|
828
|
+
|
|
829
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
830
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
790
831
|
"""
|
|
791
|
-
A constraint that enforces the Pythagorean identity: a² + b² ≈ 1,
|
|
792
|
-
where `a` and `b` are neurons or transformations.
|
|
793
832
|
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
833
|
+
def __init__(
|
|
834
|
+
self,
|
|
835
|
+
tag_prediction: str,
|
|
836
|
+
tag_reference: str,
|
|
837
|
+
rescale_factor_lower: float = 1.5,
|
|
838
|
+
rescale_factor_upper: float = 1.75,
|
|
839
|
+
stable: bool = True,
|
|
840
|
+
direction: Literal["ascending", "descending"] = "ascending",
|
|
841
|
+
name: str = None,
|
|
842
|
+
enforce: bool = True,
|
|
843
|
+
):
|
|
844
|
+
"""Constraint that enforces monotonicity on a predicted output.
|
|
798
845
|
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
neuron name (str) or a Transformation.
|
|
802
|
-
b (Union[str, Transformation]): The second input, either a
|
|
803
|
-
neuron name (str) or a Transformation.
|
|
804
|
-
rtol (float, optional): The relative tolerance for the
|
|
805
|
-
comparison (default is 0.00001).
|
|
806
|
-
atol (float, optional): The absolute tolerance for the
|
|
807
|
-
comparison (default is 1e-8).
|
|
808
|
-
name (str, optional): The name of the constraint
|
|
809
|
-
(default is None, and it is generated automatically).
|
|
810
|
-
monitor_only (bool, optional): Flag indicating whether the
|
|
811
|
-
constraint is only for monitoring (default is False).
|
|
812
|
-
rescale_factor (Number, optional): A factor used for
|
|
813
|
-
rescaling (default is 1.5).
|
|
846
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
847
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
814
848
|
|
|
815
|
-
|
|
816
|
-
|
|
849
|
+
Args:
|
|
850
|
+
tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
|
|
851
|
+
tag_reference (str): Name of the tag that acts as the monotonic reference.
|
|
852
|
+
rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
|
|
853
|
+
rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
|
|
854
|
+
stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
|
|
855
|
+
direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
|
|
856
|
+
name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
|
|
857
|
+
enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
|
|
858
|
+
"""
|
|
859
|
+
# Type checking
|
|
860
|
+
validate_type("rescale_factor_lower", rescale_factor_lower, float)
|
|
861
|
+
validate_type("rescale_factor_upper", rescale_factor_upper, float)
|
|
862
|
+
validate_type("stable", stable, bool)
|
|
863
|
+
validate_type("direction", direction, str)
|
|
817
864
|
|
|
865
|
+
# Compose constraint name
|
|
866
|
+
if name is None:
|
|
867
|
+
name = f"{tag_prediction} monotonically {direction} by {tag_reference}"
|
|
868
|
+
|
|
869
|
+
# Init parent class
|
|
870
|
+
super().__init__({tag_prediction}, name, enforce, 1.0)
|
|
871
|
+
|
|
872
|
+
# Init variables
|
|
873
|
+
self.tag_prediction = tag_prediction
|
|
874
|
+
self.tag_reference = tag_reference
|
|
875
|
+
self.rescale_factor_lower = rescale_factor_lower
|
|
876
|
+
self.rescale_factor_upper = rescale_factor_upper
|
|
877
|
+
self.stable = stable
|
|
878
|
+
self.descending = direction == "descending"
|
|
879
|
+
|
|
880
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
881
|
+
"""Evaluate whether the monotonicity constraint is satisfied."""
|
|
882
|
+
# Select relevant columns
|
|
883
|
+
preds = self.descriptor.select(self.tag_prediction, data)
|
|
884
|
+
targets = self.descriptor.select(self.tag_reference, data)
|
|
885
|
+
|
|
886
|
+
# Utility: convert values -> ranks (0 ... num_features-1)
|
|
887
|
+
def compute_ranks(x: Tensor, descending: bool) -> Tensor:
|
|
888
|
+
return argsort(
|
|
889
|
+
argsort(x, descending=descending, stable=self.stable, dim=0),
|
|
890
|
+
descending=False,
|
|
891
|
+
stable=self.stable,
|
|
892
|
+
dim=0,
|
|
893
|
+
)
|
|
894
|
+
|
|
895
|
+
# Compute predicted and target ranks
|
|
896
|
+
pred_ranks = compute_ranks(preds, descending=self.descending)
|
|
897
|
+
target_ranks = compute_ranks(targets, descending=False)
|
|
898
|
+
|
|
899
|
+
# Rank difference
|
|
900
|
+
rank_diff = pred_ranks - target_ranks
|
|
901
|
+
|
|
902
|
+
# Rescale differences into [rescale_factor_lower, rescale_factor_upper]
|
|
903
|
+
batch_size = preds.shape[0]
|
|
904
|
+
invert_direction = -1 if self.descending else 1
|
|
905
|
+
self.compared_rankings = (
|
|
906
|
+
(rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
|
|
907
|
+
+ self.rescale_factor_lower * sign(rank_diff)
|
|
908
|
+
) * invert_direction
|
|
909
|
+
|
|
910
|
+
# Calculate satisfaction
|
|
911
|
+
incorrect_rankings = eq(self.compared_rankings, 0).float()
|
|
912
|
+
|
|
913
|
+
return incorrect_rankings, ones_like(incorrect_rankings)
|
|
914
|
+
|
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Calculates ranking adjustments for monotonicity enforcement."""
+        layer, _ = self.descriptor.location(self.tag_prediction)
+        return {layer: self.compared_rankings}
+
+
|
+
class GroupedMonotonicityConstraint(MonotonicityConstraint):
|
|
922
|
+
"""Constraint that enforces a monotonic relationship between two tags.
|
|
923
|
+
|
|
924
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
925
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
818
926
|
"""
|
|
819
927
|
|
|
820
928
|
def __init__(
|
|
821
929
|
self,
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
930
|
+
tag_prediction: str,
|
|
931
|
+
tag_reference: str,
|
|
932
|
+
tag_group_identifier: str,
|
|
933
|
+
rescale_factor_lower: float = 1.5,
|
|
934
|
+
rescale_factor_upper: float = 1.75,
|
|
935
|
+
stable: bool = True,
|
|
936
|
+
direction: Literal["ascending", "descending"] = "ascending",
|
|
937
|
+
name: str = None,
|
|
938
|
+
enforce: bool = True,
|
|
939
|
+
):
|
|
940
|
+
"""Constraint that enforces monotonicity on a predicted output.
|
|
941
|
+
|
|
942
|
+
This constraint ensures that the activations of a prediction tag (`tag_prediction`)
|
|
943
|
+
are monotonically ascending or descending with respect to a target tag (`tag_reference`).
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
|
|
947
|
+
tag_reference (str): Name of the tag that acts as the monotonic reference.
|
|
948
|
+
tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
|
|
949
|
+
rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
|
|
950
|
+
rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
|
|
951
|
+
stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
|
|
952
|
+
direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
|
|
953
|
+
name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
|
|
954
|
+
enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
|
|
955
|
+
"""
|
|
956
|
+
# Compose constraint name
|
|
957
|
+
if name is None:
|
|
958
|
+
name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
|
|
959
|
+
|
|
960
|
+
# Init parent class
|
|
961
|
+
super().__init__(
|
|
962
|
+
tag_prediction=tag_prediction,
|
|
963
|
+
tag_reference=tag_reference,
|
|
964
|
+
rescale_factor_lower=rescale_factor_lower,
|
|
965
|
+
rescale_factor_upper=rescale_factor_upper,
|
|
966
|
+
stable=stable,
|
|
967
|
+
direction=direction,
|
|
968
|
+
name=name,
|
|
969
|
+
enforce=enforce,
|
|
970
|
+
)
|
|
971
|
+
|
|
972
|
+
# Init variables
|
|
973
|
+
self.tag_prediction = tag_prediction
|
|
974
|
+
self.tag_reference = tag_reference
|
|
975
|
+
self.tag_group_identifier = tag_group_identifier
|
|
976
|
+
|
|
977
|
+
def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
|
978
|
+
"""Evaluate whether the monotonicity constraint is satisfied."""
|
|
979
|
+
# Select group identifiers and convert to unique list
|
|
980
|
+
group_identifiers = self.descriptor.select(self.tag_group_identifier, data)
|
|
981
|
+
unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()
|
|
982
|
+
|
|
983
|
+
# Initialize checks and directions
|
|
984
|
+
checks = zeros_like(group_identifiers, device=self.device)
|
|
985
|
+
self.directions = zeros_like(group_identifiers, device=self.device)
|
|
986
|
+
|
|
987
|
+
# Get prediction and target keys
|
|
988
|
+
preds_key, _ = self.descriptor.location(self.tag_prediction)
|
|
989
|
+
targets_key, _ = self.descriptor.location(self.tag_reference)
|
|
990
|
+
|
|
991
|
+
for group_identifier in unique_group_identifiers:
|
|
992
|
+
# Create mask for the samples in this group
|
|
993
|
+
group_mask = (group_identifiers == group_identifier).squeeze(1)
|
|
994
|
+
|
|
995
|
+
# Create mini-batch for the group
|
|
996
|
+
group_data = {
|
|
997
|
+
preds_key: data[preds_key][group_mask],
|
|
998
|
+
targets_key: data[targets_key][group_mask],
|
|
999
|
+
}
|
|
1000
|
+
|
|
1001
|
+
# Call super on the mini-batch
|
|
1002
|
+
checks[group_mask], _ = super().check_constraint(group_data)
|
|
1003
|
+
self.directions[group_mask] = self.compared_rankings
|
|
1004
|
+
|
|
1005
|
+
return checks, ones_like(checks)
|
|
1006
|
+
|
|
1007
|
+
def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
1008
|
+
"""Calculates ranking adjustments for monotonicity enforcement."""
|
|
1009
|
+
layer, _ = self.descriptor.location(self.tag_prediction)
|
|
1010
|
+
return {layer: self.directions}
|
|
1011
|
+
|
|
1012
|
+
|
|
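A usage sketch for the grouped variant, again with hypothetical tag names; each group is ranked independently, so monotonicity is never enforced across group boundaries:

    # Hypothetical tags: within each patient's samples, the predicted
    # response should increase with dose.
    constraint = GroupedMonotonicityConstraint(
        tag_prediction="response",
        tag_reference="dose",
        tag_group_identifier="patient_id",
        direction="ascending",
    )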
+class ANDConstraint(Constraint):
+    """A composite constraint that enforces the logical AND of multiple constraints.
+
+    This class combines multiple sub-constraints and evaluates them jointly:
+
+    * The satisfaction of the AND constraint is `True` only if all sub-constraints
+      are satisfied (elementwise logical AND).
+    * The corrective direction is computed by weighting each sub-constraint's
+      direction with its satisfaction mask and summing across all sub-constraints.
+    """
+
+    def __init__(
+        self,
+        *constraints: Constraint,
         name: str = None,
         monitor_only: bool = False,
         rescale_factor: Number = 1.5,
     ) -> None:
+        """A composite constraint that enforces the logical AND of multiple constraints.
+
+        This class combines multiple sub-constraints and evaluates them jointly:
+
+        * The satisfaction of the AND constraint is `True` only if all sub-constraints
+          are satisfied (elementwise logical AND).
+        * The corrective direction is computed by weighting each sub-constraint's
+          direction with its satisfaction mask and summing across all sub-constraints.
+
+        Args:
+            *constraints (Constraint): One or more `Constraint` instances to be combined.
+            name (str, optional): A custom name for this constraint. If not provided,
+                the name will be composed from the sub-constraint names joined with
+                " AND ".
+            monitor_only (bool, optional): If True, the constraint will be monitored
+                but not enforced. Defaults to False.
+            rescale_factor (Number, optional): A scaling factor applied when rescaling
+                corrections. Defaults to 1.5.
+
+        Attributes:
+            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
+            tags (set): The union of tags referenced by the sub-constraints.
+            name (str): The name of the constraint (composed or custom).
         """
-
+        # Type checking
+        validate_iterable("constraints", constraints, Constraint)
+
+        # Compose constraint name
+        if not name:
+            name = " AND ".join([constraint.name for constraint in constraints])
+
+        # Init parent class
+        super().__init__(
+            set().union(*(constraint.tags for constraint in constraints)),
+            name,
+            monitor_only,
+            rescale_factor,
+        )
+
+        # Init variables
+        self.constraints = constraints
+
+    def check_constraint(self, data: dict[str, Tensor]):
+        """Evaluate whether all sub-constraints are satisfied.
+
+        Args:
+            data: Model predictions and associated batch/context information.
+
+        Returns:
+            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
+                * `total_satisfaction`: A boolean or numeric tensor indicating
+                  elementwise whether all constraints are satisfied
+                  (logical AND).
+                * `mask`: A tensor of ones with the same shape as
+                  `total_satisfaction`. Typically used as a weighting mask
+                  in downstream processing.
         """
+        total_satisfaction: Tensor = None
+        total_mask: Tensor = None
+
+        # TODO vectorize this loop
+        for constraint in self.constraints:
+            satisfaction, mask = constraint.check_constraint(data)
+            if total_satisfaction is None:
+                total_satisfaction = satisfaction
+                total_mask = mask
+            else:
+                total_satisfaction = logical_and(total_satisfaction, satisfaction)
+                total_mask = logical_or(total_mask, mask)

-
-        validate_type("a", a, (str, Transformation))
-        validate_type("b", b, (str, Transformation))
-        validate_type("rtol", rtol, float)
-        validate_type("atol", atol, float)
-
-        # If transformation is provided, get neuron name,
-        # else use IdentityTransformation
-        if isinstance(a, Transformation):
-            neuron_name_a = a.neuron_name
-            transformation_a = a
-        else:
-            neuron_name_a = a
-            transformation_a = IdentityTransformation(neuron_name_a)
+        return total_satisfaction.float(), total_mask.float()

-
-
-
-
-
-
+    def calculate_direction(self, data: dict[str, Tensor]):
+        """Compute the corrective direction by aggregating sub-constraint directions.
+
+        Each sub-constraint contributes its corrective direction, weighted
+        by its satisfaction mask. The directions are summed across constraints
+        for each affected layer.
+
+        Args:
+            data: Model predictions and associated batch/context information.
+
+        Returns:
+            dict[str, Tensor]: A mapping from layer identifiers to correction
+                tensors. Each entry represents the aggregated correction to apply
+                to that layer, based on the satisfaction-weighted sum of
+                sub-constraint directions.
+        """
+        total_direction: dict[str, Tensor] = {}
+
+        # TODO vectorize this loop
+        for constraint in self.constraints:
+            # TODO improve efficiency by avoiding double computation?
+            satisfaction, _ = constraint.check_constraint(data)
+            direction = constraint.calculate_direction(data)
+
+            for layer, dir in direction.items():
+                if layer not in total_direction:
+                    total_direction[layer] = satisfaction.unsqueeze(1) * dir
+                else:
+                    total_direction[layer] += satisfaction.unsqueeze(1) * dir
+
+        return total_direction
+
+
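The aggregation logic can be seen in isolation. A minimal sketch with illustrative tensors, mirroring the elementwise AND over satisfaction masks and the satisfaction-weighted sum of per-layer directions:

    import torch

    # Per-sample satisfaction masks from two hypothetical sub-constraints
    sat_a = torch.tensor([1.0, 0.0, 1.0])
    sat_b = torch.tensor([1.0, 1.0, 0.0])

    # AND semantics: a sample satisfies the composite only if it satisfies both
    total = torch.logical_and(sat_a, sat_b).float()  # tensor([1., 0., 0.])

    # Direction aggregation: each sub-constraint's per-layer correction
    # (batch x neurons) is weighted by its own satisfaction mask, then summed
    dir_a = torch.ones(3, 2)
    dir_b = torch.full((3, 2), 2.0)
    combined = sat_a.unsqueeze(1) * dir_a + sat_b.unsqueeze(1) * dir_b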
+class ORConstraint(Constraint):
+    """A composite constraint that enforces the logical OR of multiple constraints.
+
+    This class combines multiple sub-constraints and evaluates them jointly:
+
+    * The satisfaction of the OR constraint is `True` if at least one sub-constraint
+      is satisfied (elementwise logical OR).
+    * The corrective direction is computed by weighting each sub-constraint's
+      direction with its satisfaction mask and summing across all sub-constraints.
+    """
+
+    def __init__(
+        self,
+        *constraints: Constraint,
+        name: str = None,
+        monitor_only: bool = False,
+        rescale_factor: Number = 1.5,
+    ) -> None:
+        """A composite constraint that enforces the logical OR of multiple constraints.
+
+        This class combines multiple sub-constraints and evaluates them jointly:
+
+        * The satisfaction of the OR constraint is `True` if at least one sub-constraint
+          is satisfied (elementwise logical OR).
+        * The corrective direction is computed by weighting each sub-constraint's
+          direction with its satisfaction mask and summing across all sub-constraints.
+
+        Args:
+            *constraints (Constraint): One or more `Constraint` instances to be combined.
+            name (str, optional): A custom name for this constraint. If not provided,
+                the name will be composed from the sub-constraint names joined with
+                " OR ".
+            monitor_only (bool, optional): If True, the constraint will be monitored
+                but not enforced. Defaults to False.
+            rescale_factor (Number, optional): A scaling factor applied when rescaling
+                corrections. Defaults to 1.5.
+
+        Attributes:
+            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
+            tags (set): The union of tags referenced by the sub-constraints.
+            name (str): The name of the constraint (composed or custom).
+        """
+        # Type checking
+        validate_iterable("constraints", constraints, Constraint)

     # Compose constraint name
-
+        if not name:
+            name = " OR ".join([constraint.name for constraint in constraints])

     # Init parent class
     super().__init__(
-
+        set().union(*(constraint.tags for constraint in constraints)),
         name,
         monitor_only,
         rescale_factor,
     )

     # Init variables
-        self.
-        self.transformation_b = transformation_b
-        self.rtol = rtol
-        self.atol = atol
+        self.constraints = constraints

-
-
-        self.layer_b = self.descriptor.neuron_to_layer[neuron_name_b]
-        self.index_a = self.descriptor.neuron_to_index[neuron_name_a]
-        self.index_b = self.descriptor.neuron_to_index[neuron_name_b]
+    def check_constraint(self, data: dict[str, Tensor]):
+        """Evaluate whether any sub-constraint is satisfied.

-
-
-    ) -> tuple[Tensor, int]:
+        Args:
+            data: Model predictions and associated batch/context information.

-
-
-
+        Returns:
+            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
+                * `total_satisfaction`: A boolean or numeric tensor indicating
+                  elementwise whether any constraint is satisfied
+                  (logical OR).
+                * `mask`: A tensor of ones with the same shape as
+                  `total_satisfaction`. Typically used as a weighting mask
+                  in downstream processing.
+        """
+        total_satisfaction: Tensor = None
+        total_mask: Tensor = None
+
+        # TODO vectorize this loop
+        for constraint in self.constraints:
+            satisfaction, mask = constraint.check_constraint(data)
+            if total_satisfaction is None:
+                total_satisfaction = satisfaction
+                total_mask = mask
+            else:
+                total_satisfaction = logical_or(total_satisfaction, satisfaction)
+                total_mask = logical_or(total_mask, mask)

-
-        selection_a = self.transformation_a(selection_a)
-        selection_b = self.transformation_b(selection_b)
-
-        # Calculate result
-        result = isclose(
-            square(selection_a) + square(selection_b),
-            ones_like(selection_a, device=self.device),
-            rtol=self.rtol,
-            atol=self.atol,
-        ).float()
-
-        return result, numel(result)
-
-    def calculate_direction(
-        self, prediction: dict[str, Tensor]
-    ) -> Dict[str, Tensor]:
-        # NOTE currently only works for dense layers due
-        # to neuron to index translation
+        return total_satisfaction.float(), total_mask.float()

-
+    def calculate_direction(self, data: dict[str, Tensor]):
+        """Compute the corrective direction by aggregating sub-constraint directions.

-
-
+        Each sub-constraint contributes its corrective direction, weighted
+        by its satisfaction mask. The directions are summed across constraints
+        for each affected layer.
+
+        Args:
+            data: Model predictions and associated batch/context information.

-
-
-
+        Returns:
+            dict[str, Tensor]: A mapping from layer identifiers to correction
+                tensors. Each entry represents the aggregated correction to apply
+                to that layer, based on the satisfaction-weighted sum of
+                sub-constraint directions.
+        """
+        total_direction: dict[str, Tensor] = {}

-
-
+        # TODO vectorize this loop
+        for constraint in self.constraints:
+            # TODO improve efficiency by avoiding double computation?
+            satisfaction, _ = constraint.check_constraint(data)
+            direction = constraint.calculate_direction(data)

-
+            for layer, dir in direction.items():
+                if layer not in total_direction:
+                    total_direction[layer] = satisfaction.unsqueeze(1) * dir
+                else:
+                    total_direction[layer] += satisfaction.unsqueeze(1) * dir
+
+        return total_direction
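Since composites are themselves `Constraint` instances, AND and OR expressions nest freely. A sketch using the hypothetical monotonicity constraints from earlier:

    mono_area = MonotonicityConstraint(tag_prediction="price", tag_reference="floor_area")
    mono_age = MonotonicityConstraint(tag_prediction="price", tag_reference="year_built")

    # Satisfied per sample when at least one relationship holds
    either = ORConstraint(mono_area, mono_age)

    # Satisfied only when both hold; the name is composed with " AND " if omitted
    both = ANDConstraint(mono_area, mono_age, name="price monotone in area and age")

    # monitor_only tracks satisfaction without steering the gradients
    watched = ORConstraint(mono_area, mono_age, monitor_only=True)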