congrads 1.0.6__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. {congrads-1.0.6 → congrads-1.1.0}/PKG-INFO +48 -41
  2. {congrads-1.0.6 → congrads-1.1.0}/README.md +11 -5
  3. congrads-1.1.0/pyproject.toml +122 -0
  4. {congrads-1.0.6 → congrads-1.1.0/src}/congrads/__init__.py +2 -3
  5. congrads-1.1.0/src/congrads/checkpoints.py +178 -0
  6. congrads-1.1.0/src/congrads/constraints.py +1256 -0
  7. congrads-1.1.0/src/congrads/core.py +773 -0
  8. congrads-1.1.0/src/congrads/datasets.py +799 -0
  9. congrads-1.1.0/src/congrads/descriptor.py +166 -0
  10. congrads-1.1.0/src/congrads/metrics.py +139 -0
  11. congrads-1.1.0/src/congrads/networks.py +68 -0
  12. congrads-1.1.0/src/congrads/py.typed +0 -0
  13. congrads-1.1.0/src/congrads/transformations.py +116 -0
  14. {congrads-1.0.6 → congrads-1.1.0/src}/congrads/utils.py +499 -131
  15. congrads-1.0.6/.gitignore +0 -249
  16. congrads-1.0.6/.gitlab-ci.yml +0 -117
  17. congrads-1.0.6/.pylintrc +0 -399
  18. congrads-1.0.6/.readthedocs.yaml +0 -26
  19. congrads-1.0.6/.vscode/extensions.json +0 -20
  20. congrads-1.0.6/.vscode/settings.json +0 -68
  21. congrads-1.0.6/congrads/checkpoints.py +0 -232
  22. congrads-1.0.6/congrads/constraints.py +0 -919
  23. congrads-1.0.6/congrads/core.py +0 -597
  24. congrads-1.0.6/congrads/datasets.py +0 -499
  25. congrads-1.0.6/congrads/descriptor.py +0 -130
  26. congrads-1.0.6/congrads/metrics.py +0 -211
  27. congrads-1.0.6/congrads/networks.py +0 -114
  28. congrads-1.0.6/congrads/transformations.py +0 -139
  29. congrads-1.0.6/congrads.egg-info/PKG-INFO +0 -222
  30. congrads-1.0.6/congrads.egg-info/SOURCES.txt +0 -47
  31. congrads-1.0.6/congrads.egg-info/dependency_links.txt +0 -1
  32. congrads-1.0.6/congrads.egg-info/requires.txt +0 -6
  33. congrads-1.0.6/congrads.egg-info/top_level.txt +0 -1
  34. congrads-1.0.6/docs/Makefile +0 -20
  35. congrads-1.0.6/docs/_static/VanBaelen2023.bib +0 -13
  36. congrads-1.0.6/docs/_static/congrads_export.png +0 -0
  37. congrads-1.0.6/docs/_static/congrads_export.svg +0 -52
  38. congrads-1.0.6/docs/_static/congrads_favicon.png +0 -0
  39. congrads-1.0.6/docs/_static/congrads_favicon.svg +0 -80
  40. congrads-1.0.6/docs/_static/convergence_illustration.gif +0 -0
  41. congrads-1.0.6/docs/api.rst +0 -76
  42. congrads-1.0.6/docs/concepts.rst +0 -224
  43. congrads-1.0.6/docs/conf.py +0 -77
  44. congrads-1.0.6/docs/index.rst +0 -116
  45. congrads-1.0.6/docs/make.bat +0 -35
  46. congrads-1.0.6/docs/requirements.in +0 -2
  47. congrads-1.0.6/docs/requirements.txt +0 -57
  48. congrads-1.0.6/docs/start.rst +0 -126
  49. congrads-1.0.6/examples/BiasCorrection.py +0 -149
  50. congrads-1.0.6/examples/FamilyIncome.py +0 -210
  51. congrads-1.0.6/examples/NoisySines.py +0 -252
  52. congrads-1.0.6/notebooks/BiasCorrection.ipynb +0 -232
  53. congrads-1.0.6/notebooks/FamilyIncome.ipynb +0 -293
  54. congrads-1.0.6/pyproject.toml +0 -29
  55. congrads-1.0.6/setup.cfg +0 -4
  56. congrads-1.0.6/tests/congrads/test_utils.py +0 -80
  57. congrads-1.0.6/tests/examples/test_BiasCorrection.py +0 -150
  58. congrads-1.0.6/tests/examples/test_FamilyIncome.py +0 -211
  59. {congrads-1.0.6 → congrads-1.1.0}/LICENSE +0 -0
@@ -1,44 +1,45 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.3
2
2
  Name: congrads
3
- Version: 1.0.6
3
+ Version: 1.1.0
4
4
  Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
5
+ Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
5
6
  Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
6
7
  License: Copyright 2024 DTAI - KU Leuven
7
-
8
- Redistribution and use in source and binary forms, with or without modification,
9
- are permitted provided that the following conditions are met:
10
-
11
- 1. Redistributions of source code must retain the above copyright notice,
12
- this list of conditions and the following disclaimer.
13
-
14
- 2. Redistributions in binary form must reproduce the above copyright notice,
15
- this list of conditions and the following disclaimer in the documentation
16
- and/or other materials provided with the distribution.
17
-
18
- 3. Neither the name of the copyright holder nor the names of its
19
- contributors may be used to endorse or promote products derived from
20
- this software without specific prior written permission.
21
-
22
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS”
23
- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24
- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
25
- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
26
- LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
-
33
- Requires-Python: >=3.9
8
+
9
+ Redistribution and use in source and binary forms, with or without modification,
10
+ are permitted provided that the following conditions are met:
11
+
12
+ 1. Redistributions of source code must retain the above copyright notice,
13
+ this list of conditions and the following disclaimer.
14
+
15
+ 2. Redistributions in binary form must reproduce the above copyright notice,
16
+ this list of conditions and the following disclaimer in the documentation
17
+ and/or other materials provided with the distribution.
18
+
19
+ 3. Neither the name of the copyright holder nor the names of its
20
+ contributors may be used to endorse or promote products derived from
21
+ this software without specific prior written permission.
22
+
23
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS”
24
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
26
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
27
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
+ Requires-Dist: numpy>=1.24.0
34
+ Requires-Dist: pandas>=1.5.0
35
+ Requires-Dist: torch>=2.0.0
36
+ Requires-Dist: torchvision>=0.15.1
37
+ Requires-Dist: tqdm>=4.65.0
38
+ Requires-Dist: matplotlib>=3.7.0 ; extra == 'examples'
39
+ Requires-Dist: tensorboard>=2.12.0 ; extra == 'examples'
40
+ Requires-Python: >=3.11
41
+ Provides-Extra: examples
34
42
  Description-Content-Type: text/markdown
35
- License-File: LICENSE
36
- Requires-Dist: numpy>=1.26.4
37
- Requires-Dist: pandas>=2.2.2
38
- Requires-Dist: torch>=2.5.0
39
- Requires-Dist: torchvision>=0.20.0
40
- Requires-Dist: tensorboard>=2.18.0
41
- Requires-Dist: tqdm>=4.66.5
42
43
 
43
44
  <div align="center">
44
45
  <img src="https://github.com/ML-KULeuven/congrads/blob/main/docs/_static/congrads_export.png?raw=true" height="200">
@@ -49,7 +50,7 @@ Requires-Dist: tqdm>=4.66.5
49
50
 
50
51
  [![PyPi](https://img.shields.io/pypi/v/congrads.svg)](https://pypi.org/project/congrads)
51
52
  [![Read the Docs](https://img.shields.io/readthedocs/congrads/latest.svg?label=Read%20the%20Docs)](https://congrads.readthedocs.io)
52
- [![Python Version: 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://pypi.org/project/congrads)
53
+ [![Python Version: 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg)](https://pypi.org/project/congrads)
53
54
  [![Downloads](https://img.shields.io/pypi/dm/congrads.svg)](https://pypistats.org/packages/congrads)
54
55
  [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
55
56
 
@@ -80,15 +81,21 @@ Next, install the Congrads toolbox. The recommended way to install it is to use
80
81
  pip install congrads
81
82
  ```
82
83
 
84
+ You can also install Congrads together with extra packages required to run the examples:
85
+
86
+ ```bash
87
+ pip install congrads[examples]
88
+ ```
89
+
83
90
  This should automatically install all required dependencies for you. If you would like to install dependencies manually, Congrads depends on the following:
84
91
 
85
- - Python 3.9 - 3.12
92
+ - Python 3.11 - 3.13
86
93
  - **PyTorch** (install with CUDA support for GPU training, refer to [PyTorch's getting started guide](https://pytorch.org/get-started/locally/))
87
94
  - **NumPy** (install with `pip install numpy`, or refer to [NumPy's install guide](https://numpy.org/install/).)
88
95
  - **Pandas** (install with `pip install pandas`, or refer to [pandas' install guide](https://pandas.pydata.org/docs/getting_started/install.html).)
89
96
  - **Tqdm** (install with `pip install tqdm`)
90
97
  - **Torchvision** (install with `pip install torchvision`)
91
- - **Tensorboard** (install with `pip install tensorboard`)
98
+ - Optional: **Tensorboard** (install with `pip install tensorboard`)
92
99
 
93
100
  ### 2. **Core concepts**
94
101
 
@@ -182,11 +189,11 @@ core.fit(max_epochs=50)
182
189
  - **Improve Training Process**: Inject domain knowledge in the training stage, increasing learning efficiency.
183
190
  - **Physics-Informed Neural Networks (PINNs)**: Coming soon — enforce physical laws as constraints in your models.
184
191
 
185
- ## Roadmap
192
+ ## Planned changes / Roadmap
186
193
 
187
194
  - [ ] Add ODE/PDE constraints to support PINNs
195
+ - [ ] Rework callback system
188
196
  - [ ] Add support for constraint parser that can interpret equations
189
- - [ ] Determine if it is feasible to add unit and or functional tests
190
197
 
191
198
  ## Research
192
199
 
@@ -7,7 +7,7 @@
7
7
 
8
8
  [![PyPi](https://img.shields.io/pypi/v/congrads.svg)](https://pypi.org/project/congrads)
9
9
  [![Read the Docs](https://img.shields.io/readthedocs/congrads/latest.svg?label=Read%20the%20Docs)](https://congrads.readthedocs.io)
10
- [![Python Version: 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg)](https://pypi.org/project/congrads)
10
+ [![Python Version: 3.11+](https://img.shields.io/badge/Python-3.11+-blue.svg)](https://pypi.org/project/congrads)
11
11
  [![Downloads](https://img.shields.io/pypi/dm/congrads.svg)](https://pypistats.org/packages/congrads)
12
12
  [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
13
13
 
@@ -38,15 +38,21 @@ Next, install the Congrads toolbox. The recommended way to install it is to use
38
38
  pip install congrads
39
39
  ```
40
40
 
41
+ You can also install Congrads together with extra packages required to run the examples:
42
+
43
+ ```bash
44
+ pip install congrads[examples]
45
+ ```
46
+
41
47
  This should automatically install all required dependencies for you. If you would like to install dependencies manually, Congrads depends on the following:
42
48
 
43
- - Python 3.9 - 3.12
49
+ - Python 3.11 - 3.13
44
50
  - **PyTorch** (install with CUDA support for GPU training, refer to [PyTorch's getting started guide](https://pytorch.org/get-started/locally/))
45
51
  - **NumPy** (install with `pip install numpy`, or refer to [NumPy's install guide](https://numpy.org/install/).)
46
52
  - **Pandas** (install with `pip install pandas`, or refer to [pandas' install guide](https://pandas.pydata.org/docs/getting_started/install.html).)
47
53
  - **Tqdm** (install with `pip install tqdm`)
48
54
  - **Torchvision** (install with `pip install torchvision`)
49
- - **Tensorboard** (install with `pip install tensorboard`)
55
+ - Optional: **Tensorboard** (install with `pip install tensorboard`)
50
56
 
51
57
  ### 2. **Core concepts**
52
58
 
@@ -140,11 +146,11 @@ core.fit(max_epochs=50)
140
146
  - **Improve Training Process**: Inject domain knowledge in the training stage, increasing learning efficiency.
141
147
  - **Physics-Informed Neural Networks (PINNs)**: Coming soon — enforce physical laws as constraints in your models.
142
148
 
143
- ## Roadmap
149
+ ## Planned changes / Roadmap
144
150
 
145
151
  - [ ] Add ODE/PDE constraints to support PINNs
152
+ - [ ] Rework callback system
146
153
  - [ ] Add support for constraint parser that can interpret equations
147
- - [ ] Determine if it is feasible to add unit and or functional tests
148
154
 
149
155
  ## Research
150
156
 
@@ -0,0 +1,122 @@
1
+ [project]
2
+ name = "congrads"
3
+ version = "1.1.0"
4
+ description = "A toolbox for using Constraint Guided Gradient Descent when training neural networks."
5
+ readme = "README.md"
6
+ license = { file = "LICENSE" }
7
+ authors = [
8
+ { name = "Wout Rombouts", email = "wout.rombouts@kuleuven.be" },
9
+ { name = "Quinten Van Baelen", email = "quinten.vanbaelen@kuleuven.be" },
10
+ { name = "Peter Karsmakers", email = "peter.karsmakers@kuleuven.be" }
11
+ ]
12
+ requires-python = ">=3.11"
13
+ dependencies = [
14
+ "numpy>=1.24.0",
15
+ "pandas>=1.5.0",
16
+ "torch>=2.0.0",
17
+ "torchvision>=0.15.1",
18
+ "tqdm>=4.65.0",
19
+ ]
20
+
21
+ [project.optional-dependencies]
22
+ examples = [
23
+ "matplotlib>=3.7.0",
24
+ "tensorboard>=2.12.0",
25
+ ]
26
+
27
+ [build-system]
28
+ requires = ["uv_build>=0.8.13,<0.9.0"]
29
+ build-backend = "uv_build"
30
+
31
+ [dependency-groups]
32
+ dev = [
33
+ "pytest>=8.4.2",
34
+ "pytest-cov>=7.0.0",
35
+ "ruff>=0.13.0",
36
+ "sphinx>=8.2.3",
37
+ "sphinx-rtd-theme>=3.0.2",
38
+ "tox>=4.30.2",
39
+ "twine>=6.2.0",
40
+ ]
41
+
42
+ [tool.ruff]
43
+ # Match Black-style defaults but with a longer line length
44
+ line-length = 100
45
+ indent-width = 4
46
+ target-version = "py311" # or "py312" if needed
47
+
48
+ # Where Ruff runs (exclude cache, venvs, build dirs, etc.)
49
+ exclude = [
50
+ ".git",
51
+ ".mypy_cache",
52
+ ".ruff_cache",
53
+ ".venv",
54
+ "venv",
55
+ "build",
56
+ "dist",
57
+ "__pypackages__",
58
+ ]
59
+
60
+ # Which files to include
61
+ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
62
+
63
+ [tool.ruff.lint]
64
+ # Essentials + additional readability rules
65
+ select = [
66
+ "E", # pycodestyle errors
67
+ "F", # pyflakes
68
+ "I", # import sorting
69
+ "B", # bugbear
70
+ "UP", # pyupgrade
71
+ "C4", # comprehensions
72
+ "ISC", # implicit string concatenation
73
+ "D", # docstring styling
74
+ "N", # consistent naming
75
+ ]
76
+ ignore = [
77
+ "E501", # line length (handled by formatter)
78
+ ]
79
+
80
+ # Mark every enabled rule as eligible for autofix (Ruff still applies only safe fixes unless --unsafe-fixes is passed)
81
+ fixable = ["ALL"]
82
+
83
+ [tool.ruff.lint.per-file-ignores]
84
+ "tests/*" = ["D", "N"]
85
+
86
+ [tool.ruff.format]
87
+ quote-style = "double" # enforce double quotes
88
+ indent-style = "space" # 4 spaces
89
+ skip-magic-trailing-comma = false # Black-compatible
90
+ docstring-code-format = true
91
+ docstring-code-line-length = "dynamic"
92
+
93
+ [tool.ruff.lint.pydocstyle]
94
+ convention = "google"
95
+
96
+ [tool.tox]
97
+ envlist = ["py{311,312,313}-{latest,minimum}"]
98
+
99
+ [tool.tox.env_run_base]
100
+ description = "Run tests with latest dependencies"
101
+ deps = [
102
+ "pytest",
103
+ "pytest-cov"
104
+ ]
105
+ commands = [
106
+ ["pytest", "--cov=src/congrads", "tests/"]
107
+ ]
108
+ extras = [ "examples" ]
109
+
110
+ [tool.tox.env.latest]
111
+ description = "Run tests with latest dependencies"
112
+
113
+ [tool.tox.env.minimum]
114
+ description = "Run tests with minimum dependencies"
115
+ deps = [
116
+ "-c constraints-min.txt",
117
+ "pytest",
118
+ "pytest-cov"
119
+ ]
120
+
121
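+ # Skip environments whose Python interpreter is not installed (NOTE(review): this key sits under [tool.tox.env.minimum]; it is a core tox setting and likely belongs in [tool.tox] — confirm)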
122
+ skip_missing_interpreters = true
@@ -1,6 +1,4 @@
1
- # pylint: skip-file
2
-
3
- try:
1
+ try: # noqa: D104
4
2
  from importlib.metadata import version as get_version # Python 3.8+
5
3
  except ImportError:
6
4
  from pkg_resources import (
@@ -25,5 +23,6 @@ __all__ = [
25
23
  "descriptor",
26
24
  "metrics",
27
25
  "networks",
26
+ "transformations",
28
27
  "utils",
29
28
  ]
@@ -0,0 +1,178 @@
1
+ """Module for managing PyTorch model checkpoints.
2
+
3
+ Provides the `CheckpointManager` class to save and load model and optimizer
4
+ states during training, track the best metric values, and optionally report
5
+ checkpoint events.
6
+ """
7
+
8
+ import os
9
+ from collections.abc import Callable
10
+ from pathlib import Path
11
+
12
+ from torch import Tensor, load, save
13
+ from torch.nn import Module
14
+ from torch.optim import Optimizer
15
+
16
+ from .metrics import MetricManager
17
+ from .utils import validate_callable, validate_type
18
+
19
+
20
class CheckpointManager:
    """Manage saving and loading checkpoints for PyTorch models and optimizers.

    A checkpoint bundles the epoch number, the network and optimizer state
    dicts, and the best metric values seen so far. Whether a new checkpoint
    is written is decided by a user-supplied criteria function.
    """

    def __init__(
        self,
        criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
        network: Module,
        optimizer: Optimizer,
        metric_manager: "MetricManager",
        save_dir: str = "checkpoints",
        create_dir: bool = False,
        report_save: bool = False,
    ):
        """Initialize the CheckpointManager.

        Args:
            criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
                Called as ``criteria_function(current, best)`` with the current
                and the best metric values; a truthy result triggers a save.
            network (torch.nn.Module): The model to save/load.
            optimizer (torch.optim.Optimizer): The optimizer to save/load.
            metric_manager (MetricManager): Source of the aggregated metric
                values passed to the criteria function.
            save_dir (str, optional): Directory to save checkpoints.
                Defaults to 'checkpoints'.
            create_dir (bool, optional): Whether to create `save_dir` if it
                does not exist. Defaults to False.
            report_save (bool, optional): Whether to print a message when a
                checkpoint is saved. Defaults to False.

        Raises:
            TypeError: If any provided attribute has an incompatible type.
            FileNotFoundError: If `save_dir` does not exist and `create_dir`
                is False.
        """
        # Type checking
        validate_callable("criteria_function", criteria_function)
        validate_type("network", network, Module)
        validate_type("optimizer", optimizer, Optimizer)
        validate_type("metric_manager", metric_manager, MetricManager)
        validate_type("create_dir", create_dir, bool)
        validate_type("report_save", report_save, bool)

        # Create the save directory, or fail early when that is not allowed
        if not os.path.exists(save_dir):
            if not create_dir:
                raise FileNotFoundError(
                    f"Save directory '{save_dir}' configured in checkpoint manager is not found."
                )
            Path(save_dir).mkdir(parents=True, exist_ok=True)

        # Initialize object variables
        self.criteria_function = criteria_function
        self.network = network
        self.optimizer = optimizer
        self.metric_manager = metric_manager
        self.save_dir = save_dir
        self.report_save = report_save

        self.best_metric_values: dict[str, Tensor] = {}

    def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
        """Evaluate the criteria function and checkpoint when it reports improvement.

        Aggregates the metric values of `metric_group` and passes them,
        together with the best values recorded so far, to the criteria
        function. On a truthy result the best metric values are updated, a
        checkpoint is saved, and a message is optionally printed.

        Args:
            epoch (int): The current epoch number.
            metric_group (str, optional): The metric group to evaluate.
                Defaults to 'during_training'.
        """
        current_metric_values = self.metric_manager.aggregate(metric_group)

        # Guard clause: nothing to do without a criteria function or improvement
        if self.criteria_function is None or not self.criteria_function(
            current_metric_values, self.best_metric_values
        ):
            return

        # The best values must be updated first: save() serializes them
        self.best_metric_values.update(current_metric_values)
        self.save(epoch)

        # Report only after the save succeeded, so the message is never a lie
        if self.report_save:
            print(f"New checkpoint saved at epoch {epoch}.")

    def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
        """Resume training from a saved checkpoint file.

        Args:
            filename (str): The name of the checkpoint file to load.
                Defaults to "checkpoint.pth".
            ignore_missing (bool): If True, a missing checkpoint file is not
                an error; nothing is loaded and 0 is returned.
                Defaults to False.

        Returns:
            int: The epoch number stored in the loaded checkpoint, or 0 if
            `ignore_missing` is True and no checkpoint was found.

        Raises:
            TypeError: If a provided attribute has an incompatible type.
            FileNotFoundError: If the checkpoint file does not exist and
                `ignore_missing` is False.
        """
        # Type checking
        validate_type("filename", filename, str)
        validate_type("ignore_missing", ignore_missing, bool)

        filepath = os.path.join(self.save_dir, filename)
        if os.path.exists(filepath):
            return self.load(filename)["epoch"]
        if ignore_missing:
            return 0
        raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")

    def save(self, epoch: int, filename: str = "checkpoint.pth"):
        """Save a checkpoint.

        The checkpoint is first written to a temporary sibling file and then
        moved over the target with `os.replace`, so an interrupted or failed
        write does not destroy a previously saved checkpoint.

        Args:
            epoch (int): Current epoch number.
            filename (str): Name of the checkpoint file. Defaults to
                'checkpoint.pth'.
        """
        state = {
            "epoch": epoch,
            "network_state": self.network.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "best_metrics": self.best_metric_values,
        }
        filepath = os.path.join(self.save_dir, filename)
        temp_filepath = f"{filepath}.tmp"
        try:
            save(state, temp_filepath)
            os.replace(temp_filepath, filepath)
        finally:
            # Only present when writing or replacing failed part-way
            if os.path.exists(temp_filepath):
                os.remove(temp_filepath)

    def load(self, filename: str):
        """Load a checkpoint and restore the training state.

        Loads the checkpoint from the specified file (with
        ``weights_only=True``) and restores the network weights, optimizer
        state, and best metric values.

        Args:
            filename (str): Name of the checkpoint file inside `save_dir`.

        Returns:
            dict: The loaded checkpoint with the keys "epoch",
            "network_state", "optimizer_state" and "best_metrics".
        """
        filepath = os.path.join(self.save_dir, filename)

        checkpoint = load(filepath, weights_only=True)
        self.network.load_state_dict(checkpoint["network_state"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
        self.best_metric_values = checkpoint["best_metrics"]

        return checkpoint