congrads 1.0.7__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +804 -454
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +121 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.7.dist-info → congrads-1.1.1.dist-info}/METADATA +48 -41
- congrads-1.1.1.dist-info/RECORD +14 -0
- congrads-1.1.1.dist-info/WHEEL +4 -0
- congrads-1.0.7.dist-info/LICENSE +0 -26
- congrads-1.0.7.dist-info/RECORD +0 -15
- congrads-1.0.7.dist-info/WHEEL +0 -5
- congrads-1.0.7.dist-info/top_level.txt +0 -1
congrads/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
#
|
|
2
|
-
|
|
3
|
-
try:
|
|
1
|
+
try: # noqa: D104
|
|
4
2
|
from importlib.metadata import version as get_version # Python 3.8+
|
|
5
3
|
except ImportError:
|
|
6
4
|
from pkg_resources import (
|
|
@@ -25,5 +23,6 @@ __all__ = [
|
|
|
25
23
|
"descriptor",
|
|
26
24
|
"metrics",
|
|
27
25
|
"networks",
|
|
26
|
+
"transformations",
|
|
28
27
|
"utils",
|
|
29
28
|
]
|
congrads/checkpoints.py
CHANGED
|
@@ -1,130 +1,113 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
- Saving and loading the state of models, optimizers, and metrics.
|
|
8
|
-
- Registering and evaluating performance criteria to determine if a model's
|
|
9
|
-
performance has improved, enabling automated saving of the best-performing
|
|
10
|
-
model checkpoints.
|
|
11
|
-
- Resuming training from a specific checkpoint.
|
|
12
|
-
|
|
13
|
-
Usage:
|
|
14
|
-
1. Initialize the `CheckpointManager` with a PyTorch model, optimizer,
|
|
15
|
-
and metric manager.
|
|
16
|
-
2. Register criteria for tracking and evaluating metrics.
|
|
17
|
-
3. Use the `save` and `load` methods to manage checkpoints during training.
|
|
18
|
-
4. Call `evaluate_criteria` to automatically evaluate and save the
|
|
19
|
-
best-performing checkpoints.
|
|
20
|
-
|
|
21
|
-
Dependencies:
|
|
22
|
-
- PyTorch (`torch`)
|
|
1
|
+
"""Module for managing PyTorch model checkpoints.
|
|
2
|
+
|
|
3
|
+
Provides the `CheckpointManager` class to save and load model and optimizer
|
|
4
|
+
states during training, track the best metric values, and optionally report
|
|
5
|
+
checkpoint events.
|
|
23
6
|
"""
|
|
24
7
|
|
|
25
8
|
import os
|
|
9
|
+
from collections.abc import Callable
|
|
26
10
|
from pathlib import Path
|
|
27
|
-
from typing import Callable
|
|
28
11
|
|
|
29
|
-
from torch import Tensor,
|
|
12
|
+
from torch import Tensor, load, save
|
|
30
13
|
from torch.nn import Module
|
|
31
14
|
from torch.optim import Optimizer
|
|
32
15
|
|
|
33
16
|
from .metrics import MetricManager
|
|
34
|
-
from .utils import
|
|
17
|
+
from .utils import validate_callable, validate_type
|
|
35
18
|
|
|
36
19
|
|
|
37
20
|
class CheckpointManager:
|
|
38
|
-
"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
network (torch.nn.Module): The network (model) to save/load.
|
|
44
|
-
optimizer (torch.optim.Optimizer): The optimizer to save/load.
|
|
45
|
-
metric_manager (MetricManager): The metric manager to restore saved
|
|
46
|
-
metric states.
|
|
47
|
-
save_dir (str): Directory where checkpoints will be saved. Defaults
|
|
48
|
-
to 'checkpoints'.
|
|
49
|
-
create_dir (bool): Whether to create the save_dir if it does not exist.
|
|
50
|
-
Defaults to False.
|
|
51
|
-
|
|
52
|
-
Raises:
|
|
53
|
-
TypeError: If a provided attribute has an incompatible type.
|
|
54
|
-
FileNotFoundError: If the save directory does not exist and create_dir
|
|
55
|
-
is set to False.
|
|
21
|
+
"""Manage saving and loading checkpoints for PyTorch models and optimizers.
|
|
22
|
+
|
|
23
|
+
Handles checkpointing based on a criteria function, restores metric
|
|
24
|
+
states, and optionally reports when a checkpoint is saved.
|
|
56
25
|
"""
|
|
57
26
|
|
|
58
27
|
def __init__(
|
|
59
28
|
self,
|
|
29
|
+
criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
|
|
60
30
|
network: Module,
|
|
61
31
|
optimizer: Optimizer,
|
|
62
32
|
metric_manager: MetricManager,
|
|
63
33
|
save_dir: str = "checkpoints",
|
|
64
34
|
create_dir: bool = False,
|
|
35
|
+
report_save: bool = False,
|
|
65
36
|
):
|
|
66
|
-
"""
|
|
67
|
-
|
|
68
|
-
|
|
37
|
+
"""Initialize the CheckpointManager.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
|
|
41
|
+
Function that determines if the current checkpoint should be
|
|
42
|
+
saved based on the current and best metric values.
|
|
43
|
+
network (torch.nn.Module): The model to save/load.
|
|
44
|
+
optimizer (torch.optim.Optimizer): The optimizer to save/load.
|
|
45
|
+
metric_manager (MetricManager): Manages metric states for checkpointing.
|
|
46
|
+
save_dir (str, optional): Directory to save checkpoints. Defaults to 'checkpoints'.
|
|
47
|
+
create_dir (bool, optional): Whether to create `save_dir` if it does not exist.
|
|
48
|
+
Defaults to False.
|
|
49
|
+
report_save (bool, optional): Whether to report when a checkpoint is saved.
|
|
50
|
+
Defaults to False.
|
|
69
51
|
|
|
52
|
+
Raises:
|
|
53
|
+
TypeError: If any provided attribute has an incompatible type.
|
|
54
|
+
FileNotFoundError: If `save_dir` does not exist and `create_dir` is False.
|
|
55
|
+
"""
|
|
70
56
|
# Type checking
|
|
57
|
+
validate_callable("criteria_function", criteria_function)
|
|
71
58
|
validate_type("network", network, Module)
|
|
72
59
|
validate_type("optimizer", optimizer, Optimizer)
|
|
73
60
|
validate_type("metric_manager", metric_manager, MetricManager)
|
|
74
61
|
validate_type("create_dir", create_dir, bool)
|
|
62
|
+
validate_type("report_save", report_save, bool)
|
|
75
63
|
|
|
76
64
|
# Create path or raise error if create_dir is not found
|
|
77
65
|
if not os.path.exists(save_dir):
|
|
78
66
|
if not create_dir:
|
|
79
67
|
raise FileNotFoundError(
|
|
80
|
-
f"Save directory '{
|
|
81
|
-
"checkpoint manager is not found."
|
|
68
|
+
f"Save directory '{save_dir}' configured in checkpoint manager is not found."
|
|
82
69
|
)
|
|
83
70
|
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
|
84
71
|
|
|
85
72
|
# Initialize objects variables
|
|
73
|
+
self.criteria_function = criteria_function
|
|
86
74
|
self.network = network
|
|
87
75
|
self.optimizer = optimizer
|
|
88
76
|
self.metric_manager = metric_manager
|
|
89
77
|
self.save_dir = save_dir
|
|
78
|
+
self.report_save = report_save
|
|
90
79
|
|
|
91
|
-
self.
|
|
92
|
-
self.best_metrics: dict[str, Tensor] = {}
|
|
80
|
+
self.best_metric_values: dict[str, Tensor] = {}
|
|
93
81
|
|
|
94
|
-
def
|
|
95
|
-
|
|
96
|
-
metric_name: str,
|
|
97
|
-
comparator: Callable[[Tensor, Tensor], Tensor] = gt,
|
|
98
|
-
):
|
|
99
|
-
"""
|
|
100
|
-
Register a criterion for evaluating a performance metric
|
|
101
|
-
during training.
|
|
82
|
+
def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
|
|
83
|
+
"""Evaluate the criteria function to determine if a better model is found.
|
|
102
84
|
|
|
103
|
-
|
|
104
|
-
|
|
85
|
+
Aggregates the current metric values during training and applies the
|
|
86
|
+
criteria function. If the criteria function indicates improvement, the
|
|
87
|
+
best metric values are updated, a checkpoint is saved, and a message is
|
|
88
|
+
optionally printed.
|
|
105
89
|
|
|
106
90
|
Args:
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
A function that compares the current metric value against the
|
|
110
|
-
previous best value. Defaults to a greater-than (gt) comparison.
|
|
111
|
-
|
|
112
|
-
Raises:
|
|
113
|
-
TypeError: If a provided attribute has an incompatible type.
|
|
114
|
-
|
|
91
|
+
epoch (int): The current epoch number.
|
|
92
|
+
metric_group (str, optional): The metric group to evaluate. Defaults to 'during_training'.
|
|
115
93
|
"""
|
|
94
|
+
current_metric_values = self.metric_manager.aggregate(metric_group)
|
|
95
|
+
if self.criteria_function is not None and self.criteria_function(
|
|
96
|
+
current_metric_values, self.best_metric_values
|
|
97
|
+
):
|
|
98
|
+
# Print message if a new checkpoint is saved
|
|
99
|
+
if self.report_save:
|
|
100
|
+
print(f"New checkpoint saved at epoch {epoch}.")
|
|
116
101
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
102
|
+
# Update current best metric values
|
|
103
|
+
for metric_name, metric_value in current_metric_values.items():
|
|
104
|
+
self.best_metric_values[metric_name] = metric_value
|
|
120
105
|
|
|
121
|
-
|
|
106
|
+
# Save the current state
|
|
107
|
+
self.save(epoch)
|
|
122
108
|
|
|
123
|
-
def resume(
|
|
124
|
-
|
|
125
|
-
) -> int:
|
|
126
|
-
"""
|
|
127
|
-
Resumes training from a saved checkpoint file.
|
|
109
|
+
def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
|
|
110
|
+
"""Resumes training from a saved checkpoint file.
|
|
128
111
|
|
|
129
112
|
Args:
|
|
130
113
|
filename (str): The name of the checkpoint file to load.
|
|
@@ -141,7 +124,6 @@ class CheckpointManager:
|
|
|
141
124
|
TypeError: If a provided attribute has an incompatible type.
|
|
142
125
|
FileNotFoundError: If the specified checkpoint file does not exist.
|
|
143
126
|
"""
|
|
144
|
-
|
|
145
127
|
# Type checking
|
|
146
128
|
validate_type("filename", filename, str)
|
|
147
129
|
validate_type("ignore_missing", ignore_missing, bool)
|
|
@@ -149,84 +131,48 @@ class CheckpointManager:
|
|
|
149
131
|
# Return starting epoch, either from checkpoint file or default
|
|
150
132
|
filepath = os.path.join(self.save_dir, filename)
|
|
151
133
|
if os.path.exists(filepath):
|
|
152
|
-
checkpoint = self.load(
|
|
134
|
+
checkpoint = self.load(filename)
|
|
153
135
|
return checkpoint["epoch"]
|
|
154
136
|
elif ignore_missing:
|
|
155
137
|
return 0
|
|
156
138
|
else:
|
|
157
|
-
raise FileNotFoundError(
|
|
158
|
-
f"A checkpoint was not found at {filepath} to resume training."
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
def evaluate_criteria(self, epoch: int):
|
|
162
|
-
"""
|
|
163
|
-
Evaluate the defined criteria for model performance metrics
|
|
164
|
-
during training.
|
|
165
|
-
|
|
166
|
-
Args:
|
|
167
|
-
epoch (int): The current epoch number.
|
|
168
|
-
|
|
169
|
-
Compares the current metrics against the previous best metrics using
|
|
170
|
-
predefined comparators. If a criterion is met, saves the model and
|
|
171
|
-
the corresponding best metric values.
|
|
172
|
-
"""
|
|
173
|
-
|
|
174
|
-
for metric_name, comparator in self.criteria.items():
|
|
139
|
+
raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")
|
|
175
140
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
].aggregate()
|
|
179
|
-
best_metric_value = self.best_metrics.get(metric_name)
|
|
180
|
-
|
|
181
|
-
# TODO improve efficiency by not checking is None each iteration
|
|
182
|
-
if best_metric_value is None or comparator(
|
|
183
|
-
current_metric_value,
|
|
184
|
-
best_metric_value,
|
|
185
|
-
):
|
|
186
|
-
self.save(epoch)
|
|
187
|
-
self.best_metrics[metric_name] = current_metric_value
|
|
188
|
-
|
|
189
|
-
def save(
|
|
190
|
-
self,
|
|
191
|
-
epoch: int,
|
|
192
|
-
filename: str = "checkpoint.pth",
|
|
193
|
-
):
|
|
194
|
-
"""
|
|
195
|
-
Save a checkpoint.
|
|
141
|
+
def save(self, epoch: int, filename: str = "checkpoint.pth"):
|
|
142
|
+
"""Save a checkpoint.
|
|
196
143
|
|
|
197
144
|
Args:
|
|
198
145
|
epoch (int): Current epoch number.
|
|
199
146
|
filename (str): Name of the checkpoint file. Defaults to
|
|
200
147
|
'checkpoint.pth'.
|
|
201
148
|
"""
|
|
202
|
-
|
|
203
149
|
state = {
|
|
204
150
|
"epoch": epoch,
|
|
205
151
|
"network_state": self.network.state_dict(),
|
|
206
152
|
"optimizer_state": self.optimizer.state_dict(),
|
|
207
|
-
"best_metrics": self.
|
|
153
|
+
"best_metrics": self.best_metric_values,
|
|
208
154
|
}
|
|
209
155
|
filepath = os.path.join(self.save_dir, filename)
|
|
210
156
|
save(state, filepath)
|
|
211
157
|
|
|
212
158
|
def load(self, filename: str):
|
|
213
|
-
"""
|
|
214
|
-
|
|
215
|
-
and
|
|
159
|
+
"""Load a checkpoint and restore the training state.
|
|
160
|
+
|
|
161
|
+
Loads the checkpoint from the specified file and restores the network
|
|
162
|
+
weights, optimizer state, and best metric values.
|
|
216
163
|
|
|
217
164
|
Args:
|
|
218
165
|
filename (str): Name of the checkpoint file.
|
|
219
166
|
|
|
220
167
|
Returns:
|
|
221
|
-
dict: A dictionary containing the loaded checkpoint
|
|
222
|
-
|
|
168
|
+
dict: A dictionary containing the loaded checkpoint information,
|
|
169
|
+
including epoch, loss, and other relevant training state.
|
|
223
170
|
"""
|
|
224
|
-
|
|
225
171
|
filepath = os.path.join(self.save_dir, filename)
|
|
226
172
|
|
|
227
|
-
checkpoint = load(filepath)
|
|
173
|
+
checkpoint = load(filepath, weights_only=True)
|
|
228
174
|
self.network.load_state_dict(checkpoint["network_state"])
|
|
229
175
|
self.optimizer.load_state_dict(checkpoint["optimizer_state"])
|
|
230
|
-
self.
|
|
176
|
+
self.best_metric_values = checkpoint["best_metrics"]
|
|
231
177
|
|
|
232
178
|
return checkpoint
|