congrads 0.2.0-py3-none-any.whl → 1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
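The headline change in `congrads/__init__.py` (first diff below) replaces the hard-coded `version = "0.2.0"` with a lookup against the installed distribution's metadata, falling back to `pkg_resources` on interpreters without `importlib.metadata`. At runtime that lookup amounts to roughly the sketch below; it is illustrative only and uses `PackageNotFoundError` where the package itself catches a broad `Exception`:

```python
# Sketch of the metadata-based version lookup introduced in this release.
# Assumes the congrads wheel is installed in the current environment.
from importlib.metadata import PackageNotFoundError, version

try:
    resolved = version("congrads")  # read the installed distribution's version
except PackageNotFoundError:
    resolved = "0.0.0"  # mirrors the package's fallback when not installed

print(resolved)  # e.g. "1.0.2" for this release
```

In practice, `import congrads; congrads.version` now reports whichever wheel is actually installed instead of a string that must be bumped by hand in the source.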
congrads/__init__.py CHANGED
@@ -1,17 +1,24 @@
- # __init__.py
- version = "0.2.0"
+ # pylint: skip-file
+
+ try:
+     from importlib.metadata import version as get_version  # Python 3.8+
+ except ImportError:
+     from pkg_resources import (
+         get_distribution as get_version,
+     )  # Fallback for older versions
+
+ try:
+     version = get_version("congrads")  # Replace with your package name
+ except Exception:
+     version = "0.0.0"  # Fallback if the package isn't installed

  # Only expose the submodules, not individual classes
- from . import constraints
- from . import core
- from . import datasets
- from . import descriptor
- from . import metrics
- from . import networks
- from . import utils
+ from . import constraints, core, datasets, descriptor, metrics, networks, utils

- # Define __all__ to specify that the submodules are accessible, but not classes directly.
+ # Define __all__ to specify that the submodules are accessible,
+ # but not classes directly.
  __all__ = [
+     "checkpoints",
      "constraints",
      "core",
      "datasets",
congrads/checkpoints.py ADDED
@@ -0,0 +1,232 @@
+ """
+ This module provides a `CheckpointManager` class for managing the saving and
+ loading of checkpoints during PyTorch model training.
+
+ The `CheckpointManager` handles:
+
+ - Saving and loading the state of models, optimizers, and metrics.
+ - Registering and evaluating performance criteria to determine if a model's
+   performance has improved, enabling automated saving of the best-performing
+   model checkpoints.
+ - Resuming training from a specific checkpoint.
+
+ Usage:
+     1. Initialize the `CheckpointManager` with a PyTorch model, optimizer,
+        and metric manager.
+     2. Register criteria for tracking and evaluating metrics.
+     3. Use the `save` and `load` methods to manage checkpoints during training.
+     4. Call `evaluate_criteria` to automatically evaluate and save the
+        best-performing checkpoints.
+
+ Dependencies:
+     - PyTorch (`torch`)
+ """
+
+ import os
+ from pathlib import Path
+ from typing import Callable
+
+ from torch import Tensor, gt, load, save
+ from torch.nn import Module
+ from torch.optim import Optimizer
+
+ from .metrics import MetricManager
+ from .utils import validate_comparator_pytorch, validate_type
+
+
+ class CheckpointManager:
+     """
+     A class to handle saving and loading checkpoints for
+     PyTorch models and optimizers.
+
+     Args:
+         network (torch.nn.Module): The network (model) to save/load.
+         optimizer (torch.optim.Optimizer): The optimizer to save/load.
+         metric_manager (MetricManager): The metric manager to restore saved
+             metric states.
+         save_dir (str): Directory where checkpoints will be saved. Defaults
+             to 'checkpoints'.
+         create_dir (bool): Whether to create the save_dir if it does not exist.
+             Defaults to False.
+
+     Raises:
+         TypeError: If a provided attribute has an incompatible type.
+         FileNotFoundError: If the save directory does not exist and create_dir
+             is set to False.
+     """
+
+     def __init__(
+         self,
+         network: Module,
+         optimizer: Optimizer,
+         metric_manager: MetricManager,
+         save_dir: str = "checkpoints",
+         create_dir: bool = False,
+     ):
+         """
+         Initialize the checkpoint manager.
+         """
+
+         # Type checking
+         validate_type("network", network, Module)
+         validate_type("optimizer", optimizer, Optimizer)
+         validate_type("metric_manager", metric_manager, MetricManager)
+         validate_type("create_dir", create_dir, bool)
+
+         # Create save_dir, or fail if it is missing and create_dir is False
+         if not os.path.exists(save_dir):
+             if not create_dir:
+                 raise FileNotFoundError(
+                     f"Save directory '{str(save_dir)}' configured in "
+                     "checkpoint manager is not found."
+                 )
+             Path(save_dir).mkdir(parents=True, exist_ok=True)
+
+         # Initialize instance variables
+         self.network = network
+         self.optimizer = optimizer
+         self.metric_manager = metric_manager
+         self.save_dir = save_dir
+
+         self.criteria: dict[str, Callable[[Tensor, Tensor], Tensor]] = {}
+         self.best_metrics: dict[str, Tensor] = {}
+
+     def register(
+         self,
+         metric_name: str,
+         comparator: Callable[[Tensor, Tensor], Tensor] = gt,
+     ):
+         """
+         Register a criterion for evaluating a performance metric
+         during training.
+
+         Stores the comparator to determine whether the current metric has
+         improved relative to the previous best metric value.
+
+         Args:
+             metric_name (str): The name of the metric to evaluate.
+             comparator (Callable[[Tensor, Tensor], Tensor], optional):
+                 A function that compares the current metric value against the
+                 previous best value. Defaults to a greater-than (gt) comparison.
+
+         Raises:
+             TypeError: If a provided attribute has an incompatible type.
+
+         """
+
+         validate_type("metric_name", metric_name, str)
+         validate_comparator_pytorch("comparator", comparator)
+
+         self.criteria[metric_name] = comparator
+
123
+ def resume(
124
+ self, filename: str = "checkpoint.pth", ignore_missing: bool = False
125
+ ) -> int:
126
+ """
127
+ Resumes training from a saved checkpoint file.
128
+
129
+ Args:
130
+ filename (str): The name of the checkpoint file to load.
131
+ Defaults to "checkpoint.pth".
132
+ ignore_missing (bool): If True, does not raise an error if the
133
+ checkpoint file is missing and continues without loading,
134
+ starting from epoch 0. Defaults to False.
135
+
136
+ Returns:
137
+ int: The epoch number from the loaded checkpoint, or 0 if
138
+ ignore_missing is True and no checkpoint was found.
139
+
140
+ Raises:
141
+ TypeError: If a provided attribute has an incompatible type.
142
+ FileNotFoundError: If the specified checkpoint file does not exist.
143
+ """
144
+
145
+ # Type checking
146
+ validate_type("filename", filename, str)
147
+ validate_type("ignore_missing", ignore_missing, bool)
148
+
149
+ # Return starting epoch, either from checkpoint file or default
150
+ filepath = os.path.join(self.save_dir, filename)
151
+ if os.path.exists(filepath):
152
+ checkpoint = self.load("checkpoint.pth")
153
+ return checkpoint["epoch"]
154
+ elif ignore_missing:
155
+ return 0
156
+ else:
157
+ raise FileNotFoundError(
158
+ f"A checkpoint was not found at {filepath} to resume training."
159
+ )
160
+
+     def evaluate_criteria(self, epoch: int):
+         """
+         Evaluate the defined criteria for model performance metrics
+         during training.
+
+         Args:
+             epoch (int): The current epoch number.
+
+         Compares the current metrics against the previous best metrics using
+         predefined comparators. If a criterion is met, saves the model and
+         the corresponding best metric values.
+         """
+
+         for metric_name, comparator in self.criteria.items():
+
+             current_metric_value = self.metric_manager.metrics[
+                 metric_name
+             ].aggregate()
+             best_metric_value = self.best_metrics.get(metric_name)
+
+             # TODO improve efficiency by not checking is None each iteration
+             if best_metric_value is None or comparator(
+                 current_metric_value,
+                 best_metric_value,
+             ):
+                 self.save(epoch)
+                 self.best_metrics[metric_name] = current_metric_value
+
+     def save(
+         self,
+         epoch: int,
+         filename: str = "checkpoint.pth",
+     ):
+         """
+         Save a checkpoint.
+
+         Args:
+             epoch (int): Current epoch number.
+             filename (str): Name of the checkpoint file. Defaults to
+                 'checkpoint.pth'.
+         """
+
+         state = {
+             "epoch": epoch,
+             "network_state": self.network.state_dict(),
+             "optimizer_state": self.optimizer.state_dict(),
+             "best_metrics": self.best_metrics,
+         }
+         filepath = os.path.join(self.save_dir, filename)
+         save(state, filepath)
+
+     def load(self, filename: str):
+         """
+         Load a checkpoint and restore the state of the network, optimizer
+         and best_metrics.
+
+         Args:
+             filename (str): Name of the checkpoint file.
+
+         Returns:
+             dict: A dictionary containing the loaded checkpoint
+                 information (epoch, network and optimizer state, best metrics).
+         """
+
+         filepath = os.path.join(self.save_dir, filename)
+
+         checkpoint = load(filepath)
+         self.network.load_state_dict(checkpoint["network_state"])
+         self.optimizer.load_state_dict(checkpoint["optimizer_state"])
+         self.best_metrics = checkpoint["best_metrics"]
+
+         return checkpoint
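The new `CheckpointManager` is meant to be driven from a training loop, following the four usage steps in its module docstring. The sketch below is a hedged illustration rather than documented API usage: the linear stand-in model, the `"val_accuracy"`/`"val_loss"` metric names, and the `MetricManager` construction are assumptions (its constructor is not part of this diff), and the import path assumes the new module is exposed as `congrads.checkpoints`, consistent with the `"checkpoints"` entry added to `__all__`:

```python
# Hedged usage sketch of the new CheckpointManager. The stand-in model, the
# metric names, and the MetricManager wiring are assumptions, not part of this diff.
from torch import lt, nn
from torch.optim import Adam

from congrads.checkpoints import CheckpointManager
from congrads.metrics import MetricManager

network = nn.Linear(4, 2)              # stand-in for a real model
optimizer = Adam(network.parameters())
metric_manager = MetricManager(...)    # construction not shown in this diff;
                                       # assumed to track "val_accuracy" / "val_loss"

manager = CheckpointManager(
    network,
    optimizer,
    metric_manager,
    save_dir="checkpoints",
    create_dir=True,
)

# Higher is better for accuracy (default comparator torch.gt);
# lower is better for loss, so pass torch.lt instead.
manager.register("val_accuracy")
manager.register("val_loss", comparator=lt)

# Pick up from an earlier run, or start at epoch 0 if no checkpoint exists yet.
start_epoch = manager.resume(ignore_missing=True)

for epoch in range(start_epoch, 100):
    ...  # training/validation step that updates the metrics in metric_manager
    manager.evaluate_criteria(epoch)  # saves checkpoint.pth when a criterion improves
```

Passing `torch.lt` for a loss-style metric is the point of the `comparator` hook: `evaluate_criteria` saves only when `comparator(current, best)` is truthy, so the default `gt` suits higher-is-better metrics and `lt` suits lower-is-better ones.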
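A checkpoint written by `save` is a single `torch.save` file holding the epoch, the network and optimizer `state_dict`s, and the tracked best metrics; `load` restores those in place and also returns the raw dictionary. Continuing the sketch above (same hypothetical `manager`):

```python
# Inspect what a saved checkpoint holds, based on the keys written by save().
checkpoint = manager.load("checkpoint.pth")  # also restores network/optimizer state
print(sorted(checkpoint.keys()))
# ['best_metrics', 'epoch', 'network_state', 'optimizer_state']
print(checkpoint["epoch"], checkpoint["best_metrics"])
```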