congrads-0.2.0-py3-none-any.whl → congrads-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -21
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +148 -29
- congrads/metrics.py +109 -19
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/{utils.py → utils/preprocessors.py} +201 -72
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -389
- congrads/core.py +0 -225
- congrads/datasets.py +0 -195
- congrads/networks.py +0 -90
- congrads-0.2.0.dist-info/LICENSE +0 -26
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- congrads-0.2.0.dist-info/WHEEL +0 -5
- congrads-0.2.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,1255 @@
|
|
|
1
|
+
"""Module providing constraint classes for guiding neural network training.
|
|
2
|
+
|
|
3
|
+
This module defines constraints that enforce specific conditions on network outputs
|
|
4
|
+
to steer learning. Available constraint types include:
|
|
5
|
+
|
|
6
|
+
- `Constraint`: Base class for all constraint types, defining the interface and core
|
|
7
|
+
behavior.
|
|
8
|
+
- `ImplicationConstraint`: Enforces one condition only if another condition is met,
|
|
9
|
+
useful for modeling implications between outputs.
|
|
10
|
+
- `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
|
|
11
|
+
- `BinaryConstraint`: Enforces a binary comparison between two tags using a
|
|
12
|
+
comparison function (e.g., less than, greater than).
|
|
13
|
+
- `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
|
|
14
|
+
value, controlling total output.
|
|
15
|
+
|
|
16
|
+
These constraints can steer the learning process by applying logical implications
|
|
17
|
+
or numerical bounds.
|
|
18
|
+
|
|
19
|
+
Usage:
|
|
20
|
+
1. Define a custom constraint class by inheriting from `Constraint`.
|
|
21
|
+
2. Apply the constraint to your neural network during training.
|
|
22
|
+
3. Use helper classes like `IdentityTransformation` for transformations and
|
|
23
|
+
comparisons in constraints.
|
|
24
|
+
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from collections.abc import Callable
|
|
28
|
+
from numbers import Number
|
|
29
|
+
from typing import Literal
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
from torch import (
|
|
33
|
+
Tensor,
|
|
34
|
+
argsort,
|
|
35
|
+
eq,
|
|
36
|
+
logical_and,
|
|
37
|
+
logical_not,
|
|
38
|
+
logical_or,
|
|
39
|
+
ones,
|
|
40
|
+
ones_like,
|
|
41
|
+
reshape,
|
|
42
|
+
sign,
|
|
43
|
+
sort,
|
|
44
|
+
stack,
|
|
45
|
+
tensor,
|
|
46
|
+
triu,
|
|
47
|
+
unique,
|
|
48
|
+
zeros_like,
|
|
49
|
+
)
|
|
50
|
+
from torch.nn.functional import normalize
|
|
51
|
+
|
|
52
|
+
from ..transformations.base import Transformation
|
|
53
|
+
from ..transformations.registry import IdentityTransformation
|
|
54
|
+
from ..utils.validation import validate_comparator, validate_iterable, validate_type
|
|
55
|
+
from .base import Constraint, MonotonicityConstraint
|
|
56
|
+
|
|
57
|
+
# Elementwise torch comparison ops, keyed by their operator symbol.
# Used by the constraint classes below to turn a `comparator` string
# into a callable Tensor predicate.
COMPARATOR_MAP: dict[str, Callable[[Tensor, Tensor], Tensor]] = dict(
    [
        (">", torch.gt),
        (">=", torch.ge),
        ("<", torch.lt),
        ("<=", torch.le),
    ]
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ImplicationConstraint(Constraint):
    """Represents an implication constraint between two constraints (head and body).

    The implication constraint ensures that the `body` constraint only applies
    when the `head` constraint is satisfied. If the `head` constraint is not
    satisfied, the `body` constraint does not apply.
    """

    def __init__(
        self,
        head: Constraint,
        body: Constraint,
        name: str | None = None,
    ):
        """Initializes an ImplicationConstraint instance.

        Uses `enforce` and `rescale_factor` from the body constraint.

        Args:
            head (Constraint): Constraint defining the head of the implication.
            body (Constraint): Constraint defining the body of the implication.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is composed as "<body.name> if <head.name>".

        Raises:
            TypeError: If a provided attribute has an incompatible type.
        """
        # Type checking
        validate_type("head", head, Constraint)
        validate_type("body", body, Constraint)

        # Honor a caller-supplied name; only auto-generate when absent.
        # (Previously the `name` argument was accepted but silently
        # overwritten, so custom names never took effect.)
        if name is None:
            name = f"{body.name} if {head.name}"

        # Init parent class; enforcement settings are inherited from the body.
        super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)

        self.head = head
        self.body = body

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check whether the implication constraint is satisfied.

        Evaluates the `head` and `body` constraints. The `body` constraint
        is enforced only if the `head` constraint is satisfied. If the
        `head` constraint is not satisfied, the `body` constraint does not
        affect the result.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating satisfaction of the implication
                  constraint (1 if satisfied, 0 otherwise).
                - head_satisfaction: Tensor indicating satisfaction of the
                  head constraint alone.
        """
        # Check satisfaction of head and body constraints
        head_satisfaction, _ = self.head.check_constraint(data)
        body_satisfaction, _ = self.body.check_constraint(data)

        # Material implication: head -> body  ==  (not head) or body.
        # If the head is not satisfied (0), the implication holds (1)
        # regardless of the body; otherwise the body decides.
        result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()

        return result, head_satisfaction

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for tags to satisfy the constraint.

        Uses the `body` constraint directions as the update vector. Only
        applies updates if the `head` constraint is satisfied. Currently,
        this method only works for dense layers due to tag-to-index
        translation limitations.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping tags to tensors
                specifying the adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        # Use directions of constraint body as update vector
        return self.body.calculate_direction(data)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class ScalarConstraint(Constraint):
    """A constraint that enforces scalar-based comparisons on a specific tag.

    This class ensures that the output of a specified tag satisfies a scalar
    comparison operation (e.g., less than, greater than, etc.). It uses a
    comparator function to validate the condition and calculates adjustment
    directions accordingly.

    Notes:
        - The `tag` must be defined in the `descriptor` mapping.
        - When no name is given, it is composed as "<tag> <comparator> <scalar>".
    """

    def __init__(
        self,
        operand: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        scalar: Number,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a ScalarConstraint instance.

        Args:
            operand (Union[str, Transformation]): Name of the tag, or a
                transformation to apply to the tag's values before the
                comparison.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator
                used in the constraint.
            scalar (Number): Constant to compare the variable to.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<tag> <comparator> <scalar>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5. Should be greater
                than 1 to give weight to the constraint.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
        """
        # Type checking
        validate_type("operand", operand, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("scalar", scalar, Number)

        # Resolve the operand to a (tag, transformation) pair; a plain tag
        # name gets an identity transformation.
        if isinstance(operand, Transformation):
            tag = operand.tag
            transformation = operand
        else:
            tag = operand
            transformation = IdentityTransformation(tag)

        # Honor a caller-supplied name; only auto-generate when absent.
        # (Previously the `name` argument was accepted but silently
        # overwritten, so custom names never took effect.)
        if name is None:
            name = f"{tag} {comparator} {scalar}"

        # Init parent class
        super().__init__({tag}, name, enforce, rescale_factor)

        # Init variables
        self.tag = tag
        self.scalar = scalar
        self.transformation = transformation

        # Adjustment direction derived from the operator; validate_comparator
        # above guarantees exactly one branch is taken.
        if comparator in ("<", "<="):
            self.direction = 1
        elif comparator in (">", ">="):
            self.direction = -1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check if the scalar constraint is satisfied for the tag.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating whether the tag satisfies the
                  constraint (1 satisfied, 0 violated).
                - ones_like(result): Tensor of ones with same shape as
                  `result`, used for constraint aggregation.
        """
        # Select relevant columns
        selection = self.descriptor.select(self.tag, data)

        # Apply transformation
        selection = self.transformation(selection)

        # Calculate current constraint result
        result = self.comparator(selection, self.scalar).float()
        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions to satisfy the scalar constraint.

        Only works for dense layers due to tag-to-index translation.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping layers to tensors specifying
                the normalized adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from an all-zero direction vector per layer.
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set the direction only at the constrained tag's position.
        layer, index = self.descriptor.location(self.tag)
        output[layer][index] = self.direction

        # Normalize each layer's direction to unit length.
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class BinaryConstraint(Constraint):
    """A constraint that enforces a binary comparison between two tags.

    This class ensures that the output of one tag satisfies a comparison
    operation with the output of another tag (e.g., less than, greater
    than, etc.). It uses a comparator function to validate the condition
    and calculates adjustment directions accordingly.

    Notes:
        - The tags must be defined in the `descriptor` mapping.
        - When no name is given, it is composed as
          "<operand_left> <comparator> <operand_right>".
    """

    def __init__(
        self,
        operand_left: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        operand_right: str | Transformation,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a BinaryConstraint instance.

        Args:
            operand_left (Union[str, Transformation]): Name of the left
                tag or a transformation to apply.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator
                used in the constraint.
            operand_right (Union[str, Transformation]): Name of the right
                tag or a transformation to apply.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<operand_left> <comparator> <operand_right>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
        """
        # Type checking
        validate_type("operand_left", operand_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("operand_right", operand_right, (str, Transformation))

        # Resolve each operand to a (tag, transformation) pair; plain tag
        # names get an identity transformation.
        if isinstance(operand_left, Transformation):
            tag_left = operand_left.tag
            transformation_left = operand_left
        else:
            tag_left = operand_left
            transformation_left = IdentityTransformation(tag_left)

        if isinstance(operand_right, Transformation):
            tag_right = operand_right.tag
            transformation_right = operand_right
        else:
            tag_right = operand_right
            transformation_right = IdentityTransformation(tag_right)

        # Honor a caller-supplied name; only auto-generate when absent.
        # (Previously the `name` argument was accepted but silently
        # overwritten, so custom names never took effect.)
        if name is None:
            name = f"{tag_left} {comparator} {tag_right}"

        # Init parent class
        super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)

        # Init variables
        self.tag_left = tag_left
        self.tag_right = tag_right
        self.transformation_left = transformation_left
        self.transformation_right = transformation_right

        # Adjustment directions derived from the operator; validate_comparator
        # above guarantees exactly one branch is taken.
        if comparator in ("<", "<="):
            self.direction_left = 1
            self.direction_right = -1
        elif comparator in (">", ">="):
            self.direction_left = -1
            self.direction_right = 1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the binary constraint is satisfied per sample.

        The constraint compares the outputs of two tags using the specified
        comparator function. A result of `1` indicates the constraint is
        satisfied for a sample, and `0` indicates it is violated.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating constraint
                  satisfaction (1 satisfied, 0 violated) for each sample.
                - mask (Tensor): Tensor of ones with the same shape as
                  `result`, used for constraint aggregation.
        """
        # Select relevant columns
        selection_left = self.descriptor.select(self.tag_left, data)
        selection_right = self.descriptor.select(self.tag_right, data)

        # Apply transformations
        selection_left = self.transformation_left(selection_left)
        selection_right = self.transformation_right(selection_right)

        result = self.comparator(selection_left, selection_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for both tags of the constraint.

        The returned directions indicate how to adjust each tag's output to
        satisfy the constraint. Only currently supported for dense layers.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            dict[str, Tensor]: A mapping from layer names to tensors
                specifying the normalized adjustment directions for each tag
                involved in the constraint.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from an all-zero direction vector per layer.
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set opposite directions at the two tags' positions.
        layer_left, index_left = self.descriptor.location(self.tag_left)
        layer_right, index_right = self.descriptor.location(self.tag_right)
        output[layer_left][index_left] = self.direction_left
        output[layer_right][index_right] = self.direction_right

        # Normalize each layer's direction to unit length.
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class SumConstraint(Constraint):
    """A constraint enforcing a weighted-sum comparison between two tag groups.

    This class evaluates whether the weighted sum of outputs from one set of
    tags satisfies a comparison operation with the weighted sum of outputs
    from another set of tags.
    """

    def __init__(
        self,
        operands_left: list[str | Transformation],
        comparator: Literal[">", "<", ">=", "<="],
        operands_right: list[str | Transformation],
        weights_left: list[Number] | None = None,
        weights_right: list[Number] | None = None,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes the SumConstraint.

        Args:
            operands_left (list[Union[str, Transformation]]): List of tags
                or transformations on the left side.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator
                used in the constraint.
            operands_right (list[Union[str, Transformation]]): List of tags
                or transformations on the right side.
            weights_left (list[Number], optional): Weights for the left
                tags. Defaults to None (all ones).
            weights_right (list[Number], optional): Weights for the right
                tags. Defaults to None (all ones).
            name (str, optional): Unique name for the constraint.
                If None, it's auto-generated. Defaults to None.
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If the dimensions of tags and weights mismatch.
        """
        # Type checking
        validate_iterable("operands_left", operands_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_iterable("operands_right", operands_right, (str, Transformation))
        validate_iterable("weights_left", weights_left, Number, allow_none=True)
        validate_iterable("weights_right", weights_right, Number, allow_none=True)

        # Resolve operands into parallel (tags, transformations) lists.
        tags_left, transformations_left = self._resolve_operands(operands_left)
        tags_right, transformations_right = self._resolve_operands(operands_right)

        # Fail fast on mismatched weight dimensions, before the base class
        # is initialized and any state is set. (Previously this check ran
        # after `super().__init__`, leaving a partially initialized object.)
        if weights_left and (len(tags_left) != len(weights_left)):
            raise ValueError(
                "The dimensions of tags_left don't match with the dimensions of weights_left."
            )
        if weights_right and (len(tags_right) != len(weights_right)):
            raise ValueError(
                "The dimensions of tags_right don't match with the dimensions of weights_right."
            )

        # Honor a caller-supplied name; only auto-generate when absent.
        # (Previously the `name` argument was accepted but silently
        # overwritten, so custom names never took effect.)
        if name is None:
            w_left = weights_left or [""] * len(tags_left)
            w_right = weights_right or [""] * len(tags_right)
            left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
            right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
            name = f"{left_expr} {comparator} {right_expr}"

        # Init parent class
        super().__init__(set(tags_left) | set(tags_right), name, enforce, rescale_factor)

        # Init variables
        self.tags_left = tags_left
        self.tags_right = tags_right
        self.transformations_left = transformations_left
        self.transformations_right = transformations_right

        # Weights default to ones when omitted; tensors live on the
        # constraint's device (set by the base class).
        if weights_left:
            self.weights_left = tensor(weights_left, device=self.device)
        else:
            self.weights_left = ones(len(tags_left), device=self.device)
        if weights_right:
            self.weights_right = tensor(weights_right, device=self.device)
        else:
            self.weights_right = ones(len(tags_right), device=self.device)

        # Adjustment directions derived from the operator.
        # NOTE(review): the left/right signs here are inverted relative to
        # BinaryConstraint for the same operator — confirm this asymmetry
        # is intentional.
        if comparator in ("<", "<="):
            self.direction_left = -1
            self.direction_right = 1
        elif comparator in (">", ">="):
            self.direction_left = 1
            self.direction_right = -1
        self.comparator = COMPARATOR_MAP[comparator]

    @staticmethod
    def _resolve_operands(
        operands: list[str | Transformation],
    ) -> tuple[list[str], list[Transformation]]:
        """Split operands into parallel tag and transformation lists.

        Plain tag names are wrapped in an identity transformation so that
        both lists always align element-for-element.
        """
        tags: list[str] = []
        transformations: list[Transformation] = []
        for operand in operands:
            if isinstance(operand, Transformation):
                tags.append(operand.tag)
                transformations.append(operand)
            else:
                tags.append(operand)
                transformations.append(IdentityTransformation(operand))
        return tags, transformations

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the weighted sum constraint is satisfied.

        Computes the weighted sum of outputs from the left and right tags,
        applies the specified comparator function, and returns a binary
        result for each sample.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating whether the
                  constraint is satisfied (1) or violated (0) per sample.
                - mask (Tensor): Tensor of ones, used for constraint
                  aggregation.
        """

        def compute_weighted_sum(
            tags: list[str],
            transformations: list[Transformation],
            weights: Tensor,
        ) -> Tensor:
            # Select relevant columns
            selections = [self.descriptor.select(tag, data) for tag in tags]

            # Apply each transformation to its matching selection
            results = [
                transformation(selection)
                for transformation, selection in zip(transformations, selections, strict=False)
            ]

            # Stack predictions for all tags and apply weights in bulk
            predictions = stack(results)

            # Calculate weighted sum over the tag dimension
            return (predictions * weights.view(-1, 1, 1)).sum(dim=0)

        # Compute weighted sums
        weighted_sum_left = compute_weighted_sum(
            self.tags_left, self.transformations_left, self.weights_left
        )
        weighted_sum_right = compute_weighted_sum(
            self.tags_right, self.transformations_right, self.weights_right
        )

        # Apply the comparator and calculate the result
        result = self.comparator(weighted_sum_left, weighted_sum_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for all tags in the constraint.

        The directions indicate how to adjust each tag's output to satisfy
        the constraint. Only dense layers are currently supported.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            dict[str, Tensor]: Mapping from layer names to normalized tensors
                specifying adjustment directions for each tag involved in the
                constraint.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        output = {}

        # Start from an all-zero direction vector per layer.
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Left-side tags move in direction_left, right-side in direction_right.
        for tag_left in self.tags_left:
            layer, index = self.descriptor.location(tag_left)
            output[layer][index] = self.direction_left

        for tag_right in self.tags_right:
            layer, index = self.descriptor.location(tag_right)
            output[layer][index] = self.direction_right

        # Normalize each layer's direction to unit length.
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
class RankedMonotonicityConstraint(MonotonicityConstraint):
    """Monotonicity constraint evaluated via full-batch rank comparison.

    Ensures that the activations of a prediction tag (`tag_prediction`) are
    monotonically ascending or descending with respect to a reference tag
    (`tag_reference`). Satisfaction is decided by ranking both columns over
    the batch and comparing the resulting ranks sample by sample.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str = None,
        enforce: bool = True,
    ):
        """Initializes the ranked monotonicity constraint.

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
        """
        # Auto-generate a descriptive name when none was provided.
        if name is None:
            name = f"{tag_prediction} monotonically (ranked) {direction} by {tag_reference}"

        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the ranked monotonicity constraint is satisfied.

        Also stores the rescaled rank differences in `self.compared_rankings`
        for later use by `calculate_direction`.
        """
        predictions = self.descriptor.select(self.tag_prediction, data)
        references = self.descriptor.select(self.tag_reference, data)

        def to_ranks(values: Tensor, descending: bool) -> Tensor:
            # Double argsort converts raw values into ranks 0..N-1.
            order = argsort(values, descending=descending, stable=self.stable, dim=0)
            return argsort(order, descending=False, stable=self.stable, dim=0)

        # Per-sample difference between predicted and reference ranks.
        rank_delta = to_ranks(predictions, self.descending) - to_ranks(references, False)

        # Rescale rank differences into [rescale_factor_lower, rescale_factor_upper]
        # (sign-preserving), flipping the overall sign for descending constraints.
        num_samples = predictions.shape[0]
        flip = -1 if self.descending else 1
        self.compared_rankings = (
            (rank_delta / num_samples) * (self.rescale_factor_upper - self.rescale_factor_lower)
            + self.rescale_factor_lower * sign(rank_delta)
        ) * flip

        # A sample satisfies the constraint exactly when its rank delta is zero.
        satisfied = eq(self.compared_rankings, 0).float()
        return satisfied, ones_like(satisfied)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        prediction_layer, _ = self.descriptor.location(self.tag_prediction)
        return {prediction_layer: self.compared_rankings}
762
|
+
|
|
763
|
+
class PairwiseMonotonicityConstraint(MonotonicityConstraint):
    """Constraint that enforces a monotonic relationship between two tags.

    This constraint ensures that the activations of a prediction tag (`tag_prediction`)
    are monotonically ascending or descending with respect to a target tag
    (`tag_reference`), by counting pairwise ordering violations within the batch.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str | None = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        This constraint ensures that the activations of a prediction tag (`tag_prediction`)
        are monotonically ascending or descending with respect to a target tag (`tag_reference`).

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
        """
        # Compose constraint name
        if name is None:
            name = f"{tag_prediction} monotonically (pairwise) {direction} by {tag_reference}"

        # Init parent class
        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied.

        Also stores the rescaled, sign-carrying violation counts in
        `self.compared_rankings` for later use by `calculate_direction`.
        """
        # Select relevant columns as 1-D vectors
        preds = self.descriptor.select(self.tag_prediction, data).squeeze(1)
        targets = self.descriptor.select(self.tag_reference, data).squeeze(1)

        # Sort by target value; keep `idx` so the sort can be undone below
        sorted_targets, idx = sort(targets, stable=self.stable, dim=0, descending=self.descending)
        sorted_preds = preds[idx]

        # Pairwise differences
        preds_diff = sorted_preds.unsqueeze(1) - sorted_preds.unsqueeze(0)
        targets_diff = sorted_targets.unsqueeze(1) - sorted_targets.unsqueeze(0)

        # Consider only upper triangle to avoid duplicate comparisons.
        # BUGFIX: allocate the mask on the same device as the inputs; a
        # default (CPU) mask raises a device-mismatch error when the batch
        # lives on an accelerator.
        batch_size = preds.shape[0]
        mask = triu(
            ones(batch_size, batch_size, dtype=torch.bool, device=preds.device),
            diagonal=1,
        )

        # A pair violates monotonicity when predictions and targets move in
        # opposite directions.
        violations = (preds_diff * targets_diff < 0) & mask

        # Net violation count per sample (row involvement minus column involvement)
        sorted_rank_diff = violations.float().sum(dim=1) - violations.float().sum(dim=0)

        # Undo sorting to match original order
        rank_diff = zeros_like(sorted_rank_diff)
        rank_diff[idx] = sorted_rank_diff
        rank_diff = rank_diff.unsqueeze(1)

        # Rescale differences into [rescale_factor_lower, rescale_factor_upper]
        invert_direction = -1 if self.descending else 1
        self.compared_rankings = (
            (rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
            + self.rescale_factor_lower * sign(rank_diff)
        ) * invert_direction

        # Satisfied exactly where no net violations remain
        incorrect_rankings = eq(self.compared_rankings, 0).float()

        return incorrect_rankings, ones_like(incorrect_rankings)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        layer, _ = self.descriptor.location(self.tag_prediction)
        return {layer: self.compared_rankings}
+
|
|
859
|
+
class PerGroupMonotonicityConstraint(Constraint):
    """Applies a monotonicity constraint independently per group of samples.

    This class wraps an existing `MonotonicityConstraint` instance (`base`) and
    enforces it **separately** for each unique group defined by `tag_group`.

    Each group is treated as an independent mini-batch:
    - The base constraint is applied to the group's subset of data.
    - Violations and directions are computed per group and then reassembled
      into the original batch order.

    This is an explicit alternative to :class:`EncodedGroupedMonotonicityConstraint`,
    which enforces the same logic using vectorized interval encoding.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the per-group monotonicity constraint.

        Args:
            base (MonotonicityConstraint): A pre-configured monotonicity constraint to apply per group.
            tag_group (str): Column/tag name that identifies groups.
            name (str | None, optional): Custom name for the constraint. Defaults to
                f"{base.tag_prediction} for each {tag_group} ...".
        """
        # Compose constraint name by annotating the base constraint's name
        # with the grouping tag (first occurrence of the prediction tag only).
        if name is None:
            name = base.name.replace(
                base.tag_prediction, f"{base.tag_prediction} for each {tag_group}", 1
            )

        # Init parent class with the base constraint's tags and settings
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied per group.

        Side effect: stores per-sample directions in `self.directions`, which
        `calculate_direction` returns. `check_constraint` must therefore run
        before `calculate_direction` on each batch.
        """
        # Select group identifiers and convert to unique list
        group_identifiers = self.descriptor.select(self.tag_group, data)
        unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()

        # Initialize checks and directions.
        # NOTE(review): these inherit the dtype of the group-identifier column;
        # assumes that column can hold the float values produced by the base
        # constraint — confirm against how group tags are encoded upstream.
        checks = zeros_like(group_identifiers, device=self.device)
        self.directions = zeros_like(group_identifiers, device=self.device)

        # Get prediction and target keys
        preds_key, _ = self.descriptor.location(self.base.tag_prediction)
        targets_key, _ = self.descriptor.location(self.base.tag_reference)

        for group_identifier in unique_group_identifiers:
            # Create mask for the samples in this group
            group_mask = (group_identifiers == group_identifier).squeeze(1)

            # Create mini-batch for the group
            group_data = {
                preds_key: data[preds_key][group_mask],
                targets_key: data[targets_key][group_mask],
            }

            # Run the wrapped base constraint on the mini-batch and scatter
            # its results back into the original batch positions
            checks[group_mask], _ = self.base.check_constraint(group_data)
            self.directions[group_mask] = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement.

        Returns the per-sample directions assembled by the most recent
        `check_constraint` call, keyed by the prediction tag's layer.
        """
        layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {layer: self.directions}
|
+
|
|
937
|
+
class EncodedGroupedMonotonicityConstraint(Constraint):
    """Applies a base monotonicity constraint to groups via interval encoding.

    This constraint enforces a monotonic relationship between a prediction tag
    (`base.tag_prediction`) and a reference tag (`base.tag_reference`) for
    each group defined by `tag_group`.

    Instead of looping over groups, each group's predictions and targets are
    shifted by a large offset to place them in **non-overlapping intervals**.
    This allows the base constraint to be applied in a single, vectorized
    operation.

    This is a vectorized alternative to :class:`PerGroupMonotonicityConstraint`,
    which enforces the same logic via explicit per-group iteration.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the encoded grouped monotonicity constraint.

        Args:
            base (MonotonicityConstraint): The base constraint to apply to each group.
            tag_group (str): Column/tag name identifying groups in the batch.
            name (str | None, optional): Optional custom name for this constraint.
                If None, generates a descriptive name based on `base.name`.
        """
        # Compose constraint name by annotating the base constraint's name
        # with the grouping tag (first occurrence of the prediction tag only).
        if name is None:
            name = base.name.replace(
                base.tag_prediction,
                f"{base.tag_prediction} for each {tag_group} (encoded)",
                1,
            )

        # Init parent class with the base constraint's tags and settings
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

        # Initialize negation factor based on direction: reference offsets are
        # applied in the opposite direction for descending constraints.
        self.negation = 1 if self.base.direction == "ascending" else -1

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied by mapping each group on a non-overlapping interval.

        Side effect: stores the base constraint's directions in
        `self.directions` for later use by `calculate_direction`.
        """
        # Get data and keys
        ids = self.descriptor.select(self.tag_group, data)
        preds = self.descriptor.select(self.base.tag_prediction, data)
        targets = self.descriptor.select(self.base.tag_reference, data)
        preds_key, _ = self.descriptor.location(self.base.tag_prediction)
        targets_key, _ = self.descriptor.location(self.base.tag_reference)

        # Shift each group's values by `id * (value range + 1)` so groups land
        # in disjoint value intervals and never mix in the base ranking.
        # NOTE(review): this assumes group ids are numeric (presumably small
        # non-negative integer codes) — confirm how `tag_group` is encoded.
        new_preds = preds + ids * (preds.max() - preds.min() + 1)
        new_targets = targets + self.negation * ids * (targets.max() - targets.min() + 1)

        # Create new batch for child constraint
        new_data = {preds_key: new_preds, targets_key: new_targets}

        # Run the wrapped base constraint on the adjusted batch
        checks, _ = self.base.check_constraint(new_data)
        self.directions = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement.

        Returns the directions captured during the most recent
        `check_constraint` call, keyed by the prediction tag's layer.
        """
        layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {layer: self.directions}
+
|
|
1012
|
+
class ANDConstraint(Constraint):
    """A composite constraint that enforces the logical AND of multiple constraints.

    Combines several sub-constraints and evaluates them jointly:

    * Satisfaction is `True` only where every sub-constraint is satisfied
      (elementwise logical AND).
    * The corrective direction is the sum, over all sub-constraints, of each
      sub-constraint's direction weighted by its satisfaction mask.
    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str = None,
        monitor_only: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Builds a composite AND constraint over the given sub-constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str, optional): A custom name for this constraint. If not provided,
                the name is composed from the sub-constraint names joined with " AND ".
            monitor_only (bool, optional): If True, the constraint will be monitored
                but not enforced. Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            neurons (set): The union of neurons referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Validate that every positional argument is a Constraint.
        validate_iterable("constraints", constraints, Constraint)

        # Fall back to a composed name when none was supplied.
        if not name:
            name = " AND ".join(sub.name for sub in constraints)

        # The composite covers the union of all sub-constraint tags.
        combined_tags = set().union(*(sub.tags for sub in constraints))
        super().__init__(combined_tags, name, monitor_only, rescale_factor)

        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]):
        """Evaluate whether all sub-constraints are satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: `(total_satisfaction, mask)` where
                `total_satisfaction` marks elements satisfying every
                sub-constraint (logical AND), and `mask` is the elementwise
                logical OR of the sub-constraint masks, used as a weighting
                mask downstream.
        """
        combined_satisfaction: Tensor = None
        combined_mask: Tensor = None

        # TODO vectorize this loop
        for sub in self.constraints:
            satisfaction, mask = sub.check_constraint(data)
            if combined_satisfaction is None:
                combined_satisfaction, combined_mask = satisfaction, mask
            else:
                combined_satisfaction = logical_and(combined_satisfaction, satisfaction)
                combined_mask = logical_or(combined_mask, mask)

        return combined_satisfaction.float(), combined_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]):
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction weighted by
        its satisfaction mask; contributions are summed per affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: Mapping from layer identifiers to aggregated
                correction tensors.
        """
        aggregated: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for sub in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = sub.check_constraint(data)

            for layer, layer_direction in sub.calculate_direction(data).items():
                weighted = satisfaction * layer_direction
                if layer in aggregated:
                    aggregated[layer] += weighted
                else:
                    aggregated[layer] = weighted

        return aggregated
|
|
1135
|
+
class ORConstraint(Constraint):
    """A composite constraint that enforces the logical OR of multiple constraints.

    Combines several sub-constraints and evaluates them jointly:

    * Satisfaction is `True` where at least one sub-constraint is satisfied
      (elementwise logical OR).
    * The corrective direction is the sum, over all sub-constraints, of each
      sub-constraint's direction weighted by its satisfaction mask.
    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str = None,
        monitor_only: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Builds a composite OR constraint over the given sub-constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str, optional): A custom name for this constraint. If not provided,
                the name is composed from the sub-constraint names joined with " OR ".
            monitor_only (bool, optional): If True, the constraint will be monitored
                but not enforced. Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            neurons (set): The union of neurons referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Validate that every positional argument is a Constraint.
        validate_iterable("constraints", constraints, Constraint)

        # Fall back to a composed name when none was supplied.
        if not name:
            name = " OR ".join(sub.name for sub in constraints)

        # The composite covers the union of all sub-constraint tags.
        combined_tags = set().union(*(sub.tags for sub in constraints))
        super().__init__(combined_tags, name, monitor_only, rescale_factor)

        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]):
        """Evaluate whether any sub-constraint is satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: `(total_satisfaction, mask)` where
                `total_satisfaction` marks elements satisfying at least one
                sub-constraint (logical OR), and `mask` is the elementwise
                logical OR of the sub-constraint masks, used as a weighting
                mask downstream.
        """
        combined_satisfaction: Tensor = None
        combined_mask: Tensor = None

        # TODO vectorize this loop
        for sub in self.constraints:
            satisfaction, mask = sub.check_constraint(data)
            if combined_satisfaction is None:
                combined_satisfaction, combined_mask = satisfaction, mask
            else:
                combined_satisfaction = logical_or(combined_satisfaction, satisfaction)
                combined_mask = logical_or(combined_mask, mask)

        return combined_satisfaction.float(), combined_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]):
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction weighted by
        its satisfaction mask; contributions are summed per affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: Mapping from layer identifiers to aggregated
                correction tensors.
        """
        aggregated: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for sub in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = sub.check_constraint(data)

            for layer, layer_direction in sub.calculate_direction(data).items():
                weighted = satisfaction * layer_direction
                if layer in aggregated:
                    aggregated[layer] += weighted
                else:
                    aggregated[layer] = weighted

        return aggregated