mc-dropout-pytorch 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,189 @@
1
+ Metadata-Version: 2.4
2
+ Name: mc-dropout-pytorch
3
+ Version: 0.1.0
4
+ Summary: MC Dropout (Gal & Ghahramani, 2016) - Pytorch
5
+ Home-page: https://github.com/lucidrains/mc-dropout-pytorch
6
+ Author: lucidrains
7
+ Author-email: lucidrains@gmail.com
8
+ License: MIT
9
+ Keywords: artificial intelligence,deep learning,bayesian deep learning,uncertainty estimation,monte carlo dropout
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3.6
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: accelerate
18
+ Requires-Dist: einops>=0.7
19
+ Requires-Dist: ema-pytorch>=0.4.2
20
+ Requires-Dist: torch>=2.0
21
+ Requires-Dist: tqdm
22
+ Dynamic: author
23
+ Dynamic: author-email
24
+ Dynamic: classifier
25
+ Dynamic: description
26
+ Dynamic: description-content-type
27
+ Dynamic: home-page
28
+ Dynamic: keywords
29
+ Dynamic: license
30
+ Dynamic: license-file
31
+ Dynamic: requires-dist
32
+ Dynamic: summary
33
+
34
+
35
+
36
+ ## MC Dropout, in Pytorch
37
+
38
+ [![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
39
+
40
+ Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
41
+
42
+ Standard dropout NNs cast as approximate Bayesian inference over deep Gaussian processes — giving free, calibrated uncertainty estimates with no architectural changes and zero inference overhead beyond T forward passes.
43
+
44
+ ## Install
45
+
46
+ ```bash
47
+ $ pip install mc-dropout-pytorch
48
+ ```
49
+
50
+ ## Usage
51
+
52
+ ### Regression with uncertainty
53
+
54
+ ```python
55
+ import torch
56
+ from torch.utils.data import TensorDataset
57
+ from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
58
+
59
+ # build model
60
+ model = BayesianMLP(
61
+ input_dim = 1,
62
+ output_dim = 1,
63
+ hidden_dims = (256, 256),
64
+ dropout_rate = 0.1,
65
+ activation = 'relu',
66
+ )
67
+
68
+ # wrap for MC inference (T=50 stochastic passes)
69
+ mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
70
+
71
+ x = torch.linspace(-3, 3, 100).unsqueeze(-1)
72
+ out = mc(x)
73
+
74
+ out.mean # predictive mean — (100, 1)
75
+ out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
76
+ out.samples # raw samples — (50, 100, 1)
77
+ ```
78
+
79
+ ### Classification with predictive entropy
80
+
81
+ ```python
82
+ import torch
83
+ from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
84
+
85
+ model = BayesianCNN(
86
+ in_channels = 1,
87
+ num_classes = 10,
88
+ base_channels = 32,
89
+ dropout_rate = 0.25,
90
+ fc_dropout_rate = 0.5,
91
+ img_size = 28,
92
+ )
93
+
94
+ mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
95
+
96
+ x = torch.randn(8, 1, 28, 28)
97
+ out = mc(x)
98
+
99
+ out.mean # class probabilities — (8, 10)
100
+ out.variance # per-class variance — (8, 10)
101
+
102
+ # active learning signals (§6)
103
+ H = mc.predictive_entropy(x) # (8,) — total uncertainty
104
+ MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
105
+ ```
106
+
107
+ ### Full training loop with the `Trainer`
108
+
109
+ ```python
110
+ import torch
111
+ from torch.utils.data import TensorDataset
112
+ from mc_dropout_pytorch import BayesianMLP, Trainer
113
+
114
+ # synthetic regression dataset
115
+ X = torch.randn(1000, 4)
116
+ y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
117
+ dataset = TensorDataset(X, y)
118
+
119
+ model = BayesianMLP(
120
+ input_dim = 4,
121
+ output_dim = 1,
122
+ hidden_dims = (128, 128),
123
+ dropout_rate = 0.1,
124
+ )
125
+
126
+ trainer = Trainer(
127
+ model,
128
+ dataset,
129
+ task = 'regression',
130
+ train_lr = 1e-3,
131
+ train_num_steps = 5_000,
132
+ train_batch_size = 64,
133
+ ema_decay = 0.995,
134
+ amp = False,
135
+ weight_decay = 1e-4, # ≡ prior precision in §3
136
+ tau = 1.0, # noise precision
137
+ num_mc_samples = 50,
138
+ )
139
+
140
+ trainer.train()
141
+
142
+ # inference via EMA model
143
+ mc = trainer.inference
144
+ out = mc(X[:10])
145
+ print(out.mean, out.variance)
146
+ ```
147
+
148
+ ### Multi-GPU
149
+
150
+ ```bash
151
+ $ accelerate config
152
+ $ accelerate launch train.py
153
+ ```
154
+
155
+ ## Key ideas from the paper
156
+
157
+ **The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
158
+
159
+ **Test-time dropout (MC Dropout)**:
160
+
161
+ ```
162
+ for t = 1 … T:
163
+ ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
164
+
165
+ E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
166
+ Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ # predictive variance (Eq. 9)
167
+ ```
168
+
169
+ **Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
170
+
171
+ **Weight correspondence** (§3.2):
172
+
173
+ | Dropout training | Bayesian GP posterior |
174
+ |---------------------------|--------------------------|
175
+ | dropout probability `p` | variational parameter |
176
+ | L2 weight decay `λ` | prior precision |
177
+ | noise precision `τ` | `τ = (1 − p) l² / (2N λ)`, with `l` the prior length-scale |
178
+
179
+ ## Citations
180
+
181
+ ```bibtex
182
+ @article{Gal2016Dropout,
183
+ title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
184
+ author = {Yarin Gal and Zoubin Ghahramani},
185
+ journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
186
+ year = {2016},
187
+ url = {https://arxiv.org/abs/1506.02142}
188
+ }
189
+ ```
@@ -0,0 +1,156 @@
1
+
2
+
3
+ ## MC Dropout, in Pytorch
4
+
5
+ [![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
6
+
7
+ Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
8
+
9
+ Standard dropout NNs cast as approximate Bayesian inference over deep Gaussian processes — giving free, calibrated uncertainty estimates with no architectural changes and zero inference overhead beyond T forward passes.
10
+
11
+ ## Install
12
+
13
+ ```bash
14
+ $ pip install mc-dropout-pytorch
15
+ ```
16
+
17
+ ## Usage
18
+
19
+ ### Regression with uncertainty
20
+
21
+ ```python
22
+ import torch
23
+ from torch.utils.data import TensorDataset
24
+ from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
25
+
26
+ # build model
27
+ model = BayesianMLP(
28
+ input_dim = 1,
29
+ output_dim = 1,
30
+ hidden_dims = (256, 256),
31
+ dropout_rate = 0.1,
32
+ activation = 'relu',
33
+ )
34
+
35
+ # wrap for MC inference (T=50 stochastic passes)
36
+ mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
37
+
38
+ x = torch.linspace(-3, 3, 100).unsqueeze(-1)
39
+ out = mc(x)
40
+
41
+ out.mean # predictive mean — (100, 1)
42
+ out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
43
+ out.samples # raw samples — (50, 100, 1)
44
+ ```
45
+
46
+ ### Classification with predictive entropy
47
+
48
+ ```python
49
+ import torch
50
+ from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
51
+
52
+ model = BayesianCNN(
53
+ in_channels = 1,
54
+ num_classes = 10,
55
+ base_channels = 32,
56
+ dropout_rate = 0.25,
57
+ fc_dropout_rate = 0.5,
58
+ img_size = 28,
59
+ )
60
+
61
+ mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
62
+
63
+ x = torch.randn(8, 1, 28, 28)
64
+ out = mc(x)
65
+
66
+ out.mean # class probabilities — (8, 10)
67
+ out.variance # per-class variance — (8, 10)
68
+
69
+ # active learning signals (§6)
70
+ H = mc.predictive_entropy(x) # (8,) — total uncertainty
71
+ MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
72
+ ```
73
+
74
+ ### Full training loop with the `Trainer`
75
+
76
+ ```python
77
+ import torch
78
+ from torch.utils.data import TensorDataset
79
+ from mc_dropout_pytorch import BayesianMLP, Trainer
80
+
81
+ # synthetic regression dataset
82
+ X = torch.randn(1000, 4)
83
+ y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
84
+ dataset = TensorDataset(X, y)
85
+
86
+ model = BayesianMLP(
87
+ input_dim = 4,
88
+ output_dim = 1,
89
+ hidden_dims = (128, 128),
90
+ dropout_rate = 0.1,
91
+ )
92
+
93
+ trainer = Trainer(
94
+ model,
95
+ dataset,
96
+ task = 'regression',
97
+ train_lr = 1e-3,
98
+ train_num_steps = 5_000,
99
+ train_batch_size = 64,
100
+ ema_decay = 0.995,
101
+ amp = False,
102
+ weight_decay = 1e-4, # ≡ prior precision in §3
103
+ tau = 1.0, # noise precision
104
+ num_mc_samples = 50,
105
+ )
106
+
107
+ trainer.train()
108
+
109
+ # inference via EMA model
110
+ mc = trainer.inference
111
+ out = mc(X[:10])
112
+ print(out.mean, out.variance)
113
+ ```
114
+
115
+ ### Multi-GPU
116
+
117
+ ```bash
118
+ $ accelerate config
119
+ $ accelerate launch train.py
120
+ ```
121
+
122
+ ## Key ideas from the paper
123
+
124
+ **The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
125
+
126
+ **Test-time dropout (MC Dropout)**:
127
+
128
+ ```
129
+ for t = 1 … T:
130
+ ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
131
+
132
+ E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
133
+ Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ # predictive variance (Eq. 9)
134
+ ```
135
+
136
+ **Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
137
+
138
+ **Weight correspondence** (§3.2):
139
+
140
+ | Dropout training | Bayesian GP posterior |
141
+ |---------------------------|--------------------------|
142
+ | dropout probability `p` | variational parameter |
143
+ | L2 weight decay `λ` | prior precision |
144
+ | noise precision `τ` | `τ = (1 − p) l² / (2N λ)`, with `l` the prior length-scale |
145
+
146
+ ## Citations
147
+
148
+ ```bibtex
149
+ @article{Gal2016Dropout,
150
+ title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
151
+ author = {Yarin Gal and Zoubin Ghahramani},
152
+ journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
153
+ year = {2016},
154
+ url = {https://arxiv.org/abs/1506.02142}
155
+ }
156
+ ```
@@ -0,0 +1,8 @@
1
+ from mc_dropout_pytorch.mc_dropout_pytorch import (
2
+ MCDropout,
3
+ MCDropout2d,
4
+ BayesianMLP,
5
+ BayesianCNN,
6
+ MCDropoutInference,
7
+ Trainer,
8
+ )
@@ -0,0 +1,461 @@
1
+ import math
2
+ from pathlib import Path
3
+ from functools import partial
4
+ from collections import namedtuple
5
+ from multiprocessing import cpu_count
6
+
7
+ import torch
8
+ from torch import nn, einsum
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
11
+
12
+ from torch.optim import Adam
13
+ from torch.optim.lr_scheduler import CosineAnnealingLR
14
+
15
+ from einops import rearrange, reduce, repeat
16
+ from einops.layers.torch import Rearrange
17
+
18
+ from tqdm.auto import tqdm
19
+ from ema_pytorch import EMA
20
+ from accelerate import Accelerator
21
+
22
+ # ──────────────────────────────────────────────
23
+ # constants
24
+ # ──────────────────────────────────────────────
25
+
26
# result bundle returned by the MC inference wrapper:
# predictive mean, predictive variance and the raw stochastic samples
ModelOutput = namedtuple('ModelOutput', ('mean', 'variance', 'samples'))
27
+
28
+ # ──────────────────────────────────────────────
29
+ # helpers
30
+ # ──────────────────────────────────────────────
31
+
32
def exists(x):
    """Return True for every value except None (falsy values such as 0 or '' count as existing)."""
    return not (x is None)
34
+
35
def default(val, d):
    """
    Return `val` unless it is None, otherwise the fallback `d`.

    A callable fallback is invoked (with no arguments) and its result
    returned, so expensive defaults are only built when needed.
    """
    if val is None:
        return d() if callable(d) else d
    return val
39
+
40
def identity(t, *_ignored, **_ignored_kwargs):
    """Return the first argument untouched; any further arguments are accepted and discarded."""
    return t
42
+
43
def cycle(dl):
    """
    Endlessly yield items from `dl`, restarting iteration each time it is exhausted.

    `dl` is re-iterated on every pass (nothing is cached), so an iterable
    that reshuffles per iteration yields a fresh order each pass.
    """
    while True:
        yield from dl
47
+
48
def cast_tuple(t, length = 1):
    """
    Pass tuples through unchanged; wrap anything else in a tuple repeated `length` times.

    Note that lists are not tuples and are therefore wrapped as a single element.
    """
    return t if isinstance(t, tuple) else (t,) * length
52
+
53
def divisible_by(numer, denom):
    """Return whether `numer` is an exact multiple of `denom` (a zero `denom` raises, as `%` does)."""
    remainder = numer % denom
    return remainder == 0
55
+
56
def num_to_groups(num, divisor):
    """
    Split `num` into chunks of size `divisor`, plus one smaller trailing chunk if needed.

    Example: num_to_groups(10, 4) -> [4, 4, 2]
    """
    full_chunks, leftover = divmod(num, divisor)
    sizes = [divisor] * full_chunks
    if leftover > 0:
        sizes.append(leftover)
    return sizes
63
+
64
+ # ──────────────────────────────────────────────
65
+ # MC Dropout core: enable dropout at test time
66
+ # ──────────────────────────────────────────────
67
+
68
class MCDropout(nn.Dropout):
    """
    Dropout that stays stochastic in eval mode (MC Dropout, Gal & Ghahramani, 2016).

    A plain nn.Dropout becomes a no-op once the parent module is put in
    eval mode. This variant ignores the module's training flag and always
    samples a fresh Bernoulli mask, so repeating the forward pass T times
    yields T different outputs {ŷ_t} whose empirical mean and spread can be
    used as a Monte Carlo estimate of the predictive distribution.

    `p` and `inplace` keep the meaning they have on nn.Dropout.
    """

    def forward(self, x):
        # training is forced on: the mask must be resampled even under .eval()
        return F.dropout(x, p = self.p, training = True, inplace = self.inplace)
85
+
86
+
87
class MCDropout2d(nn.Dropout2d):
    """
    Channel-wise (spatial) dropout that remains active in eval mode.

    Whole feature maps are dropped via F.dropout2d, and the module's
    training flag is ignored so every forward pass is stochastic.
    """

    def forward(self, x):
        # training is forced on so masks are resampled even under .eval()
        return F.dropout2d(x, p = self.p, training = True, inplace = self.inplace)
92
+
93
+ # ──────────────────────────────────────────────
94
+ # small helper modules
95
+ # ──────────────────────────────────────────────
96
+
97
class RMSNorm(nn.Module):
    """
    RMS normalisation over the last dimension with a learned per-feature gain.

    The input is L2-normalised along dim -1, then multiplied by the gain `g`
    (initialised to ones) and by sqrt(dim).
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.g * self.scale
105
+
106
+
107
class FeedForward(nn.Module):
    """
    Pre-normalised two-layer feed-forward block with MC Dropout.

    Layout: RMSNorm -> Linear(dim, dim * mult) -> GELU -> MCDropout -> Linear(dim * mult, dim)

    dim     : int   — input and output feature dimension
    mult    : float — expansion factor of the hidden layer (truncated to an int width)
    dropout : float — dropout probability of the MCDropout layer
    """

    def __init__(self, dim, mult = 4, dropout = 0.0):
        super().__init__()
        hidden_dim = int(dim * mult)

        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            MCDropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
121
+
122
+ # ──────────────────────────────────────────────
123
+ # Bayesian MLP — regression & classification
124
+ # ──────────────────────────────────────────────
125
+
126
class BayesianMLP(nn.Module):
    """
    Dropout-regularised MLP whose stochastic forward passes can be averaged
    at test time (MC Dropout, Gal & Ghahramani, 2016).

    Every hidden layer is Linear -> activation -> MCDropout; the final layer
    is a plain Linear producing raw outputs (regression values or logits).

    Parameters
    ----------
    input_dim    : int   — feature dimension
    output_dim   : int   — number of targets / classes
    hidden_dims  : tuple — widths of hidden layers, default (256, 256);
                           may be empty, giving a single linear layer
    dropout_rate : float — Bernoulli dropout probability p
    activation   : str   — 'relu' | 'tanh' | 'gelu'; any other value
                           falls back to ReLU
    """

    activation_map = {
        'relu' : nn.ReLU,
        'tanh' : nn.Tanh,
        'gelu' : nn.GELU,
    }

    def __init__(
        self,
        input_dim,
        output_dim,
        *,
        hidden_dims = (256, 256),
        dropout_rate = 0.1,
        activation = 'relu',
    ):
        super().__init__()

        act_cls = self.activation_map.get(activation, nn.ReLU)
        dims = (input_dim, *hidden_dims)

        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Linear(d_in, d_out),
                act_cls(),
                MCDropout(dropout_rate),
            ]

        # dims[-1] (not hidden_dims[-1]) so an empty hidden_dims still works:
        # the output layer then maps input_dim -> output_dim directly
        layers.append(nn.Linear(dims[-1], output_dim))

        self.net = nn.Sequential(*layers)
        self.dropout_rate = dropout_rate

    def forward(self, x):
        return self.net(x)
174
+
175
+
176
+ # ──────────────────────────────────────────────
177
+ # Bayesian CNN — for image classification (§5)
178
+ # ──────────────────────────────────────────────
179
+
180
class BayesianCNN(nn.Module):
    """
    Small convolutional classifier with MC Dropout after every conv stage.

    Feature extractor (3x3 convolutions, padding 1):
        Conv(in, c)   → ReLU → MCDrop2d
        Conv(c, 2c)   → ReLU → MaxPool(2) → MCDrop2d
        Conv(2c, 2c)  → ReLU → MaxPool(2) → MCDrop2d
    Head:
        Flatten → Linear(·, 256) → ReLU → MCDrop → Linear(256, num_classes)

    Parameters
    ----------
    in_channels     : int   — channels of the input image
    num_classes     : int   — number of output logits
    base_channels   : int   — c above
    dropout_rate    : float — dropout probability of the spatial dropout layers
    fc_dropout_rate : float — dropout probability in the fully connected head
    img_size        : int   — side length of the (square) input image; the
                              flattened feature size is derived from it
    """

    def __init__(
        self,
        in_channels = 1,
        num_classes = 10,
        *,
        base_channels = 32,
        dropout_rate = 0.25,
        fc_dropout_rate = 0.5,
        img_size = 28,
    ):
        super().__init__()

        narrow = base_channels
        wide = base_channels * 2

        # the two 2x max-pools shrink each spatial side by a factor of 4
        pooled_side = img_size // 4
        flat_features = pooled_side ** 2 * wide

        conv_stages = [
            nn.Conv2d(in_channels, narrow, 3, padding = 1),
            nn.ReLU(),
            MCDropout2d(dropout_rate),
            nn.Conv2d(narrow, wide, 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            MCDropout2d(dropout_rate),
            nn.Conv2d(wide, wide, 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            MCDropout2d(dropout_rate),
        ]
        self.conv = nn.Sequential(*conv_stages)

        head_stages = [
            Rearrange('b c h w -> b (c h w)'),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            MCDropout(fc_dropout_rate),
            nn.Linear(256, num_classes),
        ]
        self.head = nn.Sequential(*head_stages)

        self.dropout_rate = dropout_rate

    def forward(self, x):
        features = self.conv(x)
        return self.head(features)
230
+
231
+
232
+ # ──────────────────────────────────────────────
233
+ # MC Inference — the inference wrapper
234
+ # ──────────────────────────────────────────────
235
+
236
class MCDropoutInference(nn.Module):
    """
    Wraps a model containing MC Dropout layers and produces a predictive
    mean, variance and the full sample tensor from T stochastic passes.

    Regression variance has two parts:
        τ⁻¹ — observation-noise variance (inverse of the noise precision τ)
        Var — spread of the T stochastic outputs (model uncertainty)
    For classification the variance is the spread of the per-pass softmax
    probabilities; no noise term is added.

    Parameters
    ----------
    model       : nn.Module — network whose forward pass is stochastic
    num_samples : int       — T, number of stochastic passes (default 50)
    task        : str       — 'classification' applies softmax to each pass;
                              any other value is treated as regression
    tau         : float     — noise precision τ, used for regression only
    """

    def __init__(
        self,
        model,
        *,
        num_samples = 50,
        task = 'regression',
        tau = 1.0,
    ):
        super().__init__()
        self.model = model
        self.num_samples = num_samples
        self.task = task
        self.tau = tau

    @staticmethod
    def _entropy(probs):
        """Shannon entropy over the last dim; probabilities are floored at 1e-8 so log stays finite."""
        probs = probs.clamp(min = 1e-8)
        return -(probs * probs.log()).sum(dim = -1)

    @torch.no_grad()
    def forward(self, x):
        """
        Run `num_samples` stochastic passes and summarise them.

        Returns a ModelOutput with
            mean     — average over passes (of softmax probabilities for classification)
            variance — population variance over passes (+ 1/τ for regression)
            samples  — raw model outputs stacked on a new leading dim of size T
        """
        samples = torch.stack(
            [self.model(x) for _ in range(self.num_samples)],
            dim = 0,
        )

        # reductions run over dim 0 only, so model outputs of any rank work.
        # Tensor.var is used instead of E[x²] − E[x]², which can cancel to
        # slightly negative values in floating point
        if self.task == 'classification':
            probs = samples.softmax(dim = -1)
            mean = probs.mean(dim = 0)
            var = probs.var(dim = 0, unbiased = False)
        else:
            mean = samples.mean(dim = 0)
            var = samples.var(dim = 0, unbiased = False) + (1.0 / self.tau)

        return ModelOutput(mean = mean, variance = var, samples = samples)

    @torch.no_grad()
    def predictive_entropy(self, x):
        """
        Entropy of the averaged predictive distribution.

        Only meaningful when task == 'classification', because it treats the
        predictive mean as class probabilities. Draws a fresh set of samples.
        """
        out = self.forward(x)
        return self._entropy(out.mean)

    @torch.no_grad()
    def mutual_information(self, x):
        """
        Entropy of the mean prediction minus the mean entropy of the
        individual stochastic predictions:

            MI = H[mean_t p_t] − mean_t H[p_t]

        Only meaningful when task == 'classification'. Draws a fresh set of
        samples, independent of any other call.
        """
        out = self.forward(x)

        # entropy of the averaged prediction
        h_mean = self._entropy(out.mean)

        # average entropy of the individual passes
        per_pass_entropy = self._entropy(out.samples.softmax(dim = -1))
        exp_h = per_pass_entropy.mean(dim = 0)

        # mathematically non-negative; the probability floor can push it
        # marginally below zero, so clip
        return (h_mean - exp_h).clamp(min = 0.)
315
+
316
+
317
+ # ──────────────────────────────────────────────
318
+ # Trainer
319
+ # ──────────────────────────────────────────────
320
+
321
class Trainer:
    """
    Training wrapper with accelerate + EMA for MC Dropout models.

    Supports both regression (MSE) and classification (cross-entropy).

    Parameters
    ----------
    model           : nn.Module — network to train
    dataset         : Dataset   — must yield (x, y) pairs
    task            : 'regression' | 'classification'
    train_lr        : float — learning rate (default 1e-3)
    train_num_steps : int   — total gradient steps
    train_batch_size: int   — batch size
    ema_decay       : float — EMA decay for weight averaging
    amp             : bool  — fp16 mixed precision
    results_folder  : str   — where checkpoints are written (created if missing)
    num_mc_samples  : int   — T for the exposed inference object
    weight_decay    : float — weight_decay passed to Adam
    tau             : float — noise precision handed to the inference object
    save_every      : int   — checkpoint interval in steps

    After construction, `trainer.inference` is an MCDropoutInference around
    the EMA copy of the model.
    """

    def __init__(
        self,
        model,
        dataset,
        *,
        task = 'regression',
        train_lr = 1e-3,
        train_num_steps = 10_000,
        train_batch_size = 128,
        ema_decay = 0.995,
        amp = False,
        results_folder = './results',
        num_mc_samples = 50,
        weight_decay = 1e-4,
        tau = 1.0,
        save_every = 1000,
    ):
        self.accelerator = Accelerator(mixed_precision = 'fp16' if amp else 'no')

        self.model = model
        self.task = task
        self.tau = tau

        self.save_every = save_every
        self.train_num_steps = train_num_steps

        dl = DataLoader(
            dataset,
            batch_size = train_batch_size,
            shuffle = True,
            num_workers = min(4, cpu_count()),
            pin_memory = True,
        )

        self.opt = Adam(model.parameters(), lr = train_lr, weight_decay = weight_decay)

        self.ema = EMA(model, beta = ema_decay, update_every = 10)

        self.results_folder = Path(results_folder)
        # parents = True so nested result paths do not fail
        self.results_folder.mkdir(parents = True, exist_ok = True)

        # the dataloader is prepared together with model and optimizer so
        # batches arrive on the same device as the model
        self.model, self.opt, dl = self.accelerator.prepare(self.model, self.opt, dl)
        self.dl = cycle(dl)

        # the EMA wrapper is only moved to the device, not passed through
        # prepare(): it is never back-propagated through, and wrapping it
        # would hide `.ema_model` / `.update()`
        self.ema.to(self.accelerator.device)

        # expose inference wrapper around EMA model
        self.inference = MCDropoutInference(
            self.ema.ema_model,
            num_samples = num_mc_samples,
            task = task,
            tau = tau,
        )

        self.step = 0

    def save(self, milestone):
        """Write step, model, optimizer and EMA state to `model-{milestone}.pt` (main process only)."""
        data = {
            'step'  : self.step,
            'model' : self.accelerator.get_state_dict(self.model),
            'opt'   : self.opt.state_dict(),
            'ema'   : self.ema.state_dict(),
        }

        # a single writer avoids several processes clobbering the same file
        if not self.accelerator.is_main_process:
            return

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore step, model, optimizer and EMA state from `model-{milestone}.pt`."""
        data = torch.load(
            str(self.results_folder / f'model-{milestone}.pt'),
            map_location = self.accelerator.device,
        )
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])
        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        self.ema.load_state_dict(data['ema'])

    @staticmethod
    def _unpack_batch(batch):
        """Split a batch into (x, y); anything that is not a 2-item list/tuple is rejected."""
        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            return batch
        raise ValueError("Dataset must return (x, y) pairs")

    def _compute_loss(self, logits, y):
        """Cross-entropy for classification, mean squared error otherwise."""
        if self.task == 'classification':
            return F.cross_entropy(logits, y.long())

        target = y.float()

        # align target with prediction when they hold the same number of
        # elements, e.g. (B,) vs (B, 1); comparing a squeezed (B,) prediction
        # with a (B, 1) target would silently broadcast to (B, B)
        if target.numel() == logits.numel():
            return F.mse_loss(logits, target.reshape(logits.shape))

        return F.mse_loss(logits.squeeze(-1), target)

    def train(self):
        """Run optimisation until `train_num_steps`, checkpointing every `save_every` steps."""
        accelerator = self.accelerator

        with tqdm(
            initial = self.step,
            total = self.train_num_steps,
            disable = not accelerator.is_main_process,
        ) as pbar:

            while self.step < self.train_num_steps:
                x, y = self._unpack_batch(next(self.dl))

                with accelerator.autocast():
                    logits = self.model(x)
                    loss = self._compute_loss(logits, y)

                accelerator.backward(loss)
                self.opt.step()
                self.opt.zero_grad()
                self.ema.update()

                pbar.set_description(f'loss: {loss.item():.4f}')
                self.step += 1

                if divisible_by(self.step, self.save_every):
                    milestone = self.step // self.save_every
                    self.save(milestone)

                pbar.update(1)

        accelerator.print('training complete')
@@ -0,0 +1 @@
1
+ __version__ = '0.1.0'
@@ -0,0 +1,189 @@
1
+ Metadata-Version: 2.4
2
+ Name: mc-dropout-pytorch
3
+ Version: 0.1.0
4
+ Summary: MC Dropout (Gal & Ghahramani, 2016) - Pytorch
5
+ Home-page: https://github.com/lucidrains/mc-dropout-pytorch
6
+ Author: lucidrains
7
+ Author-email: lucidrains@gmail.com
8
+ License: MIT
9
+ Keywords: artificial intelligence,deep learning,bayesian deep learning,uncertainty estimation,monte carlo dropout
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3.6
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: accelerate
18
+ Requires-Dist: einops>=0.7
19
+ Requires-Dist: ema-pytorch>=0.4.2
20
+ Requires-Dist: torch>=2.0
21
+ Requires-Dist: tqdm
22
+ Dynamic: author
23
+ Dynamic: author-email
24
+ Dynamic: classifier
25
+ Dynamic: description
26
+ Dynamic: description-content-type
27
+ Dynamic: home-page
28
+ Dynamic: keywords
29
+ Dynamic: license
30
+ Dynamic: license-file
31
+ Dynamic: requires-dist
32
+ Dynamic: summary
33
+
34
+
35
+
36
+ ## MC Dropout, in Pytorch
37
+
38
+ [![PyPI version](https://badge.fury.io/py/mc-dropout-pytorch.svg)](https://badge.fury.io/py/mc-dropout-pytorch)
39
+
40
+ Implementation of [Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142) (Gal & Ghahramani, ICML 2016) in Pytorch.
41
+
42
+ Standard dropout NNs cast as approximate Bayesian inference over deep Gaussian processes — giving free, calibrated uncertainty estimates with no architectural changes and zero inference overhead beyond T forward passes.
43
+
44
+ ## Install
45
+
46
+ ```bash
47
+ $ pip install mc-dropout-pytorch
48
+ ```
49
+
50
+ ## Usage
51
+
52
+ ### Regression with uncertainty
53
+
54
+ ```python
55
+ import torch
56
+ from torch.utils.data import TensorDataset
57
+ from mc_dropout_pytorch import BayesianMLP, MCDropoutInference, Trainer
58
+
59
+ # build model
60
+ model = BayesianMLP(
61
+ input_dim = 1,
62
+ output_dim = 1,
63
+ hidden_dims = (256, 256),
64
+ dropout_rate = 0.1,
65
+ activation = 'relu',
66
+ )
67
+
68
+ # wrap for MC inference (T=50 stochastic passes)
69
+ mc = MCDropoutInference(model, num_samples = 50, task = 'regression', tau = 1.0)
70
+
71
+ x = torch.linspace(-3, 3, 100).unsqueeze(-1)
72
+ out = mc(x)
73
+
74
+ out.mean # predictive mean — (100, 1)
75
+ out.variance # predictive variance — (100, 1) includes τ⁻¹ noise term
76
+ out.samples # raw samples — (50, 100, 1)
77
+ ```
78
+
79
+ ### Classification with predictive entropy
80
+
81
+ ```python
82
+ import torch
83
+ from mc_dropout_pytorch import BayesianCNN, MCDropoutInference
84
+
85
+ model = BayesianCNN(
86
+ in_channels = 1,
87
+ num_classes = 10,
88
+ base_channels = 32,
89
+ dropout_rate = 0.25,
90
+ fc_dropout_rate = 0.5,
91
+ img_size = 28,
92
+ )
93
+
94
+ mc = MCDropoutInference(model, num_samples = 50, task = 'classification')
95
+
96
+ x = torch.randn(8, 1, 28, 28)
97
+ out = mc(x)
98
+
99
+ out.mean # class probabilities — (8, 10)
100
+ out.variance # per-class variance — (8, 10)
101
+
102
+ # active learning signals (§6)
103
+ H = mc.predictive_entropy(x) # (8,) — total uncertainty
104
+ MI = mc.mutual_information(x) # (8,) — epistemic uncertainty only
105
+ ```
106
+
107
+ ### Full training loop with the `Trainer`
108
+
109
+ ```python
110
+ import torch
111
+ from torch.utils.data import TensorDataset
112
+ from mc_dropout_pytorch import BayesianMLP, Trainer
113
+
114
+ # synthetic regression dataset
115
+ X = torch.randn(1000, 4)
116
+ y = X[:, 0] * 2 + X[:, 1] - X[:, 2] + torch.randn(1000) * 0.1
117
+ dataset = TensorDataset(X, y)
118
+
119
+ model = BayesianMLP(
120
+ input_dim = 4,
121
+ output_dim = 1,
122
+ hidden_dims = (128, 128),
123
+ dropout_rate = 0.1,
124
+ )
125
+
126
+ trainer = Trainer(
127
+ model,
128
+ dataset,
129
+ task = 'regression',
130
+ train_lr = 1e-3,
131
+ train_num_steps = 5_000,
132
+ train_batch_size = 64,
133
+ ema_decay = 0.995,
134
+ amp = False,
135
+ weight_decay = 1e-4, # ≡ prior precision in §3
136
+ tau = 1.0, # noise precision
137
+ num_mc_samples = 50,
138
+ )
139
+
140
+ trainer.train()
141
+
142
+ # inference via EMA model
143
+ mc = trainer.inference
144
+ out = mc(X[:10])
145
+ print(out.mean, out.variance)
146
+ ```
147
+
148
+ ### Multi-GPU
149
+
150
+ ```bash
151
+ $ accelerate config
152
+ $ accelerate launch train.py
153
+ ```
154
+
155
+ ## Key ideas from the paper
156
+
157
+ **The insight (§3)**: Training a NN with dropout and L2 regularisation minimises a KL divergence to the posterior of a deep Gaussian process — no variational EM, no weight sampling required.
158
+
159
+ **Test-time dropout (MC Dropout)**:
160
+
161
+ ```
162
+ for t = 1 … T:
163
+ ŷ_t = f^ω_t(x) # ω_t ~ q(ω) via Bernoulli dropout
164
+
165
+ E[y*] ≈ (1/T) Σ ŷ_t # predictive mean
166
+ Var[y*] ≈ τ⁻¹ I + (1/T) Σ ŷ_t ŷ_tᵀ − E[y*] E[y*]ᵀ # predictive variance (Eq. 9)
167
+ ```
168
+
169
+ **Active learning** (§6): Use `mc.mutual_information(x)` to identify the most informative unlabelled points — pure epistemic uncertainty, disentangled from aleatoric noise.
170
+
171
+ **Weight correspondence** (§3.2):
172
+
173
+ | Dropout training | Bayesian GP posterior |
174
+ |---------------------------|--------------------------|
175
+ | dropout probability `p` | variational parameter |
176
+ | L2 weight decay `λ` | prior precision |
177
+ | noise precision `τ` | `τ = (1 − p) l² / (2N λ)`, with `l` the prior length-scale |
178
+
179
+ ## Citations
180
+
181
+ ```bibtex
182
+ @article{Gal2016Dropout,
183
+ title = {Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning},
184
+ author = {Yarin Gal and Zoubin Ghahramani},
185
+ journal = {Proceedings of the 33rd International Conference on Machine Learning (ICML)},
186
+ year = {2016},
187
+ url = {https://arxiv.org/abs/1506.02142}
188
+ }
189
+ ```
@@ -0,0 +1,11 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ mc_dropout_pytorch/__init__.py
5
+ mc_dropout_pytorch/mc_dropout_pytorch.py
6
+ mc_dropout_pytorch/version.py
7
+ mc_dropout_pytorch.egg-info/PKG-INFO
8
+ mc_dropout_pytorch.egg-info/SOURCES.txt
9
+ mc_dropout_pytorch.egg-info/dependency_links.txt
10
+ mc_dropout_pytorch.egg-info/requires.txt
11
+ mc_dropout_pytorch.egg-info/top_level.txt
@@ -0,0 +1,5 @@
1
+ accelerate
2
+ einops>=0.7
3
+ ema-pytorch>=0.4.2
4
+ torch>=2.0
5
+ tqdm
@@ -0,0 +1 @@
1
+ mc_dropout_pytorch
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,37 @@
1
+ from setuptools import setup, find_packages
2
+
3
# version.py only assigns __version__; executing it here defines that name
# at module level without importing the package (and its heavy dependencies)
with open('mc_dropout_pytorch/version.py', encoding = 'utf-8') as version_file:
    exec(version_file.read())

# README.md contains non-ASCII characters, so the encoding is given
# explicitly instead of relying on the platform default
with open('README.md', encoding = 'utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name = 'mc-dropout-pytorch',
    packages = find_packages(),
    version = __version__,
    license = 'MIT',
    description = 'MC Dropout (Gal & Ghahramani, 2016) - Pytorch',
    long_description = long_description,
    long_description_content_type = 'text/markdown',
    author = 'lucidrains',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/mc-dropout-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'bayesian deep learning',
        'uncertainty estimation',
        'monte carlo dropout',
    ],
    install_requires = [
        'accelerate',
        'einops>=0.7',
        'ema-pytorch>=0.4.2',
        'torch>=2.0',
        'tqdm',
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)