congrads 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,242 @@
1
+ """Defines the abstract base class `Constraint` for specifying constraints on neural network outputs.
2
+
3
+ A `Constraint` monitors whether the network predictions satisfy certain
4
+ conditions during training, validation, and testing. It can optionally
5
+ adjust the loss to enforce constraints, and logs the relevant metrics.
6
+
7
+ Responsibilities:
8
+ - Track which network layers/tags the constraint applies to
9
+ - Check constraint satisfaction for a batch of predictions
10
+ - Compute adjustment directions to enforce the constraint
11
+ - Provide a rescale factor and enforcement flag to influence loss adjustment
12
+
13
+ Subclasses must implement the abstract methods:
14
+ - `check_constraint(data)`: Evaluate constraint satisfaction for a batch
15
+ - `calculate_direction(data)`: Compute directions to adjust predictions
16
+ """
17
+
18
+ import random
19
+ import string
20
+ import warnings
21
+ from abc import ABC, abstractmethod
22
+ from numbers import Number
23
+ from typing import Literal
24
+
25
+ from torch import Tensor
26
+
27
+ from congrads.descriptor import Descriptor
28
+ from congrads.utils.validation import validate_iterable, validate_type
29
+
30
+
31
class Constraint(ABC):
    """Abstract base class for defining constraints applied to neural networks.

    A `Constraint` specifies conditions that the neural network outputs
    should satisfy. It supports monitoring constraint satisfaction
    during training and can adjust loss to enforce constraints. Subclasses
    must implement the `check_constraint` and `calculate_direction` methods.

    Args:
        tags (set[str]): Tags referencing parts of the network where this
            constraint applies to.
        name (str, optional): A unique name for the constraint. If not
            provided, a name is generated based on the class name and a
            random suffix.
        enforce (bool, optional): If False, only monitor the constraint
            without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Factor to scale the
            constraint-adjusted loss. Defaults to 1.5. Should be greater
            than 1 to give weight to the constraint.

    Raises:
        TypeError: If a provided attribute has an incompatible type.
        ValueError: If any tag in `tags` is not defined in the `descriptor`.

    Note:
        - If `rescale_factor <= 1`, a warning is issued.
        - If `name` is not provided, a name is auto-generated, and a
          warning is logged.

    """

    # Class-level attributes. They are None here and are expected to be
    # populated externally before constraints are instantiated
    # (NOTE(review): presumably by the framework core — confirm).
    descriptor: Descriptor = None
    device = None

    def __init__(
        self,
        tags: set[str],
        name: str | None = None,
        enforce: bool = True,
        rescale_factor: Number = 1.5,
    ) -> None:
        """Initializes a new Constraint instance.

        Args:
            tags (set[str]): Tags referencing parts of the network where
                this constraint applies to.
            name (str, optional): A unique name for the constraint. If not
                provided, a name is generated based on the class name and a
                random suffix.
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5. Should be greater
                than 1 to give weight to the constraint.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If any tag in `tags` is not defined in the
                `descriptor`.

        Note:
            - If `rescale_factor <= 1`, a warning is issued.
            - If `name` is not provided, a name is auto-generated, and a
              warning is logged.
        """
        # Init parent class
        super().__init__()

        # Type checking
        validate_iterable("tags", tags, str)
        validate_type("name", name, str, allow_none=True)
        validate_type("enforce", enforce, bool)
        validate_type("rescale_factor", rescale_factor, Number)

        # Init object variables
        self.tags = tags
        self.rescale_factor = rescale_factor
        # Kept separately so the effective factor can be changed later while
        # the originally configured value stays recoverable.
        self.initial_rescale_factor = rescale_factor
        self.enforce = enforce

        # Resolve the constraint name first so later diagnostics can
        # reference it. If no name is given, generate one based on the class
        # name and a random suffix.
        if name:
            self.name = name
        else:
            random_suffix = "".join(
                random.choices(string.ascii_uppercase + string.digits, k=6)
            )
            self.name = f"{self.__class__.__name__}_{random_suffix}"
            warnings.warn(
                f"Name for constraint is not set. Using {self.name}.",
                stacklevel=2,
            )

        # Fix: this warning previously interpolated the raw `name` argument
        # and therefore reported "constraint None" whenever the name was
        # auto-generated; it now always reports the resolved name.
        if rescale_factor <= 1:
            warnings.warn(
                f"Rescale factor for constraint {self.name} is <= 1. The network "
                "will favor general loss over the constraint-adjusted loss. "
                "Is this intended behavior? Normally, the rescale factor "
                "should always be larger than 1.",
                stacklevel=2,
            )

        # Infer the set of network layers this constraint touches from the
        # descriptor: every tag must be registered, and each tag maps to the
        # layer it lives on.
        self.layers = set()
        for tag in self.tags:
            if not self.descriptor.exists(tag):
                raise ValueError(
                    f"The tag {tag} used with constraint "
                    f"{self.name} is not defined in the descriptor. Please "
                    "add it to the correct layer using "
                    "descriptor.add('layer', ...)."
                )

            layer, _ = self.descriptor.location(tag)
            self.layers.add(layer)

    @abstractmethod
    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluates whether the given model predictions satisfy the constraint.

        1 IS SATISFIED, 0 IS NOT SATISFIED

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            tuple[Tensor, Tensor]: A tuple where the first element is a
                tensor of floats indicating whether the constraint is
                satisfied (with value 1.0 for satisfaction, and 0.0 for
                non-satisfaction), and the second element is a tensor mask
                that indicates the relevance of each sample (`True` for
                relevant samples and `False` for irrelevant ones).
        """
        pass

    @abstractmethod
    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions to better satisfy the constraint.

        Given the model predictions, input batch, and context, this method
        calculates the direction in which the predictions referenced by a
        tag should be adjusted to satisfy the constraint.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data,
                model predictions and context.

        Returns:
            dict[str, Tensor]: Dictionary mapping network layers to tensors
                that specify the adjustment direction for each tag.
        """
        pass
170
+
171
+
172
class MonotonicityConstraint(Constraint, ABC):
    """Abstract base class for monotonicity constraints.

    Subclasses must define how monotonicity is evaluated and how corrective
    directions are computed.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str | None = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        This constraint ensures that the activations of a prediction tag
        (`tag_prediction`) are monotonically ascending or descending with
        respect to a target tag (`tag_reference`).

        Args:
            tag_prediction (str): Name of the tag whose activations should
                follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic
                reference.
            rescale_factor_lower (float, optional): Lower bound for
                rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for
                rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when
                ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce,
                either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None,
                a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only
                monitored (not enforced). Defaults to True.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If `direction` is not 'ascending' or 'descending'.
        """
        # Type checking
        validate_type("tag_prediction", tag_prediction, str)
        validate_type("tag_reference", tag_reference, str)
        validate_type("rescale_factor_lower", rescale_factor_lower, float)
        validate_type("rescale_factor_upper", rescale_factor_upper, float)
        validate_type("stable", stable, bool)
        validate_type("direction", direction, str)

        # Fix: the Literal annotation is not enforced at runtime, so check
        # the value explicitly instead of silently accepting a typo (which
        # would previously yield descending=False for any misspelled value).
        if direction not in ("ascending", "descending"):
            raise ValueError(
                "direction must be either 'ascending' or 'descending', "
                f"got {direction!r}."
            )

        # Compose constraint name
        if name is None:
            name = f"{tag_prediction} monotonically {direction} by {tag_reference}"

        # Init parent class.
        # NOTE(review): rescale_factor=1.0 triggers the parent's "<= 1"
        # warning on every construction; presumably intended because this
        # constraint uses its own lower/upper factors instead — confirm.
        super().__init__({tag_prediction}, name, enforce, 1.0)

        # Init variables
        self.tag_prediction = tag_prediction
        self.tag_reference = tag_reference
        self.rescale_factor_lower = rescale_factor_lower
        self.rescale_factor_upper = rescale_factor_upper
        self.stable = stable
        self.direction = direction
        # Convenience flag derived from the (validated) direction string.
        self.descending = direction == "descending"

        # Per-sample correction directions; populated by check_constraint().
        self.compared_rankings: Tensor = None

    @abstractmethod
    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied.

        Implementations must set `self.compared_rankings` with per-sample
        correction directions.
        """
        pass

    @abstractmethod
    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return directions for monotonicity enforcement."""
        pass