torch-blue 0.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_blue-0.9.0/LICENSE +28 -0
- torch_blue-0.9.0/PKG-INFO +243 -0
- torch_blue-0.9.0/README.md +214 -0
- torch_blue-0.9.0/pyproject.toml +155 -0
- torch_blue-0.9.0/setup.cfg +4 -0
- torch_blue-0.9.0/torch_blue/__init__.py +1 -0
- torch_blue-0.9.0/torch_blue/vi/__init__.py +48 -0
- torch_blue-0.9.0/torch_blue/vi/_globals.py +1 -0
- torch_blue-0.9.0/torch_blue/vi/analytical_kl_loss.py +443 -0
- torch_blue-0.9.0/torch_blue/vi/base.py +743 -0
- torch_blue-0.9.0/torch_blue/vi/conv.py +502 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/__init__.py +23 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/base.py +459 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/categorical.py +102 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/non_bayesian.py +172 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/normal.py +221 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/quiet.py +122 -0
- torch_blue-0.9.0/torch_blue/vi/distributions/student_t.py +133 -0
- torch_blue-0.9.0/torch_blue/vi/kl_loss.py +143 -0
- torch_blue-0.9.0/torch_blue/vi/linear.py +124 -0
- torch_blue-0.9.0/torch_blue/vi/sequential.py +119 -0
- torch_blue-0.9.0/torch_blue/vi/transformer.py +820 -0
- torch_blue-0.9.0/torch_blue/vi/utils/__init__.py +12 -0
- torch_blue-0.9.0/torch_blue/vi/utils/common_types.py +56 -0
- torch_blue-0.9.0/torch_blue/vi/utils/errors.py +19 -0
- torch_blue-0.9.0/torch_blue/vi/utils/init.py +27 -0
- torch_blue-0.9.0/torch_blue/vi/utils/post_init_metaclass.py +13 -0
- torch_blue-0.9.0/torch_blue/vi/utils/use_norm_constants.py +17 -0
- torch_blue-0.9.0/torch_blue/vi/utils/vi_return.py +72 -0
- torch_blue-0.9.0/torch_blue.egg-info/PKG-INFO +243 -0
- torch_blue-0.9.0/torch_blue.egg-info/SOURCES.txt +32 -0
- torch_blue-0.9.0/torch_blue.egg-info/dependency_links.txt +1 -0
- torch_blue-0.9.0/torch_blue.egg-info/requires.txt +19 -0
- torch_blue-0.9.0/torch_blue.egg-info/top_level.txt +1 -0
torch_blue-0.9.0/LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025, JRG Robust and Efficient AI
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch_blue
|
|
3
|
+
Version: 0.9.0
|
|
4
|
+
Summary: Library for BNNs
|
|
5
|
+
Author-email: RAI <rai@kit.edu>
|
|
6
|
+
Classifier: Programming Language :: Python
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: Development Status :: 4 - Beta
|
|
9
|
+
Requires-Python: >=3.9
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: torch
|
|
13
|
+
Provides-Extra: dev
|
|
14
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
15
|
+
Requires-Dist: ruff; extra == "dev"
|
|
16
|
+
Requires-Dist: mypy; extra == "dev"
|
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
|
18
|
+
Requires-Dist: coverage; extra == "dev"
|
|
19
|
+
Requires-Dist: build; extra == "dev"
|
|
20
|
+
Provides-Extra: docs
|
|
21
|
+
Requires-Dist: sphinx; extra == "docs"
|
|
22
|
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
|
23
|
+
Requires-Dist: sphinx-autoapi; extra == "docs"
|
|
24
|
+
Requires-Dist: myst-parser; extra == "docs"
|
|
25
|
+
Provides-Extra: scripts
|
|
26
|
+
Requires-Dist: matplotlib; extra == "scripts"
|
|
27
|
+
Requires-Dist: torchvision; extra == "scripts"
|
|
28
|
+
Dynamic: license-file
|
|
29
|
+
|
|
30
|
+
# torch_blue - A PyTorch-like framework for Bayesian learning and uncertainty estimation
|
|
31
|
+
|
|
32
|
+
`torch_blue` provides a simple way for non-expert users to implement and train Bayesian
|
|
33
|
+
Neural Networks (BNNs). Currently, it only supports Variational Inference (VI), but will
|
|
34
|
+
hopefully grow and expand in the future. To make the user experience as easy as possible
|
|
35
|
+
most components mirror components from [PyTorch](https://pytorch.org/docs/stable/index.html).
|
|
36
|
+
|
|
37
|
+
- [Installation](#installation)
|
|
38
|
+
- [Documentation](#documentation)
|
|
39
|
+
- [Quickstart](#quickstart)
|
|
40
|
+
- [Level 1](#level-1)
|
|
41
|
+
- [Level 2](#level-2)
|
|
42
|
+
- [Level 3](#level-3)
|
|
43
|
+
- [Level 4](#level-4)
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
We heavily recommend installing `torch_blue` in a dedicated `Python3.9+`
|
|
48
|
+
[virtual environment](https://docs.python.org/3/library/venv.html). You can install
|
|
49
|
+
`torch_blue` from PyPI:
|
|
50
|
+
|
|
51
|
+
```console
|
|
52
|
+
$ pip install torch-blue
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
Alternatively, you can install `torch_blue` locally. To achieve this, there
|
|
56
|
+
are two steps you need to follow:
|
|
57
|
+
|
|
58
|
+
1. Clone the repository
|
|
59
|
+
|
|
60
|
+
```console
|
|
61
|
+
$ git clone https://github.com/RAI-SCC/torch_blue
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
2. Install the code locally
|
|
65
|
+
|
|
66
|
+
```console
|
|
67
|
+
$ pip install -e .
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
To get the development dependencies, run:
|
|
71
|
+
|
|
72
|
+
```console
|
|
73
|
+
$ pip install -e .[dev]
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
For additional dependencies required if you want to run scripts from the scripts
|
|
77
|
+
directory, run:
|
|
78
|
+
|
|
79
|
+
```console
|
|
80
|
+
$ pip install -e .[scripts]
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
## Documentation
|
|
85
|
+
|
|
86
|
+
Documentation is available online at [readthedocs](https://torch-blue.readthedocs.io).
|
|
87
|
+
|
|
88
|
+
## Quickstart
|
|
89
|
+
|
|
90
|
+
This Quickstart guide assumes basic familiarity with [PyTorch](https://pytorch.org/docs/stable/index.html)
|
|
91
|
+
and knowledge of how to implement the intended model in it. For a (potentially familiar)
|
|
92
|
+
example see `scripts/pytorch_tutorial.py`, which contains a copy of the PyTorch
|
|
93
|
+
[Quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
|
|
94
|
+
modified to train a BNN with variational inference.
|
|
95
|
+
Four levels are introduced:
|
|
96
|
+
- [Level 1](#level-1): Simple sequential layer stacks
|
|
97
|
+
- [Level 2](#level-2): Customizing Bayesian assumptions and VI kwargs
|
|
98
|
+
- [Level 3](#level-3): Non-sequential models and log probabilities
|
|
99
|
+
- [Level 4](#level-4): Custom modules with weights
|
|
100
|
+
|
|
101
|
+
### Level 1
|
|
102
|
+
|
|
103
|
+
Many parts of a neural network remain completely unchanged when turning it into a BNN.
|
|
104
|
+
Indeed, only `Module`s containing `nn.Parameter`s need to be changed. Therefore, if a
|
|
105
|
+
PyTorch model fulfills two requirements it can be transferred almost unchanged:
|
|
106
|
+
|
|
107
|
+
1. All PyTorch `Module`s containing parameters have equivalents in this package (table below).
|
|
108
|
+
2. The model can be expressed purely as a sequential application of a list of layers,
|
|
109
|
+
i.e. with `nn.Sequential`.
|
|
110
|
+
|
|
111
|
+
| PyTorch | vi replacement |
|
|
112
|
+
|------------------|-----------------|
|
|
113
|
+
| `nn.Linear` | `VILinear` |
|
|
114
|
+
| `nn.Conv1d` | `VIConv1d` |
|
|
115
|
+
| `nn.Conv2d` | `VIConv2d` |
|
|
116
|
+
| `nn.Conv3d` | `VIConv3d` |
|
|
117
|
+
| `nn.Transformer` | `VITransformer` |
|
|
118
|
+
|
|
119
|
+
Given these two conditions, inherit the module from `vi.VIModule` instead of `nn.Module`
|
|
120
|
+
and use `vi.VISequential` instead of `nn.Sequential`. Then replace all layers
|
|
121
|
+
containing parameters as shown in the table above. For basic usage initialize these
|
|
122
|
+
modules with the same arguments as their PyTorch equivalent. For advanced usage see
|
|
123
|
+
[Quickstart: Level 2](#level-2). Many other layers can be included as-is. In particular
|
|
124
|
+
activation functions, pooling, and padding (even dropout, though they
|
|
125
|
+
should not be necessary since the prior acts as regularization). Currently not supported
|
|
126
|
+
are recurrent and transposed convolution layers. Normalization layers may
|
|
127
|
+
have parameters depending on their setting, but can likely be left non-Bayesian.
|
|
128
|
+
|
|
129
|
+
Additionally, the loss must be replaced. To start out use `vi.KullbackLeiblerLoss`,
|
|
130
|
+
which requires a `Distribution` with `self.is_predictive_distribution=True` and the size
|
|
131
|
+
of the training dataset (this is important for balancing assumptions and data). Choose
|
|
132
|
+
your `Distribution` from the table below based on the loss you would use in PyTorch.
|
|
133
|
+
|
|
134
|
+
> [!IMPORTANT]
|
|
135
|
+
> `KullbackLeiblerLoss` requires the length of the dataset, not the dataloader, which is
|
|
136
|
+
> just the number of batches.
|
|
137
|
+
|
|
138
|
+
| PyTorch | vi replacement <br/> from `vi.distributions` |
|
|
139
|
+
|-----------------------|----------------------------------------------|
|
|
140
|
+
| `nn.MSELoss` | `MeanFieldNormal` |
|
|
141
|
+
| `nn.CrossEntropyLoss` | `Categorical` |
|
|
142
|
+
|
|
143
|
+
> [!NOTE]
|
|
144
|
+
> Reasons for the requirement to use `VISequential` (and how to overcome it)
|
|
145
|
+
> are described in [Quickstart: Level 3](#level-3). However, adding residual connections
|
|
146
|
+
> from the start to the end of a block of layers can also be achieved using
|
|
147
|
+
> `VIResidualConnection`, which acts the same as `VISequential`, but adds the input to
|
|
148
|
+
> the output.
|
|
149
|
+
|
|
150
|
+
### Level 2
|
|
151
|
+
|
|
152
|
+
While the interface of `VIModule`s is kept intentionally similar to PyTorch, there are
|
|
153
|
+
additional arguments that customize the Bayesian assumptions that all provided layers
|
|
154
|
+
accept and custom modules should generally accept and pass on to submodules:
|
|
155
|
+
- variational_distribution (`Distribution`): defines the weight distribution and
|
|
156
|
+
variational parameters. The default `MeanFieldNormal` assumes normally distributed,
|
|
157
|
+
uncorrelated weights described by a mean and a standard deviation. While there are
|
|
158
|
+
currently no alternatives, the initial value of the standard deviation can be customized
|
|
159
|
+
here.
|
|
160
|
+
- prior (`Distribution`): defines the assumptions on the weight distribution and acts as
|
|
161
|
+
a regularizer. The default `MeanFieldNormal` assumes normally distributed, uncorrelated
|
|
162
|
+
weights with mean 0 and standard deviation 1 (also known as a standard normal prior).
|
|
163
|
+
Mean and standard deviation can be adapted here. Particularly reducing the standard
|
|
164
|
+
deviation may help convergence at the risk of an overconfident model. Other available
|
|
165
|
+
priors:
|
|
166
|
+
- `BasicQuietPrior`: an experimental prior that correlates mean and standard deviation
|
|
167
|
+
to disincentivize noisy weights
|
|
168
|
+
- rescale_prior (`bool`): Experimental. Scales the prior similar to Kaiming-initialization.
|
|
169
|
+
May help with convergence, but may lead to overconfidence. Current research.
|
|
170
|
+
- prior_initialization (`bool`): Experimental. Initialize parameters from the prior
|
|
171
|
+
instead of according to standard non-Bayesian methods. May lead to much faster
|
|
172
|
+
convergence, but can cause the issues Kaiming-initialization counteracts unless
|
|
173
|
+
rescale_prior is also set to True. Current research.
|
|
174
|
+
- return_log_probs (`bool`): This is the topic of [Quickstart: Level 3](#level-3).
|
|
175
|
+
|
|
176
|
+
### Level 3
|
|
177
|
+
|
|
178
|
+
For more advanced models one feature of Variational Inference (VI) needs to be taken
|
|
179
|
+
into account. Generally, a loss for VI will require the log probability of the actually
|
|
180
|
+
used weights (which are sampled on each forward pass) in the variational and prior
|
|
181
|
+
distribution. Since it is quite inefficient to save the samples, these log probabilities
|
|
182
|
+
are evaluated during the forward pass and returned by the model. Since this is only
|
|
183
|
+
necessary for training, it can be controlled with the argument `return_log_probs`. Once
|
|
184
|
+
the model is initialized, this flag can be changed by setting `VIModule.return_log_probs`,
|
|
185
|
+
which either enables (`True`) or disables (`False`) the returning of the log
|
|
186
|
+
probabilities for all submodules.
|
|
187
|
+
|
|
188
|
+
While `torch_blue` calculates and aggregates log probs internally, this is handled
|
|
189
|
+
by the outermost `VIModule`. This module will not have the expected output signature
|
|
190
|
+
when returning log probs, but instead return a `VIReturn` object. This class is a PyTorch
|
|
191
|
+
`Tensor` that also contains log prob information in its additional `log_probs`
|
|
192
|
+
attribute. This is the format `torch_blue` losses expect. Therefore, if you feed the
|
|
193
|
+
output directly into a loss there should be no issues. While all PyTorch tensor
|
|
194
|
+
operations can be performed on `VIReturn`s, many will delete the log prob information and
|
|
195
|
+
transform the object back into a `Tensor`. This needs to be considered when performing
|
|
196
|
+
further operations on the model output. The simplest way to avoid issues is to wrap all
|
|
197
|
+
operations - except the loss - in a `VIModule` since log prob aggregation is only
|
|
198
|
+
performed by the outermost module. For deployment `return_log_probs` should be set to
|
|
199
|
+
`False`. If multiple `Tensor`s are returned by the model, each will carry all log probs.
|
|
200
|
+
|
|
201
|
+
> [!NOTE]
|
|
202
|
+
> Always make sure your outermost module is a VIModule and keep in mind that the output
|
|
203
|
+
> of that module will be a `VIReturn` object, which behaves like a `Tensor` and carries
|
|
204
|
+
> weight log probabilities if `return_log_probs == True`. Losses in `torch_blue`
|
|
205
|
+
> expect this format.
|
|
206
|
+
|
|
207
|
+
> [!NOTE]
|
|
208
|
+
> Due to autosampling, all output tensors, i.e. each `VIReturn`
|
|
209
|
+
> in the model output and the `Tensor` containing the log probs, have an additional
|
|
210
|
+
> dimension at the beginning representing the multiple samples necessary to properly
|
|
211
|
+
> evaluate the stochastic forward pass. This is only relevant for VIModules that are not
|
|
212
|
+
> contained within other VIModules. Loss functions are designed to expect and handle
|
|
213
|
+
> this output format, i.e. you can simply feed the model output into the loss and
|
|
214
|
+
> everything will work.
|
|
215
|
+
|
|
216
|
+
### Level 4
|
|
217
|
+
|
|
218
|
+
Creating `VIModule`s with Bayesian weights - which are typically called random
|
|
219
|
+
variables in documentation and code - is arguably simpler than in PyTorch. Since a
|
|
220
|
+
different number of weight matrices needs to be created based on the variational
|
|
221
|
+
distribution, the process is completely automated. For `VIModule`s without weights,
|
|
222
|
+
`super().__init__` is called without arguments. Modules with random variables
|
|
223
|
+
expect `VIkwargs` (which you should be familiar with from [Level 2](#level-2)), but
|
|
224
|
+
defaults are used if none are passed. More importantly, `VIModule`s with weights call
|
|
225
|
+
`super().__init__` with the argument `variable_shapes`. The keys of this dictionary are
|
|
226
|
+
the names of the random variables and the values the shapes of the weight matrices as
|
|
227
|
+
tuple or list. The value may also be set to `None`, which will always be the value
|
|
228
|
+
returned for that variable.
|
|
229
|
+
|
|
230
|
+
The insertion order of this dictionary matters, as it becomes the order of the names
|
|
231
|
+
in the module attribute `random_variables`. `random_variables`, the shapes, and a similar
|
|
232
|
+
attribute of the variational distribution called `distribution_parameters` are used to
|
|
233
|
+
dynamically create the weight matrices. The weight matrices can be accessed as
|
|
234
|
+
attributes of the module, which will cause a sample to be drawn and its log prob to be
|
|
235
|
+
stored if needed.
|
|
236
|
+
|
|
237
|
+
Should you need to access the weight tensors directly, you can use `getattr` and derive
|
|
238
|
+
the name using the method `variational_parameter_name`.
|
|
239
|
+
|
|
240
|
+
> [!IMPORTANT]
|
|
241
|
+
> Every access of the weights will yield a new sample and log probability to be stored.
|
|
242
|
+
> Aggregation of multiple log probs is handled internally, but unnecessary calls will
|
|
243
|
+
> distort the result.
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# torch_blue - A PyTorch-like framework for Bayesian learning and uncertainty estimation
|
|
2
|
+
|
|
3
|
+
`torch_blue` provides a simple way for non-expert users to implement and train Bayesian
|
|
4
|
+
Neural Networks (BNNs). Currently, it only supports Variational Inference (VI), but will
|
|
5
|
+
hopefully grow and expand in the future. To make the user experience as easy as possible
|
|
6
|
+
most components mirror components from [PyTorch](https://pytorch.org/docs/stable/index.html).
|
|
7
|
+
|
|
8
|
+
- [Installation](#installation)
|
|
9
|
+
- [Documentation](#documentation)
|
|
10
|
+
- [Quickstart](#quickstart)
|
|
11
|
+
- [Level 1](#level-1)
|
|
12
|
+
- [Level 2](#level-2)
|
|
13
|
+
- [Level 3](#level-3)
|
|
14
|
+
- [Level 4](#level-4)
|
|
15
|
+
|
|
16
|
+
## Installation
|
|
17
|
+
|
|
18
|
+
We heavily recommend installing `torch_blue` in a dedicated `Python3.9+`
|
|
19
|
+
[virtual environment](https://docs.python.org/3/library/venv.html). You can install
|
|
20
|
+
`torch_blue` from PyPI:
|
|
21
|
+
|
|
22
|
+
```console
|
|
23
|
+
$ pip install torch-blue
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
Alternatively, you can install `torch_blue` locally. To achieve this, there
|
|
27
|
+
are two steps you need to follow:
|
|
28
|
+
|
|
29
|
+
1. Clone the repository
|
|
30
|
+
|
|
31
|
+
```console
|
|
32
|
+
$ git clone https://github.com/RAI-SCC/torch_blue
|
|
33
|
+
```
|
|
34
|
+
|
|
35
|
+
2. Install the code locally
|
|
36
|
+
|
|
37
|
+
```console
|
|
38
|
+
$ pip install -e .
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
To get the development dependencies, run:
|
|
42
|
+
|
|
43
|
+
```console
|
|
44
|
+
$ pip install -e .[dev]
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
For additional dependencies required if you want to run scripts from the scripts
|
|
48
|
+
directory, run:
|
|
49
|
+
|
|
50
|
+
```console
|
|
51
|
+
$ pip install -e .[scripts]
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
## Documentation
|
|
56
|
+
|
|
57
|
+
Documentation is available online at [readthedocs](https://torch-blue.readthedocs.io).
|
|
58
|
+
|
|
59
|
+
## Quickstart
|
|
60
|
+
|
|
61
|
+
This Quickstart guide assumes basic familiarity with [PyTorch](https://pytorch.org/docs/stable/index.html)
|
|
62
|
+
and knowledge of how to implement the intended model in it. For a (potentially familiar)
|
|
63
|
+
example see `scripts/pytorch_tutorial.py`, which contains a copy of the PyTorch
|
|
64
|
+
[Quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
|
|
65
|
+
modified to train a BNN with variational inference.
|
|
66
|
+
Four levels are introduced:
|
|
67
|
+
- [Level 1](#level-1): Simple sequential layer stacks
|
|
68
|
+
- [Level 2](#level-2): Customizing Bayesian assumptions and VI kwargs
|
|
69
|
+
- [Level 3](#level-3): Non-sequential models and log probabilities
|
|
70
|
+
- [Level 4](#level-4): Custom modules with weights
|
|
71
|
+
|
|
72
|
+
### Level 1
|
|
73
|
+
|
|
74
|
+
Many parts of a neural network remain completely unchanged when turning it into a BNN.
|
|
75
|
+
Indeed, only `Module`s containing `nn.Parameter`s need to be changed. Therefore, if a
|
|
76
|
+
PyTorch model fulfills two requirements it can be transferred almost unchanged:
|
|
77
|
+
|
|
78
|
+
1. All PyTorch `Module`s containing parameters have equivalents in this package (table below).
|
|
79
|
+
2. The model can be expressed purely as a sequential application of a list of layers,
|
|
80
|
+
i.e. with `nn.Sequential`.
|
|
81
|
+
|
|
82
|
+
| PyTorch | vi replacement |
|
|
83
|
+
|------------------|-----------------|
|
|
84
|
+
| `nn.Linear` | `VILinear` |
|
|
85
|
+
| `nn.Conv1d` | `VIConv1d` |
|
|
86
|
+
| `nn.Conv2d` | `VIConv2d` |
|
|
87
|
+
| `nn.Conv3d` | `VIConv3d` |
|
|
88
|
+
| `nn.Transformer` | `VITransformer` |
|
|
89
|
+
|
|
90
|
+
Given these two conditions, inherit the module from `vi.VIModule` instead of `nn.Module`
|
|
91
|
+
and use `vi.VISequential` instead of `nn.Sequential`. Then replace all layers
|
|
92
|
+
containing parameters as shown in the table above. For basic usage initialize these
|
|
93
|
+
modules with the same arguments as their PyTorch equivalent. For advanced usage see
|
|
94
|
+
[Quickstart: Level 2](#level-2). Many other layers can be included as-is. In particular
|
|
95
|
+
activation functions, pooling, and padding (even dropout, though they
|
|
96
|
+
should not be necessary since the prior acts as regularization). Currently not supported
|
|
97
|
+
are recurrent and transposed convolution layers. Normalization layers may
|
|
98
|
+
have parameters depending on their setting, but can likely be left non-Bayesian.
|
|
99
|
+
|
|
100
|
+
Additionally, the loss must be replaced. To start out use `vi.KullbackLeiblerLoss`,
|
|
101
|
+
which requires a `Distribution` with `self.is_predictive_distribution=True` and the size
|
|
102
|
+
of the training dataset (this is important for balancing assumptions and data). Choose
|
|
103
|
+
your `Distribution` from the table below based on the loss you would use in PyTorch.
|
|
104
|
+
|
|
105
|
+
> [!IMPORTANT]
|
|
106
|
+
> `KullbackLeiblerLoss` requires the length of the dataset, not the dataloader, which is
|
|
107
|
+
> just the number of batches.
|
|
108
|
+
|
|
109
|
+
| PyTorch | vi replacement <br/> from `vi.distributions` |
|
|
110
|
+
|-----------------------|----------------------------------------------|
|
|
111
|
+
| `nn.MSELoss` | `MeanFieldNormal` |
|
|
112
|
+
| `nn.CrossEntropyLoss` | `Categorical` |
|
|
113
|
+
|
|
114
|
+
> [!NOTE]
|
|
115
|
+
> Reasons for the requirement to use `VISequential` (and how to overcome it)
|
|
116
|
+
> are described in [Quickstart: Level 3](#level-3). However, adding residual connections
|
|
117
|
+
> from the start to the end of a block of layers can also be achieved using
|
|
118
|
+
> `VIResidualConnection`, which acts the same as `VISequential`, but adds the input to
|
|
119
|
+
> the output.
|
|
120
|
+
|
|
121
|
+
### Level 2
|
|
122
|
+
|
|
123
|
+
While the interface of `VIModule`s is kept intentionally similar to PyTorch, there are
|
|
124
|
+
additional arguments that customize the Bayesian assumptions that all provided layers
|
|
125
|
+
accept and custom modules should generally accept and pass on to submodules:
|
|
126
|
+
- variational_distribution (`Distribution`): defines the weight distribution and
|
|
127
|
+
variational parameters. The default `MeanFieldNormal` assumes normally distributed,
|
|
128
|
+
uncorrelated weights described by a mean and a standard deviation. While there are
|
|
129
|
+
currently no alternatives, the initial value of the standard deviation can be customized
|
|
130
|
+
here.
|
|
131
|
+
- prior (`Distribution`): defines the assumptions on the weight distribution and acts as
|
|
132
|
+
a regularizer. The default `MeanFieldNormal` assumes normally distributed, uncorrelated
|
|
133
|
+
weights with mean 0 and standard deviation 1 (also known as a standard normal prior).
|
|
134
|
+
Mean and standard deviation can be adapted here. Particularly reducing the standard
|
|
135
|
+
deviation may help convergence at the risk of an overconfident model. Other available
|
|
136
|
+
priors:
|
|
137
|
+
- `BasicQuietPrior`: an experimental prior that correlates mean and standard deviation
|
|
138
|
+
to disincentivize noisy weights
|
|
139
|
+
- rescale_prior (`bool`): Experimental. Scales the prior similar to Kaiming-initialization.
|
|
140
|
+
May help with convergence, but may lead to overconfidence. Current research.
|
|
141
|
+
- prior_initialization (`bool`): Experimental. Initialize parameters from the prior
|
|
142
|
+
instead of according to standard non-Bayesian methods. May lead to much faster
|
|
143
|
+
convergence, but can cause the issues Kaiming-initialization counteracts unless
|
|
144
|
+
rescale_prior is also set to True. Current research.
|
|
145
|
+
- return_log_probs (`bool`): This is the topic of [Quickstart: Level 3](#level-3).
|
|
146
|
+
|
|
147
|
+
### Level 3
|
|
148
|
+
|
|
149
|
+
For more advanced models one feature of Variational Inference (VI) needs to be taken
|
|
150
|
+
into account. Generally, a loss for VI will require the log probability of the actually
|
|
151
|
+
used weights (which are sampled on each forward pass) in the variational and prior
|
|
152
|
+
distribution. Since it is quite inefficient to save the samples, these log probabilities
|
|
153
|
+
are evaluated during the forward pass and returned by the model. Since this is only
|
|
154
|
+
necessary for training, it can be controlled with the argument `return_log_probs`. Once
|
|
155
|
+
the model is initialized, this flag can be changed by setting `VIModule.return_log_probs`,
|
|
156
|
+
which either enables (`True`) or disables (`False`) the returning of the log
|
|
157
|
+
probabilities for all submodules.
|
|
158
|
+
|
|
159
|
+
While `torch_blue` calculates and aggregates log probs internally, this is handled
|
|
160
|
+
by the outermost `VIModule`. This module will not have the expected output signature
|
|
161
|
+
when returning log probs, but instead return a `VIReturn` object. This class is a PyTorch
|
|
162
|
+
`Tensor` that also contains log prob information in its additional `log_probs`
|
|
163
|
+
attribute. This is the format `torch_blue` losses expect. Therefore, if you feed the
|
|
164
|
+
output directly into a loss there should be no issues. While all PyTorch tensor
|
|
165
|
+
operations can be performed on `VIReturn`s, many will delete the log prob information and
|
|
166
|
+
transform the object back into a `Tensor`. This needs to be considered when performing
|
|
167
|
+
further operations on the model output. The simplest way to avoid issues is to wrap all
|
|
168
|
+
operations - except the loss - in a `VIModule` since log prob aggregation is only
|
|
169
|
+
performed by the outermost module. For deployment `return_log_probs` should be set to
|
|
170
|
+
`False`. If multiple `Tensor`s are returned by the model, each will carry all log probs.
|
|
171
|
+
|
|
172
|
+
> [!NOTE]
|
|
173
|
+
> Always make sure your outermost module is a VIModule and keep in mind that the output
|
|
174
|
+
> of that module will be a `VIReturn` object, which behaves like a `Tensor` and carries
|
|
175
|
+
> weight log probabilities if `return_log_probs == True`. Losses in `torch_blue`
|
|
176
|
+
> expect this format.
|
|
177
|
+
|
|
178
|
+
> [!NOTE]
|
|
179
|
+
> Due to autosampling, all output tensors, i.e. each `VIReturn`
|
|
180
|
+
> in the model output and the `Tensor` containing the log probs, have an additional
|
|
181
|
+
> dimension at the beginning representing the multiple samples necessary to properly
|
|
182
|
+
> evaluate the stochastic forward pass. This is only relevant for VIModules that are not
|
|
183
|
+
> contained within other VIModules. Loss functions are designed to expect and handle
|
|
184
|
+
> this output format, i.e. you can simply feed the model output into the loss and
|
|
185
|
+
> everything will work.
|
|
186
|
+
|
|
187
|
+
### Level 4
|
|
188
|
+
|
|
189
|
+
Creating `VIModule`s with Bayesian weights - which are typically called random
|
|
190
|
+
variables in documentation and code - is arguably simpler than in PyTorch. Since a
|
|
191
|
+
different number of weight matrices needs to be created based on the variational
|
|
192
|
+
distribution, the process is completely automated. For `VIModule`s without weights,
|
|
193
|
+
`super().__init__` is called without arguments. Modules with random variables
|
|
194
|
+
expect `VIkwargs` (which you should be familiar with from [Level 2](#level-2)), but
|
|
195
|
+
defaults are used if none are passed. More importantly, `VIModule`s with weights call
|
|
196
|
+
`super().__init__` with the argument `variable_shapes`. The keys of this dictionary are
|
|
197
|
+
the names of the random variables and the values the shapes of the weight matrices as
|
|
198
|
+
tuple or list. The value may also be set to `None`, which will always be the value
|
|
199
|
+
returned for that variable.
|
|
200
|
+
|
|
201
|
+
The insertion order of this dictionary matters, as it becomes the order of the names
|
|
202
|
+
in the module attribute `random_variables`. `random_variables`, the shapes, and a similar
|
|
203
|
+
attribute of the variational distribution called `distribution_parameters` are used to
|
|
204
|
+
dynamically create the weight matrices. The weight matrices can be accessed as
|
|
205
|
+
attributes of the module, which will cause a sample to be drawn and its log prob to be
|
|
206
|
+
stored if needed.
|
|
207
|
+
|
|
208
|
+
Should you need to access the weight tensors directly, you can use `getattr` and derive
|
|
209
|
+
the name using the method `variational_parameter_name`.
|
|
210
|
+
|
|
211
|
+
> [!IMPORTANT]
|
|
212
|
+
> Every access of the weights will yield a new sample and log probability to be stored.
|
|
213
|
+
> Aggregation of multiple log probs is handled internally, but unnecessary calls will
|
|
214
|
+
> distort the result.
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools >= 61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[tool.setuptools]
|
|
6
|
+
packages = [
|
|
7
|
+
"torch_blue",
|
|
8
|
+
"torch_blue.vi",
|
|
9
|
+
"torch_blue.vi.distributions",
|
|
10
|
+
"torch_blue.vi.utils",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[project]
|
|
14
|
+
name = "torch_blue"
|
|
15
|
+
version = "0.9.0"
|
|
16
|
+
authors = [
|
|
17
|
+
{ name="RAI", email="rai@kit.edu"},
|
|
18
|
+
]
|
|
19
|
+
description = "Library for BNNs"
|
|
20
|
+
readme = "README.md"
|
|
21
|
+
requires-python = ">=3.9"
|
|
22
|
+
classifiers = [
|
|
23
|
+
"Programming Language :: Python",
|
|
24
|
+
"Programming Language :: Python :: 3",
|
|
25
|
+
"Development Status :: 4 - Beta",
|
|
26
|
+
]
|
|
27
|
+
dependencies = [
|
|
28
|
+
"torch",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.optional-dependencies]
|
|
32
|
+
dev = [
|
|
33
|
+
"pre-commit",
|
|
34
|
+
"ruff",
|
|
35
|
+
"mypy",
|
|
36
|
+
"pytest",
|
|
37
|
+
"coverage",
|
|
38
|
+
"build",
|
|
39
|
+
]
|
|
40
|
+
docs = [
|
|
41
|
+
"sphinx",
|
|
42
|
+
"sphinx-rtd-theme",
|
|
43
|
+
"sphinx-autoapi",
|
|
44
|
+
"myst-parser",
|
|
45
|
+
]
|
|
46
|
+
scripts = [
|
|
47
|
+
"matplotlib",
|
|
48
|
+
"torchvision",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
[tool.mypy]
|
|
52
|
+
python_version = "3.9"
|
|
53
|
+
modules = ["torch_blue"]
|
|
54
|
+
disallow_untyped_defs = true
|
|
55
|
+
disallow_incomplete_defs = true
|
|
56
|
+
files = [
|
|
57
|
+
"scripts/"
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
[tool.ruff]
|
|
62
|
+
# Exclude a variety of commonly ignored directories.
|
|
63
|
+
exclude = [
|
|
64
|
+
".bzr",
|
|
65
|
+
".direnv",
|
|
66
|
+
".eggs",
|
|
67
|
+
".git",
|
|
68
|
+
".git-rewrite",
|
|
69
|
+
".hg",
|
|
70
|
+
".ipynb_checkpoints",
|
|
71
|
+
".mypy_cache",
|
|
72
|
+
".nox",
|
|
73
|
+
".pants.d",
|
|
74
|
+
".pyenv",
|
|
75
|
+
".pytest_cache",
|
|
76
|
+
".pytype",
|
|
77
|
+
".ruff_cache",
|
|
78
|
+
".svn",
|
|
79
|
+
".tox",
|
|
80
|
+
".venv",
|
|
81
|
+
".vscode",
|
|
82
|
+
"__pypackages__",
|
|
83
|
+
"_build",
|
|
84
|
+
"buck-out",
|
|
85
|
+
"build",
|
|
86
|
+
"dist",
|
|
87
|
+
"node_modules",
|
|
88
|
+
"site-packages",
|
|
89
|
+
"venv",
|
|
90
|
+
]
|
|
91
|
+
|
|
92
|
+
# Same as Black.
|
|
93
|
+
line-length = 88
|
|
94
|
+
indent-width = 4
|
|
95
|
+
|
|
96
|
+
# Assume Python 3.9.
|
|
97
|
+
target-version = "py39"
|
|
98
|
+
|
|
99
|
+
[tool.ruff.lint]
|
|
100
|
+
# Enable pep8-naming (`N`), a subset of pycodestyle (`E4`, `E7`, `E9`), Pyflakes (`F`), and pydocstyle (`D`) rules.
|
|
101
|
+
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
|
|
102
|
+
# McCabe complexity (`C901`) by default.
|
|
103
|
+
select = ["N", "E4", "E7", "E9", "F", "D"]
|
|
104
|
+
ignore = ["D100", "D404"]
|
|
105
|
+
# Enable import sorting
|
|
106
|
+
extend-select = ["I"]
|
|
107
|
+
|
|
108
|
+
# Allow fixes for all enabled rules (when `--fix` is provided).
|
|
109
|
+
fixable = ["ALL"]
|
|
110
|
+
unfixable = []
|
|
111
|
+
|
|
112
|
+
# Allow unused variables when underscore-prefixed.
|
|
113
|
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
|
114
|
+
|
|
115
|
+
[tool.ruff.format]
|
|
116
|
+
# Like Black, use double quotes for strings.
|
|
117
|
+
quote-style = "double"
|
|
118
|
+
|
|
119
|
+
# Like Black, indent with spaces, rather than tabs.
|
|
120
|
+
indent-style = "space"
|
|
121
|
+
|
|
122
|
+
# Like Black, respect magic trailing commas.
|
|
123
|
+
skip-magic-trailing-comma = false
|
|
124
|
+
|
|
125
|
+
# Like Black, automatically detect the appropriate line ending.
|
|
126
|
+
line-ending = "auto"
|
|
127
|
+
|
|
128
|
+
# Enable auto-formatting of code examples in docstrings. Markdown,
|
|
129
|
+
# reStructuredText code/literal blocks and doctests are all supported.
|
|
130
|
+
#
|
|
131
|
+
# This is currently disabled by default, but it is planned for this
|
|
132
|
+
# to be opt-out in the future.
|
|
133
|
+
docstring-code-format = false
|
|
134
|
+
|
|
135
|
+
# Set the line length limit used when formatting code snippets in
|
|
136
|
+
# docstrings.
|
|
137
|
+
#
|
|
138
|
+
# This only has an effect when the `docstring-code-format` setting is
|
|
139
|
+
# enabled.
|
|
140
|
+
docstring-code-line-length = "dynamic"
|
|
141
|
+
|
|
142
|
+
[tool.ruff.lint.pydocstyle]
|
|
143
|
+
convention = "numpy"
|
|
144
|
+
|
|
145
|
+
[tool.pytest.ini_options]
|
|
146
|
+
testpaths = [
|
|
147
|
+
"tests",
|
|
148
|
+
]
|
|
149
|
+
filterwarnings = [
|
|
150
|
+
'ignore:There is a performance drop because we have not yet implemented the batching rule for aten*:UserWarning'
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
[tool.coverage.run]
|
|
154
|
+
source = ["torch_blue"]
|
|
155
|
+
command_line = "-m pytest"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Provides layers and distributions for Bayesian neural networks."""
|