congrads 0.2.0-py3-none-any.whl → 1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +17 -10
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +664 -134
- congrads/core.py +482 -110
- congrads/datasets.py +315 -11
- congrads/descriptor.py +100 -20
- congrads/metrics.py +178 -16
- congrads/networks.py +47 -23
- congrads/transformations.py +139 -0
- congrads/utils.py +439 -39
- congrads-1.0.2.dist-info/METADATA +208 -0
- congrads-1.0.2.dist-info/RECORD +15 -0
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/WHEEL +1 -1
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/LICENSE +0 -0
- {congrads-0.2.0.dist-info → congrads-1.0.2.dist-info}/top_level.txt +0 -0
congrads/__init__.py
CHANGED
@@ -1,17 +1,24 @@
-#
-
+# pylint: skip-file
+
+try:
+    from importlib.metadata import version as get_version  # Python 3.8+
+except ImportError:
+    from pkg_resources import (
+        get_distribution as get_version,
+    )  # Fallback for older versions
+
+try:
+    version = get_version("congrads")  # Replace with your package name
+except Exception:
+    version = "0.0.0"  # Fallback if the package isn't installed
 
 # Only expose the submodules, not individual classes
-from . import constraints
-from . import core
-from . import datasets
-from . import descriptor
-from . import metrics
-from . import networks
-from . import utils
+from . import checkpoints, constraints, core, datasets, descriptor, metrics, networks, utils
 
-# Define __all__ to specify that the submodules are accessible,
+# Define __all__ to specify that the submodules are accessible,
+# but not classes directly.
 __all__ = [
+    "checkpoints",
     "constraints",
     "core",
     "datasets",
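With this change, the package resolves its installed version once at import time and exposes it as a module attribute. A minimal sketch of the resulting behavior, assuming the congrads wheel is installed in the current environment (otherwise `version` falls back to "0.0.0"):

import congrads

print(congrads.version)  # "1.0.2" for this release
print(congrads.__all__)  # exposed submodules, now including "checkpoints"
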
congrads/checkpoints.py
ADDED
@@ -0,0 +1,232 @@
+"""
+This module provides a `CheckpointManager` class for managing the saving and
+loading of checkpoints during PyTorch model training.
+
+The `CheckpointManager` handles:
+
+- Saving and loading the state of models, optimizers, and metrics.
+- Registering and evaluating performance criteria to determine if a model's
+  performance has improved, enabling automated saving of the best-performing
+  model checkpoints.
+- Resuming training from a specific checkpoint.
+
+Usage:
+    1. Initialize the `CheckpointManager` with a PyTorch model, optimizer,
+       and metric manager.
+    2. Register criteria for tracking and evaluating metrics.
+    3. Use the `save` and `load` methods to manage checkpoints during training.
+    4. Call `evaluate_criteria` to automatically evaluate and save the
+       best-performing checkpoints.
+
+Dependencies:
+    - PyTorch (`torch`)
+"""
+
+import os
+from pathlib import Path
+from typing import Callable
+
+from torch import Tensor, gt, load, save
+from torch.nn import Module
+from torch.optim import Optimizer
+
+from .metrics import MetricManager
+from .utils import validate_comparator_pytorch, validate_type
+
+
+class CheckpointManager:
+    """
+    A class to handle saving and loading checkpoints for
+    PyTorch models and optimizers.
+
+    Args:
+        network (torch.nn.Module): The network (model) to save/load.
+        optimizer (torch.optim.Optimizer): The optimizer to save/load.
+        metric_manager (MetricManager): The metric manager to restore saved
+            metric states.
+        save_dir (str): Directory where checkpoints will be saved. Defaults
+            to 'checkpoints'.
+        create_dir (bool): Whether to create the save_dir if it does not exist.
+            Defaults to False.
+
+    Raises:
+        TypeError: If a provided attribute has an incompatible type.
+        FileNotFoundError: If the save directory does not exist and create_dir
+            is set to False.
+    """
+
+    def __init__(
+        self,
+        network: Module,
+        optimizer: Optimizer,
+        metric_manager: MetricManager,
+        save_dir: str = "checkpoints",
+        create_dir: bool = False,
+    ):
+        """
+        Initialize the checkpoint manager.
+        """
+
+        # Type checking
+        validate_type("network", network, Module)
+        validate_type("optimizer", optimizer, Optimizer)
+        validate_type("metric_manager", metric_manager, MetricManager)
+        validate_type("create_dir", create_dir, bool)
+
+        # Create the save directory, or raise an error if it is missing
+        # and create_dir is False
+        if not os.path.exists(save_dir):
+            if not create_dir:
+                raise FileNotFoundError(
+                    f"Save directory '{str(save_dir)}' configured in "
+                    "checkpoint manager is not found."
+                )
+            Path(save_dir).mkdir(parents=True, exist_ok=True)
+
+        # Initialize object variables
+        self.network = network
+        self.optimizer = optimizer
+        self.metric_manager = metric_manager
+        self.save_dir = save_dir
+
+        self.criteria: dict[str, Callable[[Tensor, Tensor], Tensor]] = {}
+        self.best_metrics: dict[str, Tensor] = {}
+
+    def register(
+        self,
+        metric_name: str,
+        comparator: Callable[[Tensor, Tensor], Tensor] = gt,
+    ):
+        """
+        Register a criterion for evaluating a performance metric
+        during training.
+
+        Stores the comparator to determine whether the current metric has
+        improved relative to the previous best metric value.
+
+        Args:
+            metric_name (str): The name of the metric to evaluate.
+            comparator (Callable[[Tensor, Tensor], Tensor], optional):
+                A function that compares the current metric value against the
+                previous best value. Defaults to a greater-than (gt) comparison.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+        """
+
+        validate_type("metric_name", metric_name, str)
+        validate_comparator_pytorch("comparator", comparator)
+
+        self.criteria[metric_name] = comparator
+
+    def resume(
+        self, filename: str = "checkpoint.pth", ignore_missing: bool = False
+    ) -> int:
+        """
+        Resume training from a saved checkpoint file.
+
+        Args:
+            filename (str): The name of the checkpoint file to load.
+                Defaults to "checkpoint.pth".
+            ignore_missing (bool): If True, does not raise an error if the
+                checkpoint file is missing and continues without loading,
+                starting from epoch 0. Defaults to False.
+
+        Returns:
+            int: The epoch number from the loaded checkpoint, or 0 if
+                ignore_missing is True and no checkpoint was found.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            FileNotFoundError: If the specified checkpoint file does not exist.
+        """
+
+        # Type checking
+        validate_type("filename", filename, str)
+        validate_type("ignore_missing", ignore_missing, bool)
+
+        # Return starting epoch, either from checkpoint file or default
+        filepath = os.path.join(self.save_dir, filename)
+        if os.path.exists(filepath):
+            checkpoint = self.load(filename)
+            return checkpoint["epoch"]
+        elif ignore_missing:
+            return 0
+        else:
+            raise FileNotFoundError(
+                f"A checkpoint was not found at {filepath} to resume training."
+            )
+
+    def evaluate_criteria(self, epoch: int):
+        """
+        Evaluate the defined criteria for model performance metrics
+        during training.
+
+        Compares the current metrics against the previous best metrics using
+        the registered comparators. If a criterion is met, saves the model
+        and the corresponding best metric values.
+
+        Args:
+            epoch (int): The current epoch number.
+        """
+
+        for metric_name, comparator in self.criteria.items():
+
+            current_metric_value = self.metric_manager.metrics[
+                metric_name
+            ].aggregate()
+            best_metric_value = self.best_metrics.get(metric_name)
+
+            # TODO improve efficiency by not checking is None each iteration
+            if best_metric_value is None or comparator(
+                current_metric_value,
+                best_metric_value,
+            ):
+                self.save(epoch)
+                self.best_metrics[metric_name] = current_metric_value
+
+    def save(
+        self,
+        epoch: int,
+        filename: str = "checkpoint.pth",
+    ):
+        """
+        Save a checkpoint.
+
+        Args:
+            epoch (int): Current epoch number.
+            filename (str): Name of the checkpoint file. Defaults to
+                'checkpoint.pth'.
+        """
+
+        state = {
+            "epoch": epoch,
+            "network_state": self.network.state_dict(),
+            "optimizer_state": self.optimizer.state_dict(),
+            "best_metrics": self.best_metrics,
+        }
+        filepath = os.path.join(self.save_dir, filename)
+        save(state, filepath)
+
+    def load(self, filename: str):
+        """
+        Load a checkpoint and restore the state of the network, optimizer,
+        and best_metrics.
+
+        Args:
+            filename (str): Name of the checkpoint file.
+
+        Returns:
+            dict: A dictionary containing the loaded checkpoint
+                information (epoch, loss, etc.).
+        """
+
+        filepath = os.path.join(self.save_dir, filename)
+
+        checkpoint = load(filepath)
+        self.network.load_state_dict(checkpoint["network_state"])
+        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
+        self.best_metrics = checkpoint["best_metrics"]
+
+        return checkpoint
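The Usage steps in the module docstring correspond to a short training loop. Below is a minimal sketch: the model, optimizer, metric name "loss/validation", and the `MetricManager` construction are hypothetical placeholders (its constructor is not part of this diff); only the `CheckpointManager` calls follow the code above.

from torch import lt, nn, optim

from congrads.checkpoints import CheckpointManager
from congrads.metrics import MetricManager

network = nn.Linear(16, 1)  # hypothetical placeholder model
optimizer = optim.Adam(network.parameters())
metric_manager = MetricManager(...)  # hypothetical; constructor not shown in this diff

manager = CheckpointManager(
    network, optimizer, metric_manager, save_dir="checkpoints", create_dir=True
)

# Lower validation loss means improvement, so pass torch.lt instead of the
# default torch.gt (which suits higher-is-better metrics such as accuracy).
manager.register("loss/validation", comparator=lt)

start_epoch = manager.resume(ignore_missing=True)  # 0 on a fresh run
for epoch in range(start_epoch, 100):
    ...  # training step; metric values accumulate in metric_manager
    manager.evaluate_criteria(epoch)  # saves when a registered criterion improves

Note that `save` and `resume` share the default filename "checkpoint.pth", so the checkpoint written when a criterion improves is the same file a restarted run resumes from.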