congrads-0.3.0-py3-none-any.whl → congrads-0.3.1.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,6 +45,9 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Any, Literal, Self
 
+__all__ = ["Callback", "CallbackManager", "Operation"]
+
+
 Stage = Literal[
     "on_train_start",
     "on_train_end",
@@ -7,23 +7,33 @@ collect all callback implementations in one place for easy reference
 and import, and can be extended as new callbacks are added.
 """
 
+from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
 
 from ..callbacks.base import Callback
 from ..metrics import MetricManager
 from ..utils.utility import CSVLogger
 
+__all__ = ["LoggerCallback"]
+
 
 class LoggerCallback(Callback):
-    """Callback to log metrics to TensorBoard and CSV after each epoch or test.
+    """Callback to periodically aggregate and store metrics during training and testing.
+
+    This callback works in conjunction with a MetricManager that accumulates metrics
+    internally (e.g. per batch). Metrics are:
 
-    This callback queries a MetricManager for aggregated metrics, writes them
-    to TensorBoard using SummaryWriter, and logs them to a CSV file via CSVLogger.
-    It also flushes loggers and resets metrics after logging.
+    - Aggregated at a configurable epoch interval (`aggregate_interval`)
+    - Cached in memory (GPU-resident tensors)
+    - Written to TensorBoard and CSV at a separate interval (`store_interval`)
+
+    Aggregation and storage are decoupled to avoid unnecessary GPU-to-CPU
+    synchronization. Any remaining cached metrics are flushed at the end of training.
 
     Methods implemented:
-    - on_epoch_end: Logs metrics at the end of a training epoch.
-    - on_test_end: Logs metrics at the end of testing.
+    - on_epoch_end: Periodically aggregates and stores training metrics.
+    - on_train_end: Flushes any remaining cached training metrics.
+    - on_test_end: Aggregates and stores test metrics immediately.
     """
 
     def __init__(
@@ -31,6 +41,9 @@ class LoggerCallback(Callback):
         metric_manager: MetricManager,
         tensorboard_logger: SummaryWriter,
         csv_logger: CSVLogger,
+        *,
+        aggregate_interval: int = 1,
+        store_interval: int = 1,
     ):
         """Initialize the LoggerCallback.
 
@@ -38,69 +51,115 @@ class LoggerCallback(Callback):
             metric_manager: Instance of MetricManager used to collect metrics.
             tensorboard_logger: TensorBoard SummaryWriter instance for logging scalars.
             csv_logger: CSVLogger instance for logging metrics to CSV files.
+            aggregate_interval: Number of epochs between metric aggregation.
+            store_interval: Number of epochs between metric storage.
         """
         super().__init__()
+
+        # Input validation
+        if aggregate_interval <= 0 or store_interval <= 0:
+            raise ValueError("Intervals must be positive integers")
+
+        if store_interval % aggregate_interval != 0:
+            raise ValueError("store_interval must be a multiple of aggregate_interval")
+
+        # Store references
         self.metric_manager = metric_manager
         self.tensorboard_logger = tensorboard_logger
         self.csv_logger = csv_logger
+        self.aggregate_interval = aggregate_interval
+        self.store_interval = store_interval
+
+        # Cached metrics on GPU by epoch
+        self._accumulated_metrics: dict[int, dict[str, Tensor]] = {}
 
     def on_epoch_end(self, data: dict[str, any], ctx: dict[str, any]):
-        """Log training metrics at the end of an epoch.
+        """Handle end-of-epoch training logic.
+
+        At the end of each epoch, this method may:
+        - Aggregate training metrics from the MetricManager (every `aggregate_interval` epochs)
+        - Cache aggregated metrics keyed by epoch
+        - Store cached metrics to disk (every `store_interval` epochs)
 
-        Aggregates metrics from the MetricManager under the 'during_training' category,
-        writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics
-        for the next epoch.
+        Metric aggregation resets the MetricManager accumulation state.
+        Metric storage triggers GPU-to-CPU synchronization and writes to loggers.
 
         Args:
-            data: Dictionary containing batch/epoch context (must include 'epoch').
-            ctx: Additional context dictionary (unused in this implementation).
+            data: Dictionary containing epoch context (must include 'epoch').
+            ctx: Additional context dictionary (unused).
 
         Returns:
            data: The same input dictionary, unmodified.
        """
        epoch = data["epoch"]
 
-        # Log training metrics
-        metrics = self.metric_manager.aggregate("during_training")
-        for name, value in metrics.items():
-            self.tensorboard_logger.add_scalar(name, value.item(), epoch)
-            self.csv_logger.add_value(name, value.item(), epoch)
+        # Cache training metrics
+        if epoch % self.aggregate_interval == 0:
+            metrics = self.metric_manager.aggregate("during_training")
+            self._accumulated_metrics[epoch] = metrics
+            self.metric_manager.reset("during_training")
 
-        # Flush/save
-        self.tensorboard_logger.flush()
-        self.csv_logger.save()
+        # Store metrics to disk
+        if epoch % self.store_interval == 0:
+            self._save(self._accumulated_metrics)
+            self._accumulated_metrics.clear()
+
+        return data
+
+    def on_train_end(self, data, ctx):
+        """Flush any remaining cached training metrics at the end of training.
 
-        # Reset metric manager for training
-        self.metric_manager.reset("during_training")
+        This ensures that aggregated metrics that were not yet written due to
+        `store_interval` alignment are persisted before training terminates.
+
+        Args:
+            data: Dictionary containing training context (unused).
+            ctx: Additional context dictionary (unused).
+
+        Returns:
+            data: The same input dictionary, unmodified.
+        """
+        if self._accumulated_metrics:
+            self._save(self._accumulated_metrics)
+            self._accumulated_metrics.clear()
 
         return data
 
     def on_test_end(self, data: dict[str, any], ctx: dict[str, any]):
-        """Log test metrics at the end of testing.
+        """Aggregate and store test metrics at the end of testing.
 
-        Aggregates metrics from the MetricManager under the 'after_training' category,
-        writes them to TensorBoard and CSV, flushes the loggers, and resets the metrics.
+        Test metrics are aggregated once and written immediately to disk.
+        Interval-based aggregation and caching are not applied to testing.
 
         Args:
            data: Dictionary containing test context (must include 'epoch').
-            ctx: Additional context dictionary (unused in this implementation).
+            ctx: Additional context dictionary (unused).
 
         Returns:
            data: The same input dictionary, unmodified.
        """
        epoch = data["epoch"]
 
-        # Log test metrics
+        # Save test metrics
         metrics = self.metric_manager.aggregate("after_training")
-        for name, value in metrics.items():
-            self.tensorboard_logger.add_scalar(name, value.item(), epoch)
-            self.csv_logger.add_value(name, value.item(), epoch)
+        self._save({epoch: metrics})
+        self.metric_manager.reset("after_training")
+
+        return data
+
+    def _save(self, metrics: dict[int, dict[str, Tensor]]):
+        """Write aggregated metrics to TensorBoard and CSV loggers.
+
+        Args:
+            metrics: Mapping from epoch to a dictionary of metric name to scalar tensor.
+                Tensors are expected to be detached and graph-free.
+        """
+        for epoch, metrics_by_name in metrics.items():
+            for name, value in metrics_by_name.items():
+                cpu_value = value.item()
+                self.tensorboard_logger.add_scalar(name, cpu_value, epoch)
+                self.csv_logger.add_value(name, cpu_value, epoch)
 
         # Flush/save
         self.tensorboard_logger.flush()
         self.csv_logger.save()
-
-        # Reset metric manager for test
-        self.metric_manager.reset("after_training")
-
-        return data
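
The new logging cadence is easiest to see with concrete numbers. A minimal standalone sketch of the schedule implied by `aggregate_interval` and `store_interval` (interval values invented for illustration):

```python
# With aggregate_interval=2 and store_interval=6, metrics are aggregated and
# cached (still on GPU) every 2nd epoch, but flushed to TensorBoard/CSV only
# every 6th epoch; __init__ rejects store intervals that are not multiples
# of the aggregate interval.
aggregate_interval, store_interval = 2, 6
assert store_interval % aggregate_interval == 0

for epoch in range(1, 13):
    if epoch % aggregate_interval == 0:
        print(f"epoch {epoch:2d}: aggregate + cache")
    if epoch % store_interval == 0:
        print(f"epoch {epoch:2d}: flush cache to disk (GPU-to-CPU sync)")
```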
congrads/checkpoints.py CHANGED
@@ -16,6 +16,8 @@ from torch.optim import Optimizer
 from .metrics import MetricManager
 from .utils.validation import validate_callable, validate_type
 
+__all__ = ["CheckpointManager"]
+
 
 class CheckpointManager:
     """Manage saving and loading checkpoints for PyTorch models and optimizers.
@@ -27,6 +27,8 @@ from torch import Tensor
 from congrads.descriptor import Descriptor
 from congrads.utils.validation import validate_iterable, validate_type
 
+__all__ = ["Constraint", "MonotonicityConstraint"]
+
 
 class Constraint(ABC):
     """Abstract base class for defining constraints applied to neural networks.
@@ -54,6 +54,20 @@ from ..transformations.registry import IdentityTransformation
 from ..utils.validation import validate_comparator, validate_iterable, validate_type
 from .base import Constraint, MonotonicityConstraint
 
+__all__ = [
+    "ImplicationConstraint",
+    "ScalarConstraint",
+    "BinaryConstraint",
+    "SumConstraint",
+    "RankedMonotonicityConstraint",
+    "PairwiseMonotonicityConstraint",
+    "PerGroupMonotonicityConstraint",
+    "EncodedGroupedMonotonicityConstraint",
+    "ANDConstraint",
+    "ORConstraint",
+]
+
+
 COMPARATOR_MAP: dict[str, Callable[[Tensor, Tensor], Tensor]] = {
     ">": torch.gt,
     ">=": torch.ge,
@@ -121,6 +135,7 @@ class ImplicationConstraint(Constraint):
             constraint (1 if satisfied, 0 otherwise).
             - head_satisfaction: Tensor indicating satisfaction of the
             head constraint alone.
+
         """
         # Check satisfaction of head and body constraints
         head_satisfaction, _ = self.head.check_constraint(data)
@@ -420,6 +435,7 @@ class BinaryConstraint(Constraint):
             (1 for satisfied, 0 for violated) for each sample.
             - mask (Tensor): Tensor of ones with the same shape as `result`,
             used for constraint aggregation.
+
         """
         # Select relevant columns
         selection_left = self.descriptor.select(self.tag_left, data)
@@ -601,6 +617,7 @@ class SumConstraint(Constraint):
             - result (Tensor): Binary tensor indicating whether the constraint
             is satisfied (1) or violated (0) for each sample.
             - mask (Tensor): Tensor of ones, used for constraint aggregation.
+
         """
 
         def compute_weighted_sum(
@@ -825,7 +842,7 @@ class PairwiseMonotonicityConstraint(MonotonicityConstraint):
 
         # Consider only upper triangle to avoid duplicate comparisons
         batch_size = preds.shape[0]
-        mask = triu(ones(batch_size, batch_size, dtype=torch.bool), diagonal=1)
+        mask = triu(ones(batch_size, batch_size, dtype=torch.bool, device=preds.device), diagonal=1)
 
         # Pairwise violations
         violations = (preds_diff * targets_diff < 0) & mask
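
The one-line change above is a device fix: the boolean mask is now allocated on the same device as `preds`, so the subsequent `& mask` no longer combines a CPU tensor with GPU predictions. A standalone sketch of the pattern (tensor values invented):

```python
import torch

# Allocate the upper-triangular comparison mask on the predictions' device;
# before the fix it was always created on the CPU.
preds = torch.tensor([0.1, 0.4, 0.2])  # a CUDA tensor on GPU runs
n = preds.shape[0]
mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=preds.device), diagonal=1)
print(mask)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
```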
@@ -864,8 +881,7 @@ class PerGroupMonotonicityConstraint(Constraint):
 
     Each group is treated as an independent mini-batch:
     - The base constraint is applied to the group's subset of data.
-    - Violations and directions are computed per group and then reassembled
-    into the original batch order.
+    - Violations and directions are computed per group and then reassembled into the original batch order.
 
     This is an explicit alternative to :class:`EncodedGroupedMonotonicityConstraint`,
     which enforces the same logic using vectorized interval encoding.
@@ -986,13 +1002,16 @@ class EncodedGroupedMonotonicityConstraint(Constraint):
         """Evaluate whether the monotonicity constraint is satisfied by mapping each group on a non-overlapping interval."""
         # Get data and keys
         ids = self.descriptor.select(self.tag_group, data)
-        preds = self.descriptor.select(self.base.tag_prediction, data)
-        targets = self.descriptor.select(self.base.tag_reference, data)
         preds_key, _ = self.descriptor.location(self.base.tag_prediction)
         targets_key, _ = self.descriptor.location(self.base.tag_reference)
 
-        new_preds = preds + ids * (preds.max() - preds.min() + 1)
-        new_targets = targets + self.negation * ids * (targets.max() - targets.min() + 1)
+        preds = data[preds_key]
+        targets = data[targets_key]
+
+        new_preds = preds + ids * (preds.amax(dim=0) - preds.amin(dim=0) + 1)
+        new_targets = targets + self.negation * ids * (
+            targets.amax(dim=0) - targets.amin(dim=0) + 1
+        )
 
         # Create new batch for child constraint
         new_data = {preds_key: new_preds, targets_key: new_targets}
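
Besides reading `preds`/`targets` directly from the batch, the encoding now uses per-column ranges (`amax(dim=0)`/`amin(dim=0)`) instead of a single global `max()`/`min()`, so each output column is shifted by its own span. A standalone sketch of the interval-encoding idea (values invented; the `negation` factor is omitted):

```python
import torch

# Offset each group's values into a disjoint interval so one monotonicity
# check over the whole batch cannot compare samples from different groups.
preds = torch.tensor([[0.2], [0.9], [0.1], [0.7]])  # (batch, features)
ids = torch.tensor([[0], [0], [1], [1]])            # group id per sample

span = preds.amax(dim=0) - preds.amin(dim=0) + 1    # per-column range + gap
encoded = preds + ids * span
print(encoded.squeeze(1))
# tensor([0.2000, 0.9000, 1.9000, 2.5000])  -> group 1 sits strictly above group 0
```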
@@ -1018,13 +1037,14 @@ class ANDConstraint(Constraint):
     are satisfied (elementwise logical AND).
     * The corrective direction is computed by weighting each sub-constraint's
     direction with its satisfaction mask and summing across all sub-constraints.
+
     """
 
     def __init__(
         self,
         *constraints: Constraint,
         name: str = None,
-        monitor_only: bool = False,
+        enforce: bool = False,
         rescale_factor: Number = 1.5,
     ) -> None:
         """A composite constraint that enforces the logical AND of multiple constraints.
@@ -1041,7 +1061,7 @@ class ANDConstraint(Constraint):
             name (str, optional): A custom name for this constraint. If not provided,
             the name will be composed from the sub-constraint names joined with
             " AND ".
-            monitor_only (bool, optional): If True, the constraint will be monitored
+            enforce (bool, optional): If True, the constraint will be monitored
             but not enforced. Defaults to False.
             rescale_factor (Number, optional): A scaling factor applied when rescaling
             corrections. Defaults to 1.5.
@@ -1062,7 +1082,7 @@ class ANDConstraint(Constraint):
         super().__init__(
             set().union(*(constraint.tags for constraint in constraints)),
             name,
-            monitor_only,
+            enforce,
             rescale_factor,
         )
 
@@ -1083,6 +1103,7 @@ class ANDConstraint(Constraint):
         * `mask`: A tensor of ones with the same shape as
         `total_satisfaction`. Typically used as a weighting mask
         in downstream processing.
+
         """
         total_satisfaction: Tensor = None
         total_mask: Tensor = None
@@ -1141,13 +1162,14 @@ class ORConstraint(Constraint):
     is satisfied (elementwise logical OR).
     * The corrective direction is computed by weighting each sub-constraint's
     direction with its satisfaction mask and summing across all sub-constraints.
+
     """
 
     def __init__(
         self,
         *constraints: Constraint,
         name: str = None,
-        monitor_only: bool = False,
+        enforce: bool = False,
         rescale_factor: Number = 1.5,
     ) -> None:
         """A composite constraint that enforces the logical OR of multiple constraints.
@@ -1164,7 +1186,7 @@ class ORConstraint(Constraint):
             name (str, optional): A custom name for this constraint. If not provided,
             the name will be composed from the sub-constraint names joined with
             " OR ".
-            monitor_only (bool, optional): If True, the constraint will be monitored
+            enforce (bool, optional): If True, the constraint will be monitored
             but not enforced. Defaults to False.
             rescale_factor (Number, optional): A scaling factor applied when rescaling
             corrections. Defaults to 1.5.
@@ -1185,7 +1207,7 @@ class ORConstraint(Constraint):
         super().__init__(
             set().union(*(constraint.tags for constraint in constraints)),
             name,
-            monitor_only,
+            enforce,
             rescale_factor,
         )
 
@@ -1206,6 +1228,7 @@ class ORConstraint(Constraint):
         * `mask`: A tensor of ones with the same shape as
         `total_satisfaction`. Typically used as a weighting mask
         in downstream processing.
+
         """
         total_satisfaction: Tensor = None
         total_mask: Tensor = None
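
Note that both composite constraints rename the keyword `monitor_only` to `enforce`, which breaks call sites that passed the old name as a keyword; the diffed docstrings still describe the parameter with the old `monitor_only` wording ("If True, the constraint will be monitored but not enforced"), so callers should confirm the intended semantics against the source. A hypothetical migration, shown as comments (`c1` and `c2` are placeholders for two constructed Constraint objects, not names from the package):

```python
# Hypothetical keyword call site across the version bump:
#
#   combined = ANDConstraint(c1, c2, name="both", monitor_only=False)  # 0.3.0
#   combined = ANDConstraint(c1, c2, name="both", enforce=False)       # 0.3.1.post1
#
# Positional callers are unaffected: the argument keeps its position.
```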
@@ -18,6 +18,8 @@ from ..callbacks.base import CallbackManager
 from ..core.constraint_engine import ConstraintEngine
 from ..metrics import MetricManager
 
+__all__ = ["BatchRunner"]
+
 
 class BatchRunner:
     """Executes a single batch for training, validation, or testing.
@@ -41,6 +41,8 @@ from ..core.epoch_runner import EpochRunner
 from ..descriptor import Descriptor
 from ..metrics import MetricManager
 
+__all__ = ["CongradsCore"]
+
 
 class CongradsCore:
     """The CongradsCore class is the central training engine for constraint-guided optimization.
@@ -14,6 +14,8 @@ from ..constraints.base import Constraint
 from ..descriptor import Descriptor
 from ..metrics import MetricManager
 
+__all__ = ["ConstraintEngine"]
+
 
 class ConstraintEngine:
     """Manages constraint evaluation and enforcement for a neural network.
@@ -17,6 +17,8 @@ from tqdm import tqdm
 
 from ..core.batch_runner import BatchRunner
 
+__all__ = ["EpochRunner"]
+
 
 class EpochRunner:
     """Runs full epochs over DataLoaders.
@@ -5,9 +5,11 @@ downloading, loading, and transforming specific datasets where applicable.
 
 Classes:
 
-- SyntheticClusterDataset: A dataset class for generating synthetic clustered 2D data with labels.
 - BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.
 - FamilyIncome: A dataset class for the Family Income and Expenditure dataset.
+- SectionedGaussians: A synthetic dataset generating smoothly varying Gaussian signals across multiple sections.
+- SyntheticMonotonicity: A synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.
+- SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.
 
 Each dataset class provides methods for downloading the data
 (if not already available or synthetic), checking the integrity of the dataset, loading
@@ -31,6 +33,14 @@ from torchvision.datasets.utils import (
     download_and_extract_archive,
 )
 
+__all__ = [
+    "BiasCorrection",
+    "FamilyIncome",
+    "SectionedGaussians",
+    "SyntheticMonotonicity",
+    "SyntheticClusters",
+]
+
 
 class BiasCorrection(Dataset):
     """A dataset class for accessing the Bias Correction dataset.
congrads/descriptor.py CHANGED
@@ -14,6 +14,8 @@ from torch import Tensor
 
 from .utils.validation import validate_type
 
+__all__ = ["Descriptor"]
+
 
 class Descriptor:
     """A class to manage the mapping between tags.
congrads/metrics.py CHANGED
@@ -11,6 +11,8 @@ from torch import Tensor, cat, nanmean, tensor
 
 from .utils.validation import validate_callable, validate_type
 
+__all__ = ["Metric", "MetricManager"]
+
 
 class Metric:
     """Represents a single metric to be accumulated and aggregated.
@@ -3,6 +3,8 @@
 from torch import Tensor
 from torch.nn import Linear, Module, ReLU, Sequential
 
+__all__ = ["MLPNetwork"]
+
 
 class MLPNetwork(Module):
     """A multi-layer perceptron (MLP) neural network with configurable hidden layers."""
@@ -6,6 +6,8 @@ from torch import Tensor
 
 from ..utils.validation import validate_type
 
+__all__ = ["Transformation"]
+
 
 class Transformation(ABC):
     """Abstract base class for tag data transformations."""
@@ -7,6 +7,8 @@ from torch import Tensor
 from ..utils.validation import validate_callable, validate_type
 from .base import Transformation
 
+__all__ = ["IdentityTransformation", "DenormalizeMinMax", "ApplyOperator"]
+
 
 class IdentityTransformation(Transformation):
     """A transformation that returns the input unchanged."""
@@ -12,6 +12,8 @@ normalization, feature engineering, constraint filtering, and sampling.
 import numpy as np
 import pandas as pd
 
+__all__ = ["preprocess_BiasCorrection", "preprocess_FamilyIncome", "preprocess_AdultCensusIncome"]
+
 
 def preprocess_BiasCorrection(df: pd.DataFrame) -> pd.DataFrame:  # noqa: N802
     """Preprocesses the given dataframe for bias correction by performing a series of transformations.
congrads/utils/utility.py CHANGED
@@ -13,6 +13,16 @@ from torch import Generator, Tensor, argsort, cat, int32, unique
 from torch.nn.modules.loss import _Loss
 from torch.utils.data import DataLoader, Dataset, random_split
 
+__all__ = [
+    "CSVLogger",
+    "split_data_loaders",
+    "ZeroLoss",
+    "LossWrapper",
+    "process_data_monotonicity_constraint",
+    "DictDatasetWrapper",
+    "Seeder",
+]
+
 
 class CSVLogger:
     """A utility class for logging key-value pairs to a CSV file, organized by epochs.
@@ -466,8 +476,7 @@ class Seeder:
         """Initialize the Seeder with a base seed.
 
         Args:
-            base_seed (int): The initial seed from which all subsequent
-            pseudo-random seeds are deterministically derived.
+            base_seed (int): The initial seed from which all subsequent pseudo-random seeds are deterministically derived.
         """
         self._rng = random.Random(base_seed)
 
@@ -486,8 +495,7 @@ class Seeder:
     def set_reproducible(self) -> None:
         """Configure global random states for reproducibility.
 
-        Seeds the following libraries with deterministically generated
-        seeds based on the base seed:
+        Seeds the following libraries with deterministically generated seeds based on the base seed:
         - Python's built-in `random`
         - NumPy's random number generator
         - PyTorch (CPU and GPU)
@@ -496,6 +504,7 @@ class Seeder:
         - Seeding all CUDA devices
         - Disabling CuDNN benchmarking
         - Enabling CuDNN deterministic mode
+
         """
         random.seed(self.roll_seed())
         np.random.seed(self.roll_seed())
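
A minimal usage sketch of the behavior documented above (`Seeder(base_seed)` and `set_reproducible()` are taken from the diffed docstrings; the rest is illustrative):

```python
from congrads.utils.utility import Seeder

# All library seeds are derived deterministically from one base seed.
seeder = Seeder(base_seed=42)
seeder.set_reproducible()  # seeds random, NumPy, PyTorch (CPU + all CUDA
                           # devices) and puts CuDNN in deterministic mode
```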
@@ -7,6 +7,15 @@ validation functions.
 
 from torch.utils.data import DataLoader
 
+__all__ = [
+    "validate_type",
+    "validate_iterable",
+    "validate_comparator",
+    "validate_callable",
+    "validate_callable_iterable",
+    "validate_loaders",
+]
+
 
 def validate_type(name, value, expected_types, allow_none=False):
     """Validate that a value is of the specified type(s).
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: congrads
-Version: 0.3.0
+Version: 0.3.1.post1
 Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
 Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
 Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
@@ -0,0 +1,23 @@
+congrads/__init__.py,sha256=XJKWRteSvmTYawgS1Pon8kWhhd3haKo4RGAsyjEGS8Q,383
+congrads/callbacks/base.py,sha256=F1eanEwEOvnBGZ_HKNdfN0QWuK9QFhQ9RbjrvjVJZ00,13771
+congrads/callbacks/registry.py,sha256=UtQuoG48GR0zhOKNIhiUxVvD7Z46ClSiVitx8UAI244,6268
+congrads/checkpoints.py,sha256=sWOL_v9ox0zajmNshNSKHP6l8f-p1RzwqyJIdPQnvjg,7261
+congrads/constraints/base.py,sha256=j-DfdY7C4TJynglUn7JpH_YsREGXI-K9x1XCMEXSqZg,10157
+congrads/constraints/registry.py,sha256=Ki6waNtnbcAA5F7V3nkSOKEnpyGS3g1UKTJwyR3CgC8,52577
+congrads/core/batch_runner.py,sha256=jo7KP-64US3BxTO-zie2pD0zxdVPxNtQDx0TdRpFwFE,6994
+congrads/core/congradscore.py,sha256=S8_ph3vs5LNl53u2y6-l-vKetOXYW5Xku2AeIl7mct8,12098
+congrads/core/constraint_engine.py,sha256=y0I7cxm2oslovXMtT3V5hw3_TZm0NsP1pAgI5uqTyJw,9026
+congrads/core/epoch_runner.py,sha256=8mCsfuAnux3Ws_dohD8HQUsCnSo_TxWlOcHwX6cq4ys,4120
+congrads/datasets/registry.py,sha256=ZYLXnHxCRk0tgVcwsntNfezGdJNLVUFEKWKGMsz1Dsw,31477
+congrads/descriptor.py,sha256=uzAuR_qGM1zzDgeLnSEwbUvvvT1eDL_2hFwwq-I9WoI,6940
+congrads/metrics.py,sha256=Bb0g7KEulytCPTwnT-Q4hfIXwurZuyjPsydx2YmoaqY,4664
+congrads/networks/registry.py,sha256=xxSpJDmkMlIG_1f-aEzmjUTNekVrR9uWDJVv7xgWofQ,2278
+congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+congrads/transformations/base.py,sha256=ZjPCqa5GoqWuaWcP4afrGm0eSEXQoZ4RylGPRKwd48w,930
+congrads/transformations/registry.py,sha256=ZF1wuPsG5A5jse8jZcgx7pyQ5htJ4VVX1_PzVLvADLs,2683
+congrads/utils/preprocessors.py,sha256=8AvY-6kX47SPsQTIaxLl7yBcROlv3pjLCmsG7SLDTqM,18290
+congrads/utils/utility.py,sha256=CWiLZ7tv3dmHA3h8qeWJ8abUKhVckzFZWOyuAweQzhw,19046
+congrads/utils/validation.py,sha256=dleerBb0sAqnn3AjhcUxIRGFnqImWsDbfNNvBVmopsI,6379
+congrads-0.3.1.post1.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
+congrads-0.3.1.post1.dist-info/METADATA,sha256=E8GcJ6cKT0QX4P8MNrKRgxmLX_-ggcqyb010gsUZ3fQ,10754
+congrads-0.3.1.post1.dist-info/RECORD,,
@@ -1,23 +0,0 @@
-congrads/__init__.py,sha256=XJKWRteSvmTYawgS1Pon8kWhhd3haKo4RGAsyjEGS8Q,383
-congrads/callbacks/base.py,sha256=OChXls-tndgJQOXNfqavnPywHZHn3N87yLKD4kbHDHk,13714
-congrads/callbacks/registry.py,sha256=KkzjDqMS3CkE__PpGrmAEYwRngqGSdQNE8NVWl7ogeA,3898
-congrads/checkpoints.py,sha256=V79n3mqjB48nbNkBELqKDg9iou0b1vc5eRrlcu8aIA4,7228
-congrads/constraints/base.py,sha256=k9OyPS2A4bP3fSEAEANGuw7zofiWlGIxqb5ows1LQWs,10105
-congrads/constraints/registry.py,sha256=k__RfcXle-qDL9OJ-nfwgL9zeM6-ISwgQyIzmx-lsgc,52302
-congrads/core/batch_runner.py,sha256=emc7smJLDHq0J8_J9t9X0RtqrXaYwOP9mmhlX_M78e4,6967
-congrads/core/congradscore.py,sha256=9ZKUVMB9RbmudtuC-MQcNColBYUjb6XLqb0eISzfrGk,12070
-congrads/core/constraint_engine.py,sha256=UEt-tmtJeJX0Wu3ol17Z0A9hacL0F8oouJUwHgIIoDE,8994
-congrads/core/epoch_runner.py,sha256=l0x3uLXQ5I5o1C63wXgL4_QkhFmXxW-jeejNJK6sf18,4093
-congrads/datasets/registry.py,sha256=RfffRiA7Qijc69cJTBJhItTZ8x9B-p1kXMjvcfEC_nA,31102
-congrads/descriptor.py,sha256=tUHF4vvyNzJP5vpq1xn0uhKnOlAkElwG2R9gG4glHvQ,6914
-congrads/metrics.py,sha256=e52QC8yNKsxAndjC3U4WMUnQ_0GmiSlExKtxRRShHao,4625
-congrads/networks/registry.py,sha256=UPzPDU0wI2zoOEvi697QBSDOtaa3Rc0rgCb-tCxbjak,2252
-congrads/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-congrads/transformations/base.py,sha256=KZQkloaDGcqAp8EhlUnHL8VfZqSq8OCqp_Iy_a2Nfns,900
-congrads/transformations/registry.py,sha256=p2cLnt3X1bspEPfR7IVd31qPXQimVe_bRu2VhUOIZj0,2607
-congrads/utils/preprocessors.py,sha256=oqW3hV_yoUd-6I-NSoE61e_JDNEPnBJvvvdsuKd9Ekg,18190
-congrads/utils/utility.py,sha256=zvOAVjQjtmsvyuJm0rF0cy_jApR6qluQsFxk9ItalzE,18893
-congrads/utils/validation.py,sha256=Jj8ZJGJrrH9B02cIaScsQpne3zjyarkPldDdT1pejVA,6208
-congrads-0.3.0.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
-congrads-0.3.0.dist-info/METADATA,sha256=UCutFzaiD6CaeSr9BqtfiyMS-LsDs6Iwpw8QcXSS2jc,10748
-congrads-0.3.0.dist-info/RECORD,,