congrads 1.0.7.tar.gz → 1.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {congrads-1.0.7 → congrads-1.1.0}/PKG-INFO +48 -41
- {congrads-1.0.7 → congrads-1.1.0}/README.md +11 -5
- congrads-1.1.0/pyproject.toml +122 -0
- {congrads-1.0.7 → congrads-1.1.0/src}/congrads/__init__.py +2 -3
- congrads-1.1.0/src/congrads/checkpoints.py +178 -0
- congrads-1.1.0/src/congrads/constraints.py +1256 -0
- congrads-1.1.0/src/congrads/core.py +773 -0
- congrads-1.1.0/src/congrads/datasets.py +799 -0
- congrads-1.1.0/src/congrads/descriptor.py +166 -0
- congrads-1.1.0/src/congrads/metrics.py +139 -0
- congrads-1.1.0/src/congrads/networks.py +68 -0
- congrads-1.1.0/src/congrads/py.typed +0 -0
- congrads-1.1.0/src/congrads/transformations.py +116 -0
- {congrads-1.0.7 → congrads-1.1.0/src}/congrads/utils.py +499 -131
- congrads-1.0.7/.gitignore +0 -249
- congrads-1.0.7/.gitlab-ci.yml +0 -117
- congrads-1.0.7/.pylintrc +0 -399
- congrads-1.0.7/.readthedocs.yaml +0 -26
- congrads-1.0.7/.vscode/extensions.json +0 -20
- congrads-1.0.7/.vscode/settings.json +0 -68
- congrads-1.0.7/congrads/checkpoints.py +0 -232
- congrads-1.0.7/congrads/constraints.py +0 -906
- congrads-1.0.7/congrads/core.py +0 -597
- congrads-1.0.7/congrads/datasets.py +0 -499
- congrads-1.0.7/congrads/descriptor.py +0 -130
- congrads-1.0.7/congrads/metrics.py +0 -211
- congrads-1.0.7/congrads/networks.py +0 -114
- congrads-1.0.7/congrads/transformations.py +0 -139
- congrads-1.0.7/congrads.egg-info/PKG-INFO +0 -222
- congrads-1.0.7/congrads.egg-info/SOURCES.txt +0 -47
- congrads-1.0.7/congrads.egg-info/dependency_links.txt +0 -1
- congrads-1.0.7/congrads.egg-info/requires.txt +0 -6
- congrads-1.0.7/congrads.egg-info/top_level.txt +0 -1
- congrads-1.0.7/docs/Makefile +0 -20
- congrads-1.0.7/docs/_static/VanBaelen2023.bib +0 -13
- congrads-1.0.7/docs/_static/congrads_export.png +0 -0
- congrads-1.0.7/docs/_static/congrads_export.svg +0 -52
- congrads-1.0.7/docs/_static/congrads_favicon.png +0 -0
- congrads-1.0.7/docs/_static/congrads_favicon.svg +0 -80
- congrads-1.0.7/docs/_static/convergence_illustration.gif +0 -0
- congrads-1.0.7/docs/api.rst +0 -76
- congrads-1.0.7/docs/concepts.rst +0 -224
- congrads-1.0.7/docs/conf.py +0 -77
- congrads-1.0.7/docs/index.rst +0 -116
- congrads-1.0.7/docs/make.bat +0 -35
- congrads-1.0.7/docs/requirements.in +0 -2
- congrads-1.0.7/docs/requirements.txt +0 -57
- congrads-1.0.7/docs/start.rst +0 -126
- congrads-1.0.7/examples/BiasCorrection.py +0 -149
- congrads-1.0.7/examples/FamilyIncome.py +0 -210
- congrads-1.0.7/examples/NoisySines.py +0 -252
- congrads-1.0.7/notebooks/BiasCorrection.ipynb +0 -232
- congrads-1.0.7/notebooks/FamilyIncome.ipynb +0 -293
- congrads-1.0.7/pyproject.toml +0 -29
- congrads-1.0.7/setup.cfg +0 -4
- congrads-1.0.7/tests/congrads/test_utils.py +0 -80
- congrads-1.0.7/tests/examples/test_BiasCorrection.py +0 -150
- congrads-1.0.7/tests/examples/test_FamilyIncome.py +0 -211
- {congrads-1.0.7 → congrads-1.1.0}/LICENSE +0 -0
````diff
--- congrads-1.0.7/PKG-INFO
+++ congrads-1.1.0/PKG-INFO
@@ -1,44 +1,45 @@
-Metadata-Version: 2.
+Metadata-Version: 2.3
 Name: congrads
-Version: 1.0.7
+Version: 1.1.0
 Summary: A toolbox for using Constraint Guided Gradient Descent when training neural networks.
+Author: Wout Rombouts, Quinten Van Baelen, Peter Karsmakers
 Author-email: Wout Rombouts <wout.rombouts@kuleuven.be>, Quinten Van Baelen <quinten.vanbaelen@kuleuven.be>, Peter Karsmakers <peter.karsmakers@kuleuven.be>
 License: Copyright 2024 DTAI - KU Leuven
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-Requires-
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS”
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+Requires-Dist: numpy>=1.24.0
+Requires-Dist: pandas>=1.5.0
+Requires-Dist: torch>=2.0.0
+Requires-Dist: torchvision>=0.15.1
+Requires-Dist: tqdm>=4.65.0
+Requires-Dist: matplotlib>=3.7.0 ; extra == 'examples'
+Requires-Dist: tensorboard>=2.12.0 ; extra == 'examples'
+Requires-Python: >=3.11
+Provides-Extra: examples
 Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: numpy>=1.26.4
-Requires-Dist: pandas>=2.2.2
-Requires-Dist: torch>=2.5.0
-Requires-Dist: torchvision>=0.20.0
-Requires-Dist: tensorboard>=2.18.0
-Requires-Dist: tqdm>=4.66.5

 <div align="center">
 <img src="https://github.com/ML-KULeuven/congrads/blob/main/docs/_static/congrads_export.png?raw=true" height="200">
@@ -49,7 +50,7 @@ Requires-Dist: tqdm>=4.66.5

 [](https://pypi.org/project/congrads)
 [](https://congrads.readthedocs.io)
-[](https://pypi.org/project/congrads)
 [](https://pypistats.org/packages/congrads)
 [](https://opensource.org/licenses/BSD-3-Clause)

@@ -80,15 +81,21 @@ Next, install the Congrads toolbox. The recommended way to install it is to use
 pip install congrads
 ```

+You can also install Congrads together with extra packages required to run the examples:
+
+```bash
+pip install congrads[examples]
+```
+
 This should automatically install all required dependencies for you. If you would like to install dependencies manually, Congrads depends on the following:

-- Python 3.
+- Python 3.11 - 3.13
 - **PyTorch** (install with CUDA support for GPU training, refer to [PyTorch's getting started guide](https://pytorch.org/get-started/locally/))
 - **NumPy** (install with `pip install numpy`, or refer to [NumPy's install guide](https://numpy.org/install/).)
 - **Pandas** (install with `pip install pandas`, or refer to [Panda's install guide](https://pandas.pydata.org/docs/getting_started/install.html).)
 - **Tqdm** (install with `pip install tqdm`)
 - **Torchvision** (install with `pip install torchvision`)
-- **Tensorboard** (install with `pip install tensorboard`)
+- Optional: **Tensorboard** (install with `pip install tensorboard`)

 ### 2. **Core concepts**

@@ -182,11 +189,11 @@ core.fit(max_epochs=50)
 - **Improve Training Process**: Inject domain knowledge in the training stage, increasing learning efficiency.
 - **Physics-Informed Neural Networks (PINNs)**: Coming soon, Enforce physical laws as constraints in your models.

-## Roadmap
+## Planned changes / Roadmap

 - [ ] Add ODE/PDE constraints to support PINNs
+- [ ] Rework callback system
 - [ ] Add support for constraint parser that can interpret equations
-- [ ] Determine if it is feasible to add unit and or functional tests

 ## Research

````
````diff
--- congrads-1.0.7/README.md
+++ congrads-1.1.0/README.md
@@ -7,7 +7,7 @@

 [](https://pypi.org/project/congrads)
 [](https://congrads.readthedocs.io)
-[](https://pypi.org/project/congrads)
 [](https://pypistats.org/packages/congrads)
 [](https://opensource.org/licenses/BSD-3-Clause)

@@ -38,15 +38,21 @@ Next, install the Congrads toolbox. The recommended way to install it is to use
 pip install congrads
 ```

+You can also install Congrads together with extra packages required to run the examples:
+
+```bash
+pip install congrads[examples]
+```
+
 This should automatically install all required dependencies for you. If you would like to install dependencies manually, Congrads depends on the following:

-- Python 3.
+- Python 3.11 - 3.13
 - **PyTorch** (install with CUDA support for GPU training, refer to [PyTorch's getting started guide](https://pytorch.org/get-started/locally/))
 - **NumPy** (install with `pip install numpy`, or refer to [NumPy's install guide](https://numpy.org/install/).)
 - **Pandas** (install with `pip install pandas`, or refer to [Panda's install guide](https://pandas.pydata.org/docs/getting_started/install.html).)
 - **Tqdm** (install with `pip install tqdm`)
 - **Torchvision** (install with `pip install torchvision`)
-- **Tensorboard** (install with `pip install tensorboard`)
+- Optional: **Tensorboard** (install with `pip install tensorboard`)

 ### 2. **Core concepts**

@@ -140,11 +146,11 @@ core.fit(max_epochs=50)
 - **Improve Training Process**: Inject domain knowledge in the training stage, increasing learning efficiency.
 - **Physics-Informed Neural Networks (PINNs)**: Coming soon, Enforce physical laws as constraints in your models.

-## Roadmap
+## Planned changes / Roadmap

 - [ ] Add ODE/PDE constraints to support PINNs
+- [ ] Rework callback system
 - [ ] Add support for constraint parser that can interpret equations
-- [ ] Determine if it is feasible to add unit and or functional tests

 ## Research

````
````diff
--- /dev/null
+++ congrads-1.1.0/pyproject.toml
@@ -0,0 +1,122 @@
+[project]
+name = "congrads"
+version = "1.1.0"
+description = "A toolbox for using Constraint Guided Gradient Descent when training neural networks."
+readme = "README.md"
+license = { file = "LICENSE" }
+authors = [
+    { name = "Wout Rombouts", email = "wout.rombouts@kuleuven.be" },
+    { name = "Quinten Van Baelen", email = "quinten.vanbaelen@kuleuven.be" },
+    { name = "Peter Karsmakers", email = "peter.karsmakers@kuleuven.be" }
+]
+requires-python = ">=3.11"
+dependencies = [
+    "numpy>=1.24.0",
+    "pandas>=1.5.0",
+    "torch>=2.0.0",
+    "torchvision>=0.15.1",
+    "tqdm>=4.65.0",
+]
+
+[project.optional-dependencies]
+examples = [
+    "matplotlib>=3.7.0",
+    "tensorboard>=2.12.0",
+]
+
+[build-system]
+requires = ["uv_build>=0.8.13,<0.9.0"]
+build-backend = "uv_build"
+
+[dependency-groups]
+dev = [
+    "pytest>=8.4.2",
+    "pytest-cov>=7.0.0",
+    "ruff>=0.13.0",
+    "sphinx>=8.2.3",
+    "sphinx-rtd-theme>=3.0.2",
+    "tox>=4.30.2",
+    "twine>=6.2.0",
+]
+
+[tool.ruff]
+# Match Black-style defaults but with a longer line length
+line-length = 100
+indent-width = 4
+target-version = "py311" # or "py312" if needed
+
+# Where Ruff runs (exclude cache, venvs, build dirs, etc.)
+exclude = [
+    ".git",
+    ".mypy_cache",
+    ".ruff_cache",
+    ".venv",
+    "venv",
+    "build",
+    "dist",
+    "__pypackages__",
+]
+
+# Which files to include
+include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
+
+[tool.ruff.lint]
+# Essentials + additional readability rules
+select = [
+    "E", # pycodestyle errors
+    "F", # pyflakes
+    "I", # import sorting
+    "B", # bugbear
+    "UP", # pyupgrade
+    "C4", # comprehensions
+    "ISC", # implicit string concatenation
+    "D", # docstring styling
+    "N", # consistent naming
+]
+ignore = [
+    "E501", # line length (handled by formatter)
+]
+
+# Enable autofix for safe rules
+fixable = ["ALL"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["D", "N"]
+
+[tool.ruff.format]
+quote-style = "double" # enforce double quotes
+indent-style = "space" # 4 spaces
+skip-magic-trailing-comma = false # Black-compatible
+docstring-code-format = true
+docstring-code-line-length = "dynamic"
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.tox]
+envlist = ["py{311,312,313}-{latest,minimum}"]
+
+[tool.tox.env_run_base]
+description = "Run tests with latest dependencies"
+deps = [
+    "pytest",
+    "pytest-cov"
+]
+commands = [
+    ["pytest", "--cov=src/congrads", "tests/"]
+]
+extras = [ "examples" ]
+
+[tool.tox.env.latest]
+description = "Run tests with latest dependencies"
+
+[tool.tox.env.minimum]
+description = "Run tests with minimum dependencies"
+deps = [
+    "-c constraints-min.txt",
+    "pytest",
+    "pytest-cov"
+]
+
+# Optional: ensure coverage is reported for all envs
+skip_missing_interpreters = true
````
````diff
--- congrads-1.0.7/congrads/__init__.py
+++ congrads-1.1.0/src/congrads/__init__.py
@@ -1,6 +1,4 @@
-#
-
-try:
+try:  # noqa: D104
     from importlib.metadata import version as get_version  # Python 3.8+
 except ImportError:
     from pkg_resources import (
@@ -25,5 +23,6 @@ __all__ = [
     "descriptor",
     "metrics",
     "networks",
+    "transformations",
     "utils",
 ]
````
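The `__init__.py` hunk above shows the package resolving its own version through `importlib.metadata`, falling back to `pkg_resources` on interpreters that lack it; the fallback import is cut off in this view. Below is a minimal sketch of that pattern, assuming the distribution name `congrads` and a hypothetical `get_version` alias for the fallback path. The package's actual fallback code is not shown in this diff.

```python
# Illustrative sketch of the version-lookup pattern from the __init__.py hunk.
# The real fallback body in congrads is truncated above; this is an assumed
# reconstruction, not the package's verbatim code.
try:
    # Python 3.8+: read the installed distribution's version from metadata.
    from importlib.metadata import version as get_version
except ImportError:
    # Older interpreters: fall back to setuptools' pkg_resources.
    from pkg_resources import get_distribution

    def get_version(name: str) -> str:
        """Mimic importlib.metadata.version() using pkg_resources."""
        return get_distribution(name).version

__version__ = get_version("congrads")
```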
````diff
--- /dev/null
+++ congrads-1.1.0/src/congrads/checkpoints.py
@@ -0,0 +1,178 @@
+"""Module for managing PyTorch model checkpoints.
+
+Provides the `CheckpointManager` class to save and load model and optimizer
+states during training, track the best metric values, and optionally report
+checkpoint events.
+"""
+
+import os
+from collections.abc import Callable
+from pathlib import Path
+
+from torch import Tensor, load, save
+from torch.nn import Module
+from torch.optim import Optimizer
+
+from .metrics import MetricManager
+from .utils import validate_callable, validate_type
+
+
+class CheckpointManager:
+    """Manage saving and loading checkpoints for PyTorch models and optimizers.
+
+    Handles checkpointing based on a criteria function, restores metric
+    states, and optionally reports when a checkpoint is saved.
+    """
+
+    def __init__(
+        self,
+        criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool],
+        network: Module,
+        optimizer: Optimizer,
+        metric_manager: MetricManager,
+        save_dir: str = "checkpoints",
+        create_dir: bool = False,
+        report_save: bool = False,
+    ):
+        """Initialize the CheckpointManager.
+
+        Args:
+            criteria_function (Callable[[dict[str, Tensor], dict[str, Tensor]], bool]):
+                Function that determines if the current checkpoint should be
+                saved based on the current and best metric values.
+            network (torch.nn.Module): The model to save/load.
+            optimizer (torch.optim.Optimizer): The optimizer to save/load.
+            metric_manager (MetricManager): Manages metric states for checkpointing.
+            save_dir (str, optional): Directory to save checkpoints. Defaults to 'checkpoints'.
+            create_dir (bool, optional): Whether to create `save_dir` if it does not exist.
+                Defaults to False.
+            report_save (bool, optional): Whether to report when a checkpoint is saved.
+                Defaults to False.
+
+        Raises:
+            TypeError: If any provided attribute has an incompatible type.
+            FileNotFoundError: If `save_dir` does not exist and `create_dir` is False.
+        """
+        # Type checking
+        validate_callable("criteria_function", criteria_function)
+        validate_type("network", network, Module)
+        validate_type("optimizer", optimizer, Optimizer)
+        validate_type("metric_manager", metric_manager, MetricManager)
+        validate_type("create_dir", create_dir, bool)
+        validate_type("report_save", report_save, bool)
+
+        # Create path or raise error if create_dir is not found
+        if not os.path.exists(save_dir):
+            if not create_dir:
+                raise FileNotFoundError(
+                    f"Save directory '{save_dir}' configured in checkpoint manager is not found."
+                )
+            Path(save_dir).mkdir(parents=True, exist_ok=True)
+
+        # Initialize objects variables
+        self.criteria_function = criteria_function
+        self.network = network
+        self.optimizer = optimizer
+        self.metric_manager = metric_manager
+        self.save_dir = save_dir
+        self.report_save = report_save
+
+        self.best_metric_values: dict[str, Tensor] = {}
+
+    def evaluate_criteria(self, epoch: int, metric_group: str = "during_training"):
+        """Evaluate the criteria function to determine if a better model is found.
+
+        Aggregates the current metric values during training and applies the
+        criteria function. If the criteria function indicates improvement, the
+        best metric values are updated, a checkpoint is saved, and a message is
+        optionally printed.
+
+        Args:
+            epoch (int): The current epoch number.
+            metric_group (str, optional): The metric group to evaluate. Defaults to 'during_training'.
+        """
+        current_metric_values = self.metric_manager.aggregate(metric_group)
+        if self.criteria_function is not None and self.criteria_function(
+            current_metric_values, self.best_metric_values
+        ):
+            # Print message if a new checkpoint is saved
+            if self.report_save:
+                print(f"New checkpoint saved at epoch {epoch}.")
+
+            # Update current best metric values
+            for metric_name, metric_value in current_metric_values.items():
+                self.best_metric_values[metric_name] = metric_value
+
+            # Save the current state
+            self.save(epoch)
+
+    def resume(self, filename: str = "checkpoint.pth", ignore_missing: bool = False) -> int:
+        """Resumes training from a saved checkpoint file.
+
+        Args:
+            filename (str): The name of the checkpoint file to load.
+                Defaults to "checkpoint.pth".
+            ignore_missing (bool): If True, does not raise an error if the
+                checkpoint file is missing and continues without loading,
+                starting from epoch 0. Defaults to False.
+
+        Returns:
+            int: The epoch number from the loaded checkpoint, or 0 if
+                ignore_missing is True and no checkpoint was found.
+
+        Raises:
+            TypeError: If a provided attribute has an incompatible type.
+            FileNotFoundError: If the specified checkpoint file does not exist.
+        """
+        # Type checking
+        validate_type("filename", filename, str)
+        validate_type("ignore_missing", ignore_missing, bool)
+
+        # Return starting epoch, either from checkpoint file or default
+        filepath = os.path.join(self.save_dir, filename)
+        if os.path.exists(filepath):
+            checkpoint = self.load(filename)
+            return checkpoint["epoch"]
+        elif ignore_missing:
+            return 0
+        else:
+            raise FileNotFoundError(f"A checkpoint was not found at {filepath} to resume training.")
+
+    def save(self, epoch: int, filename: str = "checkpoint.pth"):
+        """Save a checkpoint.
+
+        Args:
+            epoch (int): Current epoch number.
+            filename (str): Name of the checkpoint file. Defaults to
+                'checkpoint.pth'.
+        """
+        state = {
+            "epoch": epoch,
+            "network_state": self.network.state_dict(),
+            "optimizer_state": self.optimizer.state_dict(),
+            "best_metrics": self.best_metric_values,
+        }
+        filepath = os.path.join(self.save_dir, filename)
+        save(state, filepath)
+
+    def load(self, filename: str):
+        """Load a checkpoint and restore the training state.
+
+        Loads the checkpoint from the specified file and restores the network
+        weights, optimizer state, and best metric values.
+
+        Args:
+            filename (str): Name of the checkpoint file.
+
+        Returns:
+            dict: A dictionary containing the loaded checkpoint information,
+                including epoch, loss, and other relevant training state.
+        """
+        filepath = os.path.join(self.save_dir, filename)
+
+        checkpoint = load(filepath, weights_only=True)
+        self.network.load_state_dict(checkpoint["network_state"])
+        self.optimizer.load_state_dict(checkpoint["optimizer_state"])
+        self.best_metric_values = checkpoint["best_metrics"]
+
+        return checkpoint
````
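Since `checkpoints.py` is a new module in 1.1.0, a minimal usage sketch may help; it uses only the `CheckpointManager` API visible in the hunk above. The criteria function, the metric name `"loss"`, and the `MetricManager` construction are assumptions for illustration; the metrics API lives in `congrads/metrics.py`, which is not shown in this diff.

```python
import torch
from torch import nn

from congrads.checkpoints import CheckpointManager
from congrads.metrics import MetricManager  # constructor arguments not shown in this diff


def loss_improved(current: dict, best: dict) -> bool:
    """Criteria function: save when the tracked loss improves (metric name assumed)."""
    if "loss" not in best:
        return True  # no best value yet, so the first evaluation always saves
    return current["loss"].item() < best["loss"].item()


network = nn.Linear(4, 1)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
metric_manager = MetricManager(...)  # build per the congrads metrics API (placeholder)

checkpoints = CheckpointManager(
    criteria_function=loss_improved,
    network=network,
    optimizer=optimizer,
    metric_manager=metric_manager,
    save_dir="checkpoints",
    create_dir=True,   # create the directory instead of raising FileNotFoundError
    report_save=True,  # print a message whenever a new checkpoint is written
)

# Resume from checkpoints/checkpoint.pth if it exists, otherwise start at epoch 0.
start_epoch = checkpoints.resume(ignore_missing=True)

for epoch in range(start_epoch, 50):
    ...  # training/validation loop that updates metric_manager
    checkpoints.evaluate_criteria(epoch)  # saves a checkpoint when loss_improved(...) is True
```

Note that `resume` returns the epoch stored in the checkpoint, so the loop above continues from where the last saved state left off.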