mc-dropout-pytorch 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mc_dropout_pytorch-0.1.0/LICENSE +21 -0
- mc_dropout_pytorch-0.1.0/PKG-INFO +189 -0
- mc_dropout_pytorch-0.1.0/README.md +156 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch/__init__.py +8 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch/mc_dropout_pytorch.py +461 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch/version.py +1 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch.egg-info/PKG-INFO +189 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch.egg-info/SOURCES.txt +11 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch.egg-info/dependency_links.txt +1 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch.egg-info/requires.txt +5 -0
- mc_dropout_pytorch-0.1.0/mc_dropout_pytorch.egg-info/top_level.txt +1 -0
- mc_dropout_pytorch-0.1.0/setup.cfg +4 -0
- mc_dropout_pytorch-0.1.0/setup.py +37 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Phil Wang
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mc-dropout-pytorch
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: MC Dropout (Gal & Ghahramani, 2016) - Pytorch
|
|
5
|
+
Home-page: https://github.com/lucidrains/mc-dropout-pytorch
|
|
6
|
+
Author: lucidrains
|
|
7
|
+
Author-email: lucidrains@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: artificial intelligence,deep learning,bayesian deep learning,uncertainty estimation,monte carlo dropout
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.6
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: accelerate
|
|
18
|
+
Requires-Dist: einops>=0.7
|
|
19
|
+
Requires-Dist: ema-pytorch>=0.4.2
|
|
20
|
+
Requires-Dist: torch>=2.0
|
|
21
|
+
Requires-Dist: tqdm
|
|
22
|
+
Dynamic: author
|
|
23
|
+
Dynamic: author-email
|
|
24
|
+
Dynamic: classifier
|
|
25
|
+
Dynamic: description
|
|
26
|
+
Dynamic: description-content-type
|
|
27
|
+
Dynamic: home-page
|
|
28
|
+
Dynamic: keywords
|
|
29
|
+
Dynamic: license
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
Dynamic: requires-dist
|
|
32
|
+
Dynamic: summary
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
## MC Dropout, in Pytorch
|
|
37
|
+
|
|
38
|
+
[![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
|
|
39
|
+
|
|
40
|
+
Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
|
|
41
|
+
|
|
42
|
+
Standard dropout NNs cast as approximate Bayesian inference over deep Gaussian processes — giving free, calibrated uncertainty estimates with no architectural changes and zero inference overhead beyond T forward passes.
|
|
43
|
+
|
|
44
|
+
## Install
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
$ pip install mc-dropout-pytorch
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Usage
|
|
51
|
+
|
|
52
|
+
### Regression with uncertainty
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
import torch
|
|
56
|
+
from torch.utils.data import TensorDataset
|
|
57
|
+
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
|
|
58
|
+
|
|
59
|
+
# build model
|
|
60
|
+
model = BayesianMLP(
|
|
61
|
+
input_dim = 1,
|
|
62
|
+
output_dim = 1,
|
|
63
|
+
hidden_dims = (256, 256),
|
|
64
|
+
dropout_rate = 0.1,
|
|
65
|
+
activation = 'relu',
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# wrap for MC inference (T=50 stochastic passes)
|
|
69
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
|
|
70
|
+
|
|
71
|
+
x = torch.linspace(-3, 3, 100).unsqueeze(-1)
|
|
72
|
+
out = mc(x)
|
|
73
|
+
|
|
74
|
+
out.mean # predictive mean — (100, 1)
|
|
75
|
+
out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
|
|
76
|
+
out.samples # raw samples — (50, 100, 1)
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
### Classification with predictive entropy
|
|
80
|
+
|
|
81
|
+
```python
|
|
82
|
+
import torch
|
|
83
|
+
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
|
|
84
|
+
|
|
85
|
+
model = BayesianCNN(
|
|
86
|
+
in_channels = 1,
|
|
87
|
+
num_classes = 10,
|
|
88
|
+
base_channels = 32,
|
|
89
|
+
dropout_rate = 0.25,
|
|
90
|
+
fc_dropout_rate = 0.5,
|
|
91
|
+
img_size = 28,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
|
|
95
|
+
|
|
96
|
+
x = torch.randn(8, 1, 28, 28)
|
|
97
|
+
out = mc(x)
|
|
98
|
+
|
|
99
|
+
out.mean # class probabilities — (8, 10)
|
|
100
|
+
out.variance # per-class variance — (8, 10)
|
|
101
|
+
|
|
102
|
+
# active learning signals (§6)
|
|
103
|
+
H = mc.predictive_entropy(x) # (8,) — total uncertainty
|
|
104
|
+
MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
### Full training loop with the `Trainer`
|
|
108
|
+
|
|
109
|
+
```python
|
|
110
|
+
import torch
|
|
111
|
+
from torch.utils.data import TensorDataset
|
|
112
|
+
from mc_dropout_pytorch import BayesianMLP, Trainer
|
|
113
|
+
|
|
114
|
+
# synthetic regression dataset
|
|
115
|
+
X = torch.randn(1000, 4)
|
|
116
|
+
y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
|
|
117
|
+
dataset = TensorDataset(X, y)
|
|
118
|
+
|
|
119
|
+
model = BayesianMLP(
|
|
120
|
+
input_dim = 4,
|
|
121
|
+
output_dim = 1,
|
|
122
|
+
hidden_dims = (128, 128),
|
|
123
|
+
dropout_rate = 0.1,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
trainer = Trainer(
|
|
127
|
+
model,
|
|
128
|
+
dataset,
|
|
129
|
+
task = 'regression',
|
|
130
|
+
train_lr = 1e-3,
|
|
131
|
+
train_num_steps = 5_000,
|
|
132
|
+
train_batch_size = 64,
|
|
133
|
+
ema_decay = 0.995,
|
|
134
|
+
amp = False,
|
|
135
|
+
weight_decay = 1e-4, # ≡ prior precision in §3
|
|
136
|
+
tau = 1.0, # noise precision
|
|
137
|
+
num_mc_samples = 50,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
trainer.train()
|
|
141
|
+
|
|
142
|
+
# inference via EMA model
|
|
143
|
+
mc = trainer.inference
|
|
144
|
+
out = mc(X[:10])
|
|
145
|
+
print(out.mean, out.variance)
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
### Multi-GPU
|
|
149
|
+
|
|
150
|
+
```bash
|
|
151
|
+
$ accelerate config
|
|
152
|
+
$ accelerate launch train.py
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
## Key ideas from the paper
|
|
156
|
+
|
|
157
|
+
**The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
|
|
158
|
+
|
|
159
|
+
**Test-time dropout (MC Dropout)**:
|
|
160
|
+
|
|
161
|
+
```
|
|
162
|
+
for t = 1 … T:
|
|
163
|
+
ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
|
|
164
|
+
|
|
165
|
+
E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
|
|
166
|
+
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*]² # predictive variance (Eq. 9)
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
**Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
|
|
170
|
+
|
|
171
|
+
**Weight correspondence** (§3.2):
|
|
172
|
+
|
|
173
|
+
| Dropout training | Bayesian GP posterior |
|
|
174
|
+
|---------------------------|--------------------------|
|
|
175
|
+
| dropout probability `p` | variational parameter |
|
|
176
|
+
| L2 weight decay `λ` | prior precision |
|
|
177
|
+
| noise precision `τ`       | `τ = (1 − p) l² / (2N λ)`, with `l` the prior length-scale |
|
|
178
|
+
|
|
179
|
+
## Citations
|
|
180
|
+
|
|
181
|
+
```bibtex
|
|
182
|
+
@article{Gal2016Dropout,
|
|
183
|
+
title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
|
|
184
|
+
author = {Yarin Gal and Zoubin Ghahramani},
|
|
185
|
+
journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
|
|
186
|
+
year = {2016},
|
|
187
|
+
url = {https://arxiv.org/abs/1506.02142}
|
|
188
|
+
}
|
|
189
|
+
```
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
## MC Dropout, in Pytorch
|
|
4
|
+
|
|
5
|
+
[![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
|
|
6
|
+
|
|
7
|
+
Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
|
|
8
|
+
|
|
9
|
+
Standard dropout NNs cast as approximate Bayesian inference over deep Gaussian processes — giving free, calibrated uncertainty estimates with no architectural changes and zero inference overhead beyond T forward passes.
|
|
10
|
+
|
|
11
|
+
## Install
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
$ pip install mc-dropout-pytorch
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Usage
|
|
18
|
+
|
|
19
|
+
### Regression with uncertainty
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
import torch
|
|
23
|
+
from torch.utils.data import TensorDataset
|
|
24
|
+
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
|
|
25
|
+
|
|
26
|
+
# build model
|
|
27
|
+
model = BayesianMLP(
|
|
28
|
+
input_dim = 1,
|
|
29
|
+
output_dim = 1,
|
|
30
|
+
hidden_dims = (256, 256),
|
|
31
|
+
dropout_rate = 0.1,
|
|
32
|
+
activation = 'relu',
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# wrap for MC inference (T=50 stochastic passes)
|
|
36
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
|
|
37
|
+
|
|
38
|
+
x = torch.linspace(-3, 3, 100).unsqueeze(-1)
|
|
39
|
+
out = mc(x)
|
|
40
|
+
|
|
41
|
+
out.mean # predictive mean — (100, 1)
|
|
42
|
+
out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
|
|
43
|
+
out.samples # raw samples — (50, 100, 1)
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
### Classification with predictive entropy
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
import torch
|
|
50
|
+
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
|
|
51
|
+
|
|
52
|
+
model = BayesianCNN(
|
|
53
|
+
in_channels = 1,
|
|
54
|
+
num_classes = 10,
|
|
55
|
+
base_channels = 32,
|
|
56
|
+
dropout_rate = 0.25,
|
|
57
|
+
fc_dropout_rate = 0.5,
|
|
58
|
+
img_size = 28,
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
|
|
62
|
+
|
|
63
|
+
x = torch.randn(8, 1, 28, 28)
|
|
64
|
+
out = mc(x)
|
|
65
|
+
|
|
66
|
+
out.mean # class probabilities — (8, 10)
|
|
67
|
+
out.variance # per-class variance — (8, 10)
|
|
68
|
+
|
|
69
|
+
# active learning signals (§6)
|
|
70
|
+
H = mc.predictive_entropy(x) # (8,) — total uncertainty
|
|
71
|
+
MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
### Full training loop with the `Trainer`
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
import torch
|
|
78
|
+
from torch.utils.data import TensorDataset
|
|
79
|
+
from mc_dropout_pytorch import BayesianMLP, Trainer
|
|
80
|
+
|
|
81
|
+
# synthetic regression dataset
|
|
82
|
+
X = torch.randn(1000, 4)
|
|
83
|
+
y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
|
|
84
|
+
dataset = TensorDataset(X, y)
|
|
85
|
+
|
|
86
|
+
model = BayesianMLP(
|
|
87
|
+
input_dim = 4,
|
|
88
|
+
output_dim = 1,
|
|
89
|
+
hidden_dims = (128, 128),
|
|
90
|
+
dropout_rate = 0.1,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
trainer = Trainer(
|
|
94
|
+
model,
|
|
95
|
+
dataset,
|
|
96
|
+
task = 'regression',
|
|
97
|
+
train_lr = 1e-3,
|
|
98
|
+
train_num_steps = 5_000,
|
|
99
|
+
train_batch_size = 64,
|
|
100
|
+
ema_decay = 0.995,
|
|
101
|
+
amp = False,
|
|
102
|
+
weight_decay = 1e-4, # ≡ prior precision in §3
|
|
103
|
+
tau = 1.0, # noise precision
|
|
104
|
+
num_mc_samples = 50,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
trainer.train()
|
|
108
|
+
|
|
109
|
+
# inference via EMA model
|
|
110
|
+
mc = trainer.inference
|
|
111
|
+
out = mc(X[:10])
|
|
112
|
+
print(out.mean, out.variance)
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
### Multi-GPU
|
|
116
|
+
|
|
117
|
+
```bash
|
|
118
|
+
$ accelerate config
|
|
119
|
+
$ accelerate launch train.py
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
## Key ideas from the paper
|
|
123
|
+
|
|
124
|
+
**The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
|
|
125
|
+
|
|
126
|
+
**Test-time dropout (MC Dropout)**:
|
|
127
|
+
|
|
128
|
+
```
|
|
129
|
+
for t = 1 … T:
|
|
130
|
+
ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
|
|
131
|
+
|
|
132
|
+
E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
|
|
133
|
+
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*]² # predictive variance (Eq. 9)
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
**Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
|
|
137
|
+
|
|
138
|
+
**Weight correspondence** (§3.2):
|
|
139
|
+
|
|
140
|
+
| Dropout training | Bayesian GP posterior |
|
|
141
|
+
|---------------------------|--------------------------|
|
|
142
|
+
| dropout probability `p` | variational parameter |
|
|
143
|
+
| L2 weight decay `λ` | prior precision |
|
|
144
|
+
| noise precision `τ`       | `τ = (1 − p) l² / (2N λ)`, with `l` the prior length-scale |
|
|
145
|
+
|
|
146
|
+
## Citations
|
|
147
|
+
|
|
148
|
+
```bibtex
|
|
149
|
+
@article{Gal2016Dropout,
|
|
150
|
+
title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
|
|
151
|
+
author = {Yarin Gal and Zoubin Ghahramani},
|
|
152
|
+
journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
|
|
153
|
+
year = {2016},
|
|
154
|
+
url = {https://arxiv.org/abs/1506.02142}
|
|
155
|
+
}
|
|
156
|
+
```
|
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from functools import partial
|
|
4
|
+
from collections import namedtuple
|
|
5
|
+
from multiprocessing import cpu_count
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn, einsum
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
11
|
+
|
|
12
|
+
from torch.optim import Adam
|
|
13
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
14
|
+
|
|
15
|
+
from einops import rearrange, reduce, repeat
|
|
16
|
+
from einops.layers.torch import Rearrange
|
|
17
|
+
|
|
18
|
+
from tqdm.auto import tqdm
|
|
19
|
+
from ema_pytorch import EMA
|
|
20
|
+
from accelerate import Accelerator
|
|
21
|
+
|
|
22
|
+
# ──────────────────────────────────────────────
|
|
23
|
+
# constants
|
|
24
|
+
# ──────────────────────────────────────────────
|
|
25
|
+
|
|
26
|
+
# Result bundle returned by MCDropoutInference.forward: the predictive `mean`,
# the predictive `variance`, and the stacked per-pass `samples` they were
# computed from.
ModelOutput = namedtuple('ModelOutput', ['mean', 'variance', 'samples'])
|
|
27
|
+
|
|
28
|
+
# ──────────────────────────────────────────────
|
|
29
|
+
# helpers
|
|
30
|
+
# ──────────────────────────────────────────────
|
|
31
|
+
|
|
32
|
+
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
    is_missing = x is None
    return not is_missing
|
|
34
|
+
|
|
35
|
+
def default(val, d):
    """
    Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    A callable ``d`` is invoked with no arguments to produce the fallback,
    so expensive defaults are only built when they are actually needed.
    """
    if val is None:
        fallback = d() if callable(d) else d
        return fallback
    return val
|
|
39
|
+
|
|
40
|
+
def identity(t, *args, **kwargs):
    """Hand ``t`` back unchanged; any extra positional or keyword arguments are ignored."""
    result = t
    return result
|
|
42
|
+
|
|
43
|
+
def cycle(dl):
    """
    Iterate over ``dl`` forever, starting a fresh pass each time it is exhausted.

    Each pass re-iterates ``dl`` itself rather than replaying cached items.
    """
    while True:
        yield from dl
|
|
47
|
+
|
|
48
|
+
def cast_tuple(t, length = 1):
    """
    Coerce ``t`` to a tuple.

    Tuples are returned as-is (``length`` is not enforced on them); any other
    value is repeated ``length`` times.
    """
    return t if isinstance(t, tuple) else (t,) * length
|
|
52
|
+
|
|
53
|
+
def divisible_by(numer, denom):
    """Report whether ``numer`` divided by ``denom`` leaves no remainder."""
    remainder = numer % denom
    return remainder == 0
|
|
55
|
+
|
|
56
|
+
def num_to_groups(num, divisor):
    """
    Split ``num`` into chunk sizes of at most ``divisor``.

    Returns a list of full-size chunks followed by one smaller chunk holding
    whatever is left over (omitted when ``num`` divides evenly).
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes
|
|
63
|
+
|
|
64
|
+
# ──────────────────────────────────────────────
|
|
65
|
+
# MC Dropout core: enable dropout at test time
|
|
66
|
+
# ──────────────────────────────────────────────
|
|
67
|
+
|
|
68
|
+
class MCDropout(nn.Dropout):
    """
    Always-on dropout layer for MC Dropout (Gal & Ghahramani, 2016).

    A plain nn.Dropout becomes a no-op once the module is put in eval mode.
    This subclass ignores the training flag and samples a fresh mask on
    every call, so repeating the forward pass T times yields T stochastic
    predictions {ŷ_t} from which a Monte Carlo estimate of the predictive
    distribution can be formed:

        predictive mean : E[y*]   ≈ (1/T) Σ ŷ_t
        predictive var  : Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*]²
    """

    def forward(self, x):
        # `training` is forced on rather than read from self.training, which
        # is what keeps the layer stochastic after model.eval()
        return F.dropout(
            x,
            p = self.p,
            training = True,
            inplace = self.inplace,
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class MCDropout2d(nn.Dropout2d):
    """Always-on channel-wise (spatial) dropout for convolutional feature maps."""

    def forward(self, x):
        # forced `training = True` keeps the layer stochastic in eval mode
        return F.dropout2d(
            x,
            p = self.p,
            training = True,
            inplace = self.inplace,
        )
|
|
92
|
+
|
|
93
|
+
# ──────────────────────────────────────────────
|
|
94
|
+
# small helper modules
|
|
95
|
+
# ──────────────────────────────────────────────
|
|
96
|
+
|
|
97
|
+
class RMSNorm(nn.Module):
    """
    RMS-style normalisation over the last dimension.

    The input is L2-normalised along its last axis, then multiplied by a
    learned per-feature gain `g` (initialised to ones) and by sqrt(dim).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        gained = unit * self.g
        return gained * self.scale
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class FeedForward(nn.Module):
    """
    Pre-normalised two-layer feed-forward block with MC Dropout.

    Layout: RMSNorm → Linear(dim, dim * mult) → GELU → MCDropout → Linear(dim * mult, dim)

    Parameters
    ----------
    dim     : int   — input and output feature dimension
    mult    : float — expansion factor of the hidden layer (default 4)
    dropout : float — dropout probability of the always-on MCDropout layer
    """

    def __init__(self, dim, mult = 4, dropout = 0.0):
        super().__init__()
        hidden = int(dim * mult)

        # order matters: it fixes the Sequential indices used in state_dict keys
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            MCDropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
|
|
121
|
+
|
|
122
|
+
# ──────────────────────────────────────────────
|
|
123
|
+
# Bayesian MLP — regression & classification
|
|
124
|
+
# ──────────────────────────────────────────────
|
|
125
|
+
|
|
126
|
+
class BayesianMLP(nn.Module):
    """
    Dropout-regularised MLP whose test-time stochastic forward passes
    approximate a deep Gaussian process posterior (Gal & Ghahramani, §3).

    Every hidden layer is Linear → activation → MCDropout, followed by a
    final Linear projection to `output_dim`. Because MCDropout ignores the
    training flag, the network stays stochastic after `.eval()`.

    Parameters
    ----------
    input_dim    : int   — feature dimension
    output_dim   : int   — number of targets / classes
    hidden_dims  : tuple — widths of hidden layers, default (256, 256);
                           may be empty, giving a single Linear layer
    dropout_rate : float — Bernoulli dropout probability p (typically 0.1-0.5)
    activation   : str   — 'relu' | 'tanh' | 'gelu' (§5 ablation);
                           any other value falls back to ReLU
    """

    activation_map = {
        'relu' : nn.ReLU,
        'tanh' : nn.Tanh,
        'gelu' : nn.GELU,
    }

    def __init__(
        self,
        input_dim,
        output_dim,
        *,
        hidden_dims = (256, 256),
        dropout_rate = 0.1,
        activation = 'relu',
    ):
        super().__init__()

        # unrecognised activation names silently fall back to ReLU
        act_cls = self.activation_map.get(activation, nn.ReLU)
        dims = (input_dim, *hidden_dims)

        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                act_cls(),
                MCDropout(dropout_rate),
            ]

        # dims[-1] is the last hidden width, or input_dim when hidden_dims is
        # empty — indexing hidden_dims[-1] directly would raise IndexError there
        layers.append(nn.Linear(dims[-1], output_dim))

        self.net = nn.Sequential(*layers)
        self.dropout_rate = dropout_rate

    def forward(self, x):
        return self.net(x)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# ──────────────────────────────────────────────
|
|
177
|
+
# Bayesian CNN — for image classification (§5)
|
|
178
|
+
# ──────────────────────────────────────────────
|
|
179
|
+
|
|
180
|
+
class BayesianCNN(nn.Module):
    """
    Convolutional classifier with MC Dropout after every conv stage and
    before the output layer (image-classification setting, §5 of the paper).

    Layout as built below:
        Conv → ReLU → MCDrop2d
        → Conv → ReLU → MaxPool → MCDrop2d
        → Conv → ReLU → MaxPool → MCDrop2d
        → Flatten → Linear → ReLU → MCDrop → Linear

    Parameters
    ----------
    in_channels     : int   — channels of the input image
    num_classes     : int   — number of output logits
    base_channels   : int   — width of the first conv; later convs use twice this
    dropout_rate    : float — dropout probability of the spatial dropout layers
    fc_dropout_rate : float — dropout probability before the final Linear
    img_size        : int   — side length of the (square) input image
    """

    def __init__(
        self,
        in_channels = 1,
        num_classes = 10,
        *,
        base_channels = 32,
        dropout_rate = 0.25,
        fc_dropout_rate = 0.5,
        img_size = 28,
    ):
        super().__init__()

        narrow = base_channels
        wide = narrow * 2

        # the 3x3 convs are padded to keep spatial size; the two 2x2 max-pools
        # each halve it (rounding down), i.e. img_size // 4 per side overall
        pooled_side = img_size // 4
        flat_features = pooled_side ** 2 * wide

        # order fixes the Sequential indices used in state_dict keys
        conv_stages = [
            nn.Conv2d(in_channels, narrow, 3, padding = 1),
            nn.ReLU(),
            MCDropout2d(dropout_rate),
            nn.Conv2d(narrow, wide, 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            MCDropout2d(dropout_rate),
            nn.Conv2d(wide, wide, 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            MCDropout2d(dropout_rate),
        ]
        self.conv = nn.Sequential(*conv_stages)

        head_stages = [
            Rearrange('b c h w -> b (c h w)'),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            MCDropout(fc_dropout_rate),
            nn.Linear(256, num_classes),
        ]
        self.head = nn.Sequential(*head_stages)

        self.dropout_rate = dropout_rate

    def forward(self, x):
        features = self.conv(x)
        return self.head(features)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# ──────────────────────────────────────────────
|
|
233
|
+
# MC Inference — the inference wrapper
|
|
234
|
+
# ──────────────────────────────────────────────
|
|
235
|
+
|
|
236
|
+
class MCDropoutInference(nn.Module):
    """
    Wraps any BayesianMLP / BayesianCNN to produce predictive
    mean, variance and full sample tensor via T stochastic passes.

    Predictive uncertainty decomposition (§3, Eq. 9), regression case:
        τ⁻¹ — observation-noise variance (inverse of the noise precision τ)
        Var — model (epistemic) uncertainty estimated from the T samples

    Parameters
    ----------
    model       : nn.Module — a BayesianMLP or BayesianCNN
    num_samples : int       — T in the paper (default 50), must be >= 1
    task        : str       — 'classification'; any other value is treated
                              as regression
    tau         : float     — noise precision τ for regression uncertainty,
                              must be > 0 for regression

    Raises
    ------
    ValueError — if `num_samples` < 1, or `tau` <= 0 for a regression task
    """

    def __init__(
        self,
        model,
        *,
        num_samples = 50,
        task = 'regression',
        tau = 1.0,
    ):
        super().__init__()

        if num_samples < 1:
            raise ValueError(f'num_samples must be at least 1, got {num_samples}')

        # τ is only used (as 1 / τ) on the regression path
        if task != 'classification' and tau <= 0:
            raise ValueError(f'tau must be positive for regression, got {tau}')

        self.model = model
        self.num_samples = num_samples
        self.task = task
        self.tau = tau

    @staticmethod
    def _entropy(probs):
        """Shannon entropy over the last dimension, clamped for numerical stability."""
        p = probs.clamp(min = 1e-8)
        return -(p * p.log()).sum(dim = -1)

    def _require_classification(self, method_name):
        """Reject entropy-based measures on wrappers not configured for classification."""
        if self.task != 'classification':
            raise ValueError(
                f"{method_name} is only defined for task = 'classification', "
                f"got task = {self.task!r}"
            )

    @torch.no_grad()
    def forward(self, x):
        """
        Run T stochastic passes and summarise them.

        Returns a ModelOutput with
            mean     — average over passes (of softmax probabilities for
                       classification, of raw outputs for regression)
            variance — population variance over passes, plus τ⁻¹ for regression
            samples  — raw model outputs stacked along a new leading dim of size T
        """
        # T stochastic forward passes — shape (T, *model_output_shape)
        samples = torch.stack(
            [self.model(x) for _ in range(self.num_samples)],
            dim = 0,
        )

        is_classification = self.task == 'classification'

        # softmax each sample before averaging → predictive probabilities
        draws = samples.softmax(dim = -1) if is_classification else samples

        mean = draws.mean(dim = 0)

        # population variance computed around the mean; unlike E[x²] − E[x]²
        # it cannot turn negative through floating-point cancellation
        var = draws.var(dim = 0, unbiased = False)

        if not is_classification:
            # regression: Eq. (9) — add the observation-noise variance τ⁻¹
            var = var + (1.0 / self.tau)

        return ModelOutput(mean = mean, variance = var, samples = samples)

    def predictive_entropy(self, x):
        """
        H[y | x, X, Y] — used for active learning (§6 of paper).
        High entropy → model is uncertain → good candidate to label.

        Raises ValueError unless the wrapper was built with task = 'classification'.
        """
        self._require_classification('predictive_entropy')
        out = self.forward(x)
        return self._entropy(out.mean)

    def mutual_information(self, x):
        """
        I[y, ω | x, X, Y] — epistemic (model) uncertainty (§6).
        MI = H[y|x] − E_ω[H[y|x,ω]]

        Raises ValueError unless the wrapper was built with task = 'classification'.
        """
        self._require_classification('mutual_information')
        out = self.forward(x)

        # H of predictive mean
        h_mean = self._entropy(out.mean)

        # expected H over samples: per-pass entropy, averaged over the T passes
        h_samples = self._entropy(out.samples.softmax(dim = -1))
        exp_h = h_samples.mean(dim = 0)

        # MI is non-negative in exact arithmetic; clamp away rounding noise
        return (h_mean - exp_h).clamp(min = 0.0)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
# ──────────────────────────────────────────────
|
|
318
|
+
# Trainer
|
|
319
|
+
# ──────────────────────────────────────────────
|
|
320
|
+
|
|
321
|
+
class Trainer:
    """
    Training wrapper with accelerate + EMA for MC Dropout models.

    Supports both regression (MSE) and classification (cross-entropy).

    Parameters
    ----------
    model            : BayesianMLP | BayesianCNN
    dataset          : Dataset yielding (x, y) pairs
    task             : 'classification'; any other value trains with MSE
    train_lr         : float — learning rate (default 1e-3)
    train_num_steps  : int   — total gradient steps
    train_batch_size : int   — batch size
    ema_decay        : float — EMA decay for weight averaging
    amp              : bool  — fp16 mixed precision when True
    results_folder   : str   — where to save checkpoints (created if missing)
    num_mc_samples   : int   — T for inference object
    weight_decay     : float — L2 regularisation (≡ prior precision in §3)
    tau              : float — noise precision τ passed to the inference wrapper
    save_every       : int   — checkpoint interval, in gradient steps

    Attributes
    ----------
    inference : MCDropoutInference built around the EMA copy of the model
    step      : number of gradient steps taken so far
    """

    def __init__(
        self,
        model,
        dataset,
        *,
        task = 'regression',
        train_lr = 1e-3,
        train_num_steps = 10_000,
        train_batch_size = 128,
        ema_decay = 0.995,
        amp = False,
        results_folder = './results',
        num_mc_samples = 50,
        weight_decay = 1e-4,
        tau = 1.0,
        save_every = 1000,
    ):
        self.accelerator = Accelerator(mixed_precision = 'fp16' if amp else 'no')

        self.model = model
        self.task = task
        self.tau = tau

        self.save_every = save_every
        self.train_num_steps = train_num_steps

        dl = DataLoader(
            dataset,
            batch_size = train_batch_size,
            shuffle = True,
            num_workers = min(4, cpu_count()),
            pin_memory = True,
        )

        self.opt = Adam(model.parameters(), lr = train_lr, weight_decay = weight_decay)

        # The EMA wrapper is moved to the device by hand instead of going
        # through accelerator.prepare(): prepare() may wrap modules for
        # distributed training, which would hide `.ema_model` / `.update()`.
        self.ema = EMA(model, beta = ema_decay, update_every = 10)
        self.ema.to(self.accelerator.device)

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(parents = True, exist_ok = True)

        # the dataloader is prepared too, so batches arrive on the same
        # device as the model
        self.model, self.opt, dl = self.accelerator.prepare(self.model, self.opt, dl)
        self.dl = cycle(dl)

        # expose inference wrapper around EMA model
        self.inference = MCDropoutInference(
            self.ema.ema_model,
            num_samples = num_mc_samples,
            task = task,
            tau = tau,
        )

        self.step = 0

    def save(self, milestone):
        """Write model / optimizer / EMA state to `model-{milestone}.pt` (main process only)."""
        # fetched on every process before the early return, in case the
        # backend needs all ranks to take part in gathering the weights
        model_state = self.accelerator.get_state_dict(self.model)

        # one writer is enough; every rank would otherwise target the same file
        if not self.accelerator.is_main_process:
            return

        data = {
            'step'  : self.step,
            'model' : model_state,
            'opt'   : self.opt.state_dict(),
            'ema'   : self.ema.state_dict(),
        }
        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore step counter, model, optimizer and EMA state from `model-{milestone}.pt`."""
        data = torch.load(
            str(self.results_folder / f'model-{milestone}.pt'),
            map_location = self.accelerator.device,
        )
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])

    def _compute_loss(self, logits, y):
        """Cross-entropy for classification, plain MSE otherwise."""
        if self.task == 'classification':
            return F.cross_entropy(logits, y.long())

        pred, target = logits, y.float()

        if pred.shape != target.shape:
            if pred.numel() == target.numel():
                # same element count, different layout, e.g. (B, 1) vs (B,):
                # align explicitly so mse_loss cannot broadcast to (B, B)
                target = target.reshape(pred.shape)
            else:
                # keep the original behaviour: drop a trailing singleton and
                # let mse_loss broadcast or raise
                pred = pred.squeeze(-1)

        # mean squared error — with fixed noise precision τ the Gaussian
        # negative log-likelihood reduces to this up to scale and a constant
        return F.mse_loss(pred, target)

    def train(self):
        """Run gradient steps until `train_num_steps`, checkpointing every `save_every` steps."""
        accelerator = self.accelerator

        with tqdm(
            initial = self.step,
            total = self.train_num_steps,
            disable = not accelerator.is_main_process,
        ) as pbar:

            while self.step < self.train_num_steps:
                batch = next(self.dl)

                # support (x,) or (x, y) datasets
                if isinstance(batch, (list, tuple)) and len(batch) == 2:
                    x, y = batch
                else:
                    x = batch[0] if isinstance(batch, (list, tuple)) else batch
                    y = None

                if not exists(y):
                    raise ValueError("Dataset must return (x, y) pairs")

                # forward pass and loss share one autocast context
                with accelerator.autocast():
                    logits = self.model(x)
                    loss = self._compute_loss(logits, y)

                accelerator.backward(loss)
                self.opt.step()
                self.opt.zero_grad()
                self.ema.update()

                pbar.set_description(f'loss: {loss.item():.4f}')
                self.step += 1

                if divisible_by(self.step, self.save_every):
                    milestone = self.step // self.save_every
                    self.save(milestone)

                pbar.update(1)

        accelerator.print('training complete')
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = '0.1.0'
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mc-dropout-pytorch
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: MC Dropout (Gal & Ghahramani, 2016) - Pytorch
|
|
5
|
+
Home-page: https://github.com/lucidrains/mc-dropout-pytorch
|
|
6
|
+
Author: lucidrains
|
|
7
|
+
Author-email: lucidrains@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: artificial intelligence,deep learning,bayesian deep learning,uncertainty estimation,monte carlo dropout
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.6
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: accelerate
|
|
18
|
+
Requires-Dist: einops>=0.7
|
|
19
|
+
Requires-Dist: ema-pytorch>=0.4.2
|
|
20
|
+
Requires-Dist: torch>=2.0
|
|
21
|
+
Requires-Dist: tqdm
|
|
22
|
+
Dynamic: author
|
|
23
|
+
Dynamic: author-email
|
|
24
|
+
Dynamic: classifier
|
|
25
|
+
Dynamic: description
|
|
26
|
+
Dynamic: description-content-type
|
|
27
|
+
Dynamic: home-page
|
|
28
|
+
Dynamic: keywords
|
|
29
|
+
Dynamic: license
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
Dynamic: requires-dist
|
|
32
|
+
Dynamic: summary
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
## MC Dropout, in Pytorch
|
|
37
|
+
|
|
38
|
+
[![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
|
|
39
|
+
|
|
40
|
+
Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
|
|
41
|
+
|
|
42
|
+
Standard dropout NNs are cast as approximate Bayesian inference over deep Gaussian processes, giving free, calibrated uncertainty estimates with no architectural changes and no inference overhead beyond the T stochastic forward passes.
|
|
43
|
+
|
|
44
|
+
## Install
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
$ pip install mc-dropout-pytorch
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Usage
|
|
51
|
+
|
|
52
|
+
### Regression with uncertainty
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
import torch
|
|
56
|
+
from torch.utils.data import TensorDataset
|
|
57
|
+
from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
|
|
58
|
+
|
|
59
|
+
# build model
|
|
60
|
+
model = BayesianMLP(
|
|
61
|
+
input_dim = 1,
|
|
62
|
+
output_dim = 1,
|
|
63
|
+
hidden_dims = (256, 256),
|
|
64
|
+
dropout_rate = 0.1,
|
|
65
|
+
activation = 'relu',
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# wrap for MC inference (T=50 stochastic passes)
|
|
69
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
|
|
70
|
+
|
|
71
|
+
x = torch.linspace(-3, 3, 100).unsqueeze(-1)
|
|
72
|
+
out = mc(x)
|
|
73
|
+
|
|
74
|
+
out.mean # predictive mean — (100, 1)
|
|
75
|
+
out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
|
|
76
|
+
out.samples # raw samples — (50, 100, 1)
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
### Classification with predictive entropy
|
|
80
|
+
|
|
81
|
+
```python
|
|
82
|
+
import torch
|
|
83
|
+
from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
|
|
84
|
+
|
|
85
|
+
model = BayesianCNN(
|
|
86
|
+
in_channels = 1,
|
|
87
|
+
num_classes = 10,
|
|
88
|
+
base_channels = 32,
|
|
89
|
+
dropout_rate = 0.25,
|
|
90
|
+
fc_dropout_rate = 0.5,
|
|
91
|
+
img_size = 28,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
|
|
95
|
+
|
|
96
|
+
x = torch.randn(8, 1, 28, 28)
|
|
97
|
+
out = mc(x)
|
|
98
|
+
|
|
99
|
+
out.mean # class probabilities — (8, 10)
|
|
100
|
+
out.variance # per-class variance — (8, 10)
|
|
101
|
+
|
|
102
|
+
# active learning signals (§6)
|
|
103
|
+
H = mc.predictive_entropy(x) # (8,) — total uncertainty
|
|
104
|
+
MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
### Full training loop with the `Trainer`
|
|
108
|
+
|
|
109
|
+
```python
|
|
110
|
+
import torch
|
|
111
|
+
from torch.utils.data import TensorDataset
|
|
112
|
+
from mc_dropout_pytorch import BayesianMLP, Trainer
|
|
113
|
+
|
|
114
|
+
# synthetic regression dataset
|
|
115
|
+
X = torch.randn(1000, 4)
|
|
116
|
+
y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
|
|
117
|
+
dataset = TensorDataset(X, y)
|
|
118
|
+
|
|
119
|
+
model = BayesianMLP(
|
|
120
|
+
input_dim = 4,
|
|
121
|
+
output_dim = 1,
|
|
122
|
+
hidden_dims = (128, 128),
|
|
123
|
+
dropout_rate = 0.1,
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
trainer = Trainer(
|
|
127
|
+
model,
|
|
128
|
+
dataset,
|
|
129
|
+
task = 'regression',
|
|
130
|
+
train_lr = 1e-3,
|
|
131
|
+
train_num_steps = 5_000,
|
|
132
|
+
train_batch_size = 64,
|
|
133
|
+
ema_decay = 0.995,
|
|
134
|
+
amp = False,
|
|
135
|
+
weight_decay = 1e-4, # ≡ prior precision in §3
|
|
136
|
+
tau = 1.0, # noise precision
|
|
137
|
+
num_mc_samples = 50,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
trainer.train()
|
|
141
|
+
|
|
142
|
+
# inference via EMA model
|
|
143
|
+
mc = trainer.inference
|
|
144
|
+
out = mc(X[:10])
|
|
145
|
+
print(out.mean, out.variance)
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
### Multi-GPU
|
|
149
|
+
|
|
150
|
+
```bash
|
|
151
|
+
$ accelerate config
|
|
152
|
+
$ accelerate launch train.py
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
## Key ideas from the paper
|
|
156
|
+
|
|
157
|
+
**The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
|
|
158
|
+
|
|
159
|
+
**Test-time dropout (MC Dropout)**:
|
|
160
|
+
|
|
161
|
+
```
|
|
162
|
+
for t = 1 … T:
|
|
163
|
+
ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
|
|
164
|
+
|
|
165
|
+
E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
|
|
166
|
+
Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ   # predictive variance (Eq. 9)
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
**Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
|
|
170
|
+
|
|
171
|
+
**Weight correspondence** (§3.2):
|
|
172
|
+
|
|
173
|
+
| Dropout training | Bayesian GP posterior |
|
|
174
|
+
|---------------------------|--------------------------|
|
|
175
|
+
| dropout probability `p` | variational parameter |
|
|
176
|
+
| L2 weight decay `λ` | prior precision |
|
|
177
|
+
| noise precision `τ`       | `τ = (1 − p) l² / (2N λ)` (`l` = prior length-scale) |
|
|
178
|
+
|
|
179
|
+
## Citations
|
|
180
|
+
|
|
181
|
+
```bibtex
|
|
182
|
+
@article{Gal2016Dropout,
|
|
183
|
+
title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
|
|
184
|
+
author = {Yarin Gal and Zoubin Ghahramani},
|
|
185
|
+
journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
|
|
186
|
+
year = {2016},
|
|
187
|
+
url = {https://arxiv.org/abs/1506.02142}
|
|
188
|
+
}
|
|
189
|
+
```
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
setup.py
|
|
4
|
+
mc_dropout_pytorch/__init__.py
|
|
5
|
+
mc_dropout_pytorch/mc_dropout_pytorch.py
|
|
6
|
+
mc_dropout_pytorch/version.py
|
|
7
|
+
mc_dropout_pytorch.egg-info/PKG-INFO
|
|
8
|
+
mc_dropout_pytorch.egg-info/SOURCES.txt
|
|
9
|
+
mc_dropout_pytorch.egg-info/dependency_links.txt
|
|
10
|
+
mc_dropout_pytorch.egg-info/requires.txt
|
|
11
|
+
mc_dropout_pytorch.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
mc_dropout_pytorch
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Packaging script for mc-dropout-pytorch."""

from setuptools import setup, find_packages

# Execute version.py so that `__version__` is defined in this module's globals.
# The file is opened in a `with` block so the handle is closed deterministically.
with open('mc_dropout_pytorch/version.py', encoding = 'utf-8') as version_file:
    exec(version_file.read())

# README.md contains non-ASCII characters, so decode it explicitly as UTF-8
# instead of relying on the platform's default locale encoding.
with open('README.md', encoding = 'utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name = 'mc-dropout-pytorch',
    packages = find_packages(),
    version = __version__,
    license = 'MIT',
    description = 'MC Dropout (Gal & Ghahramani, 2016) - Pytorch',
    long_description = long_description,
    long_description_content_type = 'text/markdown',
    author = 'lucidrains',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/mc-dropout-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'bayesian deep learning',
        'uncertainty estimation',
        'monte carlo dropout',
    ],
    install_requires = [
        'accelerate',
        'einops>=0.7',
        'ema-pytorch>=0.4.2',
        'torch>=2.0',
        'tqdm',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        # NOTE(review): 'torch>=2.0' above likely rules out Python 3.6;
        # confirm the supported interpreter versions and update this classifier.
        'Programming Language :: Python :: 3.6',
    ],
)
|