congrads 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""Defines the abstract base class `Constraint` for specifying constraints on neural network outputs.
|
|
2
|
+
|
|
3
|
+
A `Constraint` monitors whether the network predictions satisfy certain
|
|
4
|
+
conditions during training, validation, and testing. It can optionally
|
|
5
|
+
adjust the loss to enforce constraints, and logs the relevant metrics.
|
|
6
|
+
|
|
7
|
+
Responsibilities:
|
|
8
|
+
- Track which network layers/tags the constraint applies to
|
|
9
|
+
- Check constraint satisfaction for a batch of predictions
|
|
10
|
+
- Compute adjustment directions to enforce the constraint
|
|
11
|
+
- Provide a rescale factor and enforcement flag to influence loss adjustment
|
|
12
|
+
|
|
13
|
+
Subclasses must implement the abstract methods:
|
|
14
|
+
- `check_constraint(data)`: Evaluate constraint satisfaction for a batch
|
|
15
|
+
- `calculate_direction(data)`: Compute directions to adjust predictions
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import random
|
|
19
|
+
import string
|
|
20
|
+
import warnings
|
|
21
|
+
from abc import ABC, abstractmethod
|
|
22
|
+
from numbers import Number
|
|
23
|
+
from typing import Literal
|
|
24
|
+
|
|
25
|
+
from torch import Tensor
|
|
26
|
+
|
|
27
|
+
from congrads.descriptor import Descriptor
|
|
28
|
+
from congrads.utils.validation import validate_iterable, validate_type
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Constraint(ABC):
    """Abstract base class for constraints on neural network outputs.

    A constraint describes a condition that network predictions should
    satisfy. During training, validation, and testing it can be monitored
    only, or additionally enforced by adjusting the loss via a rescale
    factor. Concrete subclasses implement `check_constraint` and
    `calculate_direction`.

    Args:
        tags (set[str]): Tags referencing parts of the network where this constraint applies to.
        name (str, optional): Unique constraint name. When omitted, a name
            is generated from the class name plus a random suffix.
        enforce (bool, optional): When False the constraint is only
            monitored, without adjusting the loss. Defaults to True.
        rescale_factor (Number, optional): Scale applied to the
            constraint-adjusted loss. Should be greater than 1 to give
            weight to the constraint. Defaults to 1.5.

    Raises:
        TypeError: If a provided attribute has an incompatible type.
        ValueError: If any tag in `tags` is not defined in the `descriptor`.

    Note:
        - A warning is issued when `rescale_factor <= 1`.
        - A warning is issued when `name` is auto-generated.
    """

    # Class-level slots filled in externally (presumably by the training
    # engine) before constraints are constructed — TODO confirm wiring.
    descriptor: Descriptor = None
    device = None

    def __init__(
        self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
    ) -> None:
        """Create a constraint and resolve its tags to network layers.

        Args:
            tags (set[str]): Tags referencing parts of the network where this constraint applies to.
            name (str, optional): Unique constraint name; auto-generated
                when not provided.
            enforce (bool, optional): If False, only monitor the constraint
                without adjusting the loss. Defaults to True.
            rescale_factor (Number, optional): Factor to scale the
                constraint-adjusted loss. Defaults to 1.5.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If any tag in `tags` is not defined in the `descriptor`.
        """
        super().__init__()

        # Validate argument types before anything else.
        validate_iterable("tags", tags, str)
        validate_type("name", name, str, allow_none=True)
        validate_type("enforce", enforce, bool)
        validate_type("rescale_factor", rescale_factor, Number)

        self.tags = tags
        self.rescale_factor = rescale_factor
        self.initial_rescale_factor = rescale_factor
        self.enforce = enforce

        # A factor <= 1 de-emphasizes the constraint relative to the
        # general loss; warn rather than fail, since it may be intended.
        if rescale_factor <= 1:
            warnings.warn(
                f"Rescale factor for constraint {name} is <= 1. The network "
                "will favor general loss over the constraint-adjusted loss. "
                "Is this intended behavior? Normally, the rescale factor "
                "should always be larger than 1.",
                stacklevel=2,
            )

        # Fall back to an auto-generated name when none was supplied.
        if not name:
            suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
            self.name = f"{self.__class__.__name__}_{suffix}"
            warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
        else:
            self.name = name

        # Resolve every tag to the layer it lives on via the descriptor.
        self.layers = set()
        for tag in self.tags:
            if self.descriptor.exists(tag):
                layer, _ = self.descriptor.location(tag)
                self.layers.add(layer)
                continue
            raise ValueError(
                f"The tag {tag} used with constraint "
                f"{self.name} is not defined in the descriptor. Please "
                "add it to the correct layer using "
                "descriptor.add('layer', ...)."
            )

    @abstractmethod
    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the given model predictions satisfy the constraint.

        1 IS SATISFIED, 0 IS NOT SATISFIED

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            tuple[Tensor, Tensor]: Per-sample satisfaction scores (1.0 for
            satisfied, 0.0 for not satisfied) and a mask marking sample
            relevance (`True` for relevant samples, `False` otherwise).
        """

    @abstractmethod
    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Compute adjustment directions to better satisfy the constraint.

        Calculates, per tag, the direction in which predictions should be
        adjusted to satisfy the constraint.

        Args:
            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.

        Returns:
            dict[str, Tensor]: Mapping of network layers to tensors that
            specify the adjustment direction for each tag.
        """
|
|
172
|
+
class MonotonicityConstraint(Constraint, ABC):
    """Abstract base class for monotonicity constraints.

    Ensures that the activations of a prediction tag (`tag_prediction`)
    are monotonically ascending or descending with respect to a reference
    tag (`tag_reference`). Subclasses must define how monotonicity is
    evaluated and how corrective directions are computed.
    """

    def __init__(
        self,
        tag_prediction: str,
        tag_reference: str,
        rescale_factor_lower: float = 1.5,
        rescale_factor_upper: float = 1.75,
        stable: bool = True,
        direction: Literal["ascending", "descending"] = "ascending",
        name: str = None,
        enforce: bool = True,
    ):
        """Constraint that enforces monotonicity on a predicted output.

        Args:
            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
            tag_reference (str): Name of the tag that acts as the monotonic reference.
            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            ValueError: If `direction` is neither 'ascending' nor 'descending'.
        """
        # Type checking — include the tag arguments, which the original
        # validation style applies to every other parameter.
        validate_type("tag_prediction", tag_prediction, str)
        validate_type("tag_reference", tag_reference, str)
        validate_type("rescale_factor_lower", rescale_factor_lower, float)
        validate_type("rescale_factor_upper", rescale_factor_upper, float)
        validate_type("stable", stable, bool)
        validate_type("direction", direction, str)

        # Enforce the documented Literal contract. Without this check a
        # typo such as "decending" silently behaves as "ascending", since
        # self.descending below is computed by string comparison.
        if direction not in ("ascending", "descending"):
            raise ValueError(
                f"Invalid direction '{direction}' for monotonicity "
                "constraint; expected 'ascending' or 'descending'."
            )

        # Compose a descriptive constraint name when none was given
        if name is None:
            name = f"{tag_prediction} monotonically {direction} by {tag_reference}"

        # NOTE(review): rescale_factor is fixed at 1.0 here, which triggers
        # the parent's "rescale factor <= 1" warning on every construction.
        # The effective scaling is presumably handled via
        # rescale_factor_lower/upper in subclasses — confirm intended.
        super().__init__({tag_prediction}, name, enforce, 1.0)

        # Init variables
        self.tag_prediction = tag_prediction
        self.tag_reference = tag_reference
        self.rescale_factor_lower = rescale_factor_lower
        self.rescale_factor_upper = rescale_factor_upper
        self.stable = stable
        self.direction = direction
        self.descending = direction == "descending"

        # Populated by check_constraint implementations with per-sample
        # correction directions.
        self.compared_rankings: Tensor = None

    @abstractmethod
    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Evaluate whether the monotonicity constraint is satisfied.

        Implementations must set `self.compared_rankings` with per-sample
        correction directions.
        """

    @abstractmethod
    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return directions for monotonicity enforcement."""