helios-ml 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- helios-ml-0.1.0/.github/workflows/publish.yml +46 -0
- helios-ml-0.1.0/.github/workflows/tests.yml +66 -0
- helios-ml-0.1.0/.gitignore +20 -0
- helios-ml-0.1.0/.pre-commit-config.yaml +26 -0
- helios-ml-0.1.0/LICENSE +28 -0
- helios-ml-0.1.0/PKG-INFO +116 -0
- helios-ml-0.1.0/README.md +80 -0
- helios-ml-0.1.0/data/logo/logo-background.png +0 -0
- helios-ml-0.1.0/data/logo/logo-transparent.png +0 -0
- helios-ml-0.1.0/examples/cifar10/cifar10.py +297 -0
- helios-ml-0.1.0/pyproject.toml +119 -0
- helios-ml-0.1.0/requirements/ci.txt +17 -0
- helios-ml-0.1.0/requirements/default.txt +11 -0
- helios-ml-0.1.0/requirements/dev.txt +19 -0
- helios-ml-0.1.0/setup.cfg +4 -0
- helios-ml-0.1.0/src/helios/__init__.py +1 -0
- helios-ml-0.1.0/src/helios/_version.py +5 -0
- helios-ml-0.1.0/src/helios/core/__init__.py +25 -0
- helios-ml-0.1.0/src/helios/core/cuda.py +7 -0
- helios-ml-0.1.0/src/helios/core/distributed.py +179 -0
- helios-ml-0.1.0/src/helios/core/logging.py +533 -0
- helios-ml-0.1.0/src/helios/core/rng.py +151 -0
- helios-ml-0.1.0/src/helios/core/utils.py +417 -0
- helios-ml-0.1.0/src/helios/data/__init__.py +26 -0
- helios-ml-0.1.0/src/helios/data/datamodule.py +363 -0
- helios-ml-0.1.0/src/helios/data/functional.py +186 -0
- helios-ml-0.1.0/src/helios/data/samplers.py +229 -0
- helios-ml-0.1.0/src/helios/data/transforms.py +74 -0
- helios-ml-0.1.0/src/helios/losses/__init__.py +10 -0
- helios-ml-0.1.0/src/helios/losses/utils.py +21 -0
- helios-ml-0.1.0/src/helios/losses/weighted_loss.py +56 -0
- helios-ml-0.1.0/src/helios/metrics/__init__.py +27 -0
- helios-ml-0.1.0/src/helios/metrics/functional.py +447 -0
- helios-ml-0.1.0/src/helios/metrics/metrics.py +252 -0
- helios-ml-0.1.0/src/helios/model/__init__.py +10 -0
- helios-ml-0.1.0/src/helios/model/model.py +369 -0
- helios-ml-0.1.0/src/helios/model/utils.py +53 -0
- helios-ml-0.1.0/src/helios/nn/__init__.py +15 -0
- helios-ml-0.1.0/src/helios/nn/swa_utils.py +61 -0
- helios-ml-0.1.0/src/helios/nn/utils.py +56 -0
- helios-ml-0.1.0/src/helios/onnx.py +57 -0
- helios-ml-0.1.0/src/helios/optim/__init__.py +9 -0
- helios-ml-0.1.0/src/helios/optim/utils.py +37 -0
- helios-ml-0.1.0/src/helios/py.typed +0 -0
- helios-ml-0.1.0/src/helios/scheduler/__init__.py +15 -0
- helios-ml-0.1.0/src/helios/scheduler/schedulers.py +148 -0
- helios-ml-0.1.0/src/helios/scheduler/utils.py +38 -0
- helios-ml-0.1.0/src/helios/trainer.py +998 -0
- helios-ml-0.1.0/src/helios_ml.egg-info/PKG-INFO +116 -0
- helios-ml-0.1.0/src/helios_ml.egg-info/SOURCES.txt +72 -0
- helios-ml-0.1.0/src/helios_ml.egg-info/dependency_links.txt +1 -0
- helios-ml-0.1.0/src/helios_ml.egg-info/requires.txt +27 -0
- helios-ml-0.1.0/src/helios_ml.egg-info/top_level.txt +1 -0
- helios-ml-0.1.0/test/__init__.py +0 -0
- helios-ml-0.1.0/test/conftest.py +90 -0
- helios-ml-0.1.0/test/registry_test/__init__.py +3 -0
- helios-ml-0.1.0/test/registry_test/extra.py +2 -0
- helios-ml-0.1.0/test/registry_test/foo.py +6 -0
- helios-ml-0.1.0/test/registry_test/func_registry.py +7 -0
- helios-ml-0.1.0/test/registry_test/nested/__init__.py +0 -0
- helios-ml-0.1.0/test/registry_test/nested/bar.py +6 -0
- helios-ml-0.1.0/test/registry_test/nested/extra.py +2 -0
- helios-ml-0.1.0/test/registry_test/py.typed +0 -0
- helios-ml-0.1.0/test/test_core.py +158 -0
- helios-ml-0.1.0/test/test_data.py +237 -0
- helios-ml-0.1.0/test/test_losses.py +28 -0
- helios-ml-0.1.0/test/test_metrics.py +20 -0
- helios-ml-0.1.0/test/test_model.py +78 -0
- helios-ml-0.1.0/test/test_nn.py +39 -0
- helios-ml-0.1.0/test/test_onnx.py +34 -0
- helios-ml-0.1.0/test/test_optim.py +9 -0
- helios-ml-0.1.0/test/test_scheduler.py +17 -0
- helios-ml-0.1.0/test/test_trainer.py +350 -0
- helios-ml-0.1.0/tools/generate_requirements.py +49 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
name: "Publish"
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
permissions:
|
|
8
|
+
contents: read
|
|
9
|
+
|
|
10
|
+
jobs:
|
|
11
|
+
publish:
|
|
12
|
+
runs-on: ubuntu-latest
|
|
13
|
+
environment:
|
|
14
|
+
name: testpypi
|
|
15
|
+
url: https://pypi.org/project/helios-ml
|
|
16
|
+
permissions:
|
|
17
|
+
id-token: write
|
|
18
|
+
steps:
|
|
19
|
+
- name: Checkout
|
|
20
|
+
uses: actions/checkout@v3
|
|
21
|
+
|
|
22
|
+
- name: Setup Python
|
|
23
|
+
uses: actions/setup-python@v4
|
|
24
|
+
id: setup_python
|
|
25
|
+
with:
|
|
26
|
+
python-version: "3.11"
|
|
27
|
+
|
|
28
|
+
- name: Cache virtualenv
|
|
29
|
+
uses: actions/cache@v3
|
|
30
|
+
with:
|
|
31
|
+
key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('requirements/ci.txt') }}
|
|
32
|
+
path: venv
|
|
33
|
+
|
|
34
|
+
- name: Install dependencies
|
|
35
|
+
run: |
|
|
36
|
+
python -m venv venv
|
|
37
|
+
source venv/bin/activate
|
|
38
|
+
python -m pip install -r requirements/ci.txt
|
|
39
|
+
echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
|
|
40
|
+
echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
|
|
41
|
+
|
|
42
|
+
- name: Build Helios
|
|
43
|
+
run: |
|
|
44
|
+
python -m build
|
|
45
|
+
- name: Publish Helios
|
|
46
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
name: "Tests"
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [ master ]
|
|
6
|
+
paths-ignore:
|
|
7
|
+
- 'README.md'
|
|
8
|
+
- '.pre-commit-config.yaml'
|
|
9
|
+
- 'data/**'
|
|
10
|
+
- "LICENSE"
|
|
11
|
+
pull_request:
|
|
12
|
+
branches: [ master ]
|
|
13
|
+
paths-ignore:
|
|
14
|
+
- 'README.md'
|
|
15
|
+
- 'LICENSE'
|
|
16
|
+
workflow_dispatch:
|
|
17
|
+
|
|
18
|
+
jobs:
|
|
19
|
+
test:
|
|
20
|
+
runs-on: ubuntu-latest
|
|
21
|
+
strategy:
|
|
22
|
+
fail-fast: false
|
|
23
|
+
steps:
|
|
24
|
+
- name: Checkout
|
|
25
|
+
uses: actions/checkout@v3
|
|
26
|
+
|
|
27
|
+
- name: Setup Python
|
|
28
|
+
uses: actions/setup-python@v4
|
|
29
|
+
id: setup_python
|
|
30
|
+
with:
|
|
31
|
+
python-version: "3.11"
|
|
32
|
+
|
|
33
|
+
- name: Cache virtualenv
|
|
34
|
+
uses: actions/cache@v3
|
|
35
|
+
with:
|
|
36
|
+
key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('requirements/ci.txt') }}
|
|
37
|
+
path: venv
|
|
38
|
+
|
|
39
|
+
- name: Install dependencies
|
|
40
|
+
run: |
|
|
41
|
+
python -m venv venv
|
|
42
|
+
source venv/bin/activate
|
|
43
|
+
python -m pip install -r requirements/ci.txt
|
|
44
|
+
echo "$VIRTUAL_ENV/bin" >> $GITHUB_PATH
|
|
45
|
+
echo "VIRTUAL_ENV=$VIRTUAL_ENV" >> $GITHUB_ENV
|
|
46
|
+
|
|
47
|
+
- name: Install Helios
|
|
48
|
+
shell: bash
|
|
49
|
+
run: |
|
|
50
|
+
pip install .
|
|
51
|
+
|
|
52
|
+
- name: Ruff
|
|
53
|
+
run: |
|
|
54
|
+
ruff check src/helios
|
|
55
|
+
ruff check test
|
|
56
|
+
ruff check examples/cifar10
|
|
57
|
+
|
|
58
|
+
- name: Mypy
|
|
59
|
+
run: |
|
|
60
|
+
mypy src/helios
|
|
61
|
+
mypy test
|
|
62
|
+
mypy examples/cifar10
|
|
63
|
+
|
|
64
|
+
- name: Pytest
|
|
65
|
+
run: |
|
|
66
|
+
python -m pytest
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Virtual envs.
|
|
2
|
+
venv/
|
|
3
|
+
.venv/
|
|
4
|
+
|
|
5
|
+
# Distribution/packaging
|
|
6
|
+
dist/
|
|
7
|
+
build/
|
|
8
|
+
*.egg-info/
|
|
9
|
+
|
|
10
|
+
__pycache__/
|
|
11
|
+
.mypy_cache/
|
|
12
|
+
.pytest_cache/
|
|
13
|
+
.ruff_cache/
|
|
14
|
+
|
|
15
|
+
# Generated folders from examples
|
|
16
|
+
*.pth
|
|
17
|
+
examples/**/runs
|
|
18
|
+
examples/**/chkpt
|
|
19
|
+
examples/**/logs
|
|
20
|
+
examples/**/data
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
repos:
|
|
2
|
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
3
|
+
rev: c4a0b883114b00d8d76b479c820ce7950211c99b # frozen: v4.5.0
|
|
4
|
+
hooks:
|
|
5
|
+
- id: trailing-whitespace
|
|
6
|
+
- id: debug-statements
|
|
7
|
+
- id: check-ast
|
|
8
|
+
- id: mixed-line-ending
|
|
9
|
+
args: [--fix=lf]
|
|
10
|
+
- id: check-yaml
|
|
11
|
+
args: [--allow-multiple-documents]
|
|
12
|
+
- id: check-json
|
|
13
|
+
- id: check-added-large-files
|
|
14
|
+
|
|
15
|
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
16
|
+
rev: v0.3.4
|
|
17
|
+
hooks:
|
|
18
|
+
- id: ruff-format
|
|
19
|
+
|
|
20
|
+
- repo: local
|
|
21
|
+
hooks:
|
|
22
|
+
- id: generate_requirements.py
|
|
23
|
+
name: generate_requirements.py
|
|
24
|
+
language: system
|
|
25
|
+
entry: python tools/generate_requirements.py
|
|
26
|
+
files: "pyproject.toml|requirements/.*\\.txt|tools/generate_requirements.py"
|
helios-ml-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024, Mauricio A Rovira Galvez
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
helios-ml-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: helios-ml
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A Torch-based package for training AI networks
|
|
5
|
+
Author: Mauricio A. Rovira Galvez
|
|
6
|
+
Project-URL: Homepage, https://github.com/marovira/pyro-ml
|
|
7
|
+
Project-URL: Issues, https://github.com/marovira/pyro-ml/issues
|
|
8
|
+
Requires-Python: >=3.11
|
|
9
|
+
Description-Content-Type: text/x-rst
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: tqdm>=4.66.2
|
|
12
|
+
Requires-Dist: opencv-python>=4.9.0.80
|
|
13
|
+
Requires-Dist: protobuf!=4.24.0,<5.0.0,>=3.19.6
|
|
14
|
+
Requires-Dist: tensorboard>=2.16.2
|
|
15
|
+
Requires-Dist: torch>=2.2.1
|
|
16
|
+
Requires-Dist: torchvision>=0.17.1
|
|
17
|
+
Requires-Dist: onnx>=1.16.0
|
|
18
|
+
Requires-Dist: onnxruntime>=1.17.1
|
|
19
|
+
Requires-Dist: matplotlib>=3.8.4
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: mypy>=1.8.0; extra == "dev"
|
|
22
|
+
Requires-Dist: ruff>=0.3.4; extra == "dev"
|
|
23
|
+
Requires-Dist: pytest>=8.1.1; extra == "dev"
|
|
24
|
+
Requires-Dist: pre-commit>=3.6.2; extra == "dev"
|
|
25
|
+
Requires-Dist: types-Pillow>=10.2.0.20240311; extra == "dev"
|
|
26
|
+
Requires-Dist: types-tqdm>=4.66.0.20240106; extra == "dev"
|
|
27
|
+
Requires-Dist: build>=1.2.1; extra == "dev"
|
|
28
|
+
Requires-Dist: twine>=5.0.0; extra == "dev"
|
|
29
|
+
Provides-Extra: ci
|
|
30
|
+
Requires-Dist: mypy>=1.8.0; extra == "ci"
|
|
31
|
+
Requires-Dist: ruff>=0.3.4; extra == "ci"
|
|
32
|
+
Requires-Dist: pytest>=8.1.1; extra == "ci"
|
|
33
|
+
Requires-Dist: types-Pillow>=10.2.0.20240311; extra == "ci"
|
|
34
|
+
Requires-Dist: types-tqdm>=4.66.0.20240106; extra == "ci"
|
|
35
|
+
Requires-Dist: build>=1.2.1; extra == "ci"
|
|
36
|
+
|
|
37
|
+
<a id="top"></a>
|
|
38
|
+

