congrads-1.1.2-py3-none-any.whl → congrads-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,10 +24,6 @@ Usage:
 
   """
 
- import random
- import string
- import warnings
- from abc import ABC, abstractmethod
  from collections.abc import Callable
  from numbers import Number
  from typing import Literal
@@ -54,156 +50,10 @@ from torch import (
  )
 
  from torch.nn.functional import normalize
 
- from .descriptor import Descriptor
- from .transformations import IdentityTransformation, Transformation
- from .utils import validate_comparator_pytorch, validate_iterable, validate_type
-
-
- class Constraint(ABC):
-     """Abstract base class for defining constraints applied to neural networks.
-
-     A `Constraint` specifies conditions that the neural network outputs
-     should satisfy. It supports monitoring constraint satisfaction
-     during training and can adjust the loss to enforce constraints. Subclasses
-     must implement the `check_constraint` and `calculate_direction` methods.
-
-     Args:
-         tags (set[str]): Tags referencing the parts of the network that this constraint applies to.
-         name (str, optional): A unique name for the constraint. If not provided,
-             a name is generated based on the class name and a random suffix.
-         enforce (bool, optional): If False, only monitor the constraint
-             without adjusting the loss. Defaults to True.
-         rescale_factor (Number, optional): Factor to scale the
-             constraint-adjusted loss. Defaults to 1.5. Should be greater
-             than 1 to give weight to the constraint.
-
-     Raises:
-         TypeError: If a provided attribute has an incompatible type.
-         ValueError: If any tag in `tags` is not
-             defined in the `descriptor`.
-
-     Note:
-         - If `rescale_factor <= 1`, a warning is issued.
-         - If `name` is not provided, a name is auto-generated,
-           and a warning is logged.
-
-     """
-
-     descriptor: Descriptor = None
-     device = None
-
-     def __init__(
-         self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
-     ) -> None:
-         """Initializes a new Constraint instance.
-
-         Args:
-             tags (set[str]): Tags referencing the parts of the network that this constraint applies to.
-             name (str, optional): A unique name for the constraint. If not
-                 provided, a name is generated based on the class name and a
-                 random suffix.
-             enforce (bool, optional): If False, only monitor the constraint
-                 without adjusting the loss. Defaults to True.
-             rescale_factor (Number, optional): Factor to scale the
-                 constraint-adjusted loss. Defaults to 1.5. Should be greater
-                 than 1 to give weight to the constraint.
-
-         Raises:
-             TypeError: If a provided attribute has an incompatible type.
-             ValueError: If any tag in `tags` is not defined in the `descriptor`.
-
-         Note:
-             - If `rescale_factor <= 1`, a warning is issued.
-             - If `name` is not provided, a name is auto-generated, and a
-               warning is logged.
-         """
-         # Init parent class
-         super().__init__()
-
-         # Type checking
-         validate_iterable("tags", tags, str)
-         validate_type("name", name, str, allow_none=True)
-         validate_type("enforce", enforce, bool)
-         validate_type("rescale_factor", rescale_factor, Number)
-
-         # Init object variables
-         self.tags = tags
-         self.rescale_factor = rescale_factor
-         self.initial_rescale_factor = rescale_factor
-         self.enforce = enforce
-
-         # Perform checks
-         if rescale_factor <= 1:
-             warnings.warn(
-                 f"Rescale factor for constraint {name} is <= 1. The network "
-                 "will favor general loss over the constraint-adjusted loss. "
-                 "Is this intended behavior? Normally, the rescale factor "
-                 "should always be larger than 1.",
-                 stacklevel=2,
-             )
-
-         # If no constraint_name is set, generate one based
-         # on the class name and a random suffix
-         if name:
-             self.name = name
-         else:
-             random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
-             self.name = f"{self.__class__.__name__}_{random_suffix}"
-             warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
-
-         # Infer layers from descriptor and tags
-         self.layers = set()
-         for tag in self.tags:
-             if not self.descriptor.exists(tag):
-                 raise ValueError(
-                     f"The tag {tag} used with constraint "
-                     f"{self.name} is not defined in the descriptor. Please "
-                     "add it to the correct layer using "
-                     "descriptor.add('layer', ...)."
-                 )
-
-             layer, _ = self.descriptor.location(tag)
-             self.layers.add(layer)
-
-     @abstractmethod
-     def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
-         """Evaluates whether the given model predictions satisfy the constraint.
-
-         1 IS SATISFIED, 0 IS NOT SATISFIED
-
-         Args:
-             data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
-
-         Returns:
-             tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
-                 indicating whether the constraint is satisfied (with value 1.0
-                 for satisfaction and 0.0 for non-satisfaction), and the second element is a
-                 mask tensor that indicates the relevance of each sample (`True` for relevant
-                 samples and `False` for irrelevant ones).
-
-         Raises:
-             NotImplementedError: If not implemented in a subclass.
-         """
-         raise NotImplementedError
-
-     @abstractmethod
-     def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
-         """Compute adjustment directions to better satisfy the constraint.
-
-         Given the model predictions, input batch, and context, this method calculates the direction
-         in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
-
-         Args:
-             data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
-
-         Returns:
-             dict[str, Tensor]: Dictionary mapping network layers to tensors that
-                 specify the adjustment direction for each tag.
-
-         Raises:
-             NotImplementedError: Must be implemented by subclasses.
-         """
-         raise NotImplementedError
+ from ..transformations.base import Transformation
+ from ..transformations.registry import IdentityTransformation
+ from ..utils.validation import validate_comparator_pytorch, validate_iterable, validate_type
+ from .base import Constraint
 
 
  class ImplicationConstraint(Constraint):
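
The removed `Constraint` base class (now imported from `.base`) defines the two-method contract that every constraint in this file implements. As a reading aid, here is a minimal sketch of that contract. `NonNegativeConstraint`, its tag wiring, and the import path are illustrative assumptions, not part of congrads; only `check_constraint`, `calculate_direction`, and the `descriptor.select`/`descriptor.location` calls mirror the package code shown in this diff.

    import torch
    from torch import Tensor, ones_like

    from congrads.constraints import Constraint  # assumed public import path

    class NonNegativeConstraint(Constraint):
        """Toy constraint: activations selected by `tag` should be >= 0."""

        def __init__(self, tag: str, **kwargs) -> None:
            super().__init__(tags={tag}, **kwargs)
            self.tag = tag

        def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
            preds = self.descriptor.select(self.tag, data)
            satisfied = (preds >= 0).float()  # 1.0 satisfied, 0.0 violated
            # Every sample is relevant for this toy constraint.
            return satisfied, ones_like(satisfied, dtype=torch.bool)

        def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
            preds = self.descriptor.select(self.tag, data)
            layer, _ = self.descriptor.location(self.tag)
            # Positive direction where violated: push those activations upward.
            return {layer: (preds < 0).float()}
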
@@ -918,11 +768,23 @@ class MonotonicityConstraint(Constraint):
          return {layer: self.compared_rankings}
 
 
- class GroupedMonotonicityConstraint(MonotonicityConstraint):
-     """Constraint that enforces a monotonic relationship between two tags.
+ class PerGroupMonotonicityConstraint(MonotonicityConstraint):
+     """Group-wise monotonicity constraint enforced independently per group.
 
-     This constraint ensures that the activations of a prediction tag (`tag_prediction`)
-     are monotonically ascending or descending with respect to a target tag (`tag_reference`).
+     This constraint enforces a monotonic relationship between a prediction tag
+     (`tag_prediction`) and a reference tag (`tag_reference`) **within each group**
+     identified by `tag_group_identifier`.
+
+     For each unique group identifier, the base `MonotonicityConstraint` is applied
+     independently to the corresponding subset of samples. This makes the behavior
+     semantically explicit and easy to reason about, as no interaction or ordering
+     is introduced across different groups.
+
+     Notes:
+         - Groups are treated as fully independent constraint instances.
+         - This implementation prioritizes correctness and interpretability.
+         - It may be less efficient for large numbers of groups due to explicit
+           iteration and repeated constraint evaluation.
      """
 
      def __init__(
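
The per-group semantics that the docstring above describes in words can be pictured with plain tensors. This is a toy sketch on made-up data, not the class's actual implementation (which operates on tagged batch dictionaries):

    import torch

    preds  = torch.tensor([0.1, 0.4, 0.3, 0.2, 0.5, 0.8])
    refs   = torch.tensor([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
    groups = torch.tensor([0, 0, 0, 1, 1, 1])

    for g in groups.unique():
        mask = groups == g
        order = refs[mask].argsort()          # sort each group by its reference
        ordered = preds[mask][order]
        ok = bool((ordered[1:] >= ordered[:-1]).all())
        print(f"group {int(g)}: monotone={ok}")
    # group 0: monotone=False  (0.4 -> 0.3 breaks ascending order)
    # group 1: monotone=True

Because each group is checked on its own subset, no ordering is ever imposed between samples of different groups, which is exactly the independence the class guarantees.
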
@@ -1010,6 +872,106 @@ class GroupedMonotonicityConstraint(MonotonicityConstraint):
          return {layer: self.directions}
 
 
+ class EncodedGroupedMonotonicityConstraint(MonotonicityConstraint):
+     """Group-wise monotonicity constraint enforced via rank encoding.
+
+     This constraint enforces a monotonic relationship between a prediction tag
+     (`tag_prediction`) and a reference tag (`tag_reference`) within each group
+     identified by `tag_group_identifier`, using a fully vectorized approach.
+
+     Group independence is achieved by encoding the group identifiers into the
+     prediction and reference values via large offsets, effectively separating
+     the rank spaces of different groups. This allows the base
+     `MonotonicityConstraint` to be applied once to the entire batch without
+     explicit per-group iteration.
+
+     Notes:
+         - Groups are isolated implicitly through rank-space separation.
+         - The logic is less explicit than per-group evaluation and relies on
+           offset-based rank encoding for correctness.
+         - This constraint might cause floating-point errors if
+           the prediction or target range is very large.
+     """
+
+     def __init__(
+         self,
+         tag_prediction: str,
+         tag_reference: str,
+         tag_group_identifier: str,
+         rescale_factor_lower: float = 1.5,
+         rescale_factor_upper: float = 1.75,
+         stable: bool = True,
+         direction: Literal["ascending", "descending"] = "ascending",
+         name: str = None,
+         enforce: bool = True,
+     ):
+         """Constraint that enforces monotonicity on a predicted output.
+
+         This constraint ensures that the activations of a prediction tag (`tag_prediction`)
+         are monotonically ascending or descending with respect to a target tag (`tag_reference`).
+
+         Args:
+             tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
+             tag_reference (str): Name of the tag that acts as the monotonic reference.
+             tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
+             rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
+             rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
+             stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
+             direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
+             name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
+             enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
+         """
+         # Compose constraint name
+         if name is None:
+             name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
+
+         # Init parent class
+         super().__init__(
+             tag_prediction=tag_prediction,
+             tag_reference=tag_reference,
+             rescale_factor_lower=rescale_factor_lower,
+             rescale_factor_upper=rescale_factor_upper,
+             stable=stable,
+             direction=direction,
+             name=name,
+             enforce=enforce,
+         )
+
+         # Init variables
+         self.tag_prediction = tag_prediction
+         self.tag_reference = tag_reference
+         self.tag_group_identifier = tag_group_identifier
+
+         # Initialize negation factor based on direction
+         self.negation = 1 if direction == "ascending" else -1
+
+     def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+         """Evaluate whether the monotonicity constraint is satisfied."""
+         # Get data and keys
+         ids = self.descriptor.select(self.tag_group_identifier, data)
+         preds = self.descriptor.select(self.tag_prediction, data)
+         targets = self.descriptor.select(self.tag_reference, data)
+         preds_key, _ = self.descriptor.location(self.tag_prediction)
+         targets_key, _ = self.descriptor.location(self.tag_reference)
+
+         # Offset each group into its own disjoint value band so that one
+         # global ranking never interleaves samples from different groups
+         new_preds = preds + ids * (preds.max() - preds.min() + 1)
+         new_targets = targets + self.negation * ids * (targets.max() - targets.min() + 1)
+
+         # Create new batch for child constraint
+         new_data = {preds_key: new_preds, targets_key: new_targets}
+
+         # Call super on the adjusted batch
+         checks, _ = super().check_constraint(new_data)
+         self.directions = self.compared_rankings
+
+         return checks, ones_like(checks)
+
+     def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Calculates ranking adjustments for monotonicity enforcement."""
+         layer, _ = self.descriptor.location(self.tag_prediction)
+         return {layer: self.directions}
+
+
  class ANDConstraint(Constraint):
      """A composite constraint that enforces the logical AND of multiple constraints.
 
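
The offset encoding used in `EncodedGroupedMonotonicityConstraint.check_constraint` above can be sanity-checked in isolation. A minimal sketch with made-up numbers (the real class derives `ids`, `preds`, and `targets` from tagged batch data):

    import torch

    vals = torch.tensor([3.0, 1.0, 2.0, 3.0, 1.0])   # values spanning two groups
    ids  = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0])   # group identifier per sample

    # Shift each group into a disjoint band wider than the value range,
    # so one global ranking never interleaves samples from different groups.
    offset = vals.max() - vals.min() + 1              # 3.0
    encoded = vals + ids * offset                     # tensor([3., 1., 5., 6., 4.])
    print(encoded.argsort())                          # tensor([1, 0, 4, 2, 3])
    # all group-0 positions (0, 1) rank before every group-1 position

This also makes the docstring's floating-point caveat concrete: with many groups or a wide value range, `ids * offset` grows large and can exhaust float precision, at which point ranks within a group may no longer be distinguishable.
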
@@ -0,0 +1,200 @@
+ """Defines the BatchRunner, which executes individual batches for training, validation, and testing.
+
+ Responsibilities:
+     - Move batch data to the appropriate device
+     - Run forward passes through the network
+     - Compute base and constraint-adjusted losses
+     - Perform backpropagation during training
+     - Accumulate metrics for loss and other monitored quantities
+     - Trigger callbacks at key points in the batch lifecycle
+ """
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+ from torch.optim import Optimizer
+
+ from ..callbacks.base import CallbackManager
+ from ..core.constraint_engine import ConstraintEngine
+ from ..metrics import MetricManager
+
+
+ class BatchRunner:
+     """Executes a single batch for training, validation, or testing.
+
+     The BatchRunner handles moving data to the correct device, running the network
+     forward, computing base and constraint-adjusted losses, performing backpropagation
+     during training, accumulating metrics, and dispatching callbacks at key points
+     in the batch lifecycle.
+     """
+
+     def __init__(
+         self,
+         network: Module,
+         criterion,
+         optimizer: Optimizer,
+         constraint_engine: ConstraintEngine,
+         metric_manager: MetricManager | None,
+         callback_manager: CallbackManager | None,
+         device: torch.device,
+     ):
+         """Initialize the BatchRunner.
+
+         Args:
+             network: The neural network module to execute.
+             criterion: Loss function callable accepting (output, target, data=batch).
+             optimizer: Optimizer for updating network parameters.
+             constraint_engine: ConstraintEngine instance for evaluating and enforcing constraints.
+             metric_manager: Optional MetricManager for logging batch metrics.
+             callback_manager: Optional CallbackManager for triggering hooks during batch processing.
+             device: Torch device on which to place data and network.
+         """
+         self.network = network
+         self.criterion = criterion
+         self.optimizer = optimizer
+         self.constraint_engine = constraint_engine
+         self.metric_manager = metric_manager
+         self.callback_manager = callback_manager
+         self.device = device
+
+     def _to_device(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Move all tensors in the batch to the BatchRunner's device.
+
+         Args:
+             batch: Dictionary of tensors for a single batch.
+
+         Returns:
+             Dictionary of tensors moved to the target device.
+         """
+         return {k: v.to(self.device) for k, v in batch.items()}
+
+     def _run_callbacks(self, hook: str, data: dict) -> dict:
+         """Run the specified callback hook on the batch data.
+
+         Args:
+             hook: Name of the callback hook to run.
+             data: Dictionary containing batch data.
+
+         Returns:
+             Potentially modified batch data after callback execution.
+         """
+         if self.callback_manager is None:
+             return data
+         return self.callback_manager.run(hook, data)
+
+     def train_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single training batch.
+
+         Steps performed:
+             1. Move batch to device and run "on_train_batch_start" callback.
+             2. Forward pass through the network.
+             3. Compute base loss using the criterion and accumulate metric.
+             4. Apply constraint-based adjustments to the loss.
+             5. Perform backward pass and optimizer step.
+             6. Run "on_train_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed before constraint adjustments.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_train_batch_start", batch)
+
+         # Forward
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_train_forward", batch)
+
+         # Base loss
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+         # Constraint-adjusted loss
+         combined_loss = self.constraint_engine.train(batch, loss)
+
+         # Backward
+         self.optimizer.zero_grad()
+         combined_loss.backward()
+         self.optimizer.step()
+
+         batch = self._run_callbacks("on_train_batch_end", batch)
+         return loss
+
+     def valid_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single validation batch.
+
+         Steps performed:
+             1. Move batch to device and run "on_valid_batch_start" callback.
+             2. Forward pass through the network.
+             3. Compute base loss using the criterion and accumulate metric.
+             4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+             5. Run "on_valid_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed for the batch.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_valid_batch_start", batch)
+
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_valid_forward", batch)
+
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+         self.constraint_engine.validate(batch, loss)
+
+         batch = self._run_callbacks("on_valid_batch_end", batch)
+         return loss
+
+     def test_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single test batch.
+
+         Steps performed:
+             1. Move batch to device and run "on_test_batch_start" callback.
+             2. Forward pass through the network.
+             3. Compute base loss using the criterion and accumulate metric.
+             4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+             5. Run "on_test_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed for the batch.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_test_batch_start", batch)
+
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_test_forward", batch)
+
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
+
+         self.constraint_engine.test(batch, loss)
+
+         batch = self._run_callbacks("on_test_batch_end", batch)
+         return loss
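
To show where the new BatchRunner sits in a training run, here is a hedged usage sketch. The surrounding objects (`network`, `criterion`, `optimizer`, `constraint_engine`, the loaders, and the BatchRunner import) are assumptions for illustration, not congrads API shown in this diff:

    import torch

    # Assumed to exist already: network, criterion, optimizer, constraint_engine,
    # plus train_loader / valid_loader yielding dict[str, Tensor] batches.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    runner = BatchRunner(
        network=network,
        criterion=criterion,
        optimizer=optimizer,
        constraint_engine=constraint_engine,
        metric_manager=None,       # metrics are optional
        callback_manager=None,     # callbacks are optional
        device=device,
    )

    for epoch in range(10):
        network.train()
        for batch in train_loader:
            runner.train_batch(batch)   # forward, constraint-adjusted loss, backward, step

        network.eval()
        with torch.no_grad():           # valid_batch never calls backward
            for batch in valid_loader:
                runner.valid_batch(batch)
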