torch-blue 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. torch_blue-0.9.0/LICENSE +28 -0
  2. torch_blue-0.9.0/PKG-INFO +243 -0
  3. torch_blue-0.9.0/README.md +214 -0
  4. torch_blue-0.9.0/pyproject.toml +155 -0
  5. torch_blue-0.9.0/setup.cfg +4 -0
  6. torch_blue-0.9.0/torch_blue/__init__.py +1 -0
  7. torch_blue-0.9.0/torch_blue/vi/__init__.py +48 -0
  8. torch_blue-0.9.0/torch_blue/vi/_globals.py +1 -0
  9. torch_blue-0.9.0/torch_blue/vi/analytical_kl_loss.py +443 -0
  10. torch_blue-0.9.0/torch_blue/vi/base.py +743 -0
  11. torch_blue-0.9.0/torch_blue/vi/conv.py +502 -0
  12. torch_blue-0.9.0/torch_blue/vi/distributions/__init__.py +23 -0
  13. torch_blue-0.9.0/torch_blue/vi/distributions/base.py +459 -0
  14. torch_blue-0.9.0/torch_blue/vi/distributions/categorical.py +102 -0
  15. torch_blue-0.9.0/torch_blue/vi/distributions/non_bayesian.py +172 -0
  16. torch_blue-0.9.0/torch_blue/vi/distributions/normal.py +221 -0
  17. torch_blue-0.9.0/torch_blue/vi/distributions/quiet.py +122 -0
  18. torch_blue-0.9.0/torch_blue/vi/distributions/student_t.py +133 -0
  19. torch_blue-0.9.0/torch_blue/vi/kl_loss.py +143 -0
  20. torch_blue-0.9.0/torch_blue/vi/linear.py +124 -0
  21. torch_blue-0.9.0/torch_blue/vi/sequential.py +119 -0
  22. torch_blue-0.9.0/torch_blue/vi/transformer.py +820 -0
  23. torch_blue-0.9.0/torch_blue/vi/utils/__init__.py +12 -0
  24. torch_blue-0.9.0/torch_blue/vi/utils/common_types.py +56 -0
  25. torch_blue-0.9.0/torch_blue/vi/utils/errors.py +19 -0
  26. torch_blue-0.9.0/torch_blue/vi/utils/init.py +27 -0
  27. torch_blue-0.9.0/torch_blue/vi/utils/post_init_metaclass.py +13 -0
  28. torch_blue-0.9.0/torch_blue/vi/utils/use_norm_constants.py +17 -0
  29. torch_blue-0.9.0/torch_blue/vi/utils/vi_return.py +72 -0
  30. torch_blue-0.9.0/torch_blue.egg-info/PKG-INFO +243 -0
  31. torch_blue-0.9.0/torch_blue.egg-info/SOURCES.txt +32 -0
  32. torch_blue-0.9.0/torch_blue.egg-info/dependency_links.txt +1 -0
  33. torch_blue-0.9.0/torch_blue.egg-info/requires.txt +19 -0
  34. torch_blue-0.9.0/torch_blue.egg-info/top_level.txt +1 -0
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2025, JRG Robust and Efficient AI
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,243 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch_blue
3
+ Version: 0.9.0
4
+ Summary: Library for BNNs
5
+ Author-email: RAI <rai@kit.edu>
6
+ Classifier: Programming Language :: Python
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Development Status :: 4 - Beta
9
+ Requires-Python: >=3.9
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: torch
13
+ Provides-Extra: dev
14
+ Requires-Dist: pre-commit; extra == "dev"
15
+ Requires-Dist: ruff; extra == "dev"
16
+ Requires-Dist: mypy; extra == "dev"
17
+ Requires-Dist: pytest; extra == "dev"
18
+ Requires-Dist: coverage; extra == "dev"
19
+ Requires-Dist: build; extra == "dev"
20
+ Provides-Extra: docs
21
+ Requires-Dist: sphinx; extra == "docs"
22
+ Requires-Dist: sphinx-rtd-theme; extra == "docs"
23
+ Requires-Dist: sphinx-autoapi; extra == "docs"
24
+ Requires-Dist: myst-parser; extra == "docs"
25
+ Provides-Extra: scripts
26
+ Requires-Dist: matplotlib; extra == "scripts"
27
+ Requires-Dist: torchvision; extra == "scripts"
28
+ Dynamic: license-file
29
+
30
+ # torch_blue - A PyTorch-like framework for Bayesian learning and uncertainty estimation
31
+
32
+ `torch_blue` provides a simple way for non-expert users to implement and train Bayesian
33
+ Neural Networks (BNNs). Currently, it only supports Variational Inference (VI), but will
34
+ hopefully grow and expand in the future. To make the user experience as easy as possible
35
+ most components mirror components from [PyTorch](https://pytorch.org/docs/stable/index.html).
36
+
37
+ - [Installation](#installation)
38
+ - [Documentation](#documentation)
39
+ - [Quickstart](#quickstart)
40
+ - [Level 1](#level-1)
41
+ - [Level 2](#level-2)
42
+ - [Level 3](#level-3)
43
+ - [Level 4](#level-4)
44
+
45
+ ## Installation
46
+
47
+ We heavily recommend installing `torch_blue` in a dedicated `Python3.9+`
48
+ [virtual environment](https://docs.python.org/3/library/venv.html). You can install
49
+ `torch_blue` from PyPI:
50
+
51
+ ```console
52
+ $ pip install torch-blue
53
+ ```
54
+
55
+ Alternatively, you can install `torch_blue` locally. To achieve this, there
56
+ are two steps you need to follow:
57
+
58
+ 1. Clone the repository
59
+
60
+ ```console
61
+ $ git clone https://github.com/RAI-SCC/torch_blue
62
+ ```
63
+
64
+ 2. Install the code locally
65
+
66
+ ```console
67
+ $ pip install -e .
68
+ ```
69
+
70
+ To get the development dependencies, run:
71
+
72
+ ```console
73
+ $ pip install -e .[dev]
74
+ ```
75
+
76
+ For additional dependencies required if you want to run scripts from the scripts
77
+ directory, run:
78
+
79
+ ```console
80
+ $ pip install -e .[scripts]
81
+ ```
82
+
83
+
84
+ ## Documentation
85
+
86
+ Documentation is available online at [readthedocs](https://torch-blue.readthedocs.io).
87
+
88
+ ## Quickstart
89
+
90
+ This Quickstart guide assumes basic familiarity with [PyTorch](https://pytorch.org/docs/stable/index.html)
91
+ and knowledge of how to implement the intended model in it. For a (potentially familiar)
92
+ example see `scripts/pytorch_tutorial.py`, which contains a copy of the PyTorch
93
+ [Quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
94
+ modified to train a BNN with variational inference.
95
+ Four levels are introduced:
96
+ - [Level 1](#level-1): Simple sequential layer stacks
97
+ - [Level 2](#level-2): Customizing Bayesian assumptions and VI kwargs
98
+ - [Level 3](#level-3): Non-sequential models and log probabilities
99
+ - [Level 4](#level-4): Custom modules with weights
100
+
101
+ ### Level 1
102
+
103
+ Many parts of a neural network remain completely unchanged when turning it into a BNN.
104
+ Indeed, only `Module`s containing `nn.Parameter`s need to be changed. Therefore, if a
105
+ PyTorch model fulfills two requirements it can be transferred almost unchanged:
106
+
107
+ 1. All PyTorch `Module`s containing parameters have equivalents in this package (table below).
108
+ 2. The model can be expressed purely as a sequential application of a list of layers,
109
+ i.e. with `nn.Sequential`.
110
+
111
+ | PyTorch | vi replacement |
112
+ |------------------|-----------------|
113
+ | `nn.Linear` | `VILinear` |
114
+ | `nn.Conv1d` | `VIConv1d` |
115
+ | `nn.Conv2d` | `VIConv2d` |
116
+ | `nn.Conv3d` | `VIConv3d` |
117
+ | `nn.Transformer` | `VITransformer` |
118
+
119
+ Given these two conditions, inherit the module from `vi.VIModule` instead of `nn.Module`
120
+ and use `vi.VISequential` instead of `nn.Sequential`. Then replace all layers
121
+ containing parameters as shown in the table above. For basic usage initialize these
122
+ modules with the same arguments as their PyTorch equivalent. For advanced usage see
123
+ [Quickstart: Level 2](#level-2). Many other layers can be included as-is. In particular
124
+ activation functions, pooling, and padding (even dropout, though it
125
+ should not be necessary since the prior acts as regularization). Currently not supported
126
+ are recurrent and transposed convolution layers. Normalization layers may
127
+ have parameters depending on their setting, but can likely be left non-Bayesian.
128
+
129
+ Additionally, the loss must be replaced. To start out use `vi.KullbackLeiblerLoss`,
130
+ which requires a `Distribution` with `self.is_predictive_distribution=True` and the size
131
+ of the training dataset (this is important for balancing of assumptions and data). Choose
132
+ your `Distribution` from the table below based on the loss you would use in PyTorch.
133
+
134
+ > [!IMPORTANT]
135
+ > `KullbackLeiblerLoss` requires the length of the dataset, not the dataloader, which is
136
+ > just the number of batches.
137
+
138
+ | PyTorch | vi replacement <br/> from `vi.distributions` |
139
+ |-----------------------|----------------------------------------------|
140
+ | `nn.MSELoss` | `MeanFieldNormal` |
141
+ | `nn.CrossEntropyLoss` | `Categorical` |
142
+
143
+ > [!NOTE]
144
+ > Reasons for the requirement to use `VISequential` (and how to overcome it)
145
+ > are described in [Quickstart: Level 3](#level-3). However, adding residual connections
146
+ > from the start to the end of a block of layers can also be achieved using
147
+ > `VIResidualConnection`, which acts the same as `VISequential`, but adds the input to
148
+ > the output.
149
+
150
+ ### Level 2
151
+
152
+ While the interface of `VIModule`s is kept intentionally similar to PyTorch, there are
153
+ additional arguments that customize the Bayesian assumptions that all provided layers
154
+ accept and custom modules should generally accept and pass on to submodules:
155
+ - variational_distribution (`Distribution`): defines the weight distribution and
156
+ variational parameters. The default `MeanFieldNormal` assumes normal distributed,
157
+ uncorrelated weights described by a mean and a standard deviation. While there are
158
+ currently no alternatives the initial value of the standard deviation can be customized
159
+ here.
160
+ - prior (`Distribution`): defines the assumptions on the weight distribution and acts as
161
+ regularizer. The default `MeanFieldNormal` assumes normal distributed, uncorrelated
162
+ weights with mean 0 and standard deviation 1 (also known as a standard normal prior).
163
+ Mean and standard deviation can be adapted here. Particularly reducing the standard
164
+ deviation may help convergence at the risk of an overconfident model. Other available
165
+ priors:
166
+ - `BasicQuietPrior`: an experimental prior that correlates mean and standard deviation
167
+ to disincentivize noisy weights
168
+ - rescale_prior (`bool`): Experimental. Scales the prior similar to Kaiming-initialization.
169
+ May help with convergence, but may lead to overconfidence. Current research.
170
+ - prior_initialization (`bool`): Experimental. Initialize parameters from the prior
171
+ instead of according to standard non-Bayesian methods. May lead to much faster
172
+ convergence, but can cause the issues Kaiming-initialization counteracts unless
173
+ rescale_prior is also set to True. Current research.
174
+ - return_log_probs (`bool`): This is the topic of [Quickstart: Level 3](#level-3).
175
+
176
+ ### Level 3
177
+
178
+ For more advanced models one feature of Variational Inference (VI) needs to be taken
179
+ into account. Generally, a loss for VI will require the log probability of the actually
180
+ used weights (which are sampled on each forward pass) in the variational and prior
181
+ distribution. Since it is quite inefficient to save the samples these log probabilities
182
+ are evaluated during the forward pass and returned by the model. Since this is only
183
+ necessary for training it can be controlled with the argument `return_log_probs`. Once
184
+ the model is initialized this flag can be changed by setting `VIModule.return_log_probs`,
185
+ which either enables (`True`) or disables (`False`) the returning of the log
186
+ probabilities for all submodules.
187
+
188
+ While `torch_blue` calculates and aggregates log probs internally, this is handled
189
+ by the outermost `VIModule`. This module will not have the expected output signature
190
+ when returning log probs, but instead return a `VIReturn` object. This class is a PyTorch
191
+ `Tensor` that also contains log prob information in its additional `log_probs`
192
+ attribute. This is the format `torch_blue` losses expect. Therefore, if you feed the
193
+ output directly into a loss there should be no issues. While all PyTorch tensor
194
+ operations can be performed on `VIReturns` many will delete the log prob information and
195
+ transform the object back into a `Tensor`. This needs to be considered when performing
196
+ further operations on the model output. The simplest way to avoid issues is to wrap all
197
+ operations - except the loss - in a `VIModule` since log prob aggregation is only
198
+ performed by the outermost module. For deployment `return_log_probs` should be set to
199
+ `False`. If multiple `Tensor`s are returned by the model, each will carry all log probs.
200
+
201
+ > [!NOTE]
202
+ > Always make sure your outermost module is a VIModule and keep in mind that the output
203
+ > of that module will be a `VIReturn` object, which behaves like a `Tensor` but carries
204
+ > weight log probabilities, if `return_log_probs == True`. Losses in `torch_blue`
205
+ > expect this format.
206
+
207
+ > [!NOTE]
208
+ > Due to Autosampling all output Tensors, i.e. each `VIReturn`
209
+ > in the model output and the `Tensor` containing the log probs has an additional
210
+ > dimension at the beginning representing the multiple samples necessary to properly
211
+ > evaluate the stochastic forward pass. This is only relevant for VIModules that are not
212
+ > contained within other VIModules. Loss functions are designed to expect and handle
213
+ > this output format, i.e. you can simply feed the model output into the loss and
214
+ > everything will work.
215
+
216
+ ### Level 4
217
+
218
+ Creating `VIModule`s with Bayesian weights - which are typically called random
219
+ variables in documentation and code - is arguably simpler than in PyTorch. Since a
220
+ different number of weight matrices needs to be created based on the variational
221
+ distribution, the process is completely automated. For `VIModules` without weights
222
+ `super().__init__` is called without arguments. Modules with random variables
223
+ expect `VIkwargs` (which you should be familiar with from [Level 2](#level-2)), but
224
+ defaults are used if none are passed. More importantly, `VIModules` with weights call
225
+ `super().__init__` with the argument `variable_shapes`. The keys of this dictionary are
226
+ the names of the random variables and the values the shapes of the weight matrices as
227
+ tuple or list. The value may also be set to `None`, which will always be the value
228
+ returned for that variable.
229
+
230
+ The insertion order of this dictionary matters, as it becomes the order of the names
231
+ in the module attribute `random_variables`. `random_variables`, the shapes, and a similar
232
+ attribute of the variational distribution called `distribution_parameters` are used to
233
+ dynamically create the weight matrices. The weight matrices can be accessed as
234
+ attributes of the module, which will cause a sample to be drawn and its log prob to be
235
+ stored if needed.
236
+
237
+ Should you need to access the weight tensors directly you can use `getattr` and derive
238
+ the name using the method `variational_parameter_name`.
239
+
240
+ > [!IMPORTANT]
241
+ > Every access of the weights will yield a new sample and log probability to be stored.
242
+ > Aggregation of multiple log probs is handled internally, but unnecessary calls will
243
+ > distort the result.
@@ -0,0 +1,214 @@
1
+ # torch_blue - A PyTorch-like framework for Bayesian learning and uncertainty estimation
2
+
3
+ `torch_blue` provides a simple way for non-expert users to implement and train Bayesian
4
+ Neural Networks (BNNs). Currently, it only supports Variational Inference (VI), but will
5
+ hopefully grow and expand in the future. To make the user experience as easy as possible
6
+ most components mirror components from [PyTorch](https://pytorch.org/docs/stable/index.html).
7
+
8
+ - [Installation](#installation)
9
+ - [Documentation](#documentation)
10
+ - [Quickstart](#quickstart)
11
+ - [Level 1](#level-1)
12
+ - [Level 2](#level-2)
13
+ - [Level 3](#level-3)
14
+ - [Level 4](#level-4)
15
+
16
+ ## Installation
17
+
18
+ We heavily recommend installing `torch_blue` in a dedicated `Python3.9+`
19
+ [virtual environment](https://docs.python.org/3/library/venv.html). You can install
20
+ `torch_blue` from PyPI:
21
+
22
+ ```console
23
+ $ pip install torch-blue
24
+ ```
25
+
26
+ Alternatively, you can install `torch_blue` locally. To achieve this, there
27
+ are two steps you need to follow:
28
+
29
+ 1. Clone the repository
30
+
31
+ ```console
32
+ $ git clone https://github.com/RAI-SCC/torch_blue
33
+ ```
34
+
35
+ 2. Install the code locally
36
+
37
+ ```console
38
+ $ pip install -e .
39
+ ```
40
+
41
+ To get the development dependencies, run:
42
+
43
+ ```console
44
+ $ pip install -e .[dev]
45
+ ```
46
+
47
+ For additional dependencies required if you want to run scripts from the scripts
48
+ directory, run:
49
+
50
+ ```console
51
+ $ pip install -e .[scripts]
52
+ ```
53
+
54
+
55
+ ## Documentation
56
+
57
+ Documentation is available online at [readthedocs](https://torch-blue.readthedocs.io).
58
+
59
+ ## Quickstart
60
+
61
+ This Quickstart guide assumes basic familiarity with [PyTorch](https://pytorch.org/docs/stable/index.html)
62
+ and knowledge of how to implement the intended model in it. For a (potentially familiar)
63
+ example see `scripts/pytorch_tutorial.py`, which contains a copy of the PyTorch
64
+ [Quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
65
+ modified to train a BNN with variational inference.
66
+ Four levels are introduced:
67
+ - [Level 1](#level-1): Simple sequential layer stacks
68
+ - [Level 2](#level-2): Customizing Bayesian assumptions and VI kwargs
69
+ - [Level 3](#level-3): Non-sequential models and log probabilities
70
+ - [Level 4](#level-4): Custom modules with weights
71
+
72
+ ### Level 1
73
+
74
+ Many parts of a neural network remain completely unchanged when turning it into a BNN.
75
+ Indeed, only `Module`s containing `nn.Parameter`s need to be changed. Therefore, if a
76
+ PyTorch model fulfills two requirements it can be transferred almost unchanged:
77
+
78
+ 1. All PyTorch `Module`s containing parameters have equivalents in this package (table below).
79
+ 2. The model can be expressed purely as a sequential application of a list of layers,
80
+ i.e. with `nn.Sequential`.
81
+
82
+ | PyTorch | vi replacement |
83
+ |------------------|-----------------|
84
+ | `nn.Linear` | `VILinear` |
85
+ | `nn.Conv1d` | `VIConv1d` |
86
+ | `nn.Conv2d` | `VIConv2d` |
87
+ | `nn.Conv3d` | `VIConv3d` |
88
+ | `nn.Transformer` | `VITransformer` |
89
+
90
+ Given these two conditions, inherit the module from `vi.VIModule` instead of `nn.Module`
91
+ and use `vi.VISequential` instead of `nn.Sequential`. Then replace all layers
92
+ containing parameters as shown in the table above. For basic usage initialize these
93
+ modules with the same arguments as their PyTorch equivalent. For advanced usage see
94
+ [Quickstart: Level 2](#level-2). Many other layers can be included as-is. In particular
95
+ activation functions, pooling, and padding (even dropout, though it
96
+ should not be necessary since the prior acts as regularization). Currently not supported
97
+ are recurrent and transposed convolution layers. Normalization layers may
98
+ have parameters depending on their setting, but can likely be left non-Bayesian.
99
+
100
+ Additionally, the loss must be replaced. To start out use `vi.KullbackLeiblerLoss`,
101
+ which requires a `Distribution` with `self.is_predictive_distribution=True` and the size
102
+ of the training dataset (this is important for balancing of assumptions and data). Choose
103
+ your `Distribution` from the table below based on the loss you would use in PyTorch.
104
+
105
+ > [!IMPORTANT]
106
+ > `KullbackLeiblerLoss` requires the length of the dataset, not the dataloader, which is
107
+ > just the number of batches.
108
+
109
+ | PyTorch | vi replacement <br/> from `vi.distributions` |
110
+ |-----------------------|----------------------------------------------|
111
+ | `nn.MSELoss` | `MeanFieldNormal` |
112
+ | `nn.CrossEntropyLoss` | `Categorical` |
113
+
114
+ > [!NOTE]
115
+ > Reasons for the requirement to use `VISequential` (and how to overcome it)
116
+ > are described in [Quickstart: Level 3](#level-3). However, adding residual connections
117
+ > from the start to the end of a block of layers can also be achieved using
118
+ > `VIResidualConnection`, which acts the same as `VISequential`, but adds the input to
119
+ > the output.
120
+
121
+ ### Level 2
122
+
123
+ While the interface of `VIModule`s is kept intentionally similar to PyTorch, there are
124
+ additional arguments that customize the Bayesian assumptions that all provided layers
125
+ accept and custom modules should generally accept and pass on to submodules:
126
+ - variational_distribution (`Distribution`): defines the weight distribution and
127
+ variational parameters. The default `MeanFieldNormal` assumes normal distributed,
128
+ uncorrelated weights described by a mean and a standard deviation. While there are
129
+ currently no alternatives the initial value of the standard deviation can be customized
130
+ here.
131
+ - prior (`Distribution`): defines the assumptions on the weight distribution and acts as
132
+ regularizer. The default `MeanFieldNormal` assumes normal distributed, uncorrelated
133
+ weights with mean 0 and standard deviation 1 (also known as a standard normal prior).
134
+ Mean and standard deviation can be adapted here. Particularly reducing the standard
135
+ deviation may help convergence at the risk of an overconfident model. Other available
136
+ priors:
137
+ - `BasicQuietPrior`: an experimental prior that correlates mean and standard deviation
138
+ to disincentivize noisy weights
139
+ - rescale_prior (`bool`): Experimental. Scales the prior similar to Kaiming-initialization.
140
+ May help with convergence, but may lead to overconfidence. Current research.
141
+ - prior_initialization (`bool`): Experimental. Initialize parameters from the prior
142
+ instead of according to standard non-Bayesian methods. May lead to much faster
143
+ convergence, but can cause the issues Kaiming-initialization counteracts unless
144
+ rescale_prior is also set to True. Current research.
145
+ - return_log_probs (`bool`): This is the topic of [Quickstart: Level 3](#level-3).
146
+
147
+ ### Level 3
148
+
149
+ For more advanced models one feature of Variational Inference (VI) needs to be taken
150
+ into account. Generally, a loss for VI will require the log probability of the actually
151
+ used weights (which are sampled on each forward pass) in the variational and prior
152
+ distribution. Since it is quite inefficient to save the samples these log probabilities
153
+ are evaluated during the forward pass and returned by the model. Since this is only
154
+ necessary for training it can be controlled with the argument `return_log_probs`. Once
155
+ the model is initialized this flag can be changed by setting `VIModule.return_log_probs`,
156
+ which either enables (`True`) or disables (`False`) the returning of the log
157
+ probabilities for all submodules.
158
+
159
+ While `torch_blue` calculates and aggregates log probs internally, this is handled
160
+ by the outermost `VIModule`. This module will not have the expected output signature
161
+ when returning log probs, but instead return a `VIReturn` object. This class is a PyTorch
162
+ `Tensor` that also contains log prob information in its additional `log_probs`
163
+ attribute. This is the format `torch_blue` losses expect. Therefore, if you feed the
164
+ output directly into a loss there should be no issues. While all PyTorch tensor
165
+ operations can be performed on `VIReturns` many will delete the log prob information and
166
+ transform the object back into a `Tensor`. This needs to be considered when performing
167
+ further operations on the model output. The simplest way to avoid issues is to wrap all
168
+ operations - except the loss - in a `VIModule` since log prob aggregation is only
169
+ performed by the outermost module. For deployment `return_log_probs` should be set to
170
+ `False`. If multiple `Tensor`s are returned by the model, each will carry all log probs.
171
+
172
+ > [!NOTE]
173
+ > Always make sure your outermost module is a VIModule and keep in mind that the output
174
+ > of that module will be a `VIReturn` object, which behaves like a `Tensor` but carries
175
+ > weight log probabilities, if `return_log_probs == True`. Losses in `torch_blue`
176
+ > expect this format.
177
+
178
+ > [!NOTE]
179
+ > Due to Autosampling all output Tensors, i.e. each `VIReturn`
180
+ > in the model output and the `Tensor` containing the log probs has an additional
181
+ > dimension at the beginning representing the multiple samples necessary to properly
182
+ > evaluate the stochastic forward pass. This is only relevant for VIModules that are not
183
+ > contained within other VIModules. Loss functions are designed to expect and handle
184
+ > this output format, i.e. you can simply feed the model output into the loss and
185
+ > everything will work.
186
+
187
+ ### Level 4
188
+
189
+ Creating `VIModule`s with Bayesian weights - which are typically called random
190
+ variables in documentation and code - is arguably simpler than in PyTorch. Since a
191
+ different number of weight matrices needs to be created based on the variational
192
+ distribution, the process is completely automated. For `VIModules` without weights
193
+ `super().__init__` is called without arguments. Modules with random variables
194
+ expect `VIkwargs` (which you should be familiar with from [Level 2](#level-2)), but
195
+ defaults are used if none are passed. More importantly, `VIModules` with weights call
196
+ `super().__init__` with the argument `variable_shapes`. The keys of this dictionary are
197
+ the names of the random variables and the values the shapes of the weight matrices as
198
+ tuple or list. The value may also be set to `None`, which will always be the value
199
+ returned for that variable.
200
+
201
+ The insertion order of this dictionary matters, as it becomes the order of the names
202
+ in the module attribute `random_variables`. `random_variables`, the shapes, and a similar
203
+ attribute of the variational distribution called `distribution_parameters` are used to
204
+ dynamically create the weight matrices. The weight matrices can be accessed as
205
+ attributes of the module, which will cause a sample to be drawn and its log prob to be
206
+ stored if needed.
207
+
208
+ Should you need to access the weight tensors directly you can use `getattr` and derive
209
+ the name using the method `variational_parameter_name`.
210
+
211
+ > [!IMPORTANT]
212
+ > Every access of the weights will yield a new sample and log probability to be stored.
213
+ > Aggregation of multiple log probs is handled internally, but unnecessary calls will
214
+ > distort the result.
@@ -0,0 +1,155 @@
1
+ [build-system]
2
+ requires = ["setuptools >= 61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.setuptools]
6
+ packages = [
7
+ "torch_blue",
8
+ "torch_blue.vi",
9
+ "torch_blue.vi.distributions",
10
+ "torch_blue.vi.utils",
11
+ ]
12
+
13
+ [project]
14
+ name = "torch_blue"
15
+ version = "0.9.0"
16
+ authors = [
17
+ { name="RAI", email="rai@kit.edu"},
18
+ ]
19
+ description = "Library for BNNs"
20
+ readme = "README.md"
21
+ requires-python = ">=3.9"
22
+ classifiers = [
23
+ "Programming Language :: Python",
24
+ "Programming Language :: Python :: 3",
25
+ "Development Status :: 4 - Beta",
26
+ ]
27
+ dependencies = [
28
+ "torch",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pre-commit",
34
+ "ruff",
35
+ "mypy",
36
+ "pytest",
37
+ "coverage",
38
+ "build",
39
+ ]
40
+ docs = [
41
+ "sphinx",
42
+ "sphinx-rtd-theme",
43
+ "sphinx-autoapi",
44
+ "myst-parser",
45
+ ]
46
+ scripts = [
47
+ "matplotlib",
48
+ "torchvision",
49
+ ]
50
+
51
+ [tool.mypy]
52
+ python_version = "3.9"
53
+ modules = ["torch_blue"]
54
+ disallow_untyped_defs = true
55
+ disallow_incomplete_defs = true
56
+ files = [
57
+ "scripts/"
58
+ ]
59
+
60
+
61
+ [tool.ruff]
62
+ # Exclude a variety of commonly ignored directories.
63
+ exclude = [
64
+ ".bzr",
65
+ ".direnv",
66
+ ".eggs",
67
+ ".git",
68
+ ".git-rewrite",
69
+ ".hg",
70
+ ".ipynb_checkpoints",
71
+ ".mypy_cache",
72
+ ".nox",
73
+ ".pants.d",
74
+ ".pyenv",
75
+ ".pytest_cache",
76
+ ".pytype",
77
+ ".ruff_cache",
78
+ ".svn",
79
+ ".tox",
80
+ ".venv",
81
+ ".vscode",
82
+ "__pypackages__",
83
+ "_build",
84
+ "buck-out",
85
+ "build",
86
+ "dist",
87
+ "node_modules",
88
+ "site-packages",
89
+ "venv",
90
+ ]
91
+
92
+ # Same as Black.
93
+ line-length = 88
94
+ indent-width = 4
95
+
96
+ # Assume Python 3.9.
97
+ target-version = "py39"
98
+
99
+ [tool.ruff.lint]
100
+ # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
101
+ # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
102
+ # McCabe complexity (`C901`) by default.
103
+ select = ["N", "E4", "E7", "E9", "F", "D"]
104
+ ignore = ["D100", "D404"]
105
+ # Enable import sorting
106
+ extend-select = ["I"]
107
+
108
+ # Allow fix for all enabled rules (when `--fix` is provided).
109
+ fixable = ["ALL"]
110
+ unfixable = []
111
+
112
+ # Allow unused variables when underscore-prefixed.
113
+ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
114
+
115
+ [tool.ruff.format]
116
+ # Like Black, use double quotes for strings.
117
+ quote-style = "double"
118
+
119
+ # Like Black, indent with spaces, rather than tabs.
120
+ indent-style = "space"
121
+
122
+ # Like Black, respect magic trailing commas.
123
+ skip-magic-trailing-comma = false
124
+
125
+ # Like Black, automatically detect the appropriate line ending.
126
+ line-ending = "auto"
127
+
128
+ # Enable auto-formatting of code examples in docstrings. Markdown,
129
+ # reStructuredText code/literal blocks and doctests are all supported.
130
+ #
131
+ # This is currently disabled by default, but it is planned for this
132
+ # to be opt-out in the future.
133
+ docstring-code-format = false
134
+
135
+ # Set the line length limit used when formatting code snippets in
136
+ # docstrings.
137
+ #
138
+ # This only has an effect when the `docstring-code-format` setting is
139
+ # enabled.
140
+ docstring-code-line-length = "dynamic"
141
+
142
+ [tool.ruff.lint.pydocstyle]
143
+ convention = "numpy"
144
+
145
+ [tool.pytest.ini_options]
146
+ testpaths = [
147
+ "tests",
148
+ ]
149
+ filterwarnings = [
150
+ 'ignore:There is a performance drop because we have not yet implemented the batching rule for aten*:UserWarning'
151
+ ]
152
+
153
+ [tool.coverage.run]
154
+ source = ["torch_blue"]
155
+ command_line = "-m pytest"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1 @@
1
+ """Provides layers and distributions for Bayesian neural networks."""