|
|
39
|
+
|
|
40
|
+
[](LICENSE)
|
|
41
|
+
[](https://www.python.org/downloads/release/python-3110/)
|
|
42
|
+
[](https://github.com/marovira/helios-ml/actions/workflows/tests.yml)
|
|
43
|
+
|
|
44
|
+
## What is Helios?
|
|
45
|
+
|
|
46
|
+
Named after the Greek god of the sun, Helios is a light-weight package for training ML
|
|
47
|
+
networks built on top of PyTorch. It is designed to abstract all of the "boiler-plate"
|
|
48
|
+
code involved with training. Specifically, it wraps the following common patterns:
|
|
49
|
+
|
|
50
|
+
* Creation of the dataloaders.
|
|
51
|
+
* Initialization of CUDA, PyTorch, and random number states.
|
|
52
|
+
* Initialization for distributed training.
|
|
53
|
+
* Training, validation, and testing loops.
|
|
54
|
+
* Saving and loading checkpoints.
|
|
55
|
+
* Exporting to ONNX.
|
|
56
|
+
|
|
57
|
+
It is important to note that Helios is **not** a fully fledged training environment similar
|
|
58
|
+
to [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). Instead, Helios
|
|
59
|
+
is focused on providing a simple and straight-forward interface that abstracts most of the
|
|
60
|
+
common code patterns while retaining the ability to be easily overridden to suit the
|
|
61
|
+
individual needs of each training scheme.
|
|
62
|
+
|
|
63
|
+
## Main Features
|
|
64
|
+
|
|
65
|
+
Helios offers the following functionality out of the box:
|
|
66
|
+
|
|
67
|
+
1. Resume training: Helios has been built with the ability to resume training if it is
|
|
68
|
+
paused. Specifically, Helios will ensure that the behaviour of the trained model is
|
|
69
|
+
*identical* to the one it would've had if it had been trained without pauses.
|
|
70
|
+
2. Automatic detection of multi-GPU environments for distributed training. In addition,
|
|
71
|
+
Helios also supports training using `torchrun` and will automatically handle the
|
|
72
|
+
initialisation and clean up of the distributed state. It will also correctly set the
|
|
73
|
+
devices and maps to ensure weights are mapped to the correct location.
|
|
74
|
+
3. Registries for creation of arbitrary types. These include: networks, loss functions,
|
|
75
|
+
optimizers, schedulers, etc.
|
|
76
|
+
4. Correct handling of logging when doing distributed training (even over multiple nodes).
|
|
77
|
+
|
|
78
|
+
## Installation
|
|
79
|
+
|
|
80
|
+
You can install Helios using Pip as follows:
|
|
81
|
+
|
|
82
|
+
```sh
|
|
83
|
+
pip install helios-ml
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Documentation
|
|
87
|
+
|
|
88
|
+
Documentation coming soon!
|
|
89
|
+
|
|
90
|
+
## Contributing
|
|
91
|
+
|
|
92
|
+
There are three ways in which you can contribute to Helios:
|
|
93
|
+
|
|
94
|
+
* If you find a bug, please open an issue. Similarly, if you have a question
|
|
95
|
+
about how to use it, or if something is unclear, please post an issue so it
|
|
96
|
+
can be addressed.
|
|
97
|
+
* If you have a fix for a bug, or a code enhancement, please open a pull
|
|
98
|
+
request. Before you submit it though, make sure to abide by the rules written
|
|
99
|
+
below.
|
|
100
|
+
* If you have a feature proposal, you can either open an issue or create a pull
|
|
101
|
+
request. If you are submitting a pull request, it must abide by the rules
|
|
102
|
+
written below. Note that any new features need to be approved by me.
|
|
103
|
+
|
|
104
|
+
If you are submitting a pull request, the guidelines are the following:
|
|
105
|
+
|
|
106
|
+
1. Ensure that your code follows the standards and formatting of Helios. The coding
|
|
107
|
+
standards and formatting are enforced through the Ruff Linter and Formatter. Any
|
|
108
|
+
changes that do not abide by these rules will be rejected. It is your responsibility to
|
|
109
|
+
ensure that both Ruff and Mypy linters pass.
|
|
110
|
+
2. Ensure that *all* unit tests are working prior to submitting the pull
|
|
111
|
+
request. If you are adding a new feature that has been approved, it is your
|
|
112
|
+
responsibility to provide the corresponding unit tests (if applicable).
|
|
113
|
+
|
|
114
|
+
## License
|
|
115
|
+
|
|
116
|
+
Helios is published under the BSD-3 license and can be viewed [here](LICENSE).
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
<a id="top"></a>
|
|
2
|
+

|
|
3
|
+
|
|
4
|
+
[](LICENSE)
|
|
5
|
+
[](https://www.python.org/downloads/release/python-3110/)
|
|
6
|
+
[](https://github.com/marovira/helios-ml/actions/workflows/tests.yml)
|
|
7
|
+
|
|
8
|
+
## What is Helios?
|
|
9
|
+
|
|
10
|
+
Named after the Greek god of the sun, Helios is a light-weight package for training ML
|
|
11
|
+
networks built on top of PyTorch. It is designed to abstract all of the "boiler-plate"
|
|
12
|
+
code involved with training. Specifically, it wraps the following common patterns:
|
|
13
|
+
|
|
14
|
+
* Creation of the dataloaders.
|
|
15
|
+
* Initialization of CUDA, PyTorch, and random number states.
|
|
16
|
+
* Initialization for distributed training.
|
|
17
|
+
* Training, validation, and testing loops.
|
|
18
|
+
* Saving and loading checkpoints.
|
|
19
|
+
* Exporting to ONNX.
|
|
20
|
+
|
|
21
|
+
It is important to note that Helios is **not** a fully fledged training environment similar
|
|
22
|
+
to [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning). Instead, Helios
|
|
23
|
+
is focused on providing a simple and straight-forward interface that abstracts most of the
|
|
24
|
+
common code patterns while retaining the ability to be easily overridden to suit the
|
|
25
|
+
individual needs of each training scheme.
|
|
26
|
+
|
|
27
|
+
## Main Features
|
|
28
|
+
|
|
29
|
+
Helios offers the following functionality out of the box:
|
|
30
|
+
|
|
31
|
+
1. Resume training: Helios has been built with the ability to resume training if it is
|
|
32
|
+
paused. Specifically, Helios will ensure that the behaviour of the trained model is
|
|
33
|
+
*identical* to the one it would've had if it had been trained without pauses.
|
|
34
|
+
2. Automatic detection of multi-GPU environments for distributed training. In addition,
|
|
35
|
+
Helios also supports training using `torchrun` and will automatically handle the
|
|
36
|
+
initialisation and clean up of the distributed state. It will also correctly set the
|
|
37
|
+
devices and maps to ensure weights are mapped to the correct location.
|
|
38
|
+
3. Registries for creation of arbitrary types. These include: networks, loss functions,
|
|
39
|
+
optimizers, schedulers, etc.
|
|
40
|
+
4. Correct handling of logging when doing distributed training (even over multiple nodes).
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
You can install Helios using Pip as follows:
|
|
45
|
+
|
|
46
|
+
```sh
|
|
47
|
+
pip install helios-ml
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Documentation
|
|
51
|
+
|
|
52
|
+
Documentation coming soon!
|
|
53
|
+
|
|
54
|
+
## Contributing
|
|
55
|
+
|
|
56
|
+
There are three ways in which you can contribute to Helios:
|
|
57
|
+
|
|
58
|
+
* If you find a bug, please open an issue. Similarly, if you have a question
|
|
59
|
+
about how to use it, or if something is unclear, please post an issue so it
|
|
60
|
+
can be addressed.
|
|
61
|
+
* If you have a fix for a bug, or a code enhancement, please open a pull
|
|
62
|
+
request. Before you submit it though, make sure to abide by the rules written
|
|
63
|
+
below.
|
|
64
|
+
* If you have a feature proposal, you can either open an issue or create a pull
|
|
65
|
+
request. If you are submitting a pull request, it must abide by the rules
|
|
66
|
+
written below. Note that any new features need to be approved by me.
|
|
67
|
+
|
|
68
|
+
If you are submitting a pull request, the guidelines are the following:
|
|
69
|
+
|
|
70
|
+
1. Ensure that your code follows the standards and formatting of Helios. The coding
|
|
71
|
+
standards and formatting are enforced through the Ruff Linter and Formatter. Any
|
|
72
|
+
changes that do not abide by these rules will be rejected. It is your responsibility to
|
|
73
|
+
ensure that both Ruff and Mypy linters pass.
|
|
74
|
+
2. Ensure that *all* unit tests are working prior to submitting the pull
|
|
75
|
+
request. If you are adding a new feature that has been approved, it is your
|
|
76
|
+
responsibility to provide the corresponding unit tests (if applicable).
|
|
77
|
+
|
|
78
|
+
## License
|
|
79
|
+
|
|
80
|
+
Helios is published under the BSD-3 license and can be viewed [here](LICENSE).
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pathlib
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
import torchvision
|
|
8
|
+
import torchvision.transforms.v2 as T
|
|
9
|
+
from torch import nn
|
|
10
|
+
|
|
11
|
+
import helios.core as hlc
|
|
12
|
+
import helios.data as hld
|
|
13
|
+
import helios.model as hlm
|
|
14
|
+
import helios.optim as hlo
|
|
15
|
+
import helios.trainer as hlt
|
|
16
|
+
from helios.core import logging
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CIFARDataModule(hld.DataModule):
    """
    Datamodule serving the CIFAR10 dataset.

    Demonstrates a datamodule whose data must be downloaded before use, relying on
    torchvision's CIFAR10 dataset. Adapted from PyTorch's "Training a Classifier"
    tutorial:
    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

    Args:
        root (pathlib.Path): directory under which the "data" folder is placed.
    """

    def __init__(self, root: pathlib.Path) -> None:
        """Remember where the dataset lives on disk."""
        super().__init__()
        self._root = root / "data"

    def prepare_data(self) -> None:
        """Download the training split followed by the validation split."""
        for is_train in (True, False):
            torchvision.datasets.CIFAR10(root=self._root, train=is_train, download=True)

    def setup(self) -> None:
        """Build the training and validation datasets."""
        # ToTensor is the transform shipped with Helios; it automates the conversion
        # from images to tensors.
        to_tensor = hld.transforms.ToTensor()
        normalise = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        pipeline = T.Compose([to_tensor, normalise])

        loader_params = hld.DataLoaderParams()
        loader_params.batch_size = 4
        loader_params.shuffle = True
        loader_params.num_workers = 2
        loader_params.drop_last = True

        train_split = torchvision.datasets.CIFAR10(
            root=self._root, train=True, download=False, transform=pipeline
        )
        self._train_dataset = self._create_dataset(train_split, loader_params)

        # The dataloader params are copied when the dataset is created, so the options
        # can be changed for validation without affecting the training ones.
        loader_params.drop_last = False
        loader_params.shuffle = False

        valid_split = torchvision.datasets.CIFAR10(
            root=self._root, train=False, download=False, transform=pipeline
        )
        self._valid_dataset = self._create_dataset(valid_split, loader_params)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Net(nn.Module):
    """
    Small convolutional image classifier.

    The architecture comes from PyTorch's "Training a Classifier" tutorial:
    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    """

    def __init__(self):
        """Build the layers of the classifier."""
        super().__init__()
        # The layers are created in the same order as before so that parameter
        # initialisation under a fixed seed stays the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (one row of 10 values per input image)."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse every dimension except the batch one before the linear layers.
        flat = torch.flatten(features, 1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ClassifierModel(hlm.Model):
    """
    Example model class for training the classifier.

    Here you can see how to setup the model class and some of the basic functionality that
    is available. The code is adapted from PyTorch's "Training a Classifier" tutorial:
    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    """

    # Optimizer hyper-parameters. They are declared once so that the values used to build
    # the optimizer and the values reported to tensorboard cannot drift apart.
    _LEARNING_RATE = 0.001
    _MOMENTUM = 0.9

    def __init__(self) -> None:
        """Create the model."""
        super().__init__("classifier")

    def setup(self, fast_init: bool = False) -> None:
        """
        Create the network, loss, and optimizer.

        Args:
            fast_init (bool): unused by this example.
        """
        # Note that when we create the network and loss function, we immediately move them
        # to the current device, which has been set by the trainer.
        self._net = Net().to(self.device)
        self._criterion = nn.CrossEntropyLoss().to(self.device)

        # Note that SGD is shipped as part of the default optimizers from Helios, so we
        # can directly request it from create_optimizer instead of building it ourselves.
        self._optimizer = hlo.create_optimizer(
            "SGD",
            self._net.parameters(),
            lr=self._LEARNING_RATE,
            momentum=self._MOMENTUM,
        )

    def load_state_dict(
        self, state_dict: dict[str, typing.Any], fast_init: bool = False
    ) -> None:
        """
        Restore the model from a saved checkpoint.

        Args:
            state_dict (dict[str, typing.Any]): table with the "net", "criterion" and
                "optimizer" entries produced by state_dict.
            fast_init (bool): unused by this example.
        """
        # Note that we don't have to re-map the weights ourselves. They have already been
        # re-mapped for us by the trainer when it loaded the checkpoint.
        self._net.load_state_dict(state_dict["net"])
        self._criterion.load_state_dict(state_dict["criterion"])
        self._optimizer.load_state_dict(state_dict["optimizer"])

    def state_dict(self) -> dict[str, typing.Any]:
        """Return the state of the network, the loss function and the optimizer."""
        return {
            "net": self._net.state_dict(),
            "criterion": self._criterion.state_dict(),
            "optimizer": self._optimizer.state_dict(),
        }

    def train(self) -> None:
        """Set the model to train."""
        # If we had more networks, we would shift them to training mode here.
        self._net.train()

    def on_training_start(self) -> None:
        """Write the network graph to tensorboard before training starts."""
        tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())

        # A single random CIFAR-sized image is enough to trace the graph.
        x = torch.randn((1, 3, 32, 32)).to(self.device)
        tb_logger.add_graph(self._net, x)

    def train_step(self, batch: typing.Any, state: hlt.TrainingState) -> None:
        """
        Forward and backward training passes.

        Args:
            batch (typing.Any): pair of (inputs, labels) tensors.
            state (hlt.TrainingState): the current training state (unused here).
        """
        # Due to the simplicity of the code, we do both the forward and backward passes in
        # the training step, but you could also split it between the train_step and
        # on_training_batch_end if your setup is more complex.
        inputs, labels = batch
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        self._optimizer.zero_grad()

        outputs = self._net(inputs)
        loss = self._criterion(outputs, labels)
        loss.backward()
        self._optimizer.step()

        # Note that we save the value of the loss function to the _loss_items dictionary.
        # This allows the model to automatically gather the losses for us if we are
        # training in distributed mode.
        self._loss_items["loss"] = loss

    def on_training_batch_end(
        self,
        state: hlt.TrainingState,
        should_log: bool = False,
    ) -> None:
        """
        Perform steps after the training batch.

        Args:
            state (hlt.TrainingState): the current training state.
            should_log (bool): if true, then writing to the log should be performed.
        """
        # If we were training in distributed mode, calling the base Model's function will
        # automatically gather all loss values saved to self._loss_items.
        super().on_training_batch_end(state, should_log)

        # This flag is set to true whenever the number of iterations is a multiple of the
        # logging frequency we set when creating the trainer.
        if not should_log:
            return

        root_logger = logging.get_root_logger()
        tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())

        loss_val = self._loss_items["loss"]
        # Computed once and shared by the text log and the tensorboard log.
        running_loss = loss_val / state.running_iter

        root_logger.info(
            f"[{state.global_epoch + 1}, {state.global_iteration:5d}] "
            f"loss: {loss_val:.3f}, "
            f"running loss: {running_loss:.3f} "
            f"avg time: {state.average_iter_time:.2f}s"
        )
        tb_logger.add_scalar("train/loss", loss_val, state.global_iteration)
        tb_logger.add_scalar("train/running loss", running_loss, state.global_iteration)

    def on_training_end(self) -> None:
        """Save the hyper-parameters and final metrics to tensorboard."""
        # For our example, we're going to save the hyper-params to the tensorboard log so
        # we can compare with other runs.
        # Notice that self._val_scores is active because we validate every epoch. If your
        # validation frequency is different, you may need to alter this code.
        accuracy = self._accuracy()
        writer = hlc.get_from_optional(logging.get_tensorboard_writer())
        writer.add_hparams(
            # "epochs" mirrors the total_steps value given to the trainer in the script
            # section of this file; keep the two in sync.
            {"lr": self._LEARNING_RATE, "momentum": self._MOMENTUM, "epochs": 2},
            {
                "hparam/accuracy": accuracy,
                "hparam/loss": self._loss_items["loss"].item(),
            },
        )

    def eval(self) -> None:
        """Set the model to eval."""
        self._net.eval()

    def on_validation_start(self, validation_cycle: int) -> None:
        """
        Reset the counters used to compute the accuracy.

        Args:
            validation_cycle (int): the number of the current validation cycle.
        """
        # The base function will automatically clear the validation scores for us.
        super().on_validation_start(validation_cycle)

        # If you need to add further data to the table, you can do so here. In our case,
        # we're going to add the total number of labels seen and how many of those were
        # correct so we can compute the accuracy metric.
        self._val_scores["total"] = 0
        self._val_scores["correct"] = 0

    def valid_step(self, batch: typing.Any, step: int) -> None:
        """
        Perform the validation step.

        Args:
            batch (typing.Any): pair of (images, labels) tensors.
            step (int): the current validation step (unused here).
        """
        images, labels = batch
        images = images.to(self.device)
        labels = labels.to(self.device)

        outputs = self._net(images)

        # torch.max over dim 1 returns (values, indices); the indices are the predicted
        # class ids.
        _, predicted = torch.max(outputs.data, 1)
        self._val_scores["total"] += labels.size(0)
        self._val_scores["correct"] += (predicted == labels).sum().item()

    def on_validation_end(self, validation_cycle: int) -> None:
        """
        Log the accuracy obtained in the validation cycle.

        Args:
            validation_cycle (int): the number of the current validation cycle.
        """
        root_logger = logging.get_root_logger()
        tb_logger = hlc.get_from_optional(logging.get_tensorboard_writer())

        # Grab the validation scores and compute the accuracy metric so we can log it. If
        # we were in distributed mode, you would need to gather the values here.
        accuracy = self._accuracy()

        root_logger.info(f"[Validation {validation_cycle}] accuracy: {accuracy}")
        tb_logger.add_scalar("val", accuracy, validation_cycle)

    def _accuracy(self) -> float:
        """
        Return the validation accuracy as a whole-number percentage.

        The value is floor-divided, so it is always a whole number. When no validation
        samples have been counted, 0 is returned instead of raising ZeroDivisionError.
        """
        total = self._val_scores["total"]
        correct = self._val_scores["correct"]
        if total == 0:
            return 0
        return 100 * correct // total
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
if __name__ == "__main__":
    # CUDA >= 10.2 needs this CUBLAS workspace setting for determinism to be usable, so
    # it is exported before anything else is created.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Data, checkpoints, logs and runs are all placed under the current working
    # directory, which is resolved once and reused.
    cwd = pathlib.Path.cwd()

    datamodule = CIFARDataModule(cwd)
    model = ClassifierModel()

    trainer = hlt.Trainer(
        run_name="cifar10",
        # Train for two epochs, validating and checkpointing after each one.
        train_unit=hlt.TrainingUnit.EPOCH,
        total_steps=2,
        valid_frequency=1,
        chkpt_frequency=1,
        print_frequency=10,
        enable_tensorboard=True,
        enable_file_logging=True,
        enable_progress_bar=True,
        enable_deterministic=True,
        chkpt_root=cwd / "chkpt",
        log_path=cwd / "logs",
        run_path=cwd / "runs",
    )

    trainer.fit(model, datamodule)
|