congrads 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1255 @@
1
+ """Module providing constraint classes for guiding neural network training.
2
+
3
+ This module defines constraints that enforce specific conditions on network outputs
4
+ to steer learning. Available constraint types include:
5
+
6
+ - `Constraint`: Base class for all constraint types, defining the interface and core
7
+ behavior.
8
+ - `ImplicationConstraint`: Enforces one condition only if another condition is met,
9
+ useful for modeling implications between outputs.
10
+ - `ScalarConstraint`: Enforces scalar-based comparisons on a network's output.
11
+ - `BinaryConstraint`: Enforces a binary comparison between two tags using a
12
+ comparison function (e.g., less than, greater than).
13
+ - `SumConstraint`: Ensures the sum of selected tags' outputs equals a specified
14
+ value, controlling total output.
15
+
16
+ These constraints can steer the learning process by applying logical implications
17
+ or numerical bounds.
18
+
19
+ Usage:
20
+ 1. Define a custom constraint class by inheriting from `Constraint`.
21
+ 2. Apply the constraint to your neural network during training.
22
+ 3. Use helper classes like `IdentityTransformation` for transformations and
23
+ comparisons in constraints.
24
+
25
+ """
26
+
27
+ from collections.abc import Callable
28
+ from numbers import Number
29
+ from typing import Literal
30
+
31
+ import torch
32
+ from torch import (
33
+ Tensor,
34
+ argsort,
35
+ eq,
36
+ logical_and,
37
+ logical_not,
38
+ logical_or,
39
+ ones,
40
+ ones_like,
41
+ reshape,
42
+ sign,
43
+ sort,
44
+ stack,
45
+ tensor,
46
+ triu,
47
+ unique,
48
+ zeros_like,
49
+ )
50
+ from torch.nn.functional import normalize
51
+
52
+ from ..transformations.base import Transformation
53
+ from ..transformations.registry import IdentityTransformation
54
+ from ..utils.validation import validate_comparator, validate_iterable, validate_type
55
+ from .base import Constraint, MonotonicityConstraint
56
+
57
# Maps user-facing comparison operator symbols to their elementwise torch
# counterparts. Constraints accept a comparator as a string (e.g. ">") and
# use this table to resolve it to a callable applied to prediction tensors.
COMPARATOR_MAP: dict[str, Callable[[Tensor, Tensor], Tensor]] = {
    ">": torch.gt,
    ">=": torch.ge,
    "<": torch.lt,
    "<=": torch.le,
}
63
+
64
+
65
class ImplicationConstraint(Constraint):
    """Represents an implication constraint between two constraints (head and body).

    The implication constraint ensures that the `body` constraint only applies
    when the `head` constraint is satisfied. If the `head` constraint is not
    satisfied, the `body` constraint does not apply.
    """

    def __init__(
        self,
        head: Constraint,
        body: Constraint,
        name: str | None = None,
    ):
        """Initializes an ImplicationConstraint instance.

        Uses `enforce` and `rescale_factor` from the body constraint.

        Args:
            head (Constraint): Constraint defining the head of the implication.
            body (Constraint): Constraint defining the body of the implication.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is generated as "<body.name> if <head.name>".

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        """
        # Type checking
        validate_type("head", head, Constraint)
        validate_type("body", body, Constraint)

        # Compose constraint name only when the caller did not supply one.
        # (Bug fix: previously a user-provided name was always overwritten.)
        if name is None:
            name = f"{body.name} if {head.name}"

        # Init parent class; the implication watches the union of both tag sets
        super().__init__(head.tags | body.tags, name, body.enforce, body.rescale_factor)

        self.head = head
        self.body = body

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check whether the implication constraint is satisfied.

        Evaluates the `head` and `body` constraints. The `body` constraint
        is enforced only if the `head` constraint is satisfied. If the
        `head` constraint is not satisfied, the `body` constraint does not
        affect the result.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating satisfaction of the implication
                  constraint (1 if satisfied, 0 otherwise).
                - head_satisfaction: Tensor indicating satisfaction of the
                  head constraint alone.
        """
        # Check satisfaction of head and body constraints
        head_satisfaction, _ = self.head.check_constraint(data)
        body_satisfaction, _ = self.body.check_constraint(data)

        # Material implication: head -> body  ==  (not head) or body.
        # If the head is not satisfied, the implication holds vacuously (1);
        # otherwise the result follows the body's satisfaction.
        result = logical_or(logical_not(head_satisfaction), body_satisfaction).float()

        return result, head_satisfaction

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for tags to satisfy the constraint.

        Uses the `body` constraint directions as the update vector. Only
        applies updates if the `head` constraint is satisfied. Currently,
        this method only works for dense layers due to tag-to-index
        translation limitations.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping tags to tensors
                specifying the adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        # Use directions of constraint body as update vector
        return self.body.calculate_direction(data)
157
+
158
+
159
class ScalarConstraint(Constraint):
    """A constraint that enforces scalar-based comparisons on a specific tag.

    This class ensures that the output of a specified tag satisfies a scalar
    comparison operation (e.g., less than, greater than, etc.). It uses a
    comparator function to validate the condition and calculates adjustment
    directions accordingly.

    Args:
        operand (Union[str, Transformation]): Name of the tag or a
            transformation to apply.
        comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
        scalar (Number): The scalar value to compare against.
        name (str, optional): A unique name for the constraint. If not
            provided, a name is auto-generated in the format
            "<tag> <comparator> <scalar>".
        enforce (bool, optional): If False, only monitor the constraint
            without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Factor to scale the
            constraint-adjusted loss. Defaults to 1.5.

    Raises:
        TypeError: If a provided attribute has an incompatible type.

    Notes:
        - The `tag` must be defined in the `descriptor` mapping.
        - The constraint name is composed using the tag, comparator, and scalar value.

    """

    def __init__(
        self,
        operand: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        scalar: Number,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a ScalarConstraint instance.

        Args:
            operand (Union[str, Transformation]): Function that needs to be
                performed on the network variables before applying the
                constraint.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            scalar (Number): Constant to compare the variable to.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<tag> <comparator> <scalar>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5. Should be greater
                than 1 to give weight to the constraint.

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        Notes:
            - The `tag` must be defined in the `descriptor` mapping.
        """
        # Type checking
        validate_type("operand", operand, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("scalar", scalar, Number)

        # If transformation is provided, get tag name, else use IdentityTransformation
        if isinstance(operand, Transformation):
            tag = operand.tag
            transformation = operand
        else:
            tag = operand
            transformation = IdentityTransformation(tag)

        # Compose constraint name only when the caller did not supply one.
        # (Bug fix: previously a user-provided name was always overwritten.)
        if name is None:
            name = f"{tag} {comparator} {scalar}"

        # Init parent class
        super().__init__({tag}, name, enforce, rescale_factor)

        # Init variables
        self.tag = tag
        self.scalar = scalar
        self.transformation = transformation

        # Calculate directions based on constraint operator: to satisfy
        # "tag < scalar" the output must move down relative to the boundary,
        # hence +1 (and -1 for ">" constraints). validate_comparator has
        # already guaranteed the comparator is one of the four symbols.
        if comparator in ["<", "<="]:
            self.direction = 1
        elif comparator in [">", ">="]:
            self.direction = -1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Check if the scalar constraint is satisfied for a given tag.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result: Tensor indicating whether the tag satisfies the constraint.
                - ones_like(result): Tensor of ones with same shape as `result`.
        """
        # Select relevant columns
        selection = self.descriptor.select(self.tag, data)

        # Apply transformation
        selection = self.transformation(selection)

        # Calculate current constraint result
        result = self.comparator(selection, self.scalar).float()
        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions to satisfy the scalar constraint.

        Only works for dense layers due to tag-to-index translation.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping layers to tensors specifying
                the adjustment direction for each tag.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from a zero direction vector per layer
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set the direction for the constrained tag only
        layer, index = self.descriptor.location(self.tag)
        output[layer][index] = self.direction

        # Normalize each layer's direction vector to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
301
+
302
+
303
class BinaryConstraint(Constraint):
    """A constraint that enforces a binary comparison between two tags.

    This class ensures that the output of one tag satisfies a comparison
    operation with the output of another tag (e.g., less than, greater than, etc.).
    It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

    Args:
        operand_left (Union[str, Transformation]): Name of the left
            tag or a transformation to apply.
        comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
        operand_right (Union[str, Transformation]): Name of the right
            tag or a transformation to apply.
        name (str, optional): A unique name for the constraint. If not
            provided, a name is auto-generated in the format
            "<operand_left> <comparator> <operand_right>".
        enforce (bool, optional): If False, only monitor the constraint
            without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Factor to scale the
            constraint-adjusted loss. Defaults to 1.5.

    Raises:
        TypeError: If a provided attribute has an incompatible type.

    Notes:
        - The tags must be defined in the `descriptor` mapping.
        - The constraint name is composed using the left tag, comparator, and right tag.

    """

    def __init__(
        self,
        operand_left: str | Transformation,
        comparator: Literal[">", "<", ">=", "<="],
        operand_right: str | Transformation,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a BinaryConstraint instance.

        Args:
            operand_left (Union[str, Transformation]): Name of the left
                tag or a transformation to apply.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            operand_right (Union[str, Transformation]): Name of the right
                tag or a transformation to apply.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is auto-generated in the format
                "<operand_left> <comparator> <operand_right>".
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.

        Notes:
            - The tags must be defined in the `descriptor` mapping.
        """
        # Type checking
        validate_type("operand_left", operand_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_type("operand_right", operand_right, (str, Transformation))

        # If transformation is provided, get tag name, else use IdentityTransformation
        if isinstance(operand_left, Transformation):
            tag_left = operand_left.tag
            transformation_left = operand_left
        else:
            tag_left = operand_left
            transformation_left = IdentityTransformation(tag_left)

        if isinstance(operand_right, Transformation):
            tag_right = operand_right.tag
            transformation_right = operand_right
        else:
            tag_right = operand_right
            transformation_right = IdentityTransformation(tag_right)

        # Compose constraint name only when the caller did not supply one.
        # (Bug fix: previously a user-provided name was always overwritten.)
        if name is None:
            name = f"{tag_left} {comparator} {tag_right}"

        # Init parent class
        super().__init__({tag_left, tag_right}, name, enforce, rescale_factor)

        # Init variables
        self.tag_left = tag_left
        self.tag_right = tag_right
        self.transformation_left = transformation_left
        self.transformation_right = transformation_right

        # Calculate directions based on constraint operator: for
        # "left < right" the left output moves up toward satisfaction from
        # the violation side (+1) while the right moves down (-1), and vice
        # versa for ">". validate_comparator has already guaranteed the
        # comparator is one of the four symbols.
        if comparator in ["<", "<="]:
            self.direction_left = 1
            self.direction_right = -1
        elif comparator in [">", ">="]:
            self.direction_left = -1
            self.direction_right = 1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the binary constraint is satisfied for the current predictions.

        The constraint compares the outputs of two tags using the specified
        comparator function. A result of `1` indicates the constraint is satisfied
        for a sample, and `0` indicates it is violated.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating constraint satisfaction
                  (1 for satisfied, 0 for violated) for each sample.
                - mask (Tensor): Tensor of ones with the same shape as `result`,
                  used for constraint aggregation.
        """
        # Select relevant columns
        selection_left = self.descriptor.select(self.tag_left, data)
        selection_right = self.descriptor.select(self.tag_right, data)

        # Apply transformations
        selection_left = self.transformation_left(selection_left)
        selection_right = self.transformation_right(selection_right)

        result = self.comparator(selection_left, selection_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for the tags involved in the binary constraint.

        The returned directions indicate how to adjust each tag's output to
        satisfy the constraint. Only currently supported for dense layers.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: A mapping from layer names to tensors specifying
                the normalized adjustment directions for each tag involved in the
                constraint.
        """
        # NOTE currently only works for dense layers due
        # to tag to index translation

        output = {}

        # Start from a zero direction vector per layer
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # Set opposing directions for the two constrained tags
        layer_left, index_left = self.descriptor.location(self.tag_left)
        layer_right, index_right = self.descriptor.location(self.tag_right)
        output[layer_left][index_left] = self.direction_left
        output[layer_right][index_right] = self.direction_right

        # Normalize each layer's direction vector to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
467
+
468
+
469
class SumConstraint(Constraint):
    """A constraint that enforces a weighted summation comparison between two groups of tags.

    This class evaluates whether the weighted sum of outputs from one set of
    tags satisfies a comparison operation with the weighted sum of
    outputs from another set of tags.
    """

    def __init__(
        self,
        operands_left: list[str | Transformation],
        comparator: Literal[">", "<", ">=", "<="],
        operands_right: list[str | Transformation],
        weights_left: list[Number] | None = None,
        weights_right: list[Number] | None = None,
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes the SumConstraint.

        Args:
            operands_left (list[Union[str, Transformation]]): List of tags
                or transformations on the left side.
            comparator (Literal[">", "<", ">=", "<="]): Comparison operator used in the constraint.
            operands_right (list[Union[str, Transformation]]): List of tags
                or transformations on the right side.
            weights_left (list[Number], optional): Weights for the left
                tags. Defaults to None.
            weights_right (list[Number], optional): Weights for the right
                tags. Defaults to None.
            name (str, optional): Unique name for the constraint.
                If None, it's auto-generated. Defaults to None.
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If the dimensions of tags and weights mismatch.
        """
        # Type checking
        validate_iterable("operands_left", operands_left, (str, Transformation))
        validate_comparator("comparator", comparator, COMPARATOR_MAP)
        validate_iterable("operands_right", operands_right, (str, Transformation))
        validate_iterable("weights_left", weights_left, Number, allow_none=True)
        validate_iterable("weights_right", weights_right, Number, allow_none=True)

        # If transformation is provided, get tag, else use IdentityTransformation
        tags_left: list[str] = []
        transformations_left: list[Transformation] = []
        for operand_left in operands_left:
            if isinstance(operand_left, Transformation):
                tag_left = operand_left.tag
                tags_left.append(tag_left)
                transformations_left.append(operand_left)
            else:
                tag_left = operand_left
                tags_left.append(tag_left)
                transformations_left.append(IdentityTransformation(tag_left))

        tags_right: list[str] = []
        transformations_right: list[Transformation] = []
        for operand_right in operands_right:
            if isinstance(operand_right, Transformation):
                tag_right = operand_right.tag
                tags_right.append(tag_right)
                transformations_right.append(operand_right)
            else:
                tag_right = operand_right
                tags_right.append(tag_right)
                transformations_right.append(IdentityTransformation(tag_right))

        # Validate tag/weight dimensions *before* they are consumed below.
        # (Previously this check ran after name composition, so a mismatch
        # would silently truncate the zipped name before raising.)
        if weights_left and (len(tags_left) != len(weights_left)):
            raise ValueError(
                "The dimensions of tags_left don't match with the dimensions of weights_left."
            )
        if weights_right and (len(tags_right) != len(weights_right)):
            raise ValueError(
                "The dimensions of tags_right don't match with the dimensions of weights_right."
            )

        # Compose constraint name only when the caller did not supply one.
        # (Bug fix: previously a user-provided name was always overwritten.)
        if name is None:
            w_left = weights_left or [""] * len(tags_left)
            w_right = weights_right or [""] * len(tags_right)
            left_expr = " + ".join(f"{w}{n}" for w, n in zip(w_left, tags_left, strict=False))
            right_expr = " + ".join(f"{w}{n}" for w, n in zip(w_right, tags_right, strict=False))
            name = f"{left_expr} {comparator} {right_expr}"

        # Init parent class
        tags = set(tags_left) | set(tags_right)
        super().__init__(tags, name, enforce, rescale_factor)

        # Init variables
        self.tags_left = tags_left
        self.tags_right = tags_right
        self.transformations_left = transformations_left
        self.transformations_right = transformations_right

        # If weights are provided for summation, transform them to Tensors;
        # otherwise default to unit weights (a plain sum).
        if weights_left:
            self.weights_left = tensor(weights_left, device=self.device)
        else:
            self.weights_left = ones(len(tags_left), device=self.device)
        if weights_right:
            self.weights_right = tensor(weights_right, device=self.device)
        else:
            self.weights_right = ones(len(tags_right), device=self.device)

        # Calculate directions based on constraint operator. Note the signs
        # are intentionally opposite to BinaryConstraint's convention.
        # validate_comparator has already guaranteed a valid comparator.
        if comparator in ["<", "<="]:
            self.direction_left = -1
            self.direction_right = 1
        elif comparator in [">", ">="]:
            self.direction_left = 1
            self.direction_right = -1
        self.comparator = COMPARATOR_MAP[comparator]

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the weighted sum constraint is satisfied.

        Computes the weighted sum of outputs from the left and right tags,
        applies the specified comparator function, and returns a binary result for
        each sample.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]:
                - result (Tensor): Binary tensor indicating whether the constraint
                  is satisfied (1) or violated (0) for each sample.
                - mask (Tensor): Tensor of ones, used for constraint aggregation.
        """

        def compute_weighted_sum(
            tags: list[str],
            transformations: list[Transformation],
            weights: Tensor,
        ) -> Tensor:
            # Select relevant columns
            selections = [self.descriptor.select(tag, data) for tag in tags]

            # Apply transformations
            results = []
            for transformation, selection in zip(transformations, selections, strict=False):
                results.append(transformation(selection))

            # Extract predictions for all tags and apply weights in bulk
            predictions = stack(results)

            # Calculate weighted sum over the tag dimension
            return (predictions * weights.view(-1, 1, 1)).sum(dim=0)

        # Compute weighted sums
        weighted_sum_left = compute_weighted_sum(
            self.tags_left, self.transformations_left, self.weights_left
        )
        weighted_sum_right = compute_weighted_sum(
            self.tags_right, self.transformations_right, self.weights_right
        )

        # Apply the comparator and calculate the result
        result = self.comparator(weighted_sum_left, weighted_sum_right).float()

        return result, ones_like(result)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions for tags involved in the weighted sum constraint.

        The directions indicate how to adjust each tag's output to satisfy the
        constraint. Only dense layers are currently supported.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Mapping from layer names to normalized tensors
                specifying adjustment directions for each tag involved in the constraint.
        """
        # NOTE currently only works for dense layers
        # due to tag to index translation

        output = {}

        # Start from a zero direction vector per layer
        for layer in self.layers:
            output[layer] = zeros_like(data[layer][0], device=self.device)

        # All left-hand tags share one direction, all right-hand tags the other
        for tag_left in self.tags_left:
            layer, index = self.descriptor.location(tag_left)
            output[layer][index] = self.direction_left

        for tag_right in self.tags_right:
            layer, index = self.descriptor.location(tag_right)
            output[layer][index] = self.direction_right

        # Normalize each layer's direction vector to unit length
        for layer in self.layers:
            output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

        return output
671
+
672
+
673
class RankedMonotonicityConstraint(MonotonicityConstraint):
    """Constraint that enforces a monotonic relationship between two tags.

    Activations of `tag_prediction` are required to be monotonically
    ascending or descending with respect to `tag_reference`, measured by
    comparing within-batch ranks of the two tags.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        Args:
            tag_prediction (str): Tag whose activations must follow the monotonic relationship.
            tag_reference (str): Tag acting as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Monotonicity direction, 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom constraint name; auto-generated when None.
            enforce (bool, optional): If False, the constraint is only monitored. Defaults to True.
        """
        # Fall back to a descriptive auto-generated name
        if name is None:
            name = f"{tag_prediction} monotonically (ranked) {direction} by {tag_reference}"

        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied.

        Also caches the rescaled rank differences on `self.compared_rankings`
        for use by `calculate_direction`.
        """
        # Select relevant columns
        predictions = self.descriptor.select(self.tag_prediction, data)
        references = self.descriptor.select(self.tag_reference, data)

        def to_ranks(values: Tensor, descending: bool) -> Tensor:
            # Double argsort converts raw values into ranks 0..batch_size-1
            order = argsort(values, descending=descending, stable=self.stable, dim=0)
            return argsort(order, descending=False, stable=self.stable, dim=0)

        # Rank predictions (honoring the configured direction) and references
        prediction_ranks = to_ranks(predictions, descending=self.descending)
        reference_ranks = to_ranks(references, descending=False)

        # Per-sample rank difference; zero means the sample is correctly ordered
        rank_delta = prediction_ranks - reference_ranks

        # Rescale differences into [rescale_factor_lower, rescale_factor_upper],
        # flipping the sign for descending monotonicity
        batch_size = predictions.shape[0]
        orientation = -1 if self.descending else 1
        span = self.rescale_factor_upper - self.rescale_factor_lower
        self.compared_rankings = (
            (rank_delta / batch_size) * span + self.rescale_factor_lower * sign(rank_delta)
        ) * orientation

        # A zero rescaled difference marks a satisfied sample (1), else 0
        satisfaction = eq(self.compared_rankings, 0).float()

        return satisfaction, ones_like(satisfaction)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return the ranking adjustments cached by `check_constraint`."""
        layer, _ = self.descriptor.location(self.tag_prediction)
        return {layer: self.compared_rankings}
761
+
762
+
763
class PairwiseMonotonicityConstraint(MonotonicityConstraint):
    """Constraint that enforces a monotonic relationship between two tags.

    This constraint ensures that the activations of a prediction tag
    (`tag_prediction`) are monotonically ascending or descending with respect
    to a target tag (`tag_reference`). Violations are detected by comparing
    every pair of samples in the batch (upper-triangular pairwise check).
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str | None = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        This constraint ensures that the activations of a prediction tag (`tag_prediction`)
        are monotonically ascending or descending with respect to a target tag (`tag_reference`).

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str | None, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
        """
        # Compose constraint name
        if name is None:
            name = f"{tag_prediction} monotonically (pairwise) {direction} by {tag_reference}"

        # Init parent class
        super().__init__(
            tag_prediction=tag_prediction,
            tag_reference=tag_reference,
            rescale_factor_lower=rescale_factor_lower,
            rescale_factor_upper=rescale_factor_upper,
            stable=stable,
            direction=direction,
            name=name,
            enforce=enforce,
        )

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the pairwise monotonicity constraint is satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: `(satisfaction, mask)` where `satisfaction`
            is 1.0 for samples whose rescaled rank correction is zero, and
            `mask` is a tensor of ones with the same shape.
        """
        # Select relevant columns
        preds = self.descriptor.select(self.tag_prediction, data).squeeze(1)
        targets = self.descriptor.select(self.tag_reference, data).squeeze(1)

        # Sort samples by the reference tag so ordering violations are explicit
        sorted_targets, idx = sort(targets, stable=self.stable, dim=0, descending=self.descending)
        sorted_preds = preds[idx]

        # Pairwise differences
        preds_diff = sorted_preds.unsqueeze(1) - sorted_preds.unsqueeze(0)
        targets_diff = sorted_targets.unsqueeze(1) - sorted_targets.unsqueeze(0)

        # Consider only upper triangle to avoid duplicate comparisons.
        # BUGFIX: create the mask on the predictions' device — a CPU-only
        # mask raises a device-mismatch error when training on an accelerator.
        batch_size = preds.shape[0]
        mask = triu(
            ones(batch_size, batch_size, dtype=torch.bool, device=preds.device),
            diagonal=1,
        )

        # Pairwise violations: a pair violates when the prediction and
        # reference differences have opposite signs
        violations = (preds_diff * targets_diff < 0) & mask

        # Rank difference
        sorted_rank_diff = violations.float().sum(dim=1) - violations.float().sum(dim=0)

        # Undo sorting to match original order
        rank_diff = zeros_like(sorted_rank_diff)
        rank_diff[idx] = sorted_rank_diff
        rank_diff = rank_diff.unsqueeze(1)

        # Rescale differences into [rescale_factor_lower, rescale_factor_upper]
        invert_direction = -1 if self.descending else 1
        self.compared_rankings = (
            (rank_diff / batch_size) * (self.rescale_factor_upper - self.rescale_factor_lower)
            + self.rescale_factor_lower * sign(rank_diff)
        ) * invert_direction

        # Calculate satisfaction: a zero correction means the sample is
        # already ranked consistently with the reference
        incorrect_rankings = eq(self.compared_rankings, 0).float()

        return incorrect_rankings, ones_like(incorrect_rankings)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        layer, _ = self.descriptor.location(self.tag_prediction)
        return {layer: self.compared_rankings}
857
+
858
+
859
class PerGroupMonotonicityConstraint(Constraint):
    """Applies a monotonicity constraint independently per group of samples.

    This class wraps an existing `MonotonicityConstraint` instance (`base`) and
    enforces it **separately** for each unique group defined by `tag_group`.

    Each group is treated as an independent mini-batch:
    - The base constraint is applied to the group's subset of data.
    - Violations and directions are computed per group and then reassembled
      into the original batch order.

    This is an explicit alternative to :class:`EncodedGroupedMonotonicityConstraint`,
    which enforces the same logic using vectorized interval encoding.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the per-group monotonicity constraint.

        Args:
            base (MonotonicityConstraint): A pre-configured monotonicity constraint to apply per group.
            tag_group (str): Column/tag name that identifies groups.
            name (str | None, optional): Custom name for the constraint. Defaults to
                f"{base.tag_prediction} for each {tag_group} ...".
        """
        # Compose constraint name
        if name is None:
            name = base.name.replace(
                base.tag_prediction, f"{base.tag_prediction} for each {tag_group}", 1
            )

        # Init parent class
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied per group.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: Per-sample satisfactions (reassembled into
            the original batch order) and a mask of ones of the same shape.
        """
        # Select group identifiers and convert to unique list
        group_identifiers = self.descriptor.select(self.tag_group, data)
        unique_group_identifiers = unique(group_identifiers, sorted=False).tolist()

        # Initialize checks and directions.
        # BUGFIX: force a float dtype. `zeros_like` would otherwise inherit
        # the group-id dtype; integer group ids would silently truncate the
        # fractional satisfactions and rescaled rank corrections written below.
        checks = zeros_like(group_identifiers, dtype=torch.float32, device=self.device)
        self.directions = zeros_like(group_identifiers, dtype=torch.float32, device=self.device)

        # Get prediction and target keys
        preds_key, _ = self.descriptor.location(self.base.tag_prediction)
        targets_key, _ = self.descriptor.location(self.base.tag_reference)

        for group_identifier in unique_group_identifiers:
            # Create mask for the samples in this group
            group_mask = (group_identifiers == group_identifier).squeeze(1)

            # Create mini-batch for the group
            group_data = {
                preds_key: data[preds_key][group_mask],
                targets_key: data[targets_key][group_mask],
            }

            # Apply the base constraint to the mini-batch and scatter its
            # results back into the original batch positions
            checks[group_mask], _ = self.base.check_constraint(group_data)
            self.directions[group_mask] = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {layer: self.directions}
935
+
936
+
937
class EncodedGroupedMonotonicityConstraint(Constraint):
    """Applies a base monotonicity constraint to groups via interval encoding.

    Enforces a monotonic relationship between a prediction tag
    (`base.tag_prediction`) and a reference tag (`base.tag_reference`)
    independently for each group identified by `tag_group`.

    Rather than iterating over groups, every group's predictions and targets
    are offset by a span wider than the data range, placing each group in its
    own **non-overlapping interval**. The base constraint can then be applied
    once, in a single vectorized pass.

    This is a vectorized alternative to :class:`PerGroupMonotonicityConstraint`,
    which enforces the same logic via explicit per-group iteration.
    """

    def __init__(
        self,
        base: MonotonicityConstraint,
        tag_group: str,
        name: str | None = None,
    ):
        """Initializes the encoded grouped monotonicity constraint.

        Args:
            base (MonotonicityConstraint): The base constraint to apply to each group.
            tag_group (str): Column/tag name identifying groups in the batch.
            name (str | None, optional): Optional custom name for this constraint.
                If None, generates a descriptive name based on `base.name`.
        """
        # Compose constraint name from the base constraint's name
        if name is None:
            encoded_label = f"{base.tag_prediction} for each {tag_group} (encoded)"
            name = base.name.replace(base.tag_prediction, encoded_label, 1)

        # Init parent class
        super().__init__(base.tags, name, base.enforce, base.rescale_factor)

        # Init variables
        self.base = base
        self.tag_group = tag_group

        # Negate the target offsets for descending constraints so group
        # ordering stays consistent with the enforced direction
        self.negation = -1 if self.base.direction != "ascending" else 1

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied by mapping each group on a non-overlapping interval."""
        # Fetch tensors and their layer keys
        descriptor = self.descriptor
        ids = descriptor.select(self.tag_group, data)
        preds = descriptor.select(self.base.tag_prediction, data)
        targets = descriptor.select(self.base.tag_reference, data)
        preds_key, _ = descriptor.location(self.base.tag_prediction)
        targets_key, _ = descriptor.location(self.base.tag_reference)

        # Shift each group by a span strictly wider than the data range so
        # groups occupy disjoint intervals and can be ranked jointly
        pred_span = preds.max() - preds.min() + 1
        target_span = targets.max() - targets.min() + 1
        shifted_batch = {
            preds_key: preds + ids * pred_span,
            targets_key: targets + self.negation * ids * target_span,
        }

        # Delegate to the base constraint on the encoded batch
        checks, _ = self.base.check_constraint(shifted_batch)
        self.directions = self.base.compared_rankings

        return checks, ones_like(checks)

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Calculates ranking adjustments for monotonicity enforcement."""
        target_layer, _ = self.descriptor.location(self.base.tag_prediction)
        return {target_layer: self.directions}
1010
+
1011
+
1012
class ANDConstraint(Constraint):
    """A composite constraint that enforces the logical AND of multiple constraints.

    This class combines multiple sub-constraints and evaluates them jointly:

    * The satisfaction of the AND constraint is `True` only if all sub-constraints
      are satisfied (elementwise logical AND).
    * The corrective direction is computed by weighting each sub-constraint's
      direction with its satisfaction mask and summing across all sub-constraints.
    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str | None = None,
        monitor_only: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """A composite constraint that enforces the logical AND of multiple constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str | None, optional): A custom name for this constraint. If not provided,
                the name will be composed from the sub-constraint names joined with
                " AND ".
            monitor_only (bool, optional): If True, the constraint will be monitored
                but not enforced. Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Raises:
            ValueError: If no sub-constraints are provided.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            neurons (set): The union of neurons referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Type checking
        validate_iterable("constraints", constraints, Constraint)

        # Fail fast: with zero sub-constraints, check_constraint would later
        # crash with an AttributeError on `None` instead of a clear message.
        if not constraints:
            raise ValueError("ANDConstraint requires at least one constraint.")

        # Compose constraint name
        if not name:
            name = " AND ".join([constraint.name for constraint in constraints])

        # Init parent class
        super().__init__(
            set().union(*(constraint.tags for constraint in constraints)),
            name,
            monitor_only,
            rescale_factor,
        )

        # Init variables
        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]):
        """Evaluate whether all sub-constraints are satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
                * `total_satisfaction`: A float tensor indicating elementwise
                  whether all constraints are satisfied (logical AND).
                * `mask`: The elementwise logical OR of the sub-constraints'
                  masks, as a float tensor. Typically used as a weighting
                  mask in downstream processing.
        """
        total_satisfaction: Tensor = None
        total_mask: Tensor = None

        # TODO vectorize this loop
        for constraint in self.constraints:
            satisfaction, mask = constraint.check_constraint(data)
            if total_satisfaction is None:
                total_satisfaction = satisfaction
                total_mask = mask
            else:
                total_satisfaction = logical_and(total_satisfaction, satisfaction)
                total_mask = logical_or(total_mask, mask)

        return total_satisfaction.float(), total_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]):
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction, weighted
        by its satisfaction mask. The directions are summed across constraints
        for each affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: A mapping from layer identifiers to correction
            tensors. Each entry represents the aggregated correction to apply
            to that layer, based on the satisfaction-weighted sum of
            sub-constraint directions.
        """
        total_direction: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for constraint in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = constraint.check_constraint(data)
            direction = constraint.calculate_direction(data)

            # `layer_direction` instead of `dir` — avoid shadowing the builtin
            for layer, layer_direction in direction.items():
                if layer not in total_direction:
                    total_direction[layer] = satisfaction * layer_direction
                else:
                    total_direction[layer] += satisfaction * layer_direction

        return total_direction
1133
+
1134
+
1135
class ORConstraint(Constraint):
    """A composite constraint that enforces the logical OR of multiple constraints.

    This class combines multiple sub-constraints and evaluates them jointly:

    * The satisfaction of the OR constraint is `True` if at least one sub-constraint
      is satisfied (elementwise logical OR).
    * The corrective direction is computed by weighting each sub-constraint's
      direction with its satisfaction mask and summing across all sub-constraints.
    """

    def __init__(
        self,
        *constraints: Constraint,
        name: str | None = None,
        monitor_only: bool = False,
        rescale_factor: Number = 1.5,
    ) -> None:
        """A composite constraint that enforces the logical OR of multiple constraints.

        Args:
            *constraints (Constraint): One or more `Constraint` instances to be combined.
            name (str | None, optional): A custom name for this constraint. If not provided,
                the name will be composed from the sub-constraint names joined with
                " OR ".
            monitor_only (bool, optional): If True, the constraint will be monitored
                but not enforced. Defaults to False.
            rescale_factor (Number, optional): A scaling factor applied when rescaling
                corrections. Defaults to 1.5.

        Raises:
            ValueError: If no sub-constraints are provided.

        Attributes:
            constraints (tuple[Constraint, ...]): The sub-constraints being combined.
            neurons (set): The union of neurons referenced by the sub-constraints.
            name (str): The name of the constraint (composed or custom).
        """
        # Type checking
        validate_iterable("constraints", constraints, Constraint)

        # Fail fast: with zero sub-constraints, check_constraint would later
        # crash with an AttributeError on `None` instead of a clear message.
        if not constraints:
            raise ValueError("ORConstraint requires at least one constraint.")

        # Compose constraint name
        if not name:
            name = " OR ".join([constraint.name for constraint in constraints])

        # Init parent class
        super().__init__(
            set().union(*(constraint.tags for constraint in constraints)),
            name,
            monitor_only,
            rescale_factor,
        )

        # Init variables
        self.constraints = constraints

    def check_constraint(self, data: dict[str, Tensor]):
        """Evaluate whether any sub-constraints are satisfied.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            tuple[Tensor, Tensor]: A tuple `(total_satisfaction, mask)` where:
                * `total_satisfaction`: A float tensor indicating elementwise
                  whether any constraint is satisfied (logical OR).
                * `mask`: The elementwise logical OR of the sub-constraints'
                  masks, as a float tensor. Typically used as a weighting
                  mask in downstream processing.
        """
        total_satisfaction: Tensor = None
        total_mask: Tensor = None

        # TODO vectorize this loop
        for constraint in self.constraints:
            satisfaction, mask = constraint.check_constraint(data)
            if total_satisfaction is None:
                total_satisfaction = satisfaction
                total_mask = mask
            else:
                total_satisfaction = logical_or(total_satisfaction, satisfaction)
                total_mask = logical_or(total_mask, mask)

        return total_satisfaction.float(), total_mask.float()

    def calculate_direction(self, data: dict[str, Tensor]):
        """Compute the corrective direction by aggregating sub-constraint directions.

        Each sub-constraint contributes its corrective direction, weighted
        by its satisfaction mask. The directions are summed across constraints
        for each affected layer.

        Args:
            data: Model predictions and associated batch/context information.

        Returns:
            dict[str, Tensor]: A mapping from layer identifiers to correction
            tensors. Each entry represents the aggregated correction to apply
            to that layer, based on the satisfaction-weighted sum of
            sub-constraint directions.
        """
        total_direction: dict[str, Tensor] = {}

        # TODO vectorize this loop
        for constraint in self.constraints:
            # TODO improve efficiency by avoiding double computation?
            satisfaction, _ = constraint.check_constraint(data)
            direction = constraint.calculate_direction(data)

            # `layer_direction` instead of `dir` — avoid shadowing the builtin
            for layer, layer_direction in direction.items():
                if layer not in total_direction:
                    total_direction[layer] = satisfaction * layer_direction
                else:
                    total_direction[layer] += satisfaction * layer_direction

        return total_direction