helios-ml 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. helios-ml-0.1.0/.github/workflows/publish.yml +46 -0
  2. helios-ml-0.1.0/.github/workflows/tests.yml +66 -0
  3. helios-ml-0.1.0/.gitignore +20 -0
  4. helios-ml-0.1.0/.pre-commit-config.yaml +26 -0
  5. helios-ml-0.1.0/LICENSE +28 -0
  6. helios-ml-0.1.0/PKG-INFO +116 -0
  7. helios-ml-0.1.0/README.md +80 -0
  8. helios-ml-0.1.0/data/logo/logo-background.png +0 -0
  9. helios-ml-0.1.0/data/logo/logo-transparent.png +0 -0
  10. helios-ml-0.1.0/examples/cifar10/cifar10.py +297 -0
  11. helios-ml-0.1.0/pyproject.toml +119 -0
  12. helios-ml-0.1.0/requirements/ci.txt +17 -0
  13. helios-ml-0.1.0/requirements/default.txt +11 -0
  14. helios-ml-0.1.0/requirements/dev.txt +19 -0
  15. helios-ml-0.1.0/setup.cfg +4 -0
  16. helios-ml-0.1.0/src/helios/__init__.py +1 -0
  17. helios-ml-0.1.0/src/helios/_version.py +5 -0
  18. helios-ml-0.1.0/src/helios/core/__init__.py +25 -0
  19. helios-ml-0.1.0/src/helios/core/cuda.py +7 -0
  20. helios-ml-0.1.0/src/helios/core/distributed.py +179 -0
  21. helios-ml-0.1.0/src/helios/core/logging.py +533 -0
  22. helios-ml-0.1.0/src/helios/core/rng.py +151 -0
  23. helios-ml-0.1.0/src/helios/core/utils.py +417 -0
  24. helios-ml-0.1.0/src/helios/data/__init__.py +26 -0
  25. helios-ml-0.1.0/src/helios/data/datamodule.py +363 -0
  26. helios-ml-0.1.0/src/helios/data/functional.py +186 -0
  27. helios-ml-0.1.0/src/helios/data/samplers.py +229 -0
  28. helios-ml-0.1.0/src/helios/data/transforms.py +74 -0
  29. helios-ml-0.1.0/src/helios/losses/__init__.py +10 -0
  30. helios-ml-0.1.0/src/helios/losses/utils.py +21 -0
  31. helios-ml-0.1.0/src/helios/losses/weighted_loss.py +56 -0
  32. helios-ml-0.1.0/src/helios/metrics/__init__.py +27 -0
  33. helios-ml-0.1.0/src/helios/metrics/functional.py +447 -0
  34. helios-ml-0.1.0/src/helios/metrics/metrics.py +252 -0
  35. helios-ml-0.1.0/src/helios/model/__init__.py +10 -0
  36. helios-ml-0.1.0/src/helios/model/model.py +369 -0
  37. helios-ml-0.1.0/src/helios/model/utils.py +53 -0
  38. helios-ml-0.1.0/src/helios/nn/__init__.py +15 -0
  39. helios-ml-0.1.0/src/helios/nn/swa_utils.py +61 -0
  40. helios-ml-0.1.0/src/helios/nn/utils.py +56 -0
  41. helios-ml-0.1.0/src/helios/onnx.py +57 -0
  42. helios-ml-0.1.0/src/helios/optim/__init__.py +9 -0
  43. helios-ml-0.1.0/src/helios/optim/utils.py +37 -0
  44. helios-ml-0.1.0/src/helios/py.typed +0 -0
  45. helios-ml-0.1.0/src/helios/scheduler/__init__.py +15 -0
  46. helios-ml-0.1.0/src/helios/scheduler/schedulers.py +148 -0
  47. helios-ml-0.1.0/src/helios/scheduler/utils.py +38 -0
  48. helios-ml-0.1.0/src/helios/trainer.py +998 -0
  49. helios-ml-0.1.0/src/helios_ml.egg-info/PKG-INFO +116 -0
  50. helios-ml-0.1.0/src/helios_ml.egg-info/SOURCES.txt +72 -0
  51. helios-ml-0.1.0/src/helios_ml.egg-info/dependency_links.txt +1 -0
  52. helios-ml-0.1.0/src/helios_ml.egg-info/requires.txt +27 -0
  53. helios-ml-0.1.0/src/helios_ml.egg-info/top_level.txt +1 -0
  54. helios-ml-0.1.0/test/__init__.py +0 -0
  55. helios-ml-0.1.0/test/conftest.py +90 -0
  56. helios-ml-0.1.0/test/registry_test/__init__.py +3 -0
  57. helios-ml-0.1.0/test/registry_test/extra.py +2 -0
  58. helios-ml-0.1.0/test/registry_test/foo.py +6 -0
  59. helios-ml-0.1.0/test/registry_test/func_registry.py +7 -0
  60. helios-ml-0.1.0/test/registry_test/nested/__init__.py +0 -0
  61. helios-ml-0.1.0/test/registry_test/nested/bar.py +6 -0
  62. helios-ml-0.1.0/test/registry_test/nested/extra.py +2 -0
  63. helios-ml-0.1.0/test/registry_test/py.typed +0 -0
  64. helios-ml-0.1.0/test/test_core.py +158 -0
  65. helios-ml-0.1.0/test/test_data.py +237 -0
  66. helios-ml-0.1.0/test/test_losses.py +28 -0
  67. helios-ml-0.1.0/test/test_metrics.py +20 -0
  68. helios-ml-0.1.0/test/test_model.py +78 -0
  69. helios-ml-0.1.0/test/test_nn.py +39 -0
  70. helios-ml-0.1.0/test/test_onnx.py +34 -0
  71. helios-ml-0.1.0/test/test_optim.py +9 -0
  72. helios-ml-0.1.0/test/test_scheduler.py +17 -0
  73. helios-ml-0.1.0/test/test_trainer.py +350 -0
  74. helios-ml-0.1.0/tools/generate_requirements.py +49 -0
@@ -0,0 +1,46 @@
1
+ name: "Publish"
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ permissions:
8
+ contents: read
9
+
10
+ jobs:
11
+ publish:
12
+ runs-on: ubuntu-latest
13
+ environment:
14
+ name: testpypi
15
+ url: https://pypi.org/project/helios-ml
16
+ permissions:
17
+ id-token: write
18
+ steps:
19
+ - name: Checkout
20
+ uses: actions/checkout@v3
21
+
22
+ - name: Setup Python
23
+ uses: actions/setup-python@v4
24
+ id: setup_python
25
+ with:
26
+ python-version: "3.11"
27
+
28
+ - name: Cache virtualenv
29
+ uses: actions/cache@v3
30
+ with:
31
+ key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('requirements/ci.txt') }}
32
+ path: venv
33
+
34
+ - name: Install dependencies
35
+ run: |
36
+ python -m venv venv
37
+ source venv/bin/activate
38
+ python -m pip install -r requirements/ci.txt
39
+ echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
40
+ echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
41
+
42
+ - name: Build Helios
43
+ run: |
44
+ python -m build
45
+ - name: Publish Helios
46
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,66 @@
1
+ name: "Tests"
2
+
3
+ on:
4
+ push:
5
+ branches: [ master ]
6
+ paths-ignore:
7
+ - 'README.md'
8
+ - '.pre-commit-config.yaml'
9
+ - 'data/**'
10
+ - "LICENSE"
11
+ pull_request:
12
+ branches: [ master ]
13
+ paths-ignore:
14
+ - 'README.md'
15
+ - 'LICENSE'
16
+ workflow_dispatch:
17
+
18
+ jobs:
19
+ test:
20
+ runs-on: ubuntu-latest
21
+ strategy:
22
+ fail-fast: false
23
+ steps:
24
+ - name: Checkout
25
+ uses: actions/checkout@v3
26
+
27
+ - name: Setup Python
28
+ uses: actions/setup-python@v4
29
+ id: setup_python
30
+ with:
31
+ python-version: "3.11"
32
+
33
+ - name: Cache virtualenv
34
+ uses: actions/cache@v3
35
+ with:
36
+ key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('requirements/ci.txt') }}
37
+ path: venv
38
+
39
+ - name: Install dependencies
40
+ run: |
41
+ python -m venv venv
42
+ source venv/bin/activate
43
+ python -m pip install -r requirements/ci.txt
44
+ echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
45
+ echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
46
+
47
+ - name: Install Helios
48
+ shell: bash
49
+ run: |
50
+ pip install .
51
+
52
+ - name: Ruff
53
+ run: |
54
+ ruff check src/helios
55
+ ruff check test
56
+ ruff check examples/cifar10
57
+
58
+ - name: Mypy
59
+ run: |
60
+ mypy src/helios
61
+ mypy test
62
+ mypy examples/cifar10
63
+
64
+ - name: Pytest
65
+ run: |
66
+ python -m pytest
@@ -0,0 +1,20 @@
1
+ # Virtual envs.
2
+ venv/
3
+ .venv/
4
+
5
+ # Distribution/packaging
6
+ dist/
7
+ build/
8
+ *.egg-info/
9
+
10
+ __pycache__/
11
+ .mypy_cache/
12
+ .pytest_cache/
13
+ .ruff_cache/
14
+
15
+ # Generated folders from examples
16
+ *.pth
17
+ examples/**/runs
18
+ examples/**/chkpt
19
+ examples/**/logs
20
+ examples/**/data
@@ -0,0 +1,26 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: c4a0b883114b00d8d76b479c820ce7950211c99b # frozen: v4.5.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: debug-statements
7
+ - id: check-ast
8
+ - id: mixed-line-ending
9
+ args: [--fix=lf]
10
+ - id: check-yaml
11
+ args: [--allow-multiple-documents]
12
+ - id: check-json
13
+ - id: check-added-large-files
14
+
15
+ - repo: https://github.com/astral-sh/ruff-pre-commit
16
+ rev: v0.3.4
17
+ hooks:
18
+ - id: ruff-format
19
+
20
+ - repo: local
21
+ hooks:
22
+ - id: generate_requirements.py
23
+ name: generate_requirements.py
24
+ language: system
25
+ entry: python tools/generate_requirements.py
26
+ files: "pyproject.toml|requirements/.*\\.txt|tools/generate_requirements.py"
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2024, Mauricio A Rovira Galvez
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.1
2
+ Name: helios-ml
3
+ Version: 0.1.0
4
+ Summary: A Torch-based package for training AI networks
5
+ Author: Mauricio A. Rovira Galvez
6
+ Project-URL: Homepage, https://github.com/marovira/pyro-ml
7
+ Project-URL: Issues, https://github.com/marovira/pyro-ml/issues
8
+ Requires-Python: >=3.11
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: tqdm>=4.66.2
12
+ Requires-Dist: opencv-python>=4.9.0.80
13
+ Requires-Dist: protobuf!=4.24.0,<5.0.0,>=3.19.6
14
+ Requires-Dist: tensorboard>=2.16.2
15
+ Requires-Dist: torch>=2.2.1
16
+ Requires-Dist: torchvision>=0.17.1
17
+ Requires-Dist: onnx>=1.16.0
18
+ Requires-Dist: onnxruntime>=1.17.1
19
+ Requires-Dist: matplotlib>=3.8.4
20
+ Provides-Extra: dev
21
+ Requires-Dist: mypy>=1.8.0; extra == "dev"
22
+ Requires-Dist: ruff>=0.3.4; extra == "dev"
23
+ Requires-Dist: pytest>=8.1.1; extra == "dev"
24
+ Requires-Dist: pre-commit>=3.6.2; extra == "dev"
25
+ Requires-Dist: types-Pillow>=10.2.0.20240311; extra == "dev"
26
+ Requires-Dist: types-tqdm>=4.66.0.20240106; extra == "dev"
27
+ Requires-Dist: build>=1.2.1; extra == "dev"
28
+ Requires-Dist: twine>=5.0.0; extra == "dev"
29
+ Provides-Extra: ci
30
+ Requires-Dist: mypy>=1.8.0; extra == "ci"
31
+ Requires-Dist: ruff>=0.3.4; extra == "ci"
32
+ Requires-Dist: pytest>=8.1.1; extra == "ci"
33
+ Requires-Dist: types-Pillow>=10.2.0.20240311; extra == "ci"
34
+ Requires-Dist: types-tqdm>=4.66.0.20240106; extra == "ci"
35
+ Requires-Dist: build>=1.2.1; extra == "ci"
36
+
37
+ <a id="top"></a>
38
+ ![HELIOS logo](data/logo/logo-transparent.png)
39
+
40
+ [![Generic badge](https://img.shields.io/badge/License-BSD3-blue)](LICENSE)
41
+ [![Static Badge](https://img.shields.io/badge/Python-3.11%2B-red?logoColor=red)](https://www.python.org/downloads/release/python-3110/)
42
+ [![Tests](https://github.com/marovira/helios-ml/actions/workflows/tests.yml/badge.svg)](https://github.com/marovira/helios-ml/actions/workflows/tests.yml)
43
+
44
+ ## What is Helios?
45
+
46
+ Named after the Greek god of the sun, Helios is a lightweight package for training ML
47
+ networks built on top of PyTorch. It is designed to abstract all of the "boiler-plate"
48
+ code involved with training. Specifically, it wraps the following common patterns:
49
+
50
+ * Creation of the dataloaders.
51
+ * Initialization of CUDA, PyTorch, and random number states.
52
+ * Initialization for distributed training.
53
+ * Training, validation, and testing loops.
54
+ * Saving and loading checkpoints.
55
+ * Exporting to ONNX.
56
+
57
+ It is important to note that Helios is **not** a fully fledged training environment similar
58
+ to [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). Instead, Helios
59
+ is focused on providing a simple and straightforward interface that abstracts most of the
60
+ common code patterns while retaining the ability to be easily overridden to suit the
61
+ individual needs of each training scheme.
62
+
63
+ ## Main Features
64
+
65
+ Helios offers the following functionality out of the box:
66
+
67
+ 1. Resume training: Helios has been built with the ability to resume training if it is
68
+ paused. Specifically, Helios will ensure that the behaviour of the trained model is
69
+ *identical* to the one it would've had if it had been trained without pauses.
70
+ 2. Automatic detection of multi-GPU environments for distributed training. In addition,
71
+ Helios also supports training using `torchrun` and will automatically handle the
72
+ initialisation and clean up of the distributed state. It will also correctly set the
73
+ devices and maps to ensure weights are mapped to the correct location.
74
+ 3. Registries for creation of arbitrary types. These include: networks, loss functions,
75
+ optimizers, schedulers, etc.
76
+ 4. Correct handling of logging when doing distributed training (even over multiple nodes).
77
+
78
+ ## Installation
79
+
80
+ You can install Helios using Pip as follows:
81
+
82
+ ```sh
83
+ pip install helios-ml
84
+ ```
85
+
86
+ ## Documentation
87
+
88
+ Documentation coming soon!
89
+
90
+ ## Contributing
91
+
92
+ There are three ways in which you can contribute to Helios:
93
+
94
+ * If you find a bug, please open an issue. Similarly, if you have a question
95
+ about how to use it, or if something is unclear, please post an issue so it
96
+ can be addressed.
97
+ * If you have a fix for a bug, or a code enhancement, please open a pull
98
+ request. Before you submit it though, make sure to abide by the rules written
99
+ below.
100
+ * If you have a feature proposal, you can either open an issue or create a pull
101
+ request. If you are submitting a pull request, it must abide by the rules
102
+ written below. Note that any new features need to be approved by me.
103
+
104
+ If you are submitting a pull request, the guidelines are the following:
105
+
106
+ 1. Ensure that your code follows the standards and formatting of Helios. The coding
107
+ standards and formatting are enforced through the Ruff Linter and Formatter. Any
108
+ changes that do not abide by these rules will be rejected. It is your responsibility to
109
+ ensure that both Ruff and Mypy linters pass.
110
+ 2. Ensure that *all* unit tests are working prior to submitting the pull
111
+ request. If you are adding a new feature that has been approved, it is your
112
+ responsibility to provide the corresponding unit tests (if applicable).
113
+
114
+ ## License
115
+
116
+ Helios is published under the BSD-3 license and can be viewed [here](LICENSE).
@@ -0,0 +1,80 @@
1
+ <a id="top"></a>
2
+ ![HELIOS logo](data/logo/logo-transparent.png)
3
+
4
+ [![Generic badge](https://img.shields.io/badge/License-BSD3-blue)](LICENSE)
5
+ [![Static Badge](https://img.shields.io/badge/Python-3.11%2B-red?logoColor=red)](https://www.python.org/downloads/release/python-3110/)
6
+ [![Tests](https://github.com/marovira/helios-ml/actions/workflows/tests.yml/badge.svg)](https://github.com/marovira/helios-ml/actions/workflows/tests.yml)
7
+
8
+ ## What is Helios?
9
+
10
+ Named after the Greek god of the sun, Helios is a lightweight package for training ML
11
+ networks built on top of PyTorch. It is designed to abstract all of the "boiler-plate"
12
+ code involved with training. Specifically, it wraps the following common patterns:
13
+
14
+ * Creation of the dataloaders.
15
+ * Initialization of CUDA, PyTorch, and random number states.
16
+ * Initialization for distributed training.
17
+ * Training, validation, and testing loops.
18
+ * Saving and loading checkpoints.
19
+ * Exporting to ONNX.
20
+
21
+ It is important to note that Helios is **not** a fully fledged training environment similar
22
+ to [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). Instead, Helios
23
+ is focused on providing a simple and straightforward interface that abstracts most of the
24
+ common code patterns while retaining the ability to be easily overridden to suit the
25
+ individual needs of each training scheme.
26
+
27
+ ## Main Features
28
+
29
+ Helios offers the following functionality out of the box:
30
+
31
+ 1. Resume training: Helios has been built with the ability to resume training if it is
32
+ paused. Specifically, Helios will ensure that the behaviour of the trained model is
33
+ *identical* to the one it would've had if it had been trained without pauses.
34
+ 2. Automatic detection of multi-GPU environments for distributed training. In addition,
35
+ Helios also supports training using `torchrun` and will automatically handle the
36
+ initialisation and clean up of the distributed state. It will also correctly set the
37
+ devices and maps to ensure weights are mapped to the correct location.
38
+ 3. Registries for creation of arbitrary types. These include: networks, loss functions,
39
+ optimizers, schedulers, etc.
40
+ 4. Correct handling of logging when doing distributed training (even over multiple nodes).
41
+
42
+ ## Installation
43
+
44
+ You can install Helios using Pip as follows:
45
+
46
+ ```sh
47
+ pip install helios-ml
48
+ ```
49
+
50
+ ## Documentation
51
+
52
+ Documentation coming soon!
53
+
54
+ ## Contributing
55
+
56
+ There are three ways in which you can contribute to Helios:
57
+
58
+ * If you find a bug, please open an issue. Similarly, if you have a question
59
+ about how to use it, or if something is unclear, please post an issue so it
60
+ can be addressed.
61
+ * If you have a fix for a bug, or a code enhancement, please open a pull
62
+ request. Before you submit it though, make sure to abide by the rules written
63
+ below.
64
+ * If you have a feature proposal, you can either open an issue or create a pull
65
+ request. If you are submitting a pull request, it must abide by the rules
66
+ written below. Note that any new features need to be approved by me.
67
+
68
+ If you are submitting a pull request, the guidelines are the following:
69
+
70
+ 1. Ensure that your code follows the standards and formatting of Helios. The coding
71
+ standards and formatting are enforced through the Ruff Linter and Formatter. Any
72
+ changes that do not abide by these rules will be rejected. It is your responsibility to
73
+ ensure that both Ruff and Mypy linters pass.
74
+ 2. Ensure that *all* unit tests are working prior to submitting the pull
75
+ request. If you are adding a new feature that has been approved, it is your
76
+ responsibility to provide the corresponding unit tests (if applicable).
77
+
78
+ ## License
79
+
80
+ Helios is published under the BSD-3 license and can be viewed [here](LICENSE).
@@ -0,0 +1,297 @@
1
+ import os
2
+ import pathlib
3
+ import typing
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision
8
+ import torchvision.transforms.v2 as T
9
+ from torch import nn
10
+
11
+ import helios.core as hlc
12
+ import helios.data as hld
13
+ import helios.model as hlm
14
+ import helios.optim as hlo
15
+ import helios.trainer as hlt
16
+ from helios.core import logging
17
+
18
+
19
+ class CIFARDataModule(hld.DataModule):
20
+ """
21
+ Example datamodule class built with CIFAR10.
22
+
23
+ Here you can see how to setup a datamodule that requires data to be downloaded using
24
+ torchvision's CIFAR10 dataset. The code is adapted from PyTorch's "Training a
25
+ Classifier" tutorial:
26
+ https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
27
+
28
+ Args:
29
+ root (pathlib.Path): the root where the data will be downloaded to.
30
+ """
31
+
32
+ def __init__(self, root: pathlib.Path) -> None:
33
+ """Construct the datamodule."""
34
+ super().__init__()
35
+ self._root = root / "data"
36
+
37
+ def prepare_data(self) -> None:
38
+ """Download all necessary data."""
39
+ torchvision.datasets.CIFAR10(root=self._root, train=True, download=True)
40
+ torchvision.datasets.CIFAR10(root=self._root, train=False, download=True)
41
+
42
+ def setup(self) -> None:
43
+ """Create the datasets."""
44
+ # Use the ToTensor transform from Pyro to automate the conversion from images to
45
+ # tensors.
46
+ transforms = T.Compose(
47
+ [hld.transforms.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
48
+ )
49
+ params = hld.DataLoaderParams()
50
+ params.batch_size = 4
51
+ params.shuffle = True
52
+ params.num_workers = 2
53
+ params.drop_last = True
54
+ self._train_dataset = self._create_dataset(
55
+ torchvision.datasets.CIFAR10(
56
+ root=self._root, train=True, download=False, transform=transforms
57
+ ),
58
+ params,
59
+ )
60
+
61
+ # The dataloader params are copied when the dataset is created, so we can safely
62
+ # change the options for the validation dataset without interfering with the ones
63
+ # for training.
64
+ params.drop_last = False
65
+ params.shuffle = False
66
+ self._valid_dataset = self._create_dataset(
67
+ torchvision.datasets.CIFAR10(
68
+ root=self._root, train=False, download=False, transform=transforms
69
+ ),
70
+ params,
71
+ )
72
+
73
+
74
+ class Net(nn.Module):
75
+ """
76
+ Example image classifier.
77
+
78
+ The code is taken from PyTorch's Training a Classifier tutorial:
79
+ https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
80
+ """
81
+
82
+ def __init__(self):
83
+ """Create the classifier."""
84
+ super().__init__()
85
+ self.conv1 = nn.Conv2d(3, 6, 5)
86
+ self.pool = nn.MaxPool2d(2, 2)
87
+ self.conv2 = nn.Conv2d(6, 16, 5)
88
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
89
+ self.fc2 = nn.Linear(120, 84)
90
+ self.fc3 = nn.Linear(84, 10)
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ """Compute the label for the given image."""
94
+ x = self.pool(F.relu(self.conv1(x)))
95
+ x = self.pool(F.relu(self.conv2(x)))
96
+ x = torch.flatten(x, 1) # flatten all dimensions except batch
97
+ x = F.relu(self.fc1(x))
98
+ x = F.relu(self.fc2(x))
99
+ x = self.fc3(x)
100
+ return x
101
+
102
+
103
+ class ClassifierModel(hlm.Model):
104
+ """
105
+ Example model class for training the classifier.
106
+
107
+ Here you can see how to setup the model class and some of the basic functionality that
108
+ is available. The code is adapted from PyTorch's "Training a Classifier" tutorial:
109
+ https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
110
+ """
111
+
112
+ def __init__(self) -> None:
113
+ """Create the model."""
114
+ super().__init__("classifier")
115
+
116
+ def setup(self, fast_init: bool = False) -> None:
117
+ """Create the network, loss, and optimizer."""
118
+ # Note that when we create the network and loss function, we immediately move them
119
+ # to the current device, which has been set by the trainer.
120
+ self._net = Net().to(self.device)
121
+ self._criterion = nn.CrossEntropyLoss().to(self.device)
122
+
123
+ # Note that SGD is shipped as part of the default optimizers from Pyro, so we can
124
+ # directly request it from create_optimizer instead of building it ourselves.
125
+ self._optimizer = hlo.create_optimizer(
126
+ "SGD", self._net.parameters(), lr=0.001, momentum=0.9
127
+ )
128
+
129
+ def load_state_dict(
130
+ self, state_dict: dict[str, typing.Any], fast_init: bool = False
131
+ ) -> None:
132
+ """Restore the model from a saved checkpoint."""
133
+ # Note that we don't have to re-map the weights ourselves. They have already been
134
+ # re-mapped for us by the trainer when it loaded the checkpoint.
135
+ self._net.load_state_dict(state_dict["net"])
136
+ self._criterion.load_state_dict(state_dict["criterion"])
137
+ self._optimizer.load_state_dict(state_dict["optimizer"])
138
+
139
+ def state_dict(self) -> dict[str, typing.Any]:
140
+ """Return the state dict of the model."""
141
+ return {
142
+ "net": self._net.state_dict(),
143
+ "criterion": self._criterion.state_dict(),
144
+ "optimizer": self._optimizer.state_dict(),
145
+ }
146
+
147
+ def train(self) -> None:
148
+ """Set the model to train."""
149
+ # If we had more networks, we would shift them to training mode here.
150
+ self._net.train()
151
+
152
+ def on_training_start(self) -> None:
153
+ """Perform steps before training starts."""
154
+ tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())
155
+
156
+ x = torch.randn((1, 3, 32, 32)).to(self.device)
157
+ tb_logger.add_graph(self._net, x)
158
+
159
+ def train_step(self, batch: typing.Any, state: hlt.TrainingState) -> None:
160
+ """Forward and backward training passes."""
161
+ # Due to the simplicity of the code, we do both the forward and backward passes in
162
+ # the training step, but you could also split it between the train_step and
163
+ # on_training_batch_end if your setup is more complex.
164
+ inputs, labels = batch
165
+ inputs = inputs.to(self.device)
166
+ labels = labels.to(self.device)
167
+
168
+ self._optimizer.zero_grad()
169
+
170
+ outputs = self._net(inputs)
171
+ loss = self._criterion(outputs, labels)
172
+ loss.backward()
173
+ self._optimizer.step()
174
+
175
+ # Note that we save the value of the loss function to the _loss_items dictionary.
176
+ # This allows the model to automatically gather the losses for us if we are
177
+ # training in distributed mode.
178
+ self._loss_items["loss"] = loss
179
+
180
+ def on_training_batch_end(
181
+ self,
182
+ state: hlt.TrainingState,
183
+ should_log: bool = False,
184
+ ) -> None:
185
+ """
186
+ Perform steps after the training batch.
187
+
188
+ Args:
189
+ state (pyt.TrainingState): the current training state.
190
+ should_log (bool): if true, then writing to the log should be performed.
191
+ """
192
+ # If we were training in distributed mode, calling the base Model's function will
193
+ # automatically gather all loss values saved to self._loss_items.
194
+ super().on_training_batch_end(state, should_log)
195
+
196
+ # This flag is set to true whenever the number of iterations is a multiple of the
197
+ # logging frequency we set when creating the trainer.
198
+ if should_log:
199
+ root_logger = logging.get_root_logger()
200
+ tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())
201
+
202
+ loss_val = self._loss_items["loss"]
203
+
204
+ root_logger.info(
205
+ f"[{state.global_epoch + 1}, {state.global_iteration:5d}] "
206
+ f"loss: {loss_val:.3f}, "
207
+ f"running loss: {loss_val / state.running_iter:.3f} "
208
+ f"avg time: {state.average_iter_time:.2f}s"
209
+ )
210
+ tb_logger.add_scalar("train/loss", loss_val, state.global_iteration)
211
+ tb_logger.add_scalar(
212
+ "train/running loss",
213
+ loss_val / state.running_iter,
214
+ state.global_iteration,
215
+ )
216
+
217
+ def on_training_end(self) -> None:
218
+ """Perform steps after training ends."""
219
+ # For our example, we're going to save the hyper-params to the tensorboard log so
220
+ # we can compare with other runs.
221
+ # Notice that self._val_scores is active because we validate every epoch. If your
222
+ # validation frequency is different, you may need to alter this code.
223
+ total = self._val_scores["total"]
224
+ correct = self._val_scores["correct"]
225
+ accuracy = 100 * correct // total
226
+ writer = hlc.get_from_optional(logging.get_tensorboard_writer())
227
+ writer.add_hparams(
228
+ {"lr": 0.001, "momentum": 0.9, "epochs": 2},
229
+ {"hparam/accuracy": accuracy, "hparam/loss": self._loss_items["loss"].item()},
230
+ )
231
+
232
+ def eval(self) -> None:
233
+ """Set the model to eval."""
234
+ self._net.eval()
235
+
236
+ def on_validation_start(self, validation_cycle: int) -> None:
237
+ """Perform steps on validation start."""
238
+ # The base function will automatically clear the validation scores for us.
239
+ super().on_validation_start(validation_cycle)
240
+
241
+ # If you need to add further data to the table, you can do so here. In our case,
242
+ # we're going to add the total number of labels seen and how many of those were
243
+ # correct so we can compute the accuracy metric.
244
+ self._val_scores["total"] = 0
245
+ self._val_scores["correct"] = 0
246
+
247
+ def valid_step(self, batch: typing.Any, step: int) -> None:
248
+ """Perform the validation step."""
249
+ images, labels = batch
250
+ images = images.to(self.device)
251
+ labels = labels.to(self.device)
252
+
253
+ outputs = self._net(images)
254
+
255
+ _, predicted = torch.max(outputs.data, 1)
256
+ self._val_scores["total"] += labels.size(0)
257
+ self._val_scores["correct"] += (predicted == labels).sum().item()
258
+
259
+ def on_validation_end(self, validation_cycle: int) -> None:
260
+ """Perform steps after validation ends."""
261
+ root_logger = logging.get_root_logger()
262
+ tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())
263
+
264
+ # Grab the validation scores and compute the accuracy metric so we can log it. If
265
+ # we were in distributed mode, you would need to gather the values here.
266
+ total = self._val_scores["total"]
267
+ correct = self._val_scores["correct"]
268
+ accuracy = 100 * correct // total
269
+
270
+ root_logger.info(f"[Validation {validation_cycle}] accuracy: {accuracy}")
271
+ tb_logger.add_scalar("val", accuracy, validation_cycle)
272
+
273
+
274
+ if __name__ == "__main__":
275
+ # Set the CUBLAS workspace setting to allow determinism to be used in CUDA >= 10.2.
276
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
277
+
278
+ datamodule = CIFARDataModule(pathlib.Path.cwd())
279
+ model = ClassifierModel()
280
+
281
+ trainer = hlt.Trainer(
282
+ run_name="cifar10",
283
+ train_unit=hlt.TrainingUnit.EPOCH,
284
+ total_steps=2,
285
+ valid_frequency=1,
286
+ chkpt_frequency=1,
287
+ print_frequency=10,
288
+ enable_tensorboard=True,
289
+ enable_file_logging=True,
290
+ enable_progress_bar=True,
291
+ enable_deterministic=True,
292
+ chkpt_root=pathlib.Path.cwd() / "chkpt",
293
+ log_path=pathlib.Path.cwd() / "logs",
294
+ run_path=pathlib.Path.cwd() / "runs",
295
+ )
296
+
297
+ trainer.fit(model, datamodule)