torch_tk-1.0.8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. torch_tk-1.0.8/LICENSE +13 -0
  2. torch_tk-1.0.8/PKG-INFO +197 -0
  3. torch_tk-1.0.8/README.md +162 -0
  4. torch_tk-1.0.8/pyproject.toml +98 -0
  5. torch_tk-1.0.8/setup.cfg +4 -0
  6. torch_tk-1.0.8/src/torch_tk/__init__.py +29 -0
  7. torch_tk-1.0.8/src/torch_tk/checkpoints/__init__.py +1 -0
  8. torch_tk-1.0.8/src/torch_tk/checkpoints/checkpoint_manager.py +139 -0
  9. torch_tk-1.0.8/src/torch_tk/checkpoints/utils.py +33 -0
  10. torch_tk-1.0.8/src/torch_tk/diagnostics/__init__.py +3 -0
  11. torch_tk-1.0.8/src/torch_tk/diagnostics/diagnostics.py +281 -0
  12. torch_tk-1.0.8/src/torch_tk/diagnostics/loss.py +160 -0
  13. torch_tk-1.0.8/src/torch_tk/diagnostics/plotting.py +148 -0
  14. torch_tk-1.0.8/src/torch_tk/models/__init__.py +2 -0
  15. torch_tk-1.0.8/src/torch_tk/models/model.py +102 -0
  16. torch_tk-1.0.8/src/torch_tk/models/utils.py +18 -0
  17. torch_tk-1.0.8/src/torch_tk/optimizers/__init__.py +3 -0
  18. torch_tk-1.0.8/src/torch_tk/optimizers/adam.py +87 -0
  19. torch_tk-1.0.8/src/torch_tk/optimizers/sgd.py +79 -0
  20. torch_tk-1.0.8/src/torch_tk/optimizers/sgd_manual.py +104 -0
  21. torch_tk-1.0.8/src/torch_tk/test.py +5 -0
  22. torch_tk-1.0.8/src/torch_tk/training/__init__.py +1 -0
  23. torch_tk-1.0.8/src/torch_tk/training/trainer.py +429 -0
  24. torch_tk-1.0.8/src/torch_tk.egg-info/PKG-INFO +197 -0
  25. torch_tk-1.0.8/src/torch_tk.egg-info/SOURCES.txt +27 -0
  26. torch_tk-1.0.8/src/torch_tk.egg-info/dependency_links.txt +1 -0
  27. torch_tk-1.0.8/src/torch_tk.egg-info/requires.txt +14 -0
  28. torch_tk-1.0.8/src/torch_tk.egg-info/top_level.txt +1 -0
  29. torch_tk-1.0.8/tests/test_test.py +5 -0
torch_tk-1.0.8/LICENSE ADDED
@@ -0,0 +1,13 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2026, Jan Kazil
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
torch_tk-1.0.8/PKG-INFO ADDED
@@ -0,0 +1,197 @@
+ Metadata-Version: 2.4
+ Name: torch-tk
+ Version: 1.0.8
+ Summary: Streamlines training, checkpoint management, and diagnostics of PyTorch models.
+ Author-email: Jan Kazil <jan.kazil.dev@gmail.com>
+ License-Expression: BSD-3-Clause
+ Project-URL: Homepage, https://github.com/jankazil/torch-tk
+ Project-URL: Repository, https://github.com/jankazil/torch-tk
+ Project-URL: Issues, https://github.com/jankazil/torch-tk/issues
+ Keywords: pytorch,training,checkpointing,diagnostics,deep-learning
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: numpy
+ Requires-Dist: xarray
+ Requires-Dist: matplotlib
+ Requires-Dist: scipy
+ Provides-Extra: dev
+ Requires-Dist: build>=1.2; extra == "dev"
+ Requires-Dist: twine>=5.1; extra == "dev"
+ Requires-Dist: pytest>=8; extra == "dev"
+ Requires-Dist: pytest-cov>=5; extra == "dev"
+ Requires-Dist: mypy>=1.11; extra == "dev"
+ Requires-Dist: ruff>=0.5; extra == "dev"
+ Requires-Dist: pre-commit>=3.7; extra == "dev"
+ Dynamic: license-file
+
+ # Torch ToolKit (torch-tk)
+
+ **torch-tk** streamlines training, checkpoint management, and diagnostics of PyTorch models.
+
+ ## Overview
+
+ **torch-tk** adds a small amount of structure around PyTorch models and optimizers to simplify training and automate saving and restoring checkpoints.
+
+ **torch-tk** provides a model base class, optimizers, a checkpoint manager, a trainer, and diagnostics utilities. The model class inherits from `torch.nn.Module`, and the **torch-tk** optimizers are wrappers around `torch.optim` optimizers such as `torch.optim.SGD` and `torch.optim.Adam`. The **torch-tk** model and optimizer classes thus preserve the functionality and interface of PyTorch modules and optimizers.
+
+ ## Key features
+
+ **torch-tk** models and optimizers are self-describing: a model derived from `torch_tk.models.Model` and the optimizers in `torch_tk.optimizers` provide all information needed to save their state and recreate the same model and optimizer instances later. In practice, this means a model and optimizer can save to file the constructor arguments and state parameters needed to rebuild them, allowing them to be loaded back into fresh instances.
+
+ **torch-tk** provides a `CheckPointManager`, which manages saving, loading, and reconstructing both the model and the optimizer in the state that created the checkpoint. All that is required is that the original model and optimizer classes are importable from their recorded class paths.
+
+ **torch-tk** provides a `Trainer` class for running epoch-based training. It supports training either from a `DataLoader` or directly from tensors, and records basic diagnostics such as training loss and epoch wallclock time.
+
+ **torch-tk** provides a `Diagnostics` class for storing sample-resolved loss information together with training metadata. These diagnostics can be created from tensors or data loaders and can be saved to and restored from netCDF files for later analysis.
+
+ ## Workflow
+ The workflow is shown in the **torch-tk** [HowTo](https://github.com/jankazil/torch-tk/blob/main/notebooks/HowTo.ipynb) Jupyter notebook.
+
+ ## Installation
+
+ ### pip
+ ```bash
+ pip install torch-tk
+ ```
+
+ ### conda / mamba
+ ```bash
+ mamba install -c jan.kazil -c conda-forge torch-tk
+ ```
+
+ ## Classes
+
+ - `Model`
+
+   - A base class that makes models self-describing and automatically reconstructible by the `CheckPointManager`
+   - Automatically rebuilds a model from a saved file
+
+ - `SGD`, `Adam`, ...
+
+   Wrapper classes for PyTorch optimizers that make the optimizers self-describing and automatically reconstructible by the `CheckPointManager`
+
+ - `Trainer`
+
+   - Trains a model from
+     - a `torch.utils.data.DataLoader`
+     - or directly from tensors, using an efficient batching mechanism
+   - Records training loss and wallclock time per epoch
+
+ - `CheckPointManager`
+
+   - Saves and restores model training states
+   - Automatically rebuilds both a model and its optimizer from a saved checkpoint file
+
+ - `Diagnostics`
+
+   - Computes, stores, and plots the per-sample loss and its probability distribution
+   - Saves and restores diagnostics in netCDF file format
+   - Identifies worst-loss samples
+
+ ## Public API
+
+ ### Modules
+
+ #### `torch_tk.models.model`
+
+ Provides the abstract `Model` base class for models that can describe, save, restore, and reconstruct themselves. The `Model` class inherits from `torch.nn.Module` and thus provides the standard PyTorch `Module` interface.
+
+ The `Model` class defines the following methods:
+
+ - `Model.forward(xb)`: Abstract method that computes the forward pass.
+ - `Model.constructor_dict()`: Abstract method that returns the constructor arguments needed to reconstruct the model.
+ - `Model.save_state_dict_to_file(path)`: Save only the state dictionary.
+ - `Model.save_to_file(path)`: Save the constructor arguments and state dictionary needed to recreate the model.
+ - `Model.load_from_file(path, device=None)`: Recreate a model from a saved file.
+ - `Model.clone(constructor_dict, state_dict, device=None)`: Reconstruct a model from constructor arguments and state.
+
+ #### `torch_tk.optimizers.sgd`
+
+ Wrapper around `torch.optim.SGD` to make it self-describing and automatically reconstructible.
+
+ - `SGD(...)`: Subclass of `torch.optim.SGD` that stores its constructor arguments on the instance.
+ - `SGD.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
+
+ #### `torch_tk.optimizers.adam`
+
+ Wrapper around `torch.optim.Adam` to make it self-describing and automatically reconstructible.
+
+ - `Adam(...)`: Subclass of `torch.optim.Adam` that stores its constructor arguments on the instance.
+ - `Adam.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
+
+ #### `torch_tk.training.trainer`
+
+ Provides the `Trainer` class for epoch-based training and simple training diagnostics.
+
+ - `Trainer(model, optimizer, loss_function, epoch=0)`: Initialize trainer state.
+ - `Trainer.train_with_dataloader(data_loader, num_epochs, epoch_diag_step=1, verbose=True)`: Train from a `DataLoader`.
+ - `Trainer.train_with_data(x_train, y_train, bs, num_epochs, epoch_diag_step=1, verbose=True, shuffle=False)`: Train from in-memory tensors.
+ - `Trainer.plot_loss(...)`: Plot recorded diagnostic loss versus epoch.
+ - `Trainer.plot_wallclock_time(...)`: Plot recorded epoch wallclock time versus epoch.
+
+ #### `torch_tk.checkpoints.checkpoint_manager`
+
+ Provides checkpoint management for saving and reconstructing a model and optimizer together.
+
+ - `CheckPointManager(model, optimizer, directory)`: Manage checkpoint saving in a directory.
+ - `CheckPointManager.save(epoch)`: Save a checkpoint containing the epoch, class paths, constructor dictionaries, and state dictionaries.
+ - `CheckPointManager.load_from_file(file_path, device=None)`: Reconstruct and return `checkpoint_manager, model, optimizer, epoch` from a checkpoint file.
+
+ #### `torch_tk.diagnostics.loss`
+
+ Provides utilities for computing per-sample loss.
+
+ - `per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader)`: Compute per-sample losses and their mean from a `DataLoader`.
+ - `per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None)`: Compute per-sample losses and their mean from in-memory tensors.
+ - `model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None)`: Return the indices and values of the `n` worst losses.
+
+ #### `torch_tk.diagnostics.diagnostics`
+
+ Provides the `Diagnostics` container for sample-resolved loss diagnostics and analysis.
+
+ - `Diagnostics.from_data_loader(...)`: Build diagnostics from a model evaluated on a `DataLoader`.
+ - `Diagnostics.from_data(...)`: Build diagnostics from in-memory tensors.
+ - `Diagnostics.from_netcdf(path)`: Restore diagnostics from a saved netCDF file.
+ - `Diagnostics(...)`: Construct a diagnostics object from metadata, epochs, and per-sample loss data.
+ - `Diagnostics.__add__(other)`: Concatenate compatible diagnostics across epochs.
+ - `Diagnostics.to_netcdf(directory, verbose=True)`: Save diagnostics to a netCDF file.
+
+ #### `torch_tk.diagnostics.plotting`
+
+ Provides utilities for plotting diagnostics.
+
+ - `plot_diagnostics(diagnostics, plot_file=None, title=None, font_factor=1.5, figsize=(9, 6), xlim=None, ylim=None, loss_name='sqrt(loss)', pdf_bin_n=100, dpdlog10=False, show_plot=True, verbose=True)`: Plot kernel-density estimates of square-root per-sample loss distributions across one or more diagnostics objects and epochs.
+
+ ## Notes and limitations
+
+ - The checkpoint mechanism assumes that models and optimizers are importable from stable class paths and expose `constructor_dict()`, `state_dict()`, and `load_state_dict()`.
+ - The checkpoint design is not suitable for optimizers that require non-serializable constructor inputs or custom parameter-group reconstruction beyond `model.parameters()`.
+ - The diagnostic plotting utility requires strictly positive loss values because it plots the square root of the loss on a logarithmic axis.
+ - The recorded epoch loss in `Trainer` is exact only when the supplied loss function returns the mean per-sample loss over each batch, as stated in the trainer docstrings.
+
+ ## Development
+
+ ### Code Quality and Testing Commands
+
+ - `make fmt` - Runs `ruff format`, which reformats Python files according to the style rules in `pyproject.toml`.
+ - `make lint` - Runs `ruff check --fix`, which lints the code and auto-fixes what it can.
+ - `make check` - Runs formatting and linting.
+ - `make type` - Currently disabled. Intended to run `mypy` using the settings in `pyproject.toml`.
+ - `make test` - Runs `pytest` with the test settings configured in `pyproject.toml`.
+
+ ## Author
+
+ Jan Kazil - jan.kazil.dev@gmail.com - [jankazil.com](https://jankazil.com)
+
+ ## License
+
+ BSD-3-Clause
torch_tk-1.0.8/README.md ADDED
@@ -0,0 +1,162 @@
+ # Torch ToolKit (torch-tk)
+
+ **torch-tk** streamlines training, checkpoint management, and diagnostics of PyTorch models.
+
+ ## Overview
+
+ **torch-tk** adds a small amount of structure around PyTorch models and optimizers to simplify training and automate saving and restoring checkpoints.
+
+ **torch-tk** provides a model base class, optimizers, a checkpoint manager, a trainer, and diagnostics utilities. The model class inherits from `torch.nn.Module`, and the **torch-tk** optimizers are wrappers around `torch.optim` optimizers such as `torch.optim.SGD` and `torch.optim.Adam`. The **torch-tk** model and optimizer classes thus preserve the functionality and interface of PyTorch modules and optimizers.
+
+ ## Key features
+
+ **torch-tk** models and optimizers are self-describing: a model derived from `torch_tk.models.Model` and the optimizers in `torch_tk.optimizers` provide all information needed to save their state and recreate the same model and optimizer instances later. In practice, this means a model and optimizer can save to file the constructor arguments and state parameters needed to rebuild them, allowing them to be loaded back into fresh instances.
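+
+ A minimal sketch of a self-describing model. The `TinyNet` class and its layer sizes are illustrative; the `args`/`kwargs` layout of `constructor_dict()` follows what the `CheckPointManager` reads back, and calling `load_from_file` on the subclass assumes it is a classmethod:
+
+ ```python
+ import torch
+
+ from torch_tk import Model
+
+
+ class TinyNet(Model):
+     def __init__(self, n_in, n_hidden, n_out):
+         super().__init__()
+         # Record the constructor arguments so the model can describe itself.
+         self._ctor = {'args': [n_in, n_hidden, n_out], 'kwargs': {}}
+         self.net = torch.nn.Sequential(
+             torch.nn.Linear(n_in, n_hidden),
+             torch.nn.ReLU(),
+             torch.nn.Linear(n_hidden, n_out),
+         )
+
+     def forward(self, xb):
+         return self.net(xb)
+
+     def constructor_dict(self):
+         return self._ctor
+
+
+ model = TinyNet(4, 16, 1)
+ model.save_to_file('tiny_net.pt')              # constructor arguments + state
+ model = TinyNet.load_from_file('tiny_net.pt')  # fresh, equivalent instance
+ ```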
+
+ **torch-tk** provides a `CheckPointManager`, which manages saving, loading, and reconstructing both the model and the optimizer in the state that created the checkpoint. All that is required is that the original model and optimizer classes are importable from their recorded class paths.
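+
+ A minimal sketch of the checkpoint round trip, reusing `TinyNet` from above (directory name and hyperparameters are illustrative):
+
+ ```python
+ from torch_tk import SGD, CheckPointManager
+
+ model = TinyNet(4, 16, 1)
+ optimizer = SGD(model.parameters(), lr=1e-2)
+
+ manager = CheckPointManager(model, optimizer, 'checkpoints')
+ checkpoint_file = manager.save(epoch=3)
+
+ # Later, in a process where TinyNet is importable from a stable class path:
+ manager, model, optimizer, epoch = CheckPointManager.load_from_file(checkpoint_file)
+ ```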
+
+ **torch-tk** provides a `Trainer` class for running epoch-based training. It supports training either from a `DataLoader` or directly from tensors, and records basic diagnostics such as training loss and epoch wallclock time.
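+
+ A sketch of training from in-memory tensors, reusing the model and optimizer from the previous example (the random data is illustrative):
+
+ ```python
+ import torch
+
+ from torch_tk import Trainer
+
+ x_train = torch.randn(256, 4)
+ y_train = torch.randn(256, 1)
+
+ # mse_loss returns the mean per-sample loss over each batch by default,
+ # which matches the loss-recording convention noted under
+ # "Notes and limitations" below.
+ trainer = Trainer(model, optimizer, torch.nn.functional.mse_loss)
+ trainer.train_with_data(x_train, y_train, bs=32, num_epochs=10)
+ trainer.plot_loss()
+ ```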
+
+ **torch-tk** provides a `Diagnostics` class for storing sample-resolved loss information together with training metadata. These diagnostics can be created from tensors or data loaders and can be saved to and restored from netCDF files for later analysis.
+
+ ## Workflow
+ The workflow is shown in the **torch-tk** [HowTo](https://github.com/jankazil/torch-tk/blob/main/notebooks/HowTo.ipynb) Jupyter notebook.
+
+ ## Installation
+
+ ### pip
+ ```bash
+ pip install torch-tk
+ ```
+
+ ### conda / mamba
+ ```bash
+ mamba install -c jan.kazil -c conda-forge torch-tk
+ ```
+
+ ## Classes
+
+ - `Model`
+
+   - A base class that makes models self-describing and automatically reconstructible by the `CheckPointManager`
+   - Automatically rebuilds a model from a saved file
+
+ - `SGD`, `Adam`, ...
+
+   Wrapper classes for PyTorch optimizers that make the optimizers self-describing and automatically reconstructible by the `CheckPointManager`
+
+ - `Trainer`
+
+   - Trains a model from
+     - a `torch.utils.data.DataLoader`
+     - or directly from tensors, using an efficient batching mechanism
+   - Records training loss and wallclock time per epoch
+
+ - `CheckPointManager`
+
+   - Saves and restores model training states
+   - Automatically rebuilds both a model and its optimizer from a saved checkpoint file
+
+ - `Diagnostics`
+
+   - Computes, stores, and plots the per-sample loss and its probability distribution
+   - Saves and restores diagnostics in netCDF file format
+   - Identifies worst-loss samples
+
+ ## Public API
+
+ ### Modules
+
+ #### `torch_tk.models.model`
+
+ Provides the abstract `Model` base class for models that can describe, save, restore, and reconstruct themselves. The `Model` class inherits from `torch.nn.Module` and thus provides the standard PyTorch `Module` interface.
+
+ The `Model` class defines the following methods:
+
+ - `Model.forward(xb)`: Abstract method that computes the forward pass.
+ - `Model.constructor_dict()`: Abstract method that returns the constructor arguments needed to reconstruct the model.
+ - `Model.save_state_dict_to_file(path)`: Save only the state dictionary.
+ - `Model.save_to_file(path)`: Save the constructor arguments and state dictionary needed to recreate the model.
+ - `Model.load_from_file(path, device=None)`: Recreate a model from a saved file.
+ - `Model.clone(constructor_dict, state_dict, device=None)`: Reconstruct a model from constructor arguments and state.
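+
+ A short sketch of `clone()`, reusing `TinyNet` and `model` from the Key features section and assuming `clone` is a classmethod:
+
+ ```python
+ # Rebuild an equivalent model on the CPU from an existing instance's description.
+ copy = TinyNet.clone(model.constructor_dict(), model.state_dict(), device='cpu')
+ ```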
+
+ #### `torch_tk.optimizers.sgd`
+
+ Wrapper around `torch.optim.SGD` to make it self-describing and automatically reconstructible.
+
+ - `SGD(...)`: Subclass of `torch.optim.SGD` that stores its constructor arguments on the instance.
+ - `SGD.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
+
+ #### `torch_tk.optimizers.adam`
+
+ Wrapper around `torch.optim.Adam` to make it self-describing and automatically reconstructible.
+
+ - `Adam(...)`: Subclass of `torch.optim.Adam` that stores its constructor arguments on the instance.
+ - `Adam.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
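+
+ The wrappers accept the same arguments as their `torch.optim` counterparts; a short sketch (learning rate and betas are illustrative):
+
+ ```python
+ from torch_tk import Adam
+
+ optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
+ print(optimizer.constructor_dict())  # the stored settings, without 'params'
+ ```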
+
+ #### `torch_tk.training.trainer`
+
+ Provides the `Trainer` class for epoch-based training and simple training diagnostics.
+
+ - `Trainer(model, optimizer, loss_function, epoch=0)`: Initialize trainer state.
+ - `Trainer.train_with_dataloader(data_loader, num_epochs, epoch_diag_step=1, verbose=True)`: Train from a `DataLoader`.
+ - `Trainer.train_with_data(x_train, y_train, bs, num_epochs, epoch_diag_step=1, verbose=True, shuffle=False)`: Train from in-memory tensors.
+ - `Trainer.plot_loss(...)`: Plot recorded diagnostic loss versus epoch.
+ - `Trainer.plot_wallclock_time(...)`: Plot recorded epoch wallclock time versus epoch.
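+
+ The same training loop driven by a `DataLoader`, reusing `trainer`, `x_train`, and `y_train` from the earlier sketch (the dataset wrapping is illustrative):
+
+ ```python
+ from torch.utils.data import DataLoader, TensorDataset
+
+ loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
+ trainer.train_with_dataloader(loader, num_epochs=10)
+ trainer.plot_wallclock_time()
+ ```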
+
+ #### `torch_tk.checkpoints.checkpoint_manager`
+
+ Provides checkpoint management for saving and reconstructing a model and optimizer together.
+
+ - `CheckPointManager(model, optimizer, directory)`: Manage checkpoint saving in a directory.
+ - `CheckPointManager.save(epoch)`: Save a checkpoint containing the epoch, class paths, constructor dictionaries, and state dictionaries.
+ - `CheckPointManager.load_from_file(file_path, device=None)`: Reconstruct and return `checkpoint_manager, model, optimizer, epoch` from a checkpoint file.
+
+ #### `torch_tk.diagnostics.loss`
+
+ Provides utilities for computing per-sample loss.
+
+ - `per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader)`: Compute per-sample losses and their mean from a `DataLoader`.
+ - `per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None)`: Compute per-sample losses and their mean from in-memory tensors.
+ - `model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None)`: Return the indices and values of the `n` worst losses.
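+
+ A sketch using the signatures above, reusing `model`, `x_train`, and `y_train` from the training example. The sample-resolved loss must return one value per sample, and the return values are assumed to be `(losses, mean)` and `(indices, values)` pairs, respectively:
+
+ ```python
+ from torch_tk.diagnostics.loss import model_worst_loss, per_sample_loss_from_data
+
+
+ def mse_sample_resolved(pred, target):
+     # One loss value per sample: reduce over the feature dimension only.
+     return ((pred - target) ** 2).mean(dim=1)
+
+
+ losses, mean_loss = per_sample_loss_from_data(model, mse_sample_resolved, x_train, y_train)
+ indices, values = model_worst_loss(model, mse_sample_resolved, x_train, y_train, n=5)
+ ```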
+
+ #### `torch_tk.diagnostics.diagnostics`
+
+ Provides the `Diagnostics` container for sample-resolved loss diagnostics and analysis.
+
+ - `Diagnostics.from_data_loader(...)`: Build diagnostics from a model evaluated on a `DataLoader`.
+ - `Diagnostics.from_data(...)`: Build diagnostics from in-memory tensors.
+ - `Diagnostics.from_netcdf(path)`: Restore diagnostics from a saved netCDF file.
+ - `Diagnostics(...)`: Construct a diagnostics object from metadata, epochs, and per-sample loss data.
+ - `Diagnostics.__add__(other)`: Concatenate compatible diagnostics across epochs.
+ - `Diagnostics.to_netcdf(directory, verbose=True)`: Save diagnostics to a netCDF file.
+
+ #### `torch_tk.diagnostics.plotting`
+
+ Provides utilities for plotting diagnostics.
+
+ - `plot_diagnostics(diagnostics, plot_file=None, title=None, font_factor=1.5, figsize=(9, 6), xlim=None, ylim=None, loss_name='sqrt(loss)', pdf_bin_n=100, dpdlog10=False, show_plot=True, verbose=True)`: Plot kernel-density estimates of square-root per-sample loss distributions across one or more diagnostics objects and epochs.
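+
+ A sketch of a plotting call, given a `Diagnostics` object `diag` built earlier (for example with `Diagnostics.from_data`, whose full signature is abbreviated above):
+
+ ```python
+ from torch_tk.diagnostics.plotting import plot_diagnostics
+
+ plot_diagnostics(diag, plot_file='loss_pdf.png', title='Per-sample loss', show_plot=False)
+ ```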
+
+ ## Notes and limitations
+
+ - The checkpoint mechanism assumes that models and optimizers are importable from stable class paths and expose `constructor_dict()`, `state_dict()`, and `load_state_dict()`.
+ - The checkpoint design is not suitable for optimizers that require non-serializable constructor inputs or custom parameter-group reconstruction beyond `model.parameters()`.
+ - The diagnostic plotting utility requires strictly positive loss values because it plots the square root of the loss on a logarithmic axis.
+ - The recorded epoch loss in `Trainer` is exact only when the supplied loss function returns the mean per-sample loss over each batch, as stated in the trainer docstrings.
+
+ ## Development
+
+ ### Code Quality and Testing Commands
+
+ - `make fmt` - Runs `ruff format`, which reformats Python files according to the style rules in `pyproject.toml`.
+ - `make lint` - Runs `ruff check --fix`, which lints the code and auto-fixes what it can.
+ - `make check` - Runs formatting and linting.
+ - `make type` - Currently disabled. Intended to run `mypy` using the settings in `pyproject.toml`.
+ - `make test` - Runs `pytest` with the test settings configured in `pyproject.toml`.
+
+ ## Author
+
+ Jan Kazil - jan.kazil.dev@gmail.com - [jankazil.com](https://jankazil.com)
+
+ ## License
+
+ BSD-3-Clause
torch_tk-1.0.8/pyproject.toml ADDED
@@ -0,0 +1,98 @@
+ [build-system]
+ requires = ["setuptools>=77", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "torch-tk"
+ version = "1.0.8"
+ description = "Streamlines training, checkpoint management, and diagnostics of PyTorch models."
+ readme = "README.md"
+ requires-python = ">=3.12"
+ license = "BSD-3-Clause"
+ license-files = ["LICENSE"]
+ authors = [
+     { name = "Jan Kazil", email = "jan.kazil.dev@gmail.com" }
+ ]
+ dependencies = [
+     "torch",
+     "numpy",
+     "xarray",
+     "matplotlib",
+     "scipy",
+ ]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.12",
+     "Topic :: Scientific/Engineering",
+     "Topic :: Software Development :: Libraries :: Python Modules",
+ ]
+ keywords = ["pytorch", "training", "checkpointing", "diagnostics", "deep-learning"]
+
+ [project.urls]
+ Homepage = "https://github.com/jankazil/torch-tk"
+ Repository = "https://github.com/jankazil/torch-tk"
+ Issues = "https://github.com/jankazil/torch-tk/issues"
+
+ [project.optional-dependencies]
+ dev = [
+     "build>=1.2",
+     "twine>=5.1",
+     "pytest>=8",
+     "pytest-cov>=5",
+     "mypy>=1.11",
+     "ruff>=0.5",
+     "pre-commit>=3.7",
+ ]
+
+ [tool.setuptools]
+ package-dir = {"" = "src"}
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.ruff]
+ line-length = 128
+ target-version = "py312"
+ extend-exclude = [
+     "__init__.py",
+     "dist",
+     "build",
+     "data",
+     "demos",
+     "docs",
+     "experiments",
+     "notebooks",
+ ]
+
+ [tool.ruff.lint]
+ select = ["E", "F", "I", "UP", "B", "SIM"]
+ ignore = ["E501", "SIM108"]
+
+ [tool.ruff.format]
+ quote-style = "preserve"
+
+ #
+ # MyPy is disabled because static analysis cannot account for all
+ # dynamic runtime behaviors; mypy may report false positives that
+ # do not reflect actual runtime issues.
+ #
+ #[tool.mypy]
+ #python_version = "3.12"
+ #warn_unused_configs = true
+ #disallow_untyped_defs = true
+ #disallow_incomplete_defs = true
+ #no_implicit_optional = true
+ #check_untyped_defs = true
+ #strict_optional = true
+ #pretty = true
+ #namespace_packages = true
+ #mypy_path = "src"
+ #files = ["src/torch_tk", "scripts"]
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
+ python_functions = ["test_*"]
torch_tk-1.0.8/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
torch_tk-1.0.8/src/torch_tk/__init__.py ADDED
@@ -0,0 +1,29 @@
+ from importlib.metadata import PackageNotFoundError, version
+
+ _DIST_NAME = "torch-tk"
+
+ try:
+     __version__ = version(_DIST_NAME)
+ except PackageNotFoundError:
+     pkg = __package__ or __name__.split(".", 1)[0]
+     try:
+         __version__ = version(pkg)
+     except PackageNotFoundError:
+         __version__ = "0.0.0+local"
+
+ from .models.model import Model
+ from .optimizers.sgd import SGD
+ from .optimizers.adam import Adam
+ from .training.trainer import Trainer
+ from .checkpoints.checkpoint_manager import CheckPointManager
+ from .diagnostics.diagnostics import Diagnostics
+
+ __all__ = [
+     "__version__",
+     "Model",
+     "SGD",
+     "Adam",
+     "Trainer",
+     "CheckPointManager",
+     "Diagnostics",
+ ]
torch_tk-1.0.8/src/torch_tk/checkpoints/__init__.py ADDED
@@ -0,0 +1 @@
+ from .checkpoint_manager import CheckPointManager
torch_tk-1.0.8/src/torch_tk/checkpoints/checkpoint_manager.py ADDED
@@ -0,0 +1,139 @@
+ '''
+ Checkpoint utilities for saving and restoring self-describing model and optimizer state.
+
+ This module provides the CheckPointManager class, which saves checkpoints that
+ contain enough information to reconstruct both a model and its optimizer in the
+ state that created the checkpoint. Each checkpoint stores the fully qualified
+ class path, constructor arguments from constructor_dict(), and state from
+ state_dict() for both objects.
+
+ Models and optimizers must be importable from a stable class path and must expose
+
+ - constructor_dict()
+ - state_dict()
+ - load_state_dict()
+
+ A model must be reconstructible from its class path and constructor_dict().
+ An optimizer must be reconstructible from its class path, model.parameters(),
+ and constructor_dict().
+
+ The constructor data must be serializable. In practice, this means it should
+ contain only standard serializable Python values such as numbers, strings,
+ lists, tuples, and dictionaries.
+
+ This design is not suitable for optimizers that depend on non-serializable
+ constructor inputs, non-standard constructor signatures, custom parameter-group
+ reconstruction beyond model.parameters(), or runtime state not captured by
+ state_dict().
+ '''
+
+ from pathlib import Path
+
+ import torch
+
+ from .utils import class_path_of_instance, import_class
+
+
+ class CheckPointManager:
+     '''
+     Save checkpoints and rebuild a model and optimizer from a checkpoint file.
+
+     The checkpoint contains the epoch, class paths, constructor arguments, and
+     state dictionaries for the model and optimizer.
+     '''
+
+     def __init__(self, model, optimizer, directory):
+         '''
+         Store the model, optimizer, and checkpoint directory.
+
+         The directory is converted to a Path and created if needed.
+         '''
+         if not isinstance(directory, Path):
+             directory = Path(directory)
+
+         self.model = model
+         self.optimizer = optimizer
+         self.directory = directory
+
+     def save(self, epoch):
+         '''
+         Save a checkpoint for the current model and optimizer state.
+
+         Returns the path to the written checkpoint file.
+         '''
+         if not hasattr(self.model, 'constructor_dict'):
+             raise TypeError('Model must implement constructor_dict().')
+
+         if not hasattr(self.optimizer, 'constructor_dict'):
+             raise TypeError('Optimizer must implement constructor_dict().')
+
+         checkpoint = {
+             'epoch': epoch,
+             'model_class_path': class_path_of_instance(self.model),
+             'model_constructor_dict': self.model.constructor_dict(),
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_class_path': class_path_of_instance(self.optimizer),
+             'optimizer_constructor_dict': self.optimizer.constructor_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+         }
+
+         file_name = Path(type(self.model).__name__ + '.' + type(self.optimizer).__name__ + '.epoch=' + str(epoch) + '.pt')
+
+         self.directory.mkdir(parents=True, exist_ok=True)
+
+         file_path = self.directory / file_name
+         torch.save(checkpoint, file_path)
+
+         return file_path
+
+     @classmethod
+     def load_from_file(cls, file_path, device=None):
+         '''
+         Load a checkpoint file and reconstruct the model, optimizer, and manager.
+
+         Returns:
+             checkpoint_manager, model, optimizer, epoch
+
+         If device is given, the checkpoint is loaded onto that device and a
+         saved model constructor argument named 'device' is overridden.
+         '''
+         if not isinstance(file_path, Path):
+             file_path = Path(file_path)
+
+         checkpoint = torch.load(file_path, map_location=device)
+
+         # Reconstruct model
+
+         model_class = import_class(checkpoint['model_class_path'])
+
+         model_constructor_dict = checkpoint['model_constructor_dict']
+         model_args = model_constructor_dict.get('args', [])
+         model_kwargs = dict(model_constructor_dict.get('kwargs', {}))
+
+         if 'device' in model_kwargs and device is not None:
+             model_kwargs['device'] = device
+
+         model = model_class(*model_args, **model_kwargs)
+         model.load_state_dict(checkpoint['model_state_dict'])
+
+         if device is not None:
+             model = model.to(device)
+
+         # Reconstruct optimizer
+
+         optimizer_class = import_class(checkpoint['optimizer_class_path'])
+
+         optimizer_constructor_dict = checkpoint['optimizer_constructor_dict']
+         optimizer_args = optimizer_constructor_dict.get('args', [])
+         optimizer_kwargs = dict(optimizer_constructor_dict.get('kwargs', {}))
+
+         optimizer = optimizer_class(model.parameters(), *optimizer_args, **optimizer_kwargs)
+
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+         # Create a checkpoint manager instance
+         checkpoint_manager = cls(model, optimizer, file_path.parent)
+
+         epoch = checkpoint['epoch']
+
+         return checkpoint_manager, model, optimizer, epoch