torch-tk 1.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_tk-1.0.8/LICENSE +13 -0
- torch_tk-1.0.8/PKG-INFO +197 -0
- torch_tk-1.0.8/README.md +162 -0
- torch_tk-1.0.8/pyproject.toml +98 -0
- torch_tk-1.0.8/setup.cfg +4 -0
- torch_tk-1.0.8/src/torch_tk/__init__.py +29 -0
- torch_tk-1.0.8/src/torch_tk/checkpoints/__init__.py +1 -0
- torch_tk-1.0.8/src/torch_tk/checkpoints/checkpoint_manager.py +139 -0
- torch_tk-1.0.8/src/torch_tk/checkpoints/utils.py +33 -0
- torch_tk-1.0.8/src/torch_tk/diagnostics/__init__.py +3 -0
- torch_tk-1.0.8/src/torch_tk/diagnostics/diagnostics.py +281 -0
- torch_tk-1.0.8/src/torch_tk/diagnostics/loss.py +160 -0
- torch_tk-1.0.8/src/torch_tk/diagnostics/plotting.py +148 -0
- torch_tk-1.0.8/src/torch_tk/models/__init__.py +2 -0
- torch_tk-1.0.8/src/torch_tk/models/model.py +102 -0
- torch_tk-1.0.8/src/torch_tk/models/utils.py +18 -0
- torch_tk-1.0.8/src/torch_tk/optimizers/__init__.py +3 -0
- torch_tk-1.0.8/src/torch_tk/optimizers/adam.py +87 -0
- torch_tk-1.0.8/src/torch_tk/optimizers/sgd.py +79 -0
- torch_tk-1.0.8/src/torch_tk/optimizers/sgd_manual.py +104 -0
- torch_tk-1.0.8/src/torch_tk/test.py +5 -0
- torch_tk-1.0.8/src/torch_tk/training/__init__.py +1 -0
- torch_tk-1.0.8/src/torch_tk/training/trainer.py +429 -0
- torch_tk-1.0.8/src/torch_tk.egg-info/PKG-INFO +197 -0
- torch_tk-1.0.8/src/torch_tk.egg-info/SOURCES.txt +27 -0
- torch_tk-1.0.8/src/torch_tk.egg-info/dependency_links.txt +1 -0
- torch_tk-1.0.8/src/torch_tk.egg-info/requires.txt +14 -0
- torch_tk-1.0.8/src/torch_tk.egg-info/top_level.txt +1 -0
- torch_tk-1.0.8/tests/test_test.py +5 -0
torch_tk-1.0.8/LICENSE
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026, Jan Kazil
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
|
8
|
+
|
|
9
|
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
|
10
|
+
|
|
11
|
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
|
12
|
+
|
|
13
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
torch_tk-1.0.8/PKG-INFO
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-tk
|
|
3
|
+
Version: 1.0.8
|
|
4
|
+
Summary: Streamlines training, checkpoint management, and diagnostics of PyTorch models.
|
|
5
|
+
Author-email: Jan Kazil <jan.kazil.dev@gmail.com>
|
|
6
|
+
License-Expression: BSD-3-Clause
|
|
7
|
+
Project-URL: Homepage, https://github.com/jankazil/torch-tk
|
|
8
|
+
Project-URL: Repository, https://github.com/jankazil/torch-tk
|
|
9
|
+
Project-URL: Issues, https://github.com/jankazil/torch-tk/issues
|
|
10
|
+
Keywords: pytorch,training,checkpointing,diagnostics,deep-learning
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering
|
|
17
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
18
|
+
Requires-Python: >=3.12
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: torch
|
|
22
|
+
Requires-Dist: numpy
|
|
23
|
+
Requires-Dist: xarray
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Requires-Dist: scipy
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: build>=1.2; extra == "dev"
|
|
28
|
+
Requires-Dist: twine>=5.1; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest>=8; extra == "dev"
|
|
30
|
+
Requires-Dist: pytest-cov>=5; extra == "dev"
|
|
31
|
+
Requires-Dist: mypy>=1.11; extra == "dev"
|
|
32
|
+
Requires-Dist: ruff>=0.5; extra == "dev"
|
|
33
|
+
Requires-Dist: pre-commit>=3.7; extra == "dev"
|
|
34
|
+
Dynamic: license-file
|
|
35
|
+
|
|
36
|
+
# Torch ToolKit (torch-tk)
|
|
37
|
+
|
|
38
|
+
**torch-tk** streamlines training, checkpoint management, and diagnostics of PyTorch models.
|
|
39
|
+
|
|
40
|
+
## Overview
|
|
41
|
+
|
|
42
|
+
**torch-tk** adds a small amount of structure around PyTorch models and optimizers to simplify training and automate saving and restoring checkpoints.
|
|
43
|
+
|
|
44
|
+
**torch-tk** provides a model base class, optimizers, a checkpoint manager, a trainer, and diagnostics utilities. The model class inherits from `torch.nn.Module`, and the **torch-tk** optimizers are wrappers around `torch.optim` optimizers such as `torch.optim.SGD` and `torch.optim.Adam`. The **torch-tk** model and optimizer classes thus preserve functionality and interface of PyTorch modules and optimizers.
|
|
45
|
+
|
|
46
|
+
## Key features
|
|
47
|
+
|
|
48
|
+
**torch-tk** models and optimizers are self-describing: A model derived from `torch_tk.models.Model` and the optimizers in `torch_tk.optimizers` provide all information needed to save their state and recreate the same model and optimizer instances later. In practice, this means a model and optimizer can save to file the constructor arguments and state parameters needed to rebuild them, allowing them to be loaded back into fresh instances.
|
|
49
|
+
|
|
50
|
+
**torch-tk** provides a `CheckPointManager`. The `CheckPointManager` manages saving, loading, and reconstructing both the model and the optimizer in the state that created the checkpoint. All that is required is that the class paths are available to import the original model and optimizer classes.
|
|
51
|
+
|
|
52
|
+
**torch-tk** provides a `Trainer` class for running epoch-based training. It supports training either from a `DataLoader` or directly from tensors, and records basic diagnostics such as training loss and epoch wallclock time.
|
|
53
|
+
|
|
54
|
+
**torch-tk** provides a `Diagnostics` class for storing sample-resolved loss information together with training metadata. These diagnostics can be created from tensors or data loaders and can be saved to and restored from netCDF files for later analysis.
|
|
55
|
+
|
|
56
|
+
## Workflow
|
|
57
|
+
The workflow is shown in the **torch-tk** [HowTo](https://github.com/jankazil/torch-tk/blob/main/notebooks/HowTo.ipynb) Jupyter notebook.
|
|
58
|
+
|
|
59
|
+
## Installation
|
|
60
|
+
|
|
61
|
+
### pip
|
|
62
|
+
```bash
|
|
63
|
+
pip install torch-tk
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
### conda / mamba
|
|
67
|
+
```bash
|
|
68
|
+
mamba install -c jan.kazil -c conda-forge torch-tk
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## Classes
|
|
72
|
+
|
|
73
|
+
- `Model`
|
|
74
|
+
|
|
75
|
+
- A base class which makes models self-describing and automatically reconstructible by the `CheckPointManager`
|
|
76
|
+
- Automatically rebuilds a model from a saved file
|
|
77
|
+
|
|
78
|
+
- `SGD`, `Adam`, ...
|
|
79
|
+
|
|
80
|
+
Wrapper classes for PyTorch optimizers that make the optimizers self-describing and automatically reconstructible by the `CheckPointManager`
|
|
81
|
+
|
|
82
|
+
- `Trainer`
|
|
83
|
+
|
|
84
|
+
- Trains a model from
|
|
85
|
+
- a `torch.utils.data.DataLoader`
|
|
86
|
+
- or directly from tensors, using an efficient batching mechanism
|
|
87
|
+
- Records training loss and model timing per epoch
|
|
88
|
+
|
|
89
|
+
- `CheckPointManager`
|
|
90
|
+
|
|
91
|
+
- Saves and restores model training states
|
|
92
|
+
- Automatically rebuilds both a model and its optimizer from a saved checkpoint file
|
|
93
|
+
|
|
94
|
+
- `Diagnostics`
|
|
95
|
+
|
|
96
|
+
- Computes, stores, and plots per-sample loss and per-sample loss probability distribution
|
|
97
|
+
- Saves and restores diagnostics in netCDF file format
|
|
98
|
+
- Identifies worst-loss samples
|
|
99
|
+
|
|
100
|
+
## Public API
|
|
101
|
+
|
|
102
|
+
### Modules
|
|
103
|
+
|
|
104
|
+
#### `torch_tk.models.model`
|
|
105
|
+
|
|
106
|
+
Provides the abstract `Model` base class for models that can describe, save, restore, and reconstruct themselves. The `Model` class inherits from `torch.nn.Module`, and thus provides the standard PyTorch `Module` interface.
|
|
107
|
+
|
|
108
|
+
The `Model` class defines and provides the following methods:
|
|
109
|
+
|
|
110
|
+
- `Model.forward(xb)`: Abstract method that computes the forward pass.
|
|
111
|
+
- `Model.constructor_dict()`: Abstract method that returns the constructor arguments needed to reconstruct the model.
|
|
112
|
+
- `Model.save_state_dict_to_file(path)`: Save only the state dictionary.
|
|
113
|
+
- `Model.save_to_file(path)`: Save constructor arguments and state dictionary needed to recreate the model.
|
|
114
|
+
- `Model.load_from_file(path, device=None)`: Recreate a model from a saved file.
|
|
115
|
+
- `Model.clone(constructor_dict, state_dict, device=None)`: Reconstruct a model from constructor arguments and state.
|
|
116
|
+
|
|
117
|
+
#### `torch_tk.optimizers.sgd`
|
|
118
|
+
|
|
119
|
+
Wrapper around `torch.optim.SGD` to make it self-describing and automatically reconstructible.
|
|
120
|
+
|
|
121
|
+
- `SGD(...)`: Subclass of `torch.optim.SGD` that stores its constructor arguments on the instance.
|
|
122
|
+
- `SGD.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
|
|
123
|
+
|
|
124
|
+
#### `torch_tk.optimizers.adam`
|
|
125
|
+
|
|
126
|
+
Wrapper around `torch.optim.Adam` to make it self-describing and automatically reconstructible.
|
|
127
|
+
|
|
128
|
+
- `Adam(...)`: Subclass of `torch.optim.Adam` that stores its constructor arguments on the instance.
|
|
129
|
+
- `Adam.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
|
|
130
|
+
|
|
131
|
+
#### `torch_tk.training.trainer`
|
|
132
|
+
|
|
133
|
+
Provides the `Trainer` class for epoch-based training and simple training diagnostics.
|
|
134
|
+
|
|
135
|
+
- `Trainer(model, optimizer, loss_function, epoch=0)`: Initialize trainer state.
|
|
136
|
+
- `Trainer.train_with_dataloader(data_loader, num_epochs, epoch_diag_step=1, verbose=True)`: Train from a `DataLoader`.
|
|
137
|
+
- `Trainer.train_with_data(x_train, y_train, bs, num_epochs, epoch_diag_step=1, verbose=True, shuffle=False)`: Train from in-memory tensors.
|
|
138
|
+
- `Trainer.plot_loss(...)`: Plot recorded diagnostic loss versus epoch.
|
|
139
|
+
- `Trainer.plot_wallclock_time(...)`: Plot recorded epoch wallclock time versus epoch.
|
|
140
|
+
|
|
141
|
+
#### `torch_tk.checkpoints.checkpoint_manager`
|
|
142
|
+
|
|
143
|
+
Provides checkpoint management for saving and reconstructing a model and optimizer together.
|
|
144
|
+
|
|
145
|
+
- `CheckPointManager(model, optimizer, directory)`: Manage checkpoint saving in a directory.
|
|
146
|
+
- `CheckPointManager.save(epoch)`: Save a checkpoint containing epoch, class paths, constructor dictionaries, and state dictionaries.
|
|
147
|
+
- `CheckPointManager.load_from_file(file_path, device=None)`: Reconstruct and return `checkpoint_manager, model, optimizer, epoch` from a checkpoint file.
|
|
148
|
+
|
|
149
|
+
#### `torch_tk.diagnostics.loss`
|
|
150
|
+
|
|
151
|
+
Provides utilities for computing per-sample loss.
|
|
152
|
+
|
|
153
|
+
- `per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader)`: Compute per-sample losses and their mean from a `DataLoader`.
|
|
154
|
+
- `per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None)`: Compute per-sample losses and their mean from in-memory tensors.
|
|
155
|
+
- `model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None)`: Return the indices and values of the `n` worst losses.
|
|
156
|
+
|
|
157
|
+
#### `torch_tk.diagnostics.diagnostics`
|
|
158
|
+
|
|
159
|
+
Provides the `Diagnostics` container for sample-resolved loss diagnostics and analysis.
|
|
160
|
+
|
|
161
|
+
- `Diagnostics.from_data_loader(...)`: Build diagnostics from a model evaluated on a `DataLoader`.
|
|
162
|
+
- `Diagnostics.from_data(...)`: Build diagnostics from in-memory tensors.
|
|
163
|
+
- `Diagnostics.from_netcdf(path)`: Restore diagnostics from a saved netCDF file.
|
|
164
|
+
- `Diagnostics(...)`: Construct a diagnostics object from metadata, epochs, and per-sample loss data.
|
|
165
|
+
- `Diagnostics.__add__(other)`: Concatenate compatible diagnostics across epochs.
|
|
166
|
+
- `Diagnostics.to_netcdf(directory, verbose=True)`: Save diagnostics to a netCDF file.
|
|
167
|
+
|
|
168
|
+
#### `torch_tk.diagnostics.plotting`
|
|
169
|
+
|
|
170
|
+
Provides utilities for plotting diagnostics.
|
|
171
|
+
|
|
172
|
+
- `plot_diagnostics(diagnostics, plot_file=None, title=None, font_factor=1.5, figsize=(9, 6), xlim=None, ylim=None, loss_name='sqrt(loss)', pdf_bin_n=100, dpdlog10=False, show_plot=True, verbose=True)`: Plot kernel-density estimates of square-root per-sample loss distributions across one or more diagnostics objects and epochs.
|
|
173
|
+
|
|
174
|
+
## Notes and limitations
|
|
175
|
+
|
|
176
|
+
- The checkpoint mechanism assumes that models and optimizers are importable from stable class paths and expose `constructor_dict()`, `state_dict()`, and `load_state_dict()`.
|
|
177
|
+
- The checkpoint design is not suitable for optimizers that require non-serializable constructor inputs or custom parameter-group reconstruction beyond `model.parameters()`.
|
|
178
|
+
- The diagnostic plotting utility requires strictly positive loss values because it plots the square root of loss on a logarithmic axis.
|
|
179
|
+
- The recorded epoch loss in `Trainer` is exact only when the supplied loss function returns the mean per-sample loss over each batch, as stated in the trainer docstrings.
|
|
180
|
+
|
|
181
|
+
## Development
|
|
182
|
+
|
|
183
|
+
### Code Quality and Testing Commands
|
|
184
|
+
|
|
185
|
+
- `make fmt` - Runs `ruff format`, which reformats Python files according to the style rules in `pyproject.toml`.
|
|
186
|
+
- `make lint` - Runs `ruff check --fix`, which lints the code and auto-fixes what it can.
|
|
187
|
+
- `make check` - Runs formatting and linting.
|
|
188
|
+
- `make type` - Currently disabled. Intended to run `mypy` using the settings in `pyproject.toml`.
|
|
189
|
+
- `make test` - Runs `pytest` with the test settings configured in `pyproject.toml`.
|
|
190
|
+
|
|
191
|
+
## Author
|
|
192
|
+
|
|
193
|
+
Jan Kazil - jan.kazil.dev@gmail.com - [jankazil.com](https://jankazil.com)
|
|
194
|
+
|
|
195
|
+
## License
|
|
196
|
+
|
|
197
|
+
BSD-3-Clause
|
torch_tk-1.0.8/README.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# Torch ToolKit (torch-tk)
|
|
2
|
+
|
|
3
|
+
**torch-tk** streamlines training, checkpoint management, and diagnostics of PyTorch models.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
**torch-tk** adds a small amount of structure around PyTorch models and optimizers to simplify training and automate saving and restoring checkpoints.
|
|
8
|
+
|
|
9
|
+
**torch-tk** provides a model base class, optimizers, a checkpoint manager, a trainer, and diagnostics utilities. The model class inherits from `torch.nn.Module`, and the **torch-tk** optimizers are wrappers around `torch.optim` optimizers such as `torch.optim.SGD` and `torch.optim.Adam`. The **torch-tk** model and optimizer classes thus preserve functionality and interface of PyTorch modules and optimizers.
|
|
10
|
+
|
|
11
|
+
## Key features
|
|
12
|
+
|
|
13
|
+
**torch-tk** models and optimizers are self-describing: A model derived from `torch_tk.models.Model` and the optimizers in `torch_tk.optimizers` provide all information needed to save their state and recreate the same model and optimizer instances later. In practice, this means a model and optimizer can save to file the constructor arguments and state parameters needed to rebuild them, allowing them to be loaded back into fresh instances.
|
|
14
|
+
|
|
15
|
+
**torch-tk** provides a `CheckPointManager`. The `CheckPointManager` manages saving, loading, and reconstructing both the model and the optimizer in the state that created the checkpoint. All that is required is that the class paths are available to import the original model and optimizer classes.
|
|
16
|
+
|
|
17
|
+
**torch-tk** provides a `Trainer` class for running epoch-based training. It supports training either from a `DataLoader` or directly from tensors, and records basic diagnostics such as training loss and epoch wallclock time.
|
|
18
|
+
|
|
19
|
+
**torch-tk** provides a `Diagnostics` class for storing sample-resolved loss information together with training metadata. These diagnostics can be created from tensors or data loaders and can be saved to and restored from netCDF files for later analysis.
|
|
20
|
+
|
|
21
|
+
## Workflow
|
|
22
|
+
The workflow is shown in the **torch-tk** [HowTo](https://github.com/jankazil/torch-tk/blob/main/notebooks/HowTo.ipynb) Jupyter notebook.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
|
|
26
|
+
### pip
|
|
27
|
+
```bash
|
|
28
|
+
pip install torch-tk
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
### conda / mamba
|
|
32
|
+
```bash
|
|
33
|
+
mamba install -c jan.kazil -c conda-forge torch-tk
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Classes
|
|
37
|
+
|
|
38
|
+
- `Model`
|
|
39
|
+
|
|
40
|
+
- A base class which makes models self-describing and automatically reconstructible by the `CheckPointManager`
|
|
41
|
+
- Automatically rebuilds a model from a saved file
|
|
42
|
+
|
|
43
|
+
- `SGD`, `Adam`, ...
|
|
44
|
+
|
|
45
|
+
Wrapper classes for PyTorch optimizers that make the optimizers self-describing and automatically reconstructible by the `CheckPointManager`
|
|
46
|
+
|
|
47
|
+
- `Trainer`
|
|
48
|
+
|
|
49
|
+
- Trains a model from
|
|
50
|
+
- a `torch.utils.data.DataLoader`
|
|
51
|
+
- or directly from tensors, using an efficient batching mechanism
|
|
52
|
+
- Records training loss and model timing per epoch
|
|
53
|
+
|
|
54
|
+
- `CheckPointManager`
|
|
55
|
+
|
|
56
|
+
- Saves and restores model training states
|
|
57
|
+
- Automatically rebuilds both a model and its optimizer from a saved checkpoint file
|
|
58
|
+
|
|
59
|
+
- `Diagnostics`
|
|
60
|
+
|
|
61
|
+
- Computes, stores, and plots per-sample loss and per-sample loss probability distribution
|
|
62
|
+
- Saves and restores diagnostics in netCDF file format
|
|
63
|
+
- Identifies worst-loss samples
|
|
64
|
+
|
|
65
|
+
## Public API
|
|
66
|
+
|
|
67
|
+
### Modules
|
|
68
|
+
|
|
69
|
+
#### `torch_tk.models.model`
|
|
70
|
+
|
|
71
|
+
Provides the abstract `Model` base class for models that can describe, save, restore, and reconstruct themselves. The `Model` class inherits from `torch.nn.Module`, and thus provides the standard PyTorch `Module` interface.
|
|
72
|
+
|
|
73
|
+
The `Model` class defines and provides the following methods:
|
|
74
|
+
|
|
75
|
+
- `Model.forward(xb)`: Abstract method that computes the forward pass.
|
|
76
|
+
- `Model.constructor_dict()`: Abstract method that returns the constructor arguments needed to reconstruct the model.
|
|
77
|
+
- `Model.save_state_dict_to_file(path)`: Save only the state dictionary.
|
|
78
|
+
- `Model.save_to_file(path)`: Save constructor arguments and state dictionary needed to recreate the model.
|
|
79
|
+
- `Model.load_from_file(path, device=None)`: Recreate a model from a saved file.
|
|
80
|
+
- `Model.clone(constructor_dict, state_dict, device=None)`: Reconstruct a model from constructor arguments and state.
|
|
81
|
+
|
|
82
|
+
#### `torch_tk.optimizers.sgd`
|
|
83
|
+
|
|
84
|
+
Wrapper around `torch.optim.SGD` to make it self-describing and automatically reconstructible.
|
|
85
|
+
|
|
86
|
+
- `SGD(...)`: Subclass of `torch.optim.SGD` that stores its constructor arguments on the instance.
|
|
87
|
+
- `SGD.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
|
|
88
|
+
|
|
89
|
+
#### `torch_tk.optimizers.adam`
|
|
90
|
+
|
|
91
|
+
Wrapper around `torch.optim.Adam` to make it self-describing and automatically reconstructible.
|
|
92
|
+
|
|
93
|
+
- `Adam(...)`: Subclass of `torch.optim.Adam` that stores its constructor arguments on the instance.
|
|
94
|
+
- `Adam.constructor_dict()`: Return the stored optimizer constructor settings excluding `params`.
|
|
95
|
+
|
|
96
|
+
#### `torch_tk.training.trainer`
|
|
97
|
+
|
|
98
|
+
Provides the `Trainer` class for epoch-based training and simple training diagnostics.
|
|
99
|
+
|
|
100
|
+
- `Trainer(model, optimizer, loss_function, epoch=0)`: Initialize trainer state.
|
|
101
|
+
- `Trainer.train_with_dataloader(data_loader, num_epochs, epoch_diag_step=1, verbose=True)`: Train from a `DataLoader`.
|
|
102
|
+
- `Trainer.train_with_data(x_train, y_train, bs, num_epochs, epoch_diag_step=1, verbose=True, shuffle=False)`: Train from in-memory tensors.
|
|
103
|
+
- `Trainer.plot_loss(...)`: Plot recorded diagnostic loss versus epoch.
|
|
104
|
+
- `Trainer.plot_wallclock_time(...)`: Plot recorded epoch wallclock time versus epoch.
|
|
105
|
+
|
|
106
|
+
#### `torch_tk.checkpoints.checkpoint_manager`
|
|
107
|
+
|
|
108
|
+
Provides checkpoint management for saving and reconstructing a model and optimizer together.
|
|
109
|
+
|
|
110
|
+
- `CheckPointManager(model, optimizer, directory)`: Manage checkpoint saving in a directory.
|
|
111
|
+
- `CheckPointManager.save(epoch)`: Save a checkpoint containing epoch, class paths, constructor dictionaries, and state dictionaries.
|
|
112
|
+
- `CheckPointManager.load_from_file(file_path, device=None)`: Reconstruct and return `checkpoint_manager, model, optimizer, epoch` from a checkpoint file.
|
|
113
|
+
|
|
114
|
+
#### `torch_tk.diagnostics.loss`
|
|
115
|
+
|
|
116
|
+
Provides utilities for computing per-sample loss.
|
|
117
|
+
|
|
118
|
+
- `per_sample_loss_from_data_loader(model, loss_function_sample_resolved, data_loader)`: Compute per-sample losses and their mean from a `DataLoader`.
|
|
119
|
+
- `per_sample_loss_from_data(model, loss_function_sample_resolved, x_data, y_data, chunk_size=None)`: Compute per-sample losses and their mean from in-memory tensors.
|
|
120
|
+
- `model_worst_loss(model, loss_function_sample_resolved, x_data, y_data, n, chunk_size=None)`: Return the indices and values of the `n` worst losses.
|
|
121
|
+
|
|
122
|
+
#### `torch_tk.diagnostics.diagnostics`
|
|
123
|
+
|
|
124
|
+
Provides the `Diagnostics` container for sample-resolved loss diagnostics and analysis.
|
|
125
|
+
|
|
126
|
+
- `Diagnostics.from_data_loader(...)`: Build diagnostics from a model evaluated on a `DataLoader`.
|
|
127
|
+
- `Diagnostics.from_data(...)`: Build diagnostics from in-memory tensors.
|
|
128
|
+
- `Diagnostics.from_netcdf(path)`: Restore diagnostics from a saved netCDF file.
|
|
129
|
+
- `Diagnostics(...)`: Construct a diagnostics object from metadata, epochs, and per-sample loss data.
|
|
130
|
+
- `Diagnostics.__add__(other)`: Concatenate compatible diagnostics across epochs.
|
|
131
|
+
- `Diagnostics.to_netcdf(directory, verbose=True)`: Save diagnostics to a netCDF file.
|
|
132
|
+
|
|
133
|
+
#### `torch_tk.diagnostics.plotting`
|
|
134
|
+
|
|
135
|
+
Provides utilities for plotting diagnostics.
|
|
136
|
+
|
|
137
|
+
- `plot_diagnostics(diagnostics, plot_file=None, title=None, font_factor=1.5, figsize=(9, 6), xlim=None, ylim=None, loss_name='sqrt(loss)', pdf_bin_n=100, dpdlog10=False, show_plot=True, verbose=True)`: Plot kernel-density estimates of square-root per-sample loss distributions across one or more diagnostics objects and epochs.
|
|
138
|
+
|
|
139
|
+
## Notes and limitations
|
|
140
|
+
|
|
141
|
+
- The checkpoint mechanism assumes that models and optimizers are importable from stable class paths and expose `constructor_dict()`, `state_dict()`, and `load_state_dict()`.
|
|
142
|
+
- The checkpoint design is not suitable for optimizers that require non-serializable constructor inputs or custom parameter-group reconstruction beyond `model.parameters()`.
|
|
143
|
+
- The diagnostic plotting utility requires strictly positive loss values because it plots the square root of loss on a logarithmic axis.
|
|
144
|
+
- The recorded epoch loss in `Trainer` is exact only when the supplied loss function returns the mean per-sample loss over each batch, as stated in the trainer docstrings.
|
|
145
|
+
|
|
146
|
+
## Development
|
|
147
|
+
|
|
148
|
+
### Code Quality and Testing Commands
|
|
149
|
+
|
|
150
|
+
- `make fmt` - Runs `ruff format`, which reformats Python files according to the style rules in `pyproject.toml`.
|
|
151
|
+
- `make lint` - Runs `ruff check --fix`, which lints the code and auto-fixes what it can.
|
|
152
|
+
- `make check` - Runs formatting and linting.
|
|
153
|
+
- `make type` - Currently disabled. Intended to run `mypy` using the settings in `pyproject.toml`.
|
|
154
|
+
- `make test` - Runs `pytest` with the test settings configured in `pyproject.toml`.
|
|
155
|
+
|
|
156
|
+
## Author
|
|
157
|
+
|
|
158
|
+
Jan Kazil - jan.kazil.dev@gmail.com - [jankazil.com](https://jankazil.com)
|
|
159
|
+
|
|
160
|
+
## License
|
|
161
|
+
|
|
162
|
+
BSD-3-Clause
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=77", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torch-tk"
|
|
7
|
+
version = "1.0.8"
|
|
8
|
+
description = "Streamlines training, checkpoint management, and diagnostics of PyTorch models."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
license = "BSD-3-Clause"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Jan Kazil", email = "jan.kazil.dev@gmail.com" }
|
|
15
|
+
]
|
|
16
|
+
dependencies = [
|
|
17
|
+
"torch",
|
|
18
|
+
"numpy",
|
|
19
|
+
"xarray",
|
|
20
|
+
"matplotlib",
|
|
21
|
+
"scipy",
|
|
22
|
+
]
|
|
23
|
+
classifiers = [
|
|
24
|
+
"Development Status :: 4 - Beta",
|
|
25
|
+
"Intended Audience :: Developers",
|
|
26
|
+
"Intended Audience :: Science/Research",
|
|
27
|
+
"Programming Language :: Python :: 3",
|
|
28
|
+
"Programming Language :: Python :: 3.12",
|
|
29
|
+
"Topic :: Scientific/Engineering",
|
|
30
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
31
|
+
]
|
|
32
|
+
keywords = ["pytorch", "training", "checkpointing", "diagnostics", "deep-learning"]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/jankazil/torch-tk"
|
|
36
|
+
Repository = "https://github.com/jankazil/torch-tk"
|
|
37
|
+
Issues = "https://github.com/jankazil/torch-tk/issues"
|
|
38
|
+
|
|
39
|
+
[project.optional-dependencies]
|
|
40
|
+
dev = [
|
|
41
|
+
"build>=1.2",
|
|
42
|
+
"twine>=5.1",
|
|
43
|
+
"pytest>=8",
|
|
44
|
+
"pytest-cov>=5",
|
|
45
|
+
"mypy>=1.11",
|
|
46
|
+
"ruff>=0.5",
|
|
47
|
+
"pre-commit>=3.7",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
[tool.setuptools]
|
|
51
|
+
package-dir = {"" = "src"}
|
|
52
|
+
|
|
53
|
+
[tool.setuptools.packages.find]
|
|
54
|
+
where = ["src"]
|
|
55
|
+
|
|
56
|
+
[tool.ruff]
|
|
57
|
+
line-length = 128
|
|
58
|
+
target-version = "py312"
|
|
59
|
+
extend-exclude = [
|
|
60
|
+
"__init__.py",
|
|
61
|
+
"dist",
|
|
62
|
+
"build",
|
|
63
|
+
"data",
|
|
64
|
+
"demos",
|
|
65
|
+
"docs",
|
|
66
|
+
"experiments",
|
|
67
|
+
"notebooks",
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
[tool.ruff.lint]
|
|
71
|
+
select = ["E", "F", "I", "UP", "B", "SIM"]
|
|
72
|
+
ignore = ["E501", "SIM108"]
|
|
73
|
+
|
|
74
|
+
[tool.ruff.format]
|
|
75
|
+
quote-style = "preserve"
|
|
76
|
+
|
|
77
|
+
#
|
|
78
|
+
# MyPy is disabled because static analysis cannot account for all
|
|
79
|
+
# dynamic runtime behaviors, mypy may report false positives which
|
|
80
|
+
# do not reflect actual runtime issues.
|
|
81
|
+
#
|
|
82
|
+
#[tool.mypy]
|
|
83
|
+
#python_version = "3.12"
|
|
84
|
+
#warn_unused_configs = true
|
|
85
|
+
#disallow_untyped_defs = true
|
|
86
|
+
#disallow_incomplete_defs = true
|
|
87
|
+
#no_implicit_optional = true
|
|
88
|
+
#check_untyped_defs = true
|
|
89
|
+
#strict_optional = true
|
|
90
|
+
#pretty = true
|
|
91
|
+
#namespace_packages = true
|
|
92
|
+
#mypy_path = "src"
|
|
93
|
+
#files = ["src/torch_tk", "scripts"]
|
|
94
|
+
|
|
95
|
+
[tool.pytest.ini_options]
|
|
96
|
+
testpaths = ["tests"]
|
|
97
|
+
python_files = ["test_*.py"]
|
|
98
|
+
python_functions = ["test_*"]
|
torch_tk-1.0.8/setup.cfg
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Public API of the torch_tk package.

Resolves ``__version__`` from installed distribution metadata and re-exports
the package's main classes so they can be imported directly from ``torch_tk``.
"""

from importlib.metadata import PackageNotFoundError, version

_DIST_NAME = "torch-tk"


def _resolve_version() -> str:
    """Return the installed version string.

    Tries the distribution name first, then the import-package name, and
    falls back to a local placeholder when no metadata is installed.
    """
    candidates = (_DIST_NAME, __package__ or __name__.split(".", 1)[0])
    for candidate in candidates:
        try:
            return version(candidate)
        except PackageNotFoundError:
            continue
    return "0.0.0+local"


__version__ = _resolve_version()

from .models.model import Model
from .optimizers.sgd import SGD
from .optimizers.adam import Adam
from .training.trainer import Trainer
from .checkpoints.checkpoint_manager import CheckPointManager
from .diagnostics.diagnostics import Diagnostics

__all__ = [
    "__version__",
    "Model",
    "SGD",
    "Adam",
    "Trainer",
    "CheckPointManager",
    "Diagnostics",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .checkpoint_manager import CheckPointManager
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Checkpoint utilities for saving and restoring self-describing model and optimizer state.
|
|
3
|
+
|
|
4
|
+
This module provides the CheckPointManager class, which saves checkpoints that
|
|
5
|
+
contain enough information to reconstruct both a model and its optimizer in the
|
|
6
|
+
state that created the checkpoint. Each checkpoint stores the fully qualified
|
|
7
|
+
class path, constructor arguments from constructor_dict(), and state from
|
|
8
|
+
state_dict() for both objects.
|
|
9
|
+
|
|
10
|
+
Models and optimizers must be importable from a stable class path and must expose
|
|
11
|
+
|
|
12
|
+
- constructor_dict()
|
|
13
|
+
- state_dict()
|
|
14
|
+
- load_state_dict()
|
|
15
|
+
|
|
16
|
+
A model must be reconstructible from its class path and constructor_dict().
|
|
17
|
+
An optimizer must be reconstructible from its class path, model.parameters(),
|
|
18
|
+
and constructor_dict().
|
|
19
|
+
|
|
20
|
+
The constructor data must be serializable. In practice, this means it should
|
|
21
|
+
contain only standard serializable Python values such as numbers, strings,
|
|
22
|
+
lists, tuples, and dictionaries.
|
|
23
|
+
|
|
24
|
+
This design is not suitable for optimizers that depend on non-serializable
|
|
25
|
+
constructor inputs, non-standard constructor signatures, custom parameter-group
|
|
26
|
+
reconstruction beyond model.parameters(), or runtime state not captured by
|
|
27
|
+
state_dict().
|
|
28
|
+
'''
|
|
29
|
+
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
from .utils import class_path_of_instance, import_class
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CheckPointManager:
    '''
    Save checkpoints and rebuild a model and optimizer from a checkpoint file.

    The checkpoint contains the epoch, class paths, constructor arguments, and
    state dictionaries for the model and optimizer.
    '''

    def __init__(self, model, optimizer, directory):
        '''
        Store the model, optimizer, and checkpoint directory.

        Parameters:
            model: Model exposing constructor_dict(), state_dict(), and
                load_state_dict().
            optimizer: Optimizer exposing constructor_dict(), state_dict(),
                and load_state_dict().
            directory: Checkpoint directory as a str or Path. It is converted
                to a Path here and created on demand in save().
        '''
        self.model = model
        self.optimizer = optimizer
        # Path() is idempotent, so str and Path inputs are both accepted.
        self.directory = Path(directory)

    def save(self, epoch):
        '''
        Save a checkpoint for the current model and optimizer state.

        Parameters:
            epoch: Epoch number stored in the checkpoint and encoded in the
                file name.

        Returns:
            Path to the written checkpoint file.

        Raises:
            TypeError: If the model or optimizer does not implement
                constructor_dict().
        '''
        if not hasattr(self.model, 'constructor_dict'):
            raise TypeError('Model must implement constructor_dict().')

        if not hasattr(self.optimizer, 'constructor_dict'):
            raise TypeError('Optimizer must implement constructor_dict().')

        checkpoint = {
            'epoch': epoch,
            'model_class_path': class_path_of_instance(self.model),
            'model_constructor_dict': self.model.constructor_dict(),
            'model_state_dict': self.model.state_dict(),
            'optimizer_class_path': class_path_of_instance(self.optimizer),
            'optimizer_constructor_dict': self.optimizer.constructor_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }

        # File name encodes both class names and the epoch,
        # e.g. 'MyModel.Adam.epoch=3.pt'.
        file_name = f'{type(self.model).__name__}.{type(self.optimizer).__name__}.epoch={epoch}.pt'

        self.directory.mkdir(parents=True, exist_ok=True)

        file_path = self.directory / file_name
        torch.save(checkpoint, file_path)

        return file_path

    @classmethod
    def load_from_file(cls, file_path, device=None):
        '''
        Load a checkpoint file and reconstruct the model, optimizer, and manager.

        Parameters:
            file_path: Path (str or Path) to a checkpoint written by save().
            device: Optional device the checkpoint is mapped onto. When given,
                a saved model constructor keyword argument named 'device' is
                overridden and the model is moved to that device.

        Returns:
            checkpoint_manager, model, optimizer, epoch
        '''
        file_path = Path(file_path)

        # NOTE(review): the checkpoint holds constructor dicts of plain Python
        # values; newer torch versions default torch.load(weights_only=True),
        # which permits such primitives — confirm against the supported torch
        # versions.
        checkpoint = torch.load(file_path, map_location=device)

        # Reconstruct the model from its class path and constructor arguments.
        model_class = import_class(checkpoint['model_class_path'])

        model_constructor_dict = checkpoint['model_constructor_dict']
        model_args = model_constructor_dict.get('args', [])
        model_kwargs = dict(model_constructor_dict.get('kwargs', {}))

        # Override the saved target device only when the caller requests one
        # and the model was constructed with a 'device' keyword.
        if 'device' in model_kwargs and device is not None:
            model_kwargs['device'] = device

        model = model_class(*model_args, **model_kwargs)
        model.load_state_dict(checkpoint['model_state_dict'])

        if device is not None:
            model = model.to(device)

        # Reconstruct the optimizer against the rebuilt model's parameters.
        optimizer_class = import_class(checkpoint['optimizer_class_path'])

        optimizer_constructor_dict = checkpoint['optimizer_constructor_dict']
        optimizer_args = optimizer_constructor_dict.get('args', [])
        optimizer_kwargs = dict(optimizer_constructor_dict.get('kwargs', {}))

        optimizer = optimizer_class(model.parameters(), *optimizer_args, **optimizer_kwargs)

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Manager writes future checkpoints next to the loaded file.
        checkpoint_manager = cls(model, optimizer, file_path.parent)

        epoch = checkpoint['epoch']

        return checkpoint_manager, model, optimizer, epoch
|