congrads-0.2.0-py3-none-any.whl → congrads-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,180 @@
+ """Module for managing PyTorch model checkpoints.
+
+ Provides the `CheckpointManager` class to save and load model and optimizer
+ states during training, track the best metric values, and optionally report
+ checkpoint events.
+ """
+
+ import os
+ from collections.abc import Callable
+ from pathlib import Path
+
+ from torch import Tensor, load, save
+ from torch.nn import Module
+ from torch.optim import Optimizer
+
+ from .metrics import MetricManager
+ from .utils.validation import validate_callable, validate_type
+
+ __all__ = ["CheckpointManager"]
+
+
+ class CheckpointManager:
+     """Manage saving and loading checkpoints for PyTorch models and optimizers.
+
+     Handles checkpointing based on a criteria function, restores metric
+     states, and optionally reports when a checkpoint is saved.
+     """
+
+     def __init__(
+         self,
+         criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
+         network: Module,
+         optimizer: Optimizer,
+         metric_manager: MetricManager,
+         save_dir: str = "checkpoints",
+         create_dir: bool = False,
+         report_save: bool = False,
+     ):
+         """Initialize the CheckpointManager.
+
+         Args:
+             criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
+                 Function that determines if the current checkpoint should be
+                 saved based on the current and best metric values.
+             network (torch.nn.Module): The model to save/load.
+             optimizer (torch.optim.Optimizer): The optimizer to save/load.
+             metric_manager (MetricManager): Manages metric states for checkpointing.
+             save_dir (str, optional): Directory to save checkpoints. Defaults to
+                 'checkpoints'.
+             create_dir (bool, optional): Whether to create `save_dir` if it does
+                 not exist. Defaults to False.
+             report_save (bool, optional): Whether to report when a checkpoint is
+                 saved. Defaults to False.
+
+         Raises:
+             TypeError: If any provided attribute has an incompatible type.
+             FileNotFoundError: If `save_dir` does not exist and `create_dir` is False.
+         """
+         # Type checking
+         validate_callable("criteria_function", criteria_function)
+         validate_type("network", network, Module)
+         validate_type("optimizer", optimizer, Optimizer)
+         validate_type("metric_manager", metric_manager, MetricManager)
+         validate_type("create_dir", create_dir, bool)
+         validate_type("report_save", report_save, bool)
+
+         # Create the directory, or raise an error if it is missing and
+         # create_dir is False
+         if not os.path.exists(save_dir):
+             if not create_dir:
+                 raise FileNotFoundError(
+                     f"Save directory '{save_dir}' configured in checkpoint manager does not exist."
+                 )
+             Path(save_dir).mkdir(parents=True, exist_ok=True)
+
+         # Initialize object variables
+         self.criteria_function = criteria_function
+         self.network = network
+         self.optimizer = optimizer
+         self.metric_manager = metric_manager
+         self.save_dir = save_dir
+         self.report_save = report_save
+
+         self.best_metric_values: dict[str, Tensor] = {}
+
+     def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
+         """Evaluate the criteria function to determine if a better model is found.
+
+         Aggregates the current metric values during training and applies the
+         criteria function. If the criteria function indicates improvement, the
+         best metric values are updated, a checkpoint is saved, and a message is
+         optionally printed.
+
+         Args:
+             epoch (int): The current epoch number.
+             metric_group (str, optional): The metric group to evaluate.
+                 Defaults to 'during_training'.
+         """
+         current_metric_values = self.metric_manager.aggregate(metric_group)
+         if self.criteria_function is not None and self.criteria_function(
+             current_metric_values, self.best_metric_values
+         ):
+             # Print message if a new checkpoint is saved
+             if self.report_save:
+                 print(f"New checkpoint saved at epoch {epoch}.")
+
+             # Update current best metric values
+             for metric_name, metric_value in current_metric_values.items():
+                 self.best_metric_values[metric_name] = metric_value
+
+             # Save the current state
+             self.save(epoch)
+
+     def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
+         """Resume training from a saved checkpoint file.
+
+         Args:
+             filename (str): The name of the checkpoint file to load.
+                 Defaults to "checkpoint.pth".
+             ignore_missing (bool): If True, does not raise an error if the
+                 checkpoint file is missing and continues without loading,
+                 starting from epoch 0. Defaults to False.
+
+         Returns:
+             int: The epoch number from the loaded checkpoint, or 0 if
+                 ignore_missing is True and no checkpoint was found.
+
+         Raises:
+             TypeError: If a provided attribute has an incompatible type.
+             FileNotFoundError: If the specified checkpoint file does not exist.
+         """
+         # Type checking
+         validate_type("filename", filename, str)
+         validate_type("ignore_missing", ignore_missing, bool)
+
+         # Return starting epoch, either from checkpoint file or default
+         filepath = os.path.join(self.save_dir, filename)
+         if os.path.exists(filepath):
+             checkpoint = self.load(filename)
+             return checkpoint["epoch"]
+         elif ignore_missing:
+             return 0
+         else:
+             raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")
+
+     def save(self, epoch: int, filename: str = "checkpoint.pth"):
+         """Save a checkpoint.
+
+         Args:
+             epoch (int): Current epoch number.
+             filename (str): Name of the checkpoint file. Defaults to
+                 'checkpoint.pth'.
+         """
+         state = {
+             "epoch": epoch,
+             "network_state": self.network.state_dict(),
+             "optimizer_state": self.optimizer.state_dict(),
+             "best_metrics": self.best_metric_values,
+         }
+         filepath = os.path.join(self.save_dir, filename)
+         save(state, filepath)
+
+     def load(self, filename: str):
+         """Load a checkpoint and restore the training state.
+
+         Loads the checkpoint from the specified file and restores the network
+         weights, optimizer state, and best metric values.
+
+         Args:
+             filename (str): Name of the checkpoint file.
+
+         Returns:
+             dict: A dictionary containing the loaded checkpoint information:
+                 the epoch number, the network and optimizer state dicts, and
+                 the best metric values.
+         """
+         filepath = os.path.join(self.save_dir, filename)
+
+         checkpoint = load(filepath, weights_only=True)
+         self.network.load_state_dict(checkpoint["network_state"])
+         self.optimizer.load_state_dict(checkpoint["optimizer_state"])
+         self.best_metric_values = checkpoint["best_metrics"]
+
+         return checkpoint
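
For orientation, here is a minimal usage sketch of the `CheckpointManager` added above. It is illustrative only: the criteria function, the toy model, and the `val_loss` metric key are assumptions, and the `MetricManager` construction is elided since its API is not part of this diff.

```python
import torch
from torch import Tensor, nn

# Hypothetical criteria function: save when validation loss improves.
# `best` is empty until the first checkpoint has been saved.
def lower_val_loss(current: dict[str, Tensor], best: dict[str, Tensor]) -> bool:
    return "val_loss" not in best or current["val_loss"] < best["val_loss"]

network = nn.Linear(4, 1)  # toy model
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

manager = CheckpointManager(
    criteria_function=lower_val_loss,
    network=network,
    optimizer=optimizer,
    metric_manager=metric_manager,  # assumed: a configured MetricManager
    create_dir=True,
    report_save=True,
)

# Resume from an existing checkpoint, or start at epoch 0.
start_epoch = manager.resume(ignore_missing=True)
for epoch in range(start_epoch, 100):
    ...  # train, record metrics through metric_manager
    manager.evaluate_criteria(epoch)
```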
@@ -0,0 +1,244 @@
+ """Defines the abstract base classes `Constraint` and `MonotonicityConstraint`
+ for specifying constraints on neural network outputs.
+
+ A `Constraint` monitors whether the network predictions satisfy certain
+ conditions during training, validation, and testing. It can optionally
+ adjust the loss to enforce constraints, and logs the relevant metrics.
+
+ Responsibilities:
+ - Track which network layers/tags the constraint applies to
+ - Check constraint satisfaction for a batch of predictions
+ - Compute adjustment directions to enforce the constraint
+ - Provide a rescale factor and enforcement flag to influence loss adjustment
+
+ Subclasses must implement the abstract methods:
+ - `check_constraint(data)`: Evaluate constraint satisfaction for a batch
+ - `calculate_direction(data)`: Compute directions to adjust predictions
+ """
+
+ import random
+ import string
+ import warnings
+ from abc import ABC, abstractmethod
+ from numbers import Number
+ from typing import Literal
+
+ from torch import Tensor
+
+ from congrads.descriptor import Descriptor
+ from congrads.utils.validation import validate_iterable, validate_type
+
+ __all__ = ["Constraint", "MonotonicityConstraint"]
+
+
+ class Constraint(ABC):
+     """Abstract base class for defining constraints applied to neural networks.
+
+     A `Constraint` specifies conditions that the neural network outputs
+     should satisfy. It supports monitoring constraint satisfaction
+     during training and can adjust loss to enforce constraints. Subclasses
+     must implement the `check_constraint` and `calculate_direction` methods.
+
+     Args:
+         tags (set[str]): Tags referencing the parts of the network to which
+             this constraint applies.
+         name (str, optional): A unique name for the constraint. If not provided,
+             a name is generated based on the class name and a random suffix.
+         enforce (bool, optional): If False, only monitor the constraint
+             without adjusting the loss. Defaults to True.
+         rescale_factor (Number, optional): Factor to scale the
+             constraint-adjusted loss. Defaults to 1.5. Should be greater
+             than 1 to give weight to the constraint.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         ValueError: If any tag in `tags` is not defined in the `descriptor`.
+
+     Note:
+         - If `rescale_factor <= 1`, a warning is issued.
+         - If `name` is not provided, a name is auto-generated,
+           and a warning is logged.
+     """
+
+     descriptor: Descriptor = None
+     device = None
+
+     def __init__(
+         self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
+     ) -> None:
+         """Initialize a new Constraint instance.
+
+         Args:
+             tags (set[str]): Tags referencing the parts of the network to
+                 which this constraint applies.
+             name (str, optional): A unique name for the constraint. If not
+                 provided, a name is generated based on the class name and a
+                 random suffix.
+             enforce (bool, optional): If False, only monitor the constraint
+                 without adjusting the loss. Defaults to True.
+             rescale_factor (Number, optional): Factor to scale the
+                 constraint-adjusted loss. Defaults to 1.5. Should be greater
+                 than 1 to give weight to the constraint.
+
+         Raises:
+             TypeError: If a provided attribute has an incompatible type.
+             ValueError: If any tag in `tags` is not defined in the `descriptor`.
+
+         Note:
+             - If `rescale_factor <= 1`, a warning is issued.
+             - If `name` is not provided, a name is auto-generated, and a
+               warning is logged.
+         """
+         # Init parent class
+         super().__init__()
+
+         # Type checking
+         validate_iterable("tags", tags, str)
+         validate_type("name", name, str, allow_none=True)
+         validate_type("enforce", enforce, bool)
+         validate_type("rescale_factor", rescale_factor, Number)
+
+         # Init object variables
+         self.tags = tags
+         self.rescale_factor = rescale_factor
+         self.initial_rescale_factor = rescale_factor
+         self.enforce = enforce
+
+         # Perform checks
+         if rescale_factor <= 1:
+             warnings.warn(
+                 f"Rescale factor for constraint {name} is <= 1. The network "
+                 "will favor general loss over the constraint-adjusted loss. "
+                 "Is this intended behavior? Normally, the rescale factor "
+                 "should always be larger than 1.",
+                 stacklevel=2,
+             )
+
+         # If no name is set, generate one based
+         # on the class name and a random suffix
+         if name:
+             self.name = name
+         else:
+             random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
+             self.name = f"{self.__class__.__name__}_{random_suffix}"
+             warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
+
+         # Infer layers from descriptor and tags
+         self.layers = set()
+         for tag in self.tags:
+             if not self.descriptor.exists(tag):
+                 raise ValueError(
+                     f"The tag {tag} used with constraint "
+                     f"{self.name} is not defined in the descriptor. Please "
+                     "add it to the correct layer using "
+                     "descriptor.add('layer', ...)."
+                 )
+
+             layer, _ = self.descriptor.location(tag)
+             self.layers.add(layer)
+
+     @abstractmethod
+     def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+         """Evaluate whether the given model predictions satisfy the constraint.
+
+         Args:
+             data (dict[str, Tensor]): Dictionary that holds batch data, model
+                 predictions and context.
+
+         Returns:
+             tuple[Tensor, Tensor]: A tuple whose first element is a float
+                 tensor indicating whether the constraint is satisfied (1.0 for
+                 satisfied, 0.0 for not satisfied), and whose second element is
+                 a mask indicating the relevance of each sample (`True` for
+                 relevant samples and `False` for irrelevant ones).
+         """
+         pass
+
+     @abstractmethod
+     def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Compute adjustment directions to better satisfy the constraint.
+
+         Given the model predictions, input batch, and context, this method
+         calculates the direction in which the predictions referenced by a tag
+         should be adjusted to satisfy the constraint.
+
+         Args:
+             data (dict[str, Tensor]): Dictionary that holds batch data, model
+                 predictions and context.
+
+         Returns:
+             dict[str, Tensor]: Dictionary mapping network layers to tensors that
+                 specify the adjustment direction for each tag.
+         """
+         pass
+
+
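
To make the `check_constraint` contract concrete before the subclass below, here is a standalone sketch (not part of the package) of a non-negativity check; the `"output"` tag is a made-up key for the `data` dictionary:

```python
import torch
from torch import Tensor

def check_nonnegative(data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
    # Per-sample satisfaction: 1.0 where the prediction is >= 0, else 0.0.
    satisfied = (data["output"] >= 0).float()
    # Relevance mask: every sample counts for this toy constraint.
    mask = torch.ones_like(satisfied, dtype=torch.bool)
    return satisfied, mask

batch = {"output": torch.tensor([0.5, -0.2, 1.0])}
satisfied, mask = check_nonnegative(batch)
# satisfied -> tensor([1., 0., 1.]); mask -> tensor([True, True, True])
```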
+ class MonotonicityConstraint(Constraint, ABC):
+     """Abstract base class for monotonicity constraints.
+
+     Subclasses must define how monotonicity is evaluated and how corrective
+     directions are computed.
+     """
+
+     def __init__(
+         self,
+         tag_prediction: str,
+         tag_reference: str,
+         rescale_factor_lower: float = 1.5,
+         rescale_factor_upper: float = 1.75,
+         stable: bool = True,
+         direction: Literal["ascending", "descending"] = "ascending",
+         name: str = None,
+         enforce: bool = True,
+     ):
+         """Initialize a constraint that enforces monotonicity on a predicted output.
+
+         This constraint ensures that the activations of a prediction tag
+         (`tag_prediction`) are monotonically ascending or descending with
+         respect to a reference tag (`tag_reference`).
+
+         Args:
+             tag_prediction (str): Name of the tag whose activations should
+                 follow the monotonic relationship.
+             tag_reference (str): Name of the tag that acts as the monotonic
+                 reference.
+             rescale_factor_lower (float, optional): Lower bound for rescaling
+                 rank differences. Defaults to 1.5.
+             rescale_factor_upper (float, optional): Upper bound for rescaling
+                 rank differences. Defaults to 1.75.
+             stable (bool, optional): Whether to use stable sorting when
+                 ranking. Defaults to True.
+             direction (str, optional): Direction of monotonicity to enforce,
+                 either 'ascending' or 'descending'. Defaults to 'ascending'.
+             name (str, optional): Custom name for the constraint. If None, a
+                 descriptive name is auto-generated.
+             enforce (bool, optional): If False, the constraint is only
+                 monitored (not enforced). Defaults to True.
+         """
+         # Type checking
+         validate_type("rescale_factor_lower", rescale_factor_lower, float)
+         validate_type("rescale_factor_upper", rescale_factor_upper, float)
+         validate_type("stable", stable, bool)
+         validate_type("direction", direction, str)
+
+         # Compose constraint name
+         if name is None:
+             name = f"{tag_prediction} monotonically {direction} by {tag_reference}"
+
+         # Init parent class
+         super().__init__({tag_prediction}, name, enforce, 1.0)
+
+         # Init variables
+         self.tag_prediction = tag_prediction
+         self.tag_reference = tag_reference
+         self.rescale_factor_lower = rescale_factor_lower
+         self.rescale_factor_upper = rescale_factor_upper
+         self.stable = stable
+         self.direction = direction
+         self.descending = direction == "descending"
+
+         # Init member variables
+         self.compared_rankings: Tensor = None
+
+     @abstractmethod
+     def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+         """Evaluate whether the monotonicity constraint is satisfied.
+
+         Implementations must set `self.compared_rankings` with per-sample
+         correction directions.
+         """
+         pass
+
+     @abstractmethod
+     def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Return directions for monotonicity enforcement."""
+         pass
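
Both methods stay abstract on `MonotonicityConstraint`, so the diff does not show how the ranking is computed. As a rough, assumption-laden illustration of what the `stable` and `direction` parameters suggest (not the library's implementation), a rank-based monotonicity check could look like this:

```python
import torch
from torch import Tensor

def monotonicity_satisfaction(
    prediction: Tensor,
    reference: Tensor,
    descending: bool = False,
    stable: bool = True,
) -> Tensor:
    # Order the batch by the reference values, then flag each adjacent
    # pair whose predictions do not follow that order (1.0 = satisfied).
    order = torch.argsort(reference, descending=descending, stable=stable)
    ordered = prediction[order]
    return (ordered[1:] >= ordered[:-1]).float()

pred = torch.tensor([0.1, 0.4, 0.3])
ref = torch.tensor([1.0, 2.0, 3.0])
print(monotonicity_satisfaction(pred, ref))  # tensor([1., 0.])
```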