congrads 1.0.6-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/__init__.py CHANGED
@@ -1,6 +1,4 @@
-# pylint: skip-file
-
-try:
+try:  # noqa: D104
     from importlib.metadata import version as get_version  # Python 3.8+
 except ImportError:
     from pkg_resources import (
@@ -25,5 +23,6 @@ __all__ = [
     "descriptor",
     "metrics",
     "networks",
+    "transformations",
     "utils",
 ]
congrads/checkpoints.py CHANGED
@@ -1,130 +1,113 @@
-"""
-This module provides a `CheckpointManager` class for managing the saving and
-loading of checkpoints during PyTorch model training.
-
-The `CheckpointManager` handles:
-
-- Saving and loading the state of models, optimizers, and metrics.
-- Registering and evaluating performance criteria to determine if a model's
-  performance has improved, enabling automated saving of the best-performing
-  model checkpoints.
-- Resuming training from a specific checkpoint.
-
-Usage:
-    1. Initialize the `CheckpointManager` with a PyTorch model, optimizer,
-       and metric manager.
-    2. Register criteria for tracking and evaluating metrics.
-    3. Use the `save` and `load` methods to manage checkpoints during training.
-    4. Call `evaluate_criteria` to automatically evaluate and save the
-       best-performing checkpoints.
-
-Dependencies:
-    - PyTorch (`torch`)
+"""Module for managing PyTorch model checkpoints.
+
+Provides the `CheckpointManager` class to save and load model and optimizer
+states during training, track the best metric values, and optionally report
+checkpoint events.
 """
 
 import os
+from collections.abc import Callable
 from pathlib import Path
-from typing import Callable
 
-from torch import Tensor, gt, load, save
+from torch import Tensor, load, save
 from torch.nn import Module
 from torch.optim import Optimizer
 
 from .metrics import MetricManager
-from .utils import validate_comparator_pytorch, validate_type
+from .utils import validate_callable, validate_type
 
 
 class CheckpointManager:
-    """
-    A class to handle saving and loading checkpoints for
-    PyTorch models and optimizers.
-
-    Args:
-        network (torch.nn.Module): The network (model) to save/load.
-        optimizer (torch.optim.Optimizer): The optimizer to save/load.
-        metric_manager (MetricManager): The metric manager to restore saved
-            metric states.
-        save_dir (str): Directory where checkpoints will be saved. Defaults
-            to 'checkpoints'.
-        create_dir (bool): Whether to create the save_dir if it does not exist.
-            Defaults to False.
-
-    Raises:
-        TypeError: If a provided attribute has an incompatible type.
-        FileNotFoundError: If the save directory does not exist and create_dir
-            is set to False.
+    """Manage saving and loading checkpoints for PyTorch models and optimizers.
+
+    Handles checkpointing based on a criteria function, restores metric
+    states, and optionally reports when a checkpoint is saved.
     """
 
     def __init__(
         self,
+        criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
         network: Module,
         optimizer: Optimizer,
         metric_manager: MetricManager,
         save_dir: str = "checkpoints",
         create_dir: bool = False,
+        report_save: bool = False,
     ):
-        """
-        Initialize the checkpoint manager.
-        """
+        """Initialize the CheckpointManager.
+
+        Args:
+            criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
+                Function that determines if the current checkpoint should be
+                saved based on the current and best metric values.
+            network (torch.nn.Module): The model to save/load.
+            optimizer (torch.optim.Optimizer): The optimizer to save/load.
+            metric_manager (MetricManager): Manages metric states for checkpointing.
+            save_dir (str, optional): Directory to save checkpoints. Defaults to 'checkpoints'.
+            create_dir (bool, optional): Whether to create `save_dir` if it does not exist.
+                Defaults to False.
+            report_save (bool, optional): Whether to report when a checkpoint is saved.
+                Defaults to False.
 
+        Raises:
+            TypeError: If any provided attribute has an incompatible type.
+            FileNotFoundError: If `save_dir` does not exist and `create_dir` is False.
+        """
         # Type checking
+        validate_callable("criteria_function", criteria_function)
         validate_type("network", network, Module)
         validate_type("optimizer", optimizer, Optimizer)
         validate_type("metric_manager", metric_manager, MetricManager)
         validate_type("create_dir", create_dir, bool)
+        validate_type("report_save", report_save, bool)
 
         # Create path or raise error if create_dir is not found
         if not os.path.exists(save_dir):
             if not create_dir:
                 raise FileNotFoundError(
-                    f"Save directory '{str(save_dir)}' configured in "
-                    "checkpoint manager is not found."
+                    f"Save directory '{save_dir}' configured in checkpoint manager is not found."
                 )
             Path(save_dir).mkdir(parents=True, exist_ok=True)
 
         # Initialize objects variables
+        self.criteria_function = criteria_function
         self.network = network
         self.optimizer = optimizer
         self.metric_manager = metric_manager
         self.save_dir = save_dir
+        self.report_save = report_save
 
-        self.criteria: dict[str, Callable[[Tensor, Tensor], Tensor]] = {}
-        self.best_metrics: dict[str, Tensor] = {}
+        self.best_metric_values: dict[str, Tensor] = {}
 
-    def register(
-        self,
-        metric_name: str,
-        comparator: Callable[[Tensor, Tensor], Tensor] = gt,
-    ):
-        """
-        Register a criterion for evaluating a performance metric
-        during training.
+    def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
+        """Evaluate the criteria function to determine if a better model is found.
 
-        Stores the comparator to determine whether the current metric has
-        improved relative to the previous best metric value.
+        Aggregates the current metric values during training and applies the
+        criteria function. If the criteria function indicates improvement, the
+        best metric values are updated, a checkpoint is saved, and a message is
+        optionally printed.
 
         Args:
-            metric_name (str): The name of the metric to evaluate.
-            comparator (Callable[[Tensor, Tensor], Tensor], optional):
-                A function that compares the current metric value against the
-                previous best value. Defaults to a greater-than (gt) comparison.
-
-        Raises:
-            TypeError: If a provided attribute has an incompatible type.
-
+            epoch (int): The current epoch number.
+            metric_group (str, optional): The metric group to evaluate. Defaults to 'during_training'.
         """
+        current_metric_values = self.metric_manager.aggregate(metric_group)
+        if self.criteria_function is not None and self.criteria_function(
+            current_metric_values, self.best_metric_values
+        ):
+            # Print message if a new checkpoint is saved
+            if self.report_save:
+                print(f"New checkpoint saved at epoch {epoch}.")
 
-        validate_type("metric_name", metric_name, str)
-        validate_comparator_pytorch("comparator", comparator)
-        validate_comparator_pytorch("comparator", comparator)
+            # Update current best metric values
+            for metric_name, metric_value in current_metric_values.items():
+                self.best_metric_values[metric_name] = metric_value
 
-        self.criteria[metric_name] = comparator
+            # Save the current state
+            self.save(epoch)
 
-    def resume(
-        self, filename: str = "checkpoint.pth", ignore_missing: bool = False
-    ) -> int:
-        """
-        Resumes training from a saved checkpoint file.
+    def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
+        """Resumes training from a saved checkpoint file.
 
         Args:
             filename (str): The name of the checkpoint file to load.
@@ -141,7 +124,6 @@ class CheckpointManager:
             TypeError: If a provided attribute has an incompatible type.
             FileNotFoundError: If the specified checkpoint file does not exist.
         """
-
         # Type checking
         validate_type("filename", filename, str)
         validate_type("ignore_missing", ignore_missing, bool)
@@ -149,84 +131,48 @@ class CheckpointManager:
         # Return starting epoch, either from checkpoint file or default
        filepath = os.path.join(self.save_dir, filename)
         if os.path.exists(filepath):
-            checkpoint = self.load("checkpoint.pth")
+            checkpoint = self.load(filename)
             return checkpoint["epoch"]
         elif ignore_missing:
             return 0
         else:
-            raise FileNotFoundError(
-                f"A checkpoint was not found at {filepath} to resume training."
-            )
-
-    def evaluate_criteria(self, epoch: int):
-        """
-        Evaluate the defined criteria for model performance metrics
-        during training.
-
-        Args:
-            epoch (int): The current epoch number.
-
-        Compares the current metrics against the previous best metrics using
-        predefined comparators. If a criterion is met, saves the model and
-        the corresponding best metric values.
-        """
-
-        for metric_name, comparator in self.criteria.items():
+            raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")
 
-            current_metric_value = self.metric_manager.metrics[
-                metric_name
-            ].aggregate()
-            best_metric_value = self.best_metrics.get(metric_name)
-
-            # TODO improve efficiency by not checking is None each iteration
-            if best_metric_value is None or comparator(
-                current_metric_value,
-                best_metric_value,
-            ):
-                self.save(epoch)
-                self.best_metrics[metric_name] = current_metric_value
-
-    def save(
-        self,
-        epoch: int,
-        filename: str = "checkpoint.pth",
-    ):
-        """
-        Save a checkpoint.
+    def save(self, epoch: int, filename: str = "checkpoint.pth"):
+        """Save a checkpoint.
 
         Args:
             epoch (int): Current epoch number.
             filename (str): Name of the checkpoint file. Defaults to
                 'checkpoint.pth'.
         """
-
         state = {
             "epoch": epoch,
             "network_state": self.network.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
-            "best_metrics": self.best_metrics,
+            "best_metrics": self.best_metric_values,
         }
         filepath = os.path.join(self.save_dir, filename)
         save(state, filepath)
 
     def load(self, filename: str):
-        """
-        Load a checkpoint and restores the state of the network, optimizer
-        and best_metrics.
+        """Load a checkpoint and restore the training state.
+
+        Loads the checkpoint from the specified file and restores the network
+        weights, optimizer state, and best metric values.
 
         Args:
             filename (str): Name of the checkpoint file.
 
         Returns:
-            dict: A dictionary containing the loaded checkpoint
-                information (epoch, loss, etc.).
+            dict: A dictionary containing the loaded checkpoint information,
+                including epoch, loss, and other relevant training state.
         """
-
         filepath = os.path.join(self.save_dir, filename)
 
-        checkpoint = load(filepath)
+        checkpoint = load(filepath, weights_only=True)
         self.network.load_state_dict(checkpoint["network_state"])
         self.optimizer.load_state_dict(checkpoint["optimizer_state"])
-        self.best_metrics = checkpoint["best_metrics"]
+        self.best_metric_values = checkpoint["best_metrics"]
 
         return checkpoint
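
The diff above replaces the old register/comparator workflow with a single criteria function passed to the constructor. As a rough illustration only, the sketch below shows how the 1.1.0 CheckpointManager might be wired into a training loop; the loss_improved helper, the "loss" metric name, the train_with_checkpoints wrapper, and the way metrics are recorded into the MetricManager are assumptions for this example and are not part of the diff.

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from congrads.checkpoints import CheckpointManager
from congrads.metrics import MetricManager


def loss_improved(current: dict[str, Tensor], best: dict[str, Tensor]) -> bool:
    # Save on the first evaluation, then whenever the aggregated "loss"
    # value (an assumed metric name) decreases.
    return "loss" not in best or bool(current["loss"] < best["loss"])


def train_with_checkpoints(
    network: Module,
    optimizer: Optimizer,
    metric_manager: MetricManager,
    num_epochs: int,
) -> None:
    manager = CheckpointManager(
        criteria_function=loss_improved,
        network=network,
        optimizer=optimizer,
        metric_manager=metric_manager,
        save_dir="checkpoints",
        create_dir=True,
        report_save=True,
    )

    # Pick up from an existing checkpoint.pth if present, otherwise start at 0.
    start_epoch = manager.resume(ignore_missing=True)

    for epoch in range(start_epoch, num_epochs):
        ...  # training/validation steps that record values into metric_manager
        # Aggregates the "during_training" metric group and saves a checkpoint
        # when loss_improved() returns True.
        manager.evaluate_criteria(epoch, metric_group="during_training")

Compared with 1.0.6, the improvement check is now centralized in one callable over the full metric dictionaries rather than one comparator per registered metric, and load() passes weights_only=True to torch.load.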