congrads 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/callbacks/base.py +3 -0
- congrads/callbacks/registry.py +94 -35
- congrads/checkpoints.py +2 -0
- congrads/constraints/base.py +2 -0
- congrads/constraints/registry.py +36 -13
- congrads/core/batch_runner.py +2 -0
- congrads/core/congradscore.py +2 -0
- congrads/core/constraint_engine.py +2 -0
- congrads/core/epoch_runner.py +2 -0
- congrads/datasets/registry.py +11 -1
- congrads/descriptor.py +2 -0
- congrads/metrics.py +2 -0
- congrads/networks/registry.py +2 -0
- congrads/transformations/base.py +2 -0
- congrads/transformations/registry.py +2 -0
- congrads/utils/preprocessors.py +2 -0
- congrads/utils/utility.py +13 -4
- congrads/utils/validation.py +9 -0
- {congrads-0.3.0.dist-info → congrads-0.3.1.post1.dist-info}/METADATA +1 -1
- congrads-0.3.1.post1.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/RECORD +0 -23
- {congrads-0.3.0.dist-info → congrads-0.3.1.post1.dist-info}/WHEEL +0 -0
congrads/callbacks/base.py
CHANGED
congrads/callbacks/registry.py
CHANGED
|
@@ -7,23 +7,33 @@ collect all callback implementations in one place for easy reference
|
|
|
7
7
|
and import, and can be extended as new callbacks are added.
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
|
+
from torch import Tensor
|
|
10
11
|
from torch.utils.tensorboard import SummaryWriter
|
|
11
12
|
|
|
12
13
|
from ..callbacks.base import Callback
|
|
13
14
|
from ..metrics import MetricManager
|
|
14
15
|
from ..utils.utility import CSVLogger
|
|
15
16
|
|
|
17
|
+
__all__ = ["LoggerCallback"]
|
|
18
|
+
|
|
16
19
|
|
|
17
20
|
class LoggerCallback(Callback):
|
|
18
|
-
"""Callback to
|
|
21
|
+
"""Callback to periodically aggregate and store metrics during training and testing.
|
|
22
|
+
|
|
23
|
+
This callback works in conjunction with a MetricManager that accumulates metrics
|
|
24
|
+
internally (e.g. per batch). Metrics are:
|
|
19
25
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
26
|
+
- Aggregated at a configurable epoch interval (`aggregate_interval`)
|
|
27
|
+
- Cached in memory (GPU-resident tensors)
|
|
28
|
+
- Written to TensorBoard and CSV at a separate interval (`store_interval`)
|
|
29
|
+
|
|
30
|
+
Aggregation and storage are decoupled to avoid unnecessary GPU-to-CPU
|
|
31
|
+
synchronization. Any remaining cached metrics are flushed at the end of training.
|
|
23
32
|
|
|
24
33
|
Methods implemented:
|
|
25
|
-
- on_epoch_end:
|
|
26
|
-
-
|
|
34
|
+
- on_epoch_end: Periodically aggregates and stores training metrics.
|
|
35
|
+
- on_train_end: Flushes any remaining cached training metrics.
|
|
36
|
+
- on_test_end: Aggregates and stores test metrics immediately.
|
|
27
37
|
"""
|
|
28
38
|
|
|
29
39
|
def __init__(
|
|
@@ -31,6 +41,9 @@ class LoggerCallback(Callback):
|
|
|
31
41
|
metric_manager: MetricManager,
|
|
32
42
|
tensorboard_logger: SummaryWriter,
|
|
33
43
|
csv_logger: CSVLogger,
|
|
44
|
+
*,
|
|
45
|
+
aggregate_interval: int = 1,
|
|
46
|
+
store_interval: int = 1,
|
|
34
47
|
):
|
|
35
48
|
"""Initialize the LoggerCallback.
|
|
36
49
|
|
|
@@ -38,69 +51,115 @@ class LoggerCallback(Callback):
|
|
|
38
51
|
metric_manager: Instance of MetricManager used to collect metrics.
|
|
39
52
|
tensorboard_logger: TensorBoard SummaryWriter instance for logging scalars.
|
|
40
53
|
csv_logger: CSVLogger instance for logging metrics to CSV files.
|
|
54
|
+
aggregate_interval: Number of epochs between metric aggregation.
|
|
55
|
+
store_interval: Number of epochs between metric storage.
|
|
41
56
|
"""
|
|
42
57
|
super().__init__()
|
|
58
|
+
|
|
59
|
+
# Input validation
|
|
60
|
+
if aggregate_interval <= 0 or store_interval <= 0:
|
|
61
|
+
raise ValueError("Intervals must be positive integers")
|
|
62
|
+
|
|
63
|
+
if store_interval % aggregate_interval != 0:
|
|
64
|
+
raise ValueError("store_interval must be a multiple of aggregate_interval")
|
|
65
|
+
|
|
66
|
+
# Store references
|
|
43
67
|
self.metric_manager = metric_manager
|
|
44
68
|
self.tensorboard_logger = tensorboard_logger
|
|
45
69
|
self.csv_logger = csv_logger
|
|
70
|
+
self.aggregate_interval = aggregate_interval
|
|
71
|
+
self.store_interval = store_interval
|
|
72
|
+
|
|
73
|
+
# Cached metrics on GPU by epoch
|
|
74
|
+
self._accumulated_metrics: dict[int, dict[str, Tensor]] = {}
|
|
46
75
|
|
|
47
76
|
def on_epoch_end(self, data: dict[str, any], ctx: dict[str, any]):
|
|
48
|
-
"""
|
|
77
|
+
"""Handle end-of-epoch training logic.
|
|
78
|
+
|
|
79
|
+
At the end of each epoch, this method may:
|
|
80
|
+
- Aggregate training metrics from the MetricManager (every `aggregate_interval` epochs)
|
|
81
|
+
- Cache aggregated metrics keyed by epoch
|
|
82
|
+
- Store cached metrics to disk (every `store_interval` epochs)
|
|
49
83
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
for the next epoch.
|
|
84
|
+
Metric aggregation resets the MetricManager accumulation state.
|
|
85
|
+
Metric storage triggers GPU-to-CPU synchronization and writes to loggers.
|
|
53
86
|
|
|
54
87
|
Args:
|
|
55
|
-
data: Dictionary containing
|
|
56
|
-
ctx: Additional context dictionary (unused
|
|
88
|
+
data: Dictionary containing epoch context (must include 'epoch').
|
|
89
|
+
ctx: Additional context dictionary (unused).
|
|
57
90
|
|
|
58
91
|
Returns:
|
|
59
92
|
data: The same input dictionary, unmodified.
|
|
60
93
|
"""
|
|
61
94
|
epoch = data["epoch"]
|
|
62
95
|
|
|
63
|
-
#
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
self.
|
|
67
|
-
self.
|
|
96
|
+
# Cache training metrics
|
|
97
|
+
if epoch % self.aggregate_interval == 0:
|
|
98
|
+
metrics = self.metric_manager.aggregate("during_training")
|
|
99
|
+
self._accumulated_metrics[epoch] = metrics
|
|
100
|
+
self.metric_manager.reset("during_training")
|
|
68
101
|
|
|
69
|
-
#
|
|
70
|
-
self.
|
|
71
|
-
|
|
102
|
+
# Store metrics to disk
|
|
103
|
+
if epoch % self.store_interval == 0:
|
|
104
|
+
self._save(self._accumulated_metrics)
|
|
105
|
+
self._accumulated_metrics.clear()
|
|
106
|
+
|
|
107
|
+
return data
|
|
108
|
+
|
|
109
|
+
def on_train_end(self, data, ctx):
|
|
110
|
+
"""Flush any remaining cached training metrics at the end of training.
|
|
72
111
|
|
|
73
|
-
|
|
74
|
-
|
|
112
|
+
This ensures that aggregated metrics that were not yet written due to
|
|
113
|
+
`store_interval` alignment are persisted before training terminates.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
data: Dictionary containing training context (unused).
|
|
117
|
+
ctx: Additional context dictionary (unused).
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
data: The same input dictionary, unmodified.
|
|
121
|
+
"""
|
|
122
|
+
if self._accumulated_metrics:
|
|
123
|
+
self._save(self._accumulated_metrics)
|
|
124
|
+
self._accumulated_metrics.clear()
|
|
75
125
|
|
|
76
126
|
return data
|
|
77
127
|
|
|
78
128
|
def on_test_end(self, data: dict[str, any], ctx: dict[str, any]):
|
|
79
|
-
"""
|
|
129
|
+
"""Aggregate and store test metrics at the end of testing.
|
|
80
130
|
|
|
81
|
-
|
|
82
|
-
|
|
131
|
+
Test metrics are aggregated once and written immediately to disk.
|
|
132
|
+
Interval-based aggregation and caching are not applied to testing.
|
|
83
133
|
|
|
84
134
|
Args:
|
|
85
135
|
data: Dictionary containing test context (must include 'epoch').
|
|
86
|
-
ctx: Additional context dictionary (unused
|
|
136
|
+
ctx: Additional context dictionary (unused).
|
|
87
137
|
|
|
88
138
|
Returns:
|
|
89
139
|
data: The same input dictionary, unmodified.
|
|
90
140
|
"""
|
|
91
141
|
epoch = data["epoch"]
|
|
92
142
|
|
|
93
|
-
#
|
|
143
|
+
# Save test metrics
|
|
94
144
|
metrics = self.metric_manager.aggregate("after_training")
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
145
|
+
self._save({epoch: metrics})
|
|
146
|
+
self.metric_manager.reset("after_training")
|
|
147
|
+
|
|
148
|
+
return data
|
|
149
|
+
|
|
150
|
+
def _save(self, metrics: dict[int, dict[str, Tensor]]):
|
|
151
|
+
"""Write aggregated metrics to TensorBoard and CSV loggers.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
metrics: Mapping from epoch to a dictionary of metric name to scalar tensor.
|
|
155
|
+
Tensors are expected to be detached and graph-free.
|
|
156
|
+
"""
|
|
157
|
+
for epoch, metrics_by_name in metrics.items():
|
|
158
|
+
for name, value in metrics_by_name.items():
|
|
159
|
+
cpu_value = value.item()
|
|
160
|
+
self.tensorboard_logger.add_scalar(name, cpu_value, epoch)
|
|
161
|
+
self.csv_logger.add_value(name, cpu_value, epoch)
|
|
98
162
|
|
|
99
163
|
# Flush/save
|
|
100
164
|
self.tensorboard_logger.flush()
|
|
101
165
|
self.csv_logger.save()
|
|
102
|
-
|
|
103
|
-
# Reset metric manager for test
|
|
104
|
-
self.metric_manager.reset("after_training")
|
|
105
|
-
|
|
106
|
-
return data
|
congrads/checkpoints.py
CHANGED
|
@@ -16,6 +16,8 @@ from torch.optim import Optimizer
|
|
|
16
16
|
from .metrics import MetricManager
|
|
17
17
|
from .utils.validation import validate_callable, validate_type
|
|
18
18
|
|
|
19
|
+
__all__ = ["CheckpointManager"]
|
|
20
|
+
|
|
19
21
|
|
|
20
22
|
class CheckpointManager:
|
|
21
23
|
"""Manage saving and loading checkpoints for PyTorch models and optimizers.
|
congrads/constraints/base.py
CHANGED
|
@@ -27,6 +27,8 @@ from torch import Tensor
|
|
|
27
27
|
from congrads.descriptor import Descriptor
|
|
28
28
|
from congrads.utils.validation import validate_iterable, validate_type
|
|
29
29
|
|
|
30
|
+
__all__ = ["Constraint", "MonotonicityConstraint"]
|
|
31
|
+
|
|
30
32
|
|
|
31
33
|
class Constraint(ABC):
|
|
32
34
|
"""Abstract base class for defining constraints applied to neural networks.
|
congrads/constraints/registry.py
CHANGED
|
@@ -54,6 +54,20 @@ from ..transformations.registry import IdentityTransformation
|
|
|
54
54
|
from ..utils.validation import validate_comparator, validate_iterable, validate_type
|
|
55
55
|
from .base import Constraint, MonotonicityConstraint
|
|
56
56
|
|
|
57
|
+
__all__ = [
|
|
58
|
+
"ImplicationConstraint",
|
|
59
|
+
"ScalarConstraint",
|
|
60
|
+
"BinaryConstraint",
|
|
61
|
+
"SumConstraint",
|
|
62
|
+
"RankedMonotonicityConstraint",
|
|
63
|
+
"PairwiseMonotonicityConstraint",
|
|
64
|
+
"PerGroupMonotonicityConstraint",
|
|
65
|
+
"EncodedGroupedMonotonicityConstraint",
|
|
66
|
+
"ANDConstraint",
|
|
67
|
+
"ORConstraint",
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
|
|
57
71
|
COMPARATOR_MAP: dict[str, Callable[[Tensor, Tensor], Tensor]] = {
|
|
58
72
|
">": torch.gt,
|
|
59
73
|
">=": torch.ge,
|
|
@@ -121,6 +135,7 @@ class ImplicationConstraint(Constraint):
|
|
|
121
135
|
constraint (1 if satisfied, 0 otherwise).
|
|
122
136
|
- head_satisfaction: Tensor indicating satisfaction of the
|
|
123
137
|
head constraint alone.
|
|
138
|
+
|
|
124
139
|
"""
|
|
125
140
|
# Check satisfaction of head and body constraints
|
|
126
141
|
head_satisfaction, _ = self.head.check_constraint(data)
|
|
@@ -420,6 +435,7 @@ class BinaryConstraint(Constraint):
|
|
|
420
435
|
(1 for satisfied, 0 for violated) for each sample.
|
|
421
436
|
- mask (Tensor): Tensor of ones with the same shape as `result`,
|
|
422
437
|
used for constraint aggregation.
|
|
438
|
+
|
|
423
439
|
"""
|
|
424
440
|
# Select relevant columns
|
|
425
441
|
selection_left = self.descriptor.select(self.tag_left, data)
|
|
@@ -601,6 +617,7 @@ class SumConstraint(Constraint):
|
|
|
601
617
|
- result (Tensor): Binary tensor indicating whether the constraint
|
|
602
618
|
is satisfied (1) or violated (0) for each sample.
|
|
603
619
|
- mask (Tensor): Tensor of ones, used for constraint aggregation.
|
|
620
|
+
|
|
604
621
|
"""
|
|
605
622
|
|
|
606
623
|
def compute_weighted_sum(
|
|
@@ -825,7 +842,7 @@ class PairwiseMonotonicityConstraint(MonotonicityConstraint):
|
|
|
825
842
|
|
|
826
843
|
# Consider only upper triangle to avoid duplicate comparisons
|
|
827
844
|
batch_size = preds.shape[0]
|
|
828
|
-
mask = triu(ones(batch_size, batch_size, dtype=torch.bool), diagonal=1)
|
|
845
|
+
mask = triu(ones(batch_size, batch_size, dtype=torch.bool, device=preds.device), diagonal=1)
|
|
829
846
|
|
|
830
847
|
# Pairwise violations
|
|
831
848
|
violations = (preds_diff * targets_diff < 0) & mask
|
|
@@ -864,8 +881,7 @@ class PerGroupMonotonicityConstraint(Constraint):
|
|
|
864
881
|
|
|
865
882
|
Each group is treated as an independent mini-batch:
|
|
866
883
|
- The base constraint is applied to the group's subset of data.
|
|
867
|
-
- Violations and directions are computed per group and then reassembled
|
|
868
|
-
into the original batch order.
|
|
884
|
+
- Violations and directions are computed per group and then reassembled into the original batch order.
|
|
869
885
|
|
|
870
886
|
This is an explicit alternative to :class:`EncodedGroupedMonotonicityConstraint`,
|
|
871
887
|
which enforces the same logic using vectorized interval encoding.
|
|
@@ -986,13 +1002,16 @@ class EncodedGroupedMonotonicityConstraint(Constraint):
|
|
|
986
1002
|
"""Evaluate whether the monotonicity constraint is satisfied by mapping each group on a non-overlapping interval."""
|
|
987
1003
|
# Get data and keys
|
|
988
1004
|
ids = self.descriptor.select(self.tag_group, data)
|
|
989
|
-
preds = self.descriptor.select(self.base.tag_prediction, data)
|
|
990
|
-
targets = self.descriptor.select(self.base.tag_reference, data)
|
|
991
1005
|
preds_key, _ = self.descriptor.location(self.base.tag_prediction)
|
|
992
1006
|
targets_key, _ = self.descriptor.location(self.base.tag_reference)
|
|
993
1007
|
|
|
994
|
-
|
|
995
|
-
|
|
1008
|
+
preds = data[preds_key]
|
|
1009
|
+
targets = data[targets_key]
|
|
1010
|
+
|
|
1011
|
+
new_preds = preds + ids * (preds.amax(dim=0) - preds.amin(dim=0) + 1)
|
|
1012
|
+
new_targets = targets + self.negation * ids * (
|
|
1013
|
+
targets.amax(dim=0) - targets.amin(dim=0) + 1
|
|
1014
|
+
)
|
|
996
1015
|
|
|
997
1016
|
# Create new batch for child constraint
|
|
998
1017
|
new_data = {preds_key: new_preds, targets_key: new_targets}
|
|
@@ -1018,13 +1037,14 @@ class ANDConstraint(Constraint):
|
|
|
1018
1037
|
are satisfied (elementwise logical AND).
|
|
1019
1038
|
* The corrective direction is computed by weighting each sub-constraint's
|
|
1020
1039
|
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1040
|
+
|
|
1021
1041
|
"""
|
|
1022
1042
|
|
|
1023
1043
|
def __init__(
|
|
1024
1044
|
self,
|
|
1025
1045
|
*constraints: Constraint,
|
|
1026
1046
|
name: str = None,
|
|
1027
|
-
|
|
1047
|
+
enforce: bool = False,
|
|
1028
1048
|
rescale_factor: Number = 1.5,
|
|
1029
1049
|
) -> None:
|
|
1030
1050
|
"""A composite constraint that enforces the logical AND of multiple constraints.
|
|
@@ -1041,7 +1061,7 @@ class ANDConstraint(Constraint):
|
|
|
1041
1061
|
name (str, optional): A custom name for this constraint. If not provided,
|
|
1042
1062
|
the name will be composed from the sub-constraint names joined with
|
|
1043
1063
|
" AND ".
|
|
1044
|
-
|
|
1064
|
+
enforce (bool, optional): If False, the constraint will be monitored
|
|
1045
1065
|
but not enforced. Defaults to False.
|
|
1046
1066
|
rescale_factor (Number, optional): A scaling factor applied when rescaling
|
|
1047
1067
|
corrections. Defaults to 1.5.
|
|
@@ -1062,7 +1082,7 @@ class ANDConstraint(Constraint):
|
|
|
1062
1082
|
super().__init__(
|
|
1063
1083
|
set().union(*(constraint.tags for constraint in constraints)),
|
|
1064
1084
|
name,
|
|
1065
|
-
|
|
1085
|
+
enforce,
|
|
1066
1086
|
rescale_factor,
|
|
1067
1087
|
)
|
|
1068
1088
|
|
|
@@ -1083,6 +1103,7 @@ class ANDConstraint(Constraint):
|
|
|
1083
1103
|
* `mask`: A tensor of ones with the same shape as
|
|
1084
1104
|
`total_satisfaction`. Typically used as a weighting mask
|
|
1085
1105
|
in downstream processing.
|
|
1106
|
+
|
|
1086
1107
|
"""
|
|
1087
1108
|
total_satisfaction: Tensor = None
|
|
1088
1109
|
total_mask: Tensor = None
|
|
@@ -1141,13 +1162,14 @@ class ORConstraint(Constraint):
|
|
|
1141
1162
|
is satisfied (elementwise logical OR).
|
|
1142
1163
|
* The corrective direction is computed by weighting each sub-constraint's
|
|
1143
1164
|
direction with its satisfaction mask and summing across all sub-constraints.
|
|
1165
|
+
|
|
1144
1166
|
"""
|
|
1145
1167
|
|
|
1146
1168
|
def __init__(
|
|
1147
1169
|
self,
|
|
1148
1170
|
*constraints: Constraint,
|
|
1149
1171
|
name: str = None,
|
|
1150
|
-
|
|
1172
|
+
enforce: bool = False,
|
|
1151
1173
|
rescale_factor: Number = 1.5,
|
|
1152
1174
|
) -> None:
|
|
1153
1175
|
"""A composite constraint that enforces the logical OR of multiple constraints.
|
|
@@ -1164,7 +1186,7 @@ class ORConstraint(Constraint):
|
|
|
1164
1186
|
name (str, optional): A custom name for this constraint. If not provided,
|
|
1165
1187
|
the name will be composed from the sub-constraint names joined with
|
|
1166
1188
|
" OR ".
|
|
1167
|
-
|
|
1189
|
+
enforce (bool, optional): If False, the constraint will be monitored
|
|
1168
1190
|
but not enforced. Defaults to False.
|
|
1169
1191
|
rescale_factor (Number, optional): A scaling factor applied when rescaling
|
|
1170
1192
|
corrections. Defaults to 1.5.
|
|
@@ -1185,7 +1207,7 @@ class ORConstraint(Constraint):
|
|
|
1185
1207
|
super().__init__(
|
|
1186
1208
|
set().union(*(constraint.tags for constraint in constraints)),
|
|
1187
1209
|
name,
|
|
1188
|
-
|
|
1210
|
+
enforce,
|
|
1189
1211
|
rescale_factor,
|
|
1190
1212
|
)
|
|
1191
1213
|
|
|
@@ -1206,6 +1228,7 @@ class ORConstraint(Constraint):
|
|
|
1206
1228
|
* `mask`: A tensor of ones with the same shape as
|
|
1207
1229
|
`total_satisfaction`. Typically used as a weighting mask
|
|
1208
1230
|
in downstream processing.
|
|
1231
|
+
|
|
1209
1232
|
"""
|
|
1210
1233
|
total_satisfaction: Tensor = None
|
|
1211
1234
|
total_mask: Tensor = None
|
congrads/core/batch_runner.py
CHANGED
|
@@ -18,6 +18,8 @@ from ..callbacks.base import CallbackManager
|
|
|
18
18
|
from ..core.constraint_engine import ConstraintEngine
|
|
19
19
|
from ..metrics import MetricManager
|
|
20
20
|
|
|
21
|
+
__all__ = ["BatchRunner"]
|
|
22
|
+
|
|
21
23
|
|
|
22
24
|
class BatchRunner:
|
|
23
25
|
"""Executes a single batch for training, validation, or testing.
|
congrads/core/congradscore.py
CHANGED
|
@@ -41,6 +41,8 @@ from ..core.epoch_runner import EpochRunner
|
|
|
41
41
|
from ..descriptor import Descriptor
|
|
42
42
|
from ..metrics import MetricManager
|
|
43
43
|
|
|
44
|
+
__all__ = ["CongradsCore"]
|
|
45
|
+
|
|
44
46
|
|
|
45
47
|
class CongradsCore:
|
|
46
48
|
"""The CongradsCore class is the central training engine for constraint-guided optimization.
|
congrads/core/constraint_engine.py
CHANGED
|
@@ -14,6 +14,8 @@ from ..constraints.base import Constraint
|
|
|
14
14
|
from ..descriptor import Descriptor
|
|
15
15
|
from ..metrics import MetricManager
|
|
16
16
|
|
|
17
|
+
__all__ = ["ConstraintEngine"]
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
class ConstraintEngine:
|
|
19
21
|
"""Manages constraint evaluation and enforcement for a neural network.
|
congrads/core/epoch_runner.py
CHANGED
congrads/datasets/registry.py
CHANGED
|
@@ -5,9 +5,11 @@ downloading, loading, and transforming specific datasets where applicable.
|
|
|
5
5
|
|
|
6
6
|
Classes:
|
|
7
7
|
|
|
8
|
-
- SyntheticClusterDataset: A dataset class for generating synthetic clustered 2D data with labels.
|
|
9
8
|
- BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
|
|
10
9
|
- FamilyIncome: A dataset class for the Family Income and Expenditure dataset.
|
|
10
|
+
- SectionedGaussians: A synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
|
|
11
|
+
- SyntheticMonotonicity: A synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
|
|
12
|
+
- SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.
|
|
11
13
|
|
|
12
14
|
Each dataset class provides methods for downloading the data
|
|
13
15
|
(if not already available or synthetic), checking the integrity of the dataset, loading
|
|
@@ -31,6 +33,14 @@ from torchvision.datasets.utils import (
|
|
|
31
33
|
download_and_extract_archive,
|
|
32
34
|
)
|
|
33
35
|
|
|
36
|
+
__all__ = [
|
|
37
|
+
"BiasCorrection",
|
|
38
|
+
"FamilyIncome",
|
|
39
|
+
"SectionedGaussians",
|
|
40
|
+
"SyntheticMonotonicity",
|
|
41
|
+
"SyntheticClusters",
|
|
42
|
+
]
|
|
43
|
+
|
|
34
44
|
|
|
35
45
|
class BiasCorrection(Dataset):
|
|
36
46
|
"""A dataset class for accessing the Bias Correction dataset.
|
congrads/descriptor.py
CHANGED
congrads/metrics.py
CHANGED
congrads/networks/registry.py
CHANGED
congrads/transformations/base.py
CHANGED
congrads/transformations/registry.py
CHANGED
|
@@ -7,6 +7,8 @@ from torch import Tensor
|
|
|
7
7
|
from ..utils.validation import validate_callable, validate_type
|
|
8
8
|
from .base import Transformation
|
|
9
9
|
|
|
10
|
+
__all__ = ["IdentityTransformation", "DenormalizeMinMax", "ApplyOperator"]
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class IdentityTransformation(Transformation):
|
|
12
14
|
"""A transformation that returns the input unchanged."""
|
congrads/utils/preprocessors.py
CHANGED
|
@@ -12,6 +12,8 @@ normalization, feature engineering, constraint filtering, and sampling.
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import pandas as pd
|
|
14
14
|
|
|
15
|
+
__all__ = ["preprocess_BiasCorrection", "preprocess_FamilyIncome", "preprocess_AdultCensusIncome"]
|
|
16
|
+
|
|
15
17
|
|
|
16
18
|
def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame: # noqa: N802
|
|
17
19
|
"""Preprocesses the given dataframe for bias correction by performing a series of transformations.
|
congrads/utils/utility.py
CHANGED
|
@@ -13,6 +13,16 @@ from torch import Generator, Tensor, argsort, cat, int32, unique
|
|
|
13
13
|
from torch.nn.modules.loss import _Loss
|
|
14
14
|
from torch.utils.data import DataLoader, Dataset, random_split
|
|
15
15
|
|
|
16
|
+
__all__ = [
|
|
17
|
+
"CSVLogger",
|
|
18
|
+
"split_data_loaders",
|
|
19
|
+
"ZeroLoss",
|
|
20
|
+
"LossWrapper",
|
|
21
|
+
"process_data_monotonicity_constraint",
|
|
22
|
+
"DictDatasetWrapper",
|
|
23
|
+
"Seeder",
|
|
24
|
+
]
|
|
25
|
+
|
|
16
26
|
|
|
17
27
|
class CSVLogger:
|
|
18
28
|
"""A utility class for logging key-value pairs to a CSV file, organized by epochs.
|
|
@@ -466,8 +476,7 @@ class Seeder:
|
|
|
466
476
|
"""Initialize the Seeder with a base seed.
|
|
467
477
|
|
|
468
478
|
Args:
|
|
469
|
-
base_seed (int): The initial seed from which all subsequent
|
|
470
|
-
pseudo-random seeds are deterministically derived.
|
|
479
|
+
base_seed (int): The initial seed from which all subsequent pseudo-random seeds are deterministically derived.
|
|
471
480
|
"""
|
|
472
481
|
self._rng = random.Random(base_seed)
|
|
473
482
|
|
|
@@ -486,8 +495,7 @@ class Seeder:
|
|
|
486
495
|
def set_reproducible(self) -> None:
|
|
487
496
|
"""Configure global random states for reproducibility.
|
|
488
497
|
|
|
489
|
-
Seeds the following libraries with deterministically generated
|
|
490
|
-
seeds based on the base seed:
|
|
498
|
+
Seeds the following libraries with deterministically generated seeds based on the base seed:
|
|
491
499
|
- Python's built-in `random`
|
|
492
500
|
- NumPy's random number generator
|
|
493
501
|
- PyTorch (CPU and GPU)
|
|
@@ -496,6 +504,7 @@ class Seeder:
|
|
|
496
504
|
- Seeding all CUDA devices
|
|
497
505
|
- Disabling CuDNN benchmarking
|
|
498
506
|
- Enabling CuDNN deterministic mode
|
|
507
|
+
|
|
499
508
|
"""
|
|
500
509
|
random.seed(self.roll_seed())
|
|
501
510
|
np.random.seed(self.roll_seed())
|
congrads/utils/validation.py
CHANGED
|
@@ -7,6 +7,15 @@ validation functions.
|
|
|
7
7
|
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
9
|
|
|
10
|
+
__all__ = [
|
|
11
|
+
"validate_type",
|
|
12
|
+
"validate_iterable",
|
|
13
|
+
"validate_comparator",
|
|
14
|
+
"validate_callable",
|
|
15
|
+
"validate_callable_iterable",
|
|
16
|
+
"validate_loaders",
|
|
17
|
+
]
|
|
18
|
+
|
|
10
19
|
|
|
11
20
|
def validate_type(name, value, expected_types, allow_none=False):
|
|
12
21
|
"""Validate that a value is of the specified type(s).
|
{congrads-0.3.0.dist-info → congrads-0.3.1.post1.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: congrads
|
|
3
|
-
Version: 0.3.0
|
|
3
|
+
Version: 0.3.1.post1
|
|
4
4
|
Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
|
|
5
5
|
Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
|
|
6
6
|
Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
|
congrads-0.3.1.post1.dist-info/RECORD
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
congrads/__init__.py,sha256=XJKWRteSvmTYawgS1Pon8kWhhd3haKo4RGAsyjEGS8Q,383
|
|
2
|
+
congrads/callbacks/base.py,sha256=F1eanEwEOvnBGZ_HKNdfN0QWuK9QFhQ9RbjrvjVJZ00,13771
|
|
3
|
+
congrads/callbacks/registry.py,sha256=UtQuoG48GR0zhOKNIhiUxVvD7Z46ClSiVitx8UAI244,6268
|
|
4
|
+
congrads/checkpoints.py,sha256=sWOL_v9ox0zajmNshNSKHP6l8f-p1RzwqyJIdPQnvjg,7261
|
|
5
|
+
congrads/constraints/base.py,sha256=j-DfdY7C4TJynglUn7JpH_YsREGXI-K9x1XCMEXSqZg,10157
|
|
6
|
+
congrads/constraints/registry.py,sha256=Ki6waNtnbcAA5F7V3nkSOKEnpyGS3g1UKTJwyR3CgC8,52577
|
|
7
|
+
congrads/core/batch_runner.py,sha256=jo7KP-64US3BxTO-zie2pD0zxdVPxNtQDx0TdRpFwFE,6994
|
|
8
|
+
congrads/core/congradscore.py,sha256=S8_ph3vs5LNl53u2y6-l-vKetOXYW5Xku2AeIl7mct8,12098
|
|
9
|
+
congrads/core/constraint_engine.py,sha256=y0I7cxm2oslovXMtT3V5hw3_TZm0NsP1pAgI5uqTyJw,9026
|
|
10
|
+
congrads/core/epoch_runner.py,sha256=8mCsfuAnux3Ws_dohD8HQUsCnSo_TxWlOcHwX6cq4ys,4120
|
|
11
|
+
congrads/datasets/registry.py,sha256=ZYLXnHxCRk0tgVcwsntNfezGdJNLVUFEKWKGMsz1Dsw,31477
|
|
12
|
+
congrads/descriptor.py,sha256=uzAuR_qGM1zzDgeLnSEwbUvvvT1eDL_2hFwwq-I9WoI,6940
|
|
13
|
+
congrads/metrics.py,sha256=Bb0g7KEulytCPTwnT-Q4hfIXwurZuyjPsydx2YmoaqY,4664
|
|
14
|
+
congrads/networks/registry.py,sha256=xxSpJDmkMlIG_1f-aEzmjUTNekVrR9uWDJVv7xgWofQ,2278
|
|
15
|
+
congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
congrads/transformations/base.py,sha256=ZjPCqa5GoqWuaWcP4afrGm0eSEXQoZ4RylGPRKwd48w,930
|
|
17
|
+
congrads/transformations/registry.py,sha256=ZF1wuPsG5A5jse8jZcgx7pyQ5htJ4VVX1_PzVLvADLs,2683
|
|
18
|
+
congrads/utils/preprocessors.py,sha256=8AvY-6kX47SPsQTIaxLl7yBcROlv3pjLCmsG7SLDTqM,18290
|
|
19
|
+
congrads/utils/utility.py,sha256=CWiLZ7tv3dmHA3h8qeWJ8abUKhVckzFZWOyuAweQzhw,19046
|
|
20
|
+
congrads/utils/validation.py,sha256=dleerBb0sAqnn3AjhcUxIRGFnqImWsDbfNNvBVmopsI,6379
|
|
21
|
+
congrads-0.3.1.post1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
|
|
22
|
+
congrads-0.3.1.post1.dist-info/METADATA,sha256=E8GcJ6cKT0QX4P8MNrKRgxmLX_-ggcqyb010gsUZ3fQ,10754
|
|
23
|
+
congrads-0.3.1.post1.dist-info/RECORD,,
|
congrads-0.3.0.dist-info/RECORD
DELETED
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
congrads/__init__.py,sha256=XJKWRteSvmTYawgS1Pon8kWhhd3haKo4RGAsyjEGS8Q,383
|
|
2
|
-
congrads/callbacks/base.py,sha256=OChXls-tndgJQOXNfqavnPywHZHn3N87yLKD4kbHDHk,13714
|
|
3
|
-
congrads/callbacks/registry.py,sha256=KkzjDqMS3CkE__PpGrmAEYwRngqGSdQNE8NVWl7ogeA,3898
|
|
4
|
-
congrads/checkpoints.py,sha256=V79n3mqjB48nbNkBELqKDg9iou0b1vc5eRrlcu8aIA4,7228
|
|
5
|
-
congrads/constraints/base.py,sha256=k9OyPS2A4bP3fSEAEANGuw7zofiWlGIxqb5ows1LQWs,10105
|
|
6
|
-
congrads/constraints/registry.py,sha256=k__RfcXle-qDL9OJ-nfwgL9zeM6-ISwgQyIzmx-lsgc,52302
|
|
7
|
-
congrads/core/batch_runner.py,sha256=emc7smJLDHq0J8_J9t9X0RtqrXaYwOP9mmhlX_M78e4,6967
|
|
8
|
-
congrads/core/congradscore.py,sha256=9ZKUVMB9RbmudtuC-MQcNColBYUjb6XLqb0eISzfrGk,12070
|
|
9
|
-
congrads/core/constraint_engine.py,sha256=UEt-tmtJeJX0Wu3ol17Z0A9hacL0F8oouJUwHgIIoDE,8994
|
|
10
|
-
congrads/core/epoch_runner.py,sha256=l0x3uLXQ5I5o1C63wXgL4_QkhFmXxW-jeejNJK6sf18,4093
|
|
11
|
-
congrads/datasets/registry.py,sha256=RfffRiA7Qijc69cJTBJhItTZ8x9B-p1kXMjvcfEC_nA,31102
|
|
12
|
-
congrads/descriptor.py,sha256=tUHF4vvyNzJP5vpq1xn0uhKnOlAkElwG2R9gG4glHvQ,6914
|
|
13
|
-
congrads/metrics.py,sha256=e52QC8yNKsxAndjC3U4WMUnQ_0GmiSlExKtxRRShHao,4625
|
|
14
|
-
congrads/networks/registry.py,sha256=UPzPDU0wI2zoOEvi697QBSDOtaa3Rc0rgCb-tCxbjak,2252
|
|
15
|
-
congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
-
congrads/transformations/base.py,sha256=KZQkloaDGcqAp8EhlUnHL8VfZqSq8OCqp_Iy_a2Nfns,900
|
|
17
|
-
congrads/transformations/registry.py,sha256=p2cLnt3X1bspEPfR7IVd31qPXQimVe_bRu2VhUOIZj0,2607
|
|
18
|
-
congrads/utils/preprocessors.py,sha256=oqW3hV_yoUd-6I-NSoE61e_JDNEPnBJvvvdsuKd9Ekg,18190
|
|
19
|
-
congrads/utils/utility.py,sha256=zvOAVjQjtmsvyuJm0rF0cy_jApR6qluQsFxk9ItalzE,18893
|
|
20
|
-
congrads/utils/validation.py,sha256=Jj8ZJGJrrH9B02cIaScsQpne3zjyarkPldDdT1pejVA,6208
|
|
21
|
-
congrads-0.3.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
|
|
22
|
-
congrads-0.3.0.dist-info/METADATA,sha256=UCutFzaiD6CaeSr9BqtfiyMS-LsDs6Iwpw8QcXSS2jc,10748
|
|
23
|
-
congrads-0.3.0.dist-info/RECORD,,
|
|
File without changes
|