congrads 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1278 @@
1
+ """Module providing constraint classes for guiding neural network training.
2
+
3
+ This module defines constraints that enforce specific conditions on network outputs
4
+ to steer learning. Available constraint types include:
5
+
6
+ - `Constraint`: Base class for all constraint types, defining the interface and core
7
+ behavior.
8
+ - `ImplicationConstraint`: Enforces one condition only if another condition is met,
9
+ useful for modeling implications between outputs.
10
+ - `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
11
+ - `BinaryConstraint`: Enforces a binary comparison between two tags using a
12
+ comparison function (e.g., less than, greater than).
13
+ - `SumConstraint`: Enforces a comparison between the weighted sums of two
14
+   groups of tags' outputs, controlling total output.
15
+
16
+ These constraints can steer the learning process by applying logical implications
17
+ or numerical bounds.
18
+
19
+ Usage:
20
+ 1. Define a custom constraint class by inheriting from `Constraint`.
21
+ 2. Apply the constraint to your neural network during training.
22
+ 3. Use helper classes like `IdentityTransformation` for transformations and
23
+ comparisons in constraints.
24
+
25
+ """
26
+
27
+ from collections.abc import Callable
28
+ from numbers import Number
29
+ from typing import Literal
30
+
31
+ import torch
32
+ from torch import (
33
+ Tensor,
34
+ argsort,
35
+ eq,
36
+ logical_and,
37
+ logical_not,
38
+ logical_or,
39
+ ones,
40
+ ones_like,
41
+ reshape,
42
+ sign,
43
+ sort,
44
+ stack,
45
+ tensor,
46
+ triu,
47
+ unique,
48
+ zeros_like,
49
+ )
50
+ from torch.nn.functional import normalize
51
+
52
+ from ..transformations.base import Transformation
53
+ from ..transformations.registry import IdentityTransformation
54
+ from ..utils.validation import validate_comparator, validate_iterable, validate_type
55
+ from .base import Constraint, MonotonicityConstraint
56
+
57
# Public API of this module: the constraint classes exported to package users.
__all__ = [
    "ImplicationConstraint",
    "ScalarConstraint",
    "BinaryConstraint",
    "SumConstraint",
    "RankedMonotonicityConstraint",
    "PairwiseMonotonicityConstraint",
    "PerGroupMonotonicityConstraint",
    "EncodedGroupedMonotonicityConstraint",
    "ANDConstraint",
    "ORConstraint",
]
69
+
70
+
71
# Maps comparator symbols accepted by constraint constructors to their torch
# elementwise counterparts (each returns a boolean tensor of the broadcast shape).
COMPARATOR_MAP: dict[str, Callable[[Tensor, Tensor], Tensor]] = {
    ">": torch.gt,
    ">=": torch.ge,
    "<": torch.lt,
    "<=": torch.le,
}
77
+
78
+
79
class ImplicationConstraint(Constraint):
    """Represents an implication constraint between two constraints (head and body).

    The implication constraint ensures that the `body` constraint only applies
    when the `head` constraint is satisfied. If the `head` constraint is not
    satisfied, the `body` constraint does not apply.
    """

    def __init__(
        self,
        head: Constraint,
        body: Constraint,
        name: str | None = None,
    ):
        """Initializes an ImplicationConstraint instance.

        Uses `enforce` and `rescale_factor` from the body constraint.

        Args:
            head (Constraint): Constraint defining the head of the implication.
            body (Constraint): Constraint defining the body of the implication.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is composed as "<body.name> if <head.name>".

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        """
        # Type checking
        validate_type("head", head, Constraint)
        validate_type("body", body, Constraint)

        # Compose constraint name only when the caller did not supply one.
        # (Previously a user-provided name was silently discarded.)
        if name is None:
            name = f"{body.name} if {head.name}"

        # Init parent class; the implication constrains the union of both tag sets
        super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)

        self.head = head
        self.body = body

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check whether the implication constraint is satisfied.

        Evaluates the `head` and `body` constraints. The `body` constraint
        is enforced only if the `head` constraint is satisfied. If the
        `head` constraint is not satisfied, the `body` constraint does not
        affect the result.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating satisfaction of the implication
                  constraint (1 if satisfied, 0 otherwise).
                - head_satisfaction: Tensor indicating satisfaction of the
                  head constraint alone.

        """
        # Check satisfaction of head and body constraints
        head_satisfaction, _ = self.head.check_constraint(data)
        body_satisfaction, _ = self.body.check_constraint(data)

        # Material implication: head -> body  ==  (not head) or body.
        # If the head is not satisfied (0), the implication holds trivially (1);
        # otherwise satisfaction follows the body constraint.
        result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()

        return result, head_satisfaction

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for tags to satisfy the constraint.

        Uses the `body` constraint directions as the update vector. Only
        applies updates if the `head` constraint is satisfied. Currently,
        this method only works for dense layers due to tag-to-index
        translation limitations.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping tags to tensors
                specifying the adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        # Use directions of constraint body as update vector
        return self.body.calculate_direction(data)
172
+
173
+
174
class ScalarConstraint(Constraint):
    """A constraint that enforces scalar-based comparisons on a specific tag.

    This class ensures that the output of a specified tag satisfies a scalar
    comparison operation (e.g., less than, greater than, etc.). It uses a
    comparator function to validate the condition and calculates adjustment
    directions accordingly.

    Args:
        operand (Union[str, Transformation]): Name of the tag or a
            transformation to apply.
        comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
        scalar (Number): The scalar value to compare against.
        name (str, optional): A unique name for the constraint. If not
            provided, a name is auto-generated in the format
            "<tag> <comparator> <scalar>".
        enforce (bool, optional): If False, only monitor the constraint
            without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Factor to scale the
            constraint-adjusted loss. Defaults to 1.5.

    Raises:
        TypeError: If a provided attribute has an incompatible type.

    Notes:
        - The `tag` must be defined in the `descriptor` mapping.
        - The constraint name is composed using the tag, comparator, and scalar value.

    """

    def __init__(
        self,
        operand: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        scalar: Number,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a ScalarConstraint instance.

        Args:
            operand (Union[str, Transformation]): Function that needs to be
                performed on the network variables before applying the
                constraint.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            scalar (Number): Constant to compare the variable to.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<tag> <comparator> <scalar>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5. Should be greater
                than 1 to give weight to the constraint.

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        Notes:
            - The `tag` must be defined in the `descriptor` mapping.
        """
        # Type checking
        validate_type("operand", operand, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("scalar", scalar, Number)

        # If transformation is provided, get tag name, else use IdentityTransformation
        if isinstance(operand, Transformation):
            tag = operand.tag
            transformation = operand
        else:
            tag = operand
            transformation = IdentityTransformation(tag)

        # Compose constraint name only when the caller did not supply one.
        # (Previously a user-provided name was silently discarded.)
        if name is None:
            name = f"{tag} {comparator} {scalar}"

        # Init parent class
        super().__init__({tag}, name, enforce, rescale_factor)

        # Init variables
        self.tag = tag
        self.scalar = scalar
        self.transformation = transformation

        # Direction the prediction must move to satisfy the comparator:
        # push down (+1) for "<"/"<=", push up (-1) for ">"/">=".
        if comparator in ["<", "<="]:
            self.direction = 1
        elif comparator in [">", ">="]:
            self.direction = -1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check if the scalar constraint is satisfied for a given tag.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating whether the tag satisfies the constraint.
                - ones_like(result): Tensor of ones with same shape as `result`.
        """
        # Select relevant columns
        selection = self.descriptor.select(self.tag, data)

        # Apply transformation
        selection = self.transformation(selection)

        # Calculate current constraint result (1 = satisfied, 0 = violated)
        result = self.comparator(selection, self.scalar).float()
        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions to satisfy the scalar constraint.

        Only works for dense layers due to tag-to-index translation.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping layers to tensors specifying
                the adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from zero directions for every layer this constraint touches
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set the direction for the constrained tag only
        layer, index = self.descriptor.location(self.tag)
        output[layer][index] = self.direction

        # Normalize each layer's direction to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
316
+
317
+
318
class BinaryConstraint(Constraint):
    """A constraint that enforces a binary comparison between two tags.

    This class ensures that the output of one tag satisfies a comparison
    operation with the output of another tag (e.g., less than, greater than, etc.).
    It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

    Args:
        operand_left (Union[str, Transformation]): Name of the left
            tag or a transformation to apply.
        comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
        operand_right (Union[str, Transformation]): Name of the right
            tag or a transformation to apply.
        name (str, optional): A unique name for the constraint. If not
            provided, a name is auto-generated in the format
            "<operand_left> <comparator> <operand_right>".
        enforce (bool, optional): If False, only monitor the constraint
            without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Factor to scale the
            constraint-adjusted loss. Defaults to 1.5.

    Raises:
        TypeError: If a provided attribute has an incompatible type.

    Notes:
        - The tags must be defined in the `descriptor` mapping.
        - The constraint name is composed using the left tag, comparator, and right tag.

    """

    def __init__(
        self,
        operand_left: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        operand_right: str | Transformation,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a BinaryConstraint instance.

        Args:
            operand_left (Union[str, Transformation]): Name of the left
                tag or a transformation to apply.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            operand_right (Union[str, Transformation]): Name of the right
                tag or a transformation to apply.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<operand_left> <comparator> <operand_right>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        Notes:
            - The tags must be defined in the `descriptor` mapping.
        """
        # Type checking
        validate_type("operand_left", operand_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("operand_right", operand_right, (str, Transformation))

        # If transformation is provided, get tag name, else use IdentityTransformation
        if isinstance(operand_left, Transformation):
            tag_left = operand_left.tag
            transformation_left = operand_left
        else:
            tag_left = operand_left
            transformation_left = IdentityTransformation(tag_left)

        if isinstance(operand_right, Transformation):
            tag_right = operand_right.tag
            transformation_right = operand_right
        else:
            tag_right = operand_right
            transformation_right = IdentityTransformation(tag_right)

        # Compose constraint name only when the caller did not supply one.
        # (Previously a user-provided name was silently discarded.)
        if name is None:
            name = f"{tag_left} {comparator} {tag_right}"

        # Init parent class
        super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)

        # Init variables
        self.tag_left = tag_left
        self.tag_right = tag_right
        self.transformation_left = transformation_left
        self.transformation_right = transformation_right

        # Directions the two operands must move to satisfy the comparator:
        # for "<"/"<=" push left up / right down, mirrored for ">"/">=".
        if comparator in ["<", "<="]:
            self.direction_left = 1
            self.direction_right = -1
        elif comparator in [">", ">="]:
            self.direction_left = -1
            self.direction_right = 1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the binary constraint is satisfied for the current predictions.

        The constraint compares the outputs of two tags using the specified
        comparator function. A result of `1` indicates the constraint is satisfied
        for a sample, and `0` indicates it is violated.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating constraint satisfaction
                  (1 for satisfied, 0 for violated) for each sample.
                - mask (Tensor): Tensor of ones with the same shape as `result`,
                  used for constraint aggregation.

        """
        # Select relevant columns
        selection_left = self.descriptor.select(self.tag_left, data)
        selection_right = self.descriptor.select(self.tag_right, data)

        # Apply transformations
        selection_left = self.transformation_left(selection_left)
        selection_right = self.transformation_right(selection_right)

        result = self.comparator(selection_left, selection_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for the tags involved in the binary constraint.

        The returned directions indicate how to adjust each tag's output to
        satisfy the constraint. Only currently supported for dense layers.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: A mapping from layer names to tensors specifying
                the normalized adjustment directions for each tag involved in the
                constraint.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from zero directions for every layer this constraint touches
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set opposing directions for the two compared tags
        layer_left, index_left = self.descriptor.location(self.tag_left)
        layer_right, index_right = self.descriptor.location(self.tag_right)
        output[layer_left][index_left] = self.direction_left
        output[layer_right][index_right] = self.direction_right

        # Normalize each layer's direction to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
483
+
484
+
485
class SumConstraint(Constraint):
    """A constraint that enforces a weighted summation comparison between two groups of tags.

    This class evaluates whether the weighted sum of outputs from one set of
    tags satisfies a comparison operation with the weighted sum of
    outputs from another set of tags.
    """

    def __init__(
        self,
        operands_left: list[str | Transformation],
        comparator: Literal[">", "<", ">=", "<="],
        operands_right: list[str | Transformation],
        weights_left: list[Number] = None,
        weights_right: list[Number] = None,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes the SumConstraint.

        Args:
            operands_left (list[Union[str, Transformation]]): List of tags
                or transformations on the left side.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            operands_right (list[Union[str, Transformation]]): List of tags
                or transformations on the right side.
            weights_left (list[Number], optional): Weights for the left
                tags. Defaults to None (uniform weights of 1).
            weights_right (list[Number], optional): Weights for the right
                tags. Defaults to None (uniform weights of 1).
            name (str, optional): Unique name for the constraint.
                If None, it's auto-generated. Defaults to None.
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If the dimensions of tags and weights mismatch.
        """
        # Type checking
        validate_iterable("operands_left", operands_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_iterable("operands_right", operands_right, (str, Transformation))
        validate_iterable("weights_left", weights_left, Number, allow_none=True)
        validate_iterable("weights_right", weights_right, Number, allow_none=True)

        # If transformation is provided, get tag, else use IdentityTransformation
        tags_left: list[str] = []
        transformations_left: list[Transformation] = []
        for operand_left in operands_left:
            if isinstance(operand_left, Transformation):
                tags_left.append(operand_left.tag)
                transformations_left.append(operand_left)
            else:
                tags_left.append(operand_left)
                transformations_left.append(IdentityTransformation(operand_left))

        tags_right: list[str] = []
        transformations_right: list[Transformation] = []
        for operand_right in operands_right:
            if isinstance(operand_right, Transformation):
                tags_right.append(operand_right.tag)
                transformations_right.append(operand_right)
            else:
                tags_right.append(operand_right)
                transformations_right.append(IdentityTransformation(operand_right))

        # Fail fast (before name composition and parent init) if tag list
        # dimensions don't match weight list dimensions
        if weights_left and (len(tags_left) != len(weights_left)):
            raise ValueError(
                "The dimensions of tags_left don't match with the dimensions of weights_left."
            )
        if weights_right and (len(tags_right) != len(weights_right)):
            raise ValueError(
                "The dimensions of tags_right don't match with the dimensions of weights_right."
            )

        # Compose constraint name only when the caller did not supply one.
        # (Previously a user-provided name was silently discarded.)
        if name is None:
            w_left = weights_left or [""] * len(tags_left)
            w_right = weights_right or [""] * len(tags_right)
            left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
            right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
            name = f"{left_expr} {comparator} {right_expr}"

        # Init parent class
        super().__init__(set(tags_left) | set(tags_right), name, enforce, rescale_factor)

        # Init variables
        self.tags_left = tags_left
        self.tags_right = tags_right
        self.transformations_left = transformations_left
        self.transformations_right = transformations_right

        # If weights are provided for summation, transform them to Tensors;
        # otherwise default to uniform weights of 1
        if weights_left:
            self.weights_left = tensor(weights_left, device=self.device)
        else:
            self.weights_left = ones(len(tags_left), device=self.device)
        if weights_right:
            self.weights_right = tensor(weights_right, device=self.device)
        else:
            self.weights_right = ones(len(tags_right), device=self.device)

        # Directions the two sides must move to satisfy the comparator
        if comparator in ["<", "<="]:
            self.direction_left = -1
            self.direction_right = 1
        elif comparator in [">", ">="]:
            self.direction_left = 1
            self.direction_right = -1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the weighted sum constraint is satisfied.

        Computes the weighted sum of outputs from the left and right tags,
        applies the specified comparator function, and returns a binary result for
        each sample.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating whether the constraint
                  is satisfied (1) or violated (0) for each sample.
                - mask (Tensor): Tensor of ones, used for constraint aggregation.

        """

        def compute_weighted_sum(
            tags: list[str],
            transformations: list[Transformation],
            weights: Tensor,
        ) -> Tensor:
            # Select relevant columns
            selections = [self.descriptor.select(tag, data) for tag in tags]

            # Apply transformations
            results = []
            for transformation, selection in zip(transformations, selections, strict=False):
                results.append(transformation(selection))

            # Extract predictions for all tags and apply weights in bulk
            predictions = stack(results)

            # Calculate weighted sum across the tag dimension
            return (predictions * weights.view(-1, 1, 1)).sum(dim=0)

        # Compute weighted sums
        weighted_sum_left = compute_weighted_sum(
            self.tags_left, self.transformations_left, self.weights_left
        )
        weighted_sum_right = compute_weighted_sum(
            self.tags_right, self.transformations_right, self.weights_right
        )

        # Apply the comparator and calculate the result
        result = self.comparator(weighted_sum_left, weighted_sum_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for tags involved in the weighted sum constraint.

        The directions indicate how to adjust each tag's output to satisfy the
        constraint. Only dense layers are currently supported.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Mapping from layer names to normalized tensors
                specifying adjustment directions for each tag involved in the constraint.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        output = {}

        # Start from zero directions for every layer this constraint touches
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Left-side tags move in direction_left, right-side in direction_right
        for tag_left in self.tags_left:
            layer, index = self.descriptor.location(tag_left)
            output[layer][index] = self.direction_left

        for tag_right in self.tags_right:
            layer, index = self.descriptor.location(tag_right)
            output[layer][index] = self.direction_right

        # Normalize each layer's direction to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
688
+
689
+
690
class RankedMonotonicityConstraint(MonotonicityConstraint):
    """Constraint that enforces a monotonic relationship between two tags.

    This constraint ensures that the activations of a prediction tag (`tag_prediction`)
    are monotonically ascending or descending with respect to a target tag (`tag_reference`),
    judged by comparing the within-batch ranks of predictions and references.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
        """
        # Auto-generate a descriptive name unless the caller supplied one
        if name is None:
            name = f"{tag_prediction} monotonically (ranked) {direction} by {tag_reference}"

        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied.

        Returns a per-sample satisfaction tensor (1 where predicted and
        reference ranks agree, 0 otherwise) and an all-ones mask. Also caches
        the rescaled rank differences for `calculate_direction`.
        """
        # Fetch the prediction and reference columns for this batch
        predictions = self.descriptor.select(self.tag_prediction, data)
        references = self.descriptor.select(self.tag_reference, data)

        # Double argsort converts raw values into ranks 0..batch_size-1
        def to_ranks(values: Tensor, descending: bool) -> Tensor:
            order = argsort(values, descending=descending, stable=self.stable, dim=0)
            return argsort(order, descending=False, stable=self.stable, dim=0)

        prediction_ranks = to_ranks(predictions, self.descending)
        reference_ranks = to_ranks(references, False)

        # How far each prediction's rank deviates from the reference rank
        rank_delta = prediction_ranks - reference_ranks

        # Rescale the deviation into [rescale_factor_lower, rescale_factor_upper],
        # keeping the sign, and flip orientation for descending monotonicity
        batch_size = predictions.shape[0]
        span = self.rescale_factor_upper - self.rescale_factor_lower
        orientation = -1 if self.descending else 1
        self.compared_rankings = (
            rank_delta / batch_size * span
            + self.rescale_factor_lower * sign(rank_delta)
        ) * orientation

        # A zero rescaled delta means the ranks agree, i.e. constraint satisfied
        satisfaction = eq(self.compared_rankings, 0).float()

        return satisfaction, ones_like(satisfaction)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement.

        Relies on `check_constraint` having been called first to populate
        `self.compared_rankings`.
        """
        layer_name, _ = self.descriptor.location(self.tag_prediction)
        return {layer_name: self.compared_rankings}
778
+
779
+
780
+ class PairwiseMonotonicityConstraint(MonotonicityConstraint):
781
+ """Constraint that enforces a monotonic relationship between two tags.
782
+
783
+ This constraint ensures that the activations of a prediction tag (`tag_prediction`)
784
+ are monotonically ascending or descending with respect to a target tag (`tag_reference`).
785
+ """
786
+
787
    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        This constraint ensures that the activations of a prediction tag (`tag_prediction`)
        are monotonically ascending or descending with respect to a target tag (`tag_reference`),
        evaluated through pairwise comparisons of samples within a batch.

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
        """
        # Compose constraint name only if the caller did not supply one
        if name is None:
            name = f"{tag_prediction} monotonically (pairwise) {direction} by {tag_reference}"

        # Init parent class; all behavior-specific state is stored there
        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )
828
+
829
+ def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
830
+ """Evaluate whether the monotonicity constraint is satisfied."""
831
+ # Select relevant columns
832
+ preds = self.descriptor.select(self.tag_prediction, data).squeeze(1)
833
+ targets = self.descriptor.select(self.tag_reference, data).squeeze(1)
834
+
835
+ # Sort targets back to original order
836
+ sorted_targets, idx = sort(targets, stable=self.stable, dim=0, descending=self.descending)
837
+ sorted_preds = preds[idx]
838
+
839
+ # Pairwise differences
840
+ preds_diff = sorted_preds.unsqueeze(1) - sorted_preds.unsqueeze(0)
841
+ targets_diff = sorted_targets.unsqueeze(1) - sorted_targets.unsqueeze(0)
842
+
843
+ # Consider only upper triangle to avoid duplicate comparisons
844
+ batch_size = preds.shape[0]
845
+ mask = triu(ones(batch_size, batch_size, dtype=torch.bool, device=preds.device), diagonal=1)
846
+
847
+ # Pairwise violations
848
+ violations = (preds_diff * targets_diff < 0) & mask
849
+
850
+ # Rank difference
851
+ sorted_rank_diff = violations.float().sum(dim=1) - violations.float().sum(dim=0)
852
+
853
+ # Undo sorting to match original order
854
+ rank_diff = zeros_like(sorted_rank_diff)
855
+ rank_diff[idx] = sorted_rank_diff
856
+ rank_diff = rank_diff.unsqueeze(1)
857
+
858
+ # Rescale differences into [rescale_factor_lower, rescale_factor_upper]
859
+ invert_direction = -1 if self.descending else 1
860
+ self.compared_rankings = (
861
+ (rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
862
+ + self.rescale_factor_lower * sign(rank_diff)
863
+ ) * invert_direction
864
+
865
+ # Calculate satisfaction
866
+ incorrect_rankings = eq(self.compared_rankings, 0).float()
867
+
868
+ return incorrect_rankings, ones_like(incorrect_rankings)
869
+
870
+ def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
871
+ """Calculates ranking adjustments for monotonicity enforcement."""
872
+ layer, _ = self.descriptor.location(self.tag_prediction)
873
+ return {layer: self.compared_rankings}
874
+
875
+
876
class PerGroupMonotonicityConstraint(Constraint):
    """Applies a monotonicity constraint independently per group of samples.

    Wraps a pre-configured `MonotonicityConstraint` instance (`base`) and
    enforces it separately for every unique value found under `tag_group`.

    Each group is handled as its own mini-batch:
    - The wrapped constraint is evaluated on the group's subset of the data.
    - Per-group results are scattered back into the original batch order.

    This is the explicit (looping) counterpart of
    :class:`EncodedGroupedMonotonicityConstraint`, which achieves the same
    effect through vectorized interval encoding.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the per-group monotonicity constraint.

        Args:
            base (MonotonicityConstraint): A pre-configured monotonicity constraint to apply per group.
            tag_group (str): Column/tag name that identifies groups.
            name (str | None, optional): Custom name for the constraint. Defaults to
                f"{base.tag_prediction} for each {tag_group} ...".
        """
        # Derive a descriptive name from the wrapped constraint if none given
        if name is None:
            name = base.name.replace(
                base.tag_prediction, f"{base.tag_prediction} for each {tag_group}", 1
            )

        # Init parent class
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied per group."""
        # Group identifier for every sample in the batch
        group_ids = self.descriptor.select(self.tag_group, data)

        # Buffers that the per-group results are scattered into
        checks = zeros_like(group_ids, device=self.device)
        self.directions = zeros_like(group_ids, device=self.device)

        # Layer keys holding predictions and references
        preds_key, _ = self.descriptor.location(self.base.tag_prediction)
        targets_key, _ = self.descriptor.location(self.base.tag_reference)

        # Evaluate the wrapped constraint on each group's rows
        for gid in unique(group_ids, sorted=False).tolist():
            row_mask = (group_ids == gid).squeeze(1)

            subset = {
                preds_key: data[preds_key][row_mask],
                targets_key: data[targets_key][row_mask],
            }

            checks[row_mask], _ = self.base.check_constraint(subset)
            self.directions[row_mask] = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {layer: self.directions}
951
+
952
+
953
class EncodedGroupedMonotonicityConstraint(Constraint):
    """Applies a base monotonicity constraint to groups via interval encoding.

    Enforces a monotonic relationship between a prediction tag
    (`base.tag_prediction`) and a reference tag (`base.tag_reference`) within
    every group identified by `tag_group`.

    Rather than iterating over groups, each group's predictions and targets
    are shifted by a group-specific offset large enough to place the groups
    in **non-overlapping intervals**, so the wrapped constraint can be
    evaluated once over the whole batch in a vectorized fashion.

    This is the vectorized counterpart of
    :class:`PerGroupMonotonicityConstraint`, which loops over groups
    explicitly.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the encoded grouped monotonicity constraint.

        Args:
            base (MonotonicityConstraint): The base constraint to apply to each group.
            tag_group (str): Column/tag name identifying groups in the batch.
            name (str | None, optional): Optional custom name for this constraint.
                If None, generates a descriptive name based on `base.name`.
        """
        # Derive a descriptive name from the wrapped constraint if none given
        if name is None:
            name = base.name.replace(
                base.tag_prediction,
                f"{base.tag_prediction} for each {tag_group} (encoded)",
                1,
            )

        # Init parent class
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

        # Targets are shifted in the opposite direction for descending constraints
        self.negation = 1 if self.base.direction == "ascending" else -1

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied by mapping each group on a non-overlapping interval."""
        # Group identifiers and the layer keys for predictions/references
        group_ids = self.descriptor.select(self.tag_group, data)
        preds_key, _ = self.descriptor.location(self.base.tag_prediction)
        targets_key, _ = self.descriptor.location(self.base.tag_reference)

        preds = data[preds_key]
        targets = data[targets_key]

        # Offsets strictly larger than each column's value range keep the
        # per-group intervals disjoint
        preds_span = preds.amax(dim=0) - preds.amin(dim=0) + 1
        targets_span = targets.amax(dim=0) - targets.amin(dim=0) + 1

        # Shifted mini-batch for the wrapped constraint
        shifted = {
            preds_key: preds + group_ids * preds_span,
            targets_key: targets + self.negation * group_ids * targets_span,
        }

        # Delegate to the wrapped constraint on the shifted batch
        checks, _ = self.base.check_constraint(shifted)
        self.directions = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {layer: self.directions}
1029
+
1030
+
1031
class ANDConstraint(Constraint):
    """A composite constraint that enforces the logical AND of multiple constraints.

    This class combines multiple sub-constraints and evaluates them jointly:

    * The satisfaction of the AND constraint is `True` only if all sub-constraints
      are satisfied (elementwise logical AND).
    * The corrective direction is computed by weighting each sub-constraint's
      direction with its satisfaction mask and summing across all sub-constraints.

    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str | None = None,
        enforce: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """A composite constraint that enforces the logical AND of multiple constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str | None, optional): A custom name for this constraint. If not provided,
                the name will be composed from the sub-constraint names joined with
                " AND ".
            enforce (bool, optional): If False, the constraint is only monitored
                (not enforced). Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            tags (set): The union of tags referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Type checking
        validate_iterable("constraints", constraints, Constraint)

        # Compose constraint name
        if not name:
            name = " AND ".join([constraint.name for constraint in constraints])

        # Init parent class
        super().__init__(
            set().union(*(constraint.tags for constraint in constraints)),
            name,
            enforce,
            rescale_factor,
        )

        # Init variables
        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether all sub-constraints are satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
                * `total_satisfaction`: A numeric tensor indicating elementwise
                  whether all constraints are satisfied (logical AND).
                * `mask`: The elementwise OR of the sub-constraints' masks,
                  used as a weighting mask in downstream processing.

        """
        total_satisfaction: Tensor | None = None
        total_mask: Tensor | None = None

        # TODO vectorize this loop
        for constraint in self.constraints:
            satisfaction, mask = constraint.check_constraint(data)
            if total_satisfaction is None:
                total_satisfaction = satisfaction
                total_mask = mask
            else:
                total_satisfaction = logical_and(total_satisfaction, satisfaction)
                total_mask = logical_or(total_mask, mask)

        # NOTE(review): raises AttributeError if called with zero constraints;
        # presumably validate_iterable rejects an empty tuple — confirm.
        return total_satisfaction.float(), total_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction, weighted
        by its satisfaction mask. The directions are summed across constraints
        for each affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: A mapping from layer identifiers to correction
            tensors. Each entry represents the aggregated correction to apply
            to that layer, based on the satisfaction-weighted sum of
            sub-constraint directions.
        """
        total_direction: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for constraint in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = constraint.check_constraint(data)
            direction = constraint.calculate_direction(data)

            # NOTE(review): directions are weighted by the *satisfaction* mask
            # returned by check_constraint — confirm this weighting is intended
            # (weighting by the violation mask may be expected instead).
            for layer, layer_direction in direction.items():
                if layer not in total_direction:
                    total_direction[layer] = satisfaction * layer_direction
                else:
                    total_direction[layer] += satisfaction * layer_direction

        return total_direction
1154
+
1155
+
1156
class ORConstraint(Constraint):
    """A composite constraint that enforces the logical OR of multiple constraints.

    This class combines multiple sub-constraints and evaluates them jointly:

    * The satisfaction of the OR constraint is `True` if at least one sub-constraint
      is satisfied (elementwise logical OR).
    * The corrective direction is computed by weighting each sub-constraint's
      direction with its satisfaction mask and summing across all sub-constraints.

    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str | None = None,
        enforce: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """A composite constraint that enforces the logical OR of multiple constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str | None, optional): A custom name for this constraint. If not provided,
                the name will be composed from the sub-constraint names joined with
                " OR ".
            enforce (bool, optional): If False, the constraint is only monitored
                (not enforced). Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            tags (set): The union of tags referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Type checking
        validate_iterable("constraints", constraints, Constraint)

        # Compose constraint name
        if not name:
            name = " OR ".join([constraint.name for constraint in constraints])

        # Init parent class
        super().__init__(
            set().union(*(constraint.tags for constraint in constraints)),
            name,
            enforce,
            rescale_factor,
        )

        # Init variables
        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether any sub-constraints are satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
                * `total_satisfaction`: A numeric tensor indicating elementwise
                  whether any constraint is satisfied (logical OR).
                * `mask`: The elementwise OR of the sub-constraints' masks,
                  used as a weighting mask in downstream processing.

        """
        total_satisfaction: Tensor | None = None
        total_mask: Tensor | None = None

        # TODO vectorize this loop
        for constraint in self.constraints:
            satisfaction, mask = constraint.check_constraint(data)
            if total_satisfaction is None:
                total_satisfaction = satisfaction
                total_mask = mask
            else:
                total_satisfaction = logical_or(total_satisfaction, satisfaction)
                total_mask = logical_or(total_mask, mask)

        # NOTE(review): raises AttributeError if called with zero constraints;
        # presumably validate_iterable rejects an empty tuple — confirm.
        return total_satisfaction.float(), total_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction, weighted
        by its satisfaction mask. The directions are summed across constraints
        for each affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: A mapping from layer identifiers to correction
            tensors. Each entry represents the aggregated correction to apply
            to that layer, based on the satisfaction-weighted sum of
            sub-constraint directions.
        """
        total_direction: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for constraint in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = constraint.check_constraint(data)
            direction = constraint.calculate_direction(data)

            # NOTE(review): directions are weighted by the *satisfaction* mask
            # returned by check_constraint — confirm this weighting is intended
            # (weighting by the violation mask may be expected instead).
            for layer, layer_direction in direction.items():
                if layer not in total_direction:
                    total_direction[layer] = satisfaction * layer_direction
                else:
                    total_direction[layer] += satisfaction * layer_direction

        return total_direction