suq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
suq/SUQ_MLP.py ADDED
@@ -0,0 +1,10 @@
1
+ from .diag_suq_mlp import SUQ_MLP_Diag
2
+
3
+ def streamline_mlp(model, posterior, covariance_structure, likelihood, scale_init = 1.0):
4
+ if covariance_structure == 'diag':
5
+ return SUQ_MLP_Diag(org_model = model,
6
+ posterior_variance = posterior,
7
+ likelihood = likelihood,
8
+ scale_init = scale_init)
9
+ else:
10
+ raise NotImplementedError(f"Covariance structure '{covariance_structure}' is not implemented.")
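
As a rough, self-contained sketch (not part of the package), the following shows how streamline_mlp might be called on a small nn.Sequential classifier. The posterior variance vector here is a constant placeholder (in practice it would come from an approximate posterior, e.g. a diagonal Laplace fit) and is assumed to be flattened in the same order as parameters_to_vector over the model's linear layers; the architecture, device handling, and variance value are illustrative assumptions only.

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

from suq import streamline_mlp

# The base class keeps its scale factor on CUDA when available, so the model
# and inputs are moved to the same device to keep the forward pass consistent.
dev = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)).to(dev)

# Placeholder diagonal posterior variance, one entry per weight/bias parameter,
# in the same order as parameters_to_vector(model.parameters()).
posterior_var = 1e-2 * torch.ones_like(parameters_to_vector(model.parameters()))

suq_model = streamline_mlp(model, posterior_var, "diag", "classification")

x = torch.randn(8, 10, device=dev)
probs = suq_model(x)          # probit-softmax class probabilities, shape [8, 3]
print(probs.sum(dim=-1))      # each row sums to 1
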
suq/SUQ_ViT.py ADDED
@@ -0,0 +1,14 @@
1
+ from .diag_suq_transformer import SUQ_ViT_Diag
2
+
3
+ def streamline_vit(model, posterior, covariance_structure, likelihood, MLP_deterministic, Attn_deterministic, attention_diag_cov, num_det_blocks, scale_init = 1.0):
4
+ if covariance_structure == 'diag':
5
+ return SUQ_ViT_Diag(ViT = model,
6
+ posterior_variance = posterior,
7
+ MLP_determinstic = MLP_deterministic,
8
+ Attn_determinstic = Attn_deterministic,
9
+ likelihood = likelihood,
10
+ attention_diag_cov = attention_diag_cov,
11
+ num_det_blocks = num_det_blocks,
12
+ scale_init = scale_init)
13
+ else:
14
+ raise NotImplementedError(f"Covariance structure '{covariance_structure}' is not implemented.")
suq/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .SUQ_MLP import streamline_mlp
2
+ from .SUQ_ViT import streamline_vit
3
+
4
+ __all__ = ["streamline_mlp", "streamline_vit"]
suq/base_suq.py ADDED
@@ -0,0 +1,181 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from torch.distributions import Categorical
6
+ from torch.distributions.normal import Normal
7
+ from torch.utils.data import DataLoader
8
+
9
+ from suq.utils.utils import torch_dataset
10
+
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
+ class SUQ_Base(nn.Module):
14
+ """
15
+ Base class for SUQ models.
16
+
17
+ Provides core functionality for:
18
+ - Managing likelihood type (regression or classification)
19
+ - Probit-based approximation for classification
20
+ - NLPD-based fitting of the scale factor
21
+
22
+ Inputs:
23
+ likelihood (str): Either 'classification' or 'regression'
24
+ scale_init (float): Initial value for the scale factor parameter
25
+ """
26
+
27
+ def __init__(self, likelihood, scale_init):
28
+ super().__init__()
29
+
30
+ if likelihood not in ['classification', 'regression']:
31
+ raise ValueError(f"Invalid likelihood type {likelihood}")
32
+
33
+ self.likelihood = likelihood
34
+ self.scale_factor = nn.Parameter(torch.Tensor([scale_init]).to(device))
35
+
36
+ def probit_approximation(self, out_mean, out_var):
37
+ """
38
+ Applies a probit approximation to compute class probabilities from the latent Gaussian distribution.
39
+
40
+ Inputs:
41
+ out_mean (Tensor): Latent function mean, shape [B, C]
42
+ out_var (Tensor): Latent function variance, shape [B, C] or [B, C, C]
43
+
44
+ Outputs:
45
+ posterior_predict_mean (Tensor): Predicted class probabilities, shape [B, C]
46
+ """
47
+
48
+ if out_var.dim() == 3:
49
+ kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var.diagonal(dim1=1, dim2=2))
50
+ else:
51
+ kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)
52
+
53
+ posterior_predict_mean = torch.softmax(kappa * out_mean, dim=-1)
54
+ return posterior_predict_mean
55
+
56
+ def fit_scale_factor(self, data_loader, n_epoches, lr, speedup = True, verbose = False):
57
+ """
58
+ Fits the scale factor for predictive variance using negative log predictive density (NLPD).
59
+
60
+ Inputs:
61
+ data_loader (DataLoader): Dataloader containing (input, target) pairs
62
+ n_epoches (int): Number of epochs for optimization
63
+ lr (float): Learning rate for scale optimizer
64
+ speedup (bool): If True (classification only), caches forward pass outputs to accelerate fitting
65
+ verbose (bool): If True, prints NLPD at each epoch
66
+
67
+ Outputs:
68
+ total_train_nlpd (List[float]): Average NLPD per epoch over training data
69
+ """
70
+ print("fit scale factor")
71
+ optimizer = torch.optim.Adam([self.scale_factor], lr)
72
+ total_train_nlpd = []
73
+
74
+ # store intermediate result and pack it into a data loader, so we only need to do one forward pass
75
+ if speedup:
76
+
77
+ if self.likelihood == 'regression':
78
+ raise ValueError(f"Speed up not supported for regression atm")
79
+
80
+ if self.likelihood == 'classification':
81
+
82
+ f_mean = []
83
+ f_var = []
84
+ labels = []
85
+
86
+ for (X, y) in tqdm(data_loader, desc= "packing f_mean f_var into a dataloader"):
87
+ out_mean, out_var = self.forward_latent(X.to(device))
88
+ f_mean.append(out_mean.detach().cpu().numpy())
89
+ f_var.append(out_var.detach().cpu().numpy())
90
+ if y.dim() == 2:
91
+ labels.append(y.numpy().argmax(1).reshape(-1, 1))
92
+ if y.dim() == 1:
93
+ labels.append(y.numpy().reshape(-1, 1))
94
+
95
+ f_mean = np.vstack(f_mean)
96
+ f_var = np.vstack(f_var)
97
+ labels = np.vstack(labels)
98
+
99
+ scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
100
+ scale_fit_dataloader = DataLoader(scale_fit_dataset, batch_size=16, shuffle=True)
101
+
102
+ for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
103
+ running_nlpd = 0
104
+ for data_pair in scale_fit_dataloader:
105
+ x_mean, x_var_label = data_pair
106
+ num_class = x_mean.shape[1]
107
+ x_mean = x_mean.to(device)
108
+ x_var, label = x_var_label.split(num_class, dim=1)
109
+ x_var = x_var.to(device)
110
+ label = label.to(device)
111
+
112
+ optimizer.zero_grad()
113
+ # make prediction
114
+ x_var = x_var / self.scale_factor.to(device)
115
+ posterior_predict_mean = self.probit_approximation(x_mean, x_var)
116
+ # construct log posterior predictive distribution
117
+ posterior_predictive_dist = Categorical(posterior_predict_mean)
118
+ # calculate nlpd and update
119
+ nlpd = -posterior_predictive_dist.log_prob(label).mean()
120
+ nlpd.backward()
121
+ optimizer.step()
122
+ # log nlpd
123
+ running_nlpd += nlpd.item()
124
+ total_train_nlpd.append(running_nlpd / len(scale_fit_dataloader))
125
+ if verbose:
126
+ print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")
127
+
128
+ del scale_fit_dataloader
129
+ del scale_fit_dataset
130
+
131
+ else:
132
+
133
+ if self.likelihood == 'classification':
134
+ for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
135
+ running_nlpd = 0
136
+ for (data, label) in data_loader:
137
+
138
+ data = data.to(device)
139
+ label = label.to(device)
140
+
141
+ optimizer.zero_grad()
142
+ # make prediction
143
+ posterior_predict_mean = self.forward(data)
144
+ # construct log posterior predictive distribution
145
+ posterior_predictive_dist = Categorical(posterior_predict_mean)
146
+ # calculate nlpd and update
147
+ nlpd = -posterior_predictive_dist.log_prob(label).mean()
148
+ nlpd.backward()
149
+ optimizer.step()
150
+ # log nlpd
151
+ running_nlpd += nlpd.item()
152
+ total_train_nlpd.append(running_nlpd / len(data_loader))
153
+ if verbose:
154
+ print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")
155
+
156
+
157
+ if self.likelihood == 'regression':
158
+ for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
159
+ running_nlpd = 0
160
+ for (data, label) in data_loader:
161
+ data = data.to(device)
162
+ label = label.to(device)
163
+
164
+ optimizer.zero_grad()
165
+ # make prediction
166
+ posterior_predict_mean, posterior_predict_var = self.forward(data)
167
+ # construct log posterior predictive distribution
168
+ posterior_predictive_dist = Normal(posterior_predict_mean, posterior_predict_var.sqrt())
169
+ # calculate nlpd and update
170
+ nlpd = -posterior_predictive_dist.log_prob(label).mean()
171
+ nlpd.backward()
172
+ optimizer.step()
173
+ # log nlpd
174
+ running_nlpd += nlpd.item()
175
+
176
+ total_train_nlpd.append(running_nlpd / len(data_loader))
177
+
178
+ if verbose:
179
+ print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")
180
+
181
+ return total_train_nlpd
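
For reference, the probit correction used in probit_approximation can be reproduced in isolation. The standalone sketch below (not part of the package, with arbitrary toy numbers) shows how a class with large latent variance is pulled toward a less confident prediction.

import numpy as np
import torch

# Toy latent output for one input and three classes: mean and diagonal variance.
f_mean = torch.tensor([[2.0, 0.5, -1.0]])
f_var = torch.tensor([[0.3, 4.0, 0.1]])

# kappa = 1 / sqrt(1 + pi/8 * var) shrinks logits with large variance before
# the softmax, mirroring SUQ_Base.probit_approximation.
kappa = 1.0 / torch.sqrt(1.0 + np.pi / 8 * f_var)

print(torch.softmax(f_mean, dim=-1))          # plain softmax on the mean
print(torch.softmax(kappa * f_mean, dim=-1))  # probit-adjusted probabilities
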
suq/diag_suq_mlp.py ADDED
@@ -0,0 +1,308 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from torch.nn.utils import parameters_to_vector
5
+ import numpy as np
6
+ import copy
7
+
8
+ from suq.base_suq import SUQ_Base
9
+
10
+ def forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var):
11
+ """
12
+ Compute the mean and variance of h = a @ W^T + b when the posterior has diagonal covariance.
13
+
14
+ ----- Input -----
15
+ a_mean: [N, D_in] mean(a)
16
+ a_var: [N, D_in] a_var[i] = var(a_i)
17
+ weight: [D_out, D_in] W
18
+ bias: [D_out, ] b
19
+ b_var: [D_out, ] b_var[k]: var(b_k)
20
+ w_var: [D_out, D_in] w_var[k][i]: var(w_ki)
21
+ ----- Output -----
22
+ h_mean: [N, D_out]
23
+ h_var: [N, D_out] h_var[k] = var(h_k)
24
+ """
25
+
26
+ # calculate mean(h)
27
+ h_mean = F.linear(a_mean, weight, bias)
28
+
29
+ # calculate var(h)
30
+ weight_mean2_var_sum = weight ** 2 + w_var # [D_out, D_in]
31
+ h_var = a_mean **2 @ w_var.T + a_var @ weight_mean2_var_sum.T + b_var
32
+
33
+ return h_mean, h_var
34
+
35
+
36
+ def forward_activation_implicit_diag(activation_func, h_mean, h_var):
37
+ """
38
+ given h ~ N(h_mean, h_cov), g(·), where h_cov is a diagonal matrix,
39
+ approximate the distribution of a = g(h) as
40
+ a ~ N(g(h_mean), g'(h_mean)^T h_var g'(h_mean))
41
+
42
+ input
43
+ activation_func: g(·)
44
+ h_mean: [N, D]
45
+ h_var: [N, D], h_var[i] = var(h_i)
46
+
47
+ output
48
+ a_mean: [N, D]
49
+ a_var: [N, D]
50
+ """
51
+
52
+ h_mean_grad = h_mean.detach().clone().requires_grad_()
53
+
54
+ a_mean = activation_func(h_mean_grad)
55
+ a_mean.retain_grad()
56
+ a_mean.backward(torch.ones_like(a_mean)) #[N, D]
57
+
58
+ nabla = h_mean_grad.grad #[N, D]
59
+ a_var = nabla ** 2 * h_var
60
+
61
+ return a_mean.detach(), a_var
62
+
63
+ def forward_batch_norm_diag(h_var, bn_weight, bn_running_var, bn_eps):
64
+ """
65
+ Pass a distribution with diagonal covariance through BatchNorm layer
66
+
67
+ Input
+ h_var: variance of input distribution [B, D]
+ bn_weight: batch norm scale factor [D, ]
+ bn_running_var: batch norm running variance [D, ]
+ bn_eps: batch norm eps
+
+ Output
+ output_var: [B, D]
76
+ """
77
+
78
+ scale_factor = (1 / (bn_running_var.reshape(1, -1) + bn_eps)) * bn_weight.reshape(1, -1) **2 # [B, D]
79
+ output_var = scale_factor * h_var # [B, D]
80
+
81
+ return output_var
82
+
83
+ class SUQ_Linear_Diag(nn.Module):
84
+ """
85
+ Linear layer with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
86
+
87
+ Wraps a standard `nn.Linear` layer and applies closed-form mean and variance propagation. See the SUQ paper for theoretical background and assumptions.
88
+
89
+ Inputs:
90
+ org_linear (nn.Linear): The original linear layer to wrap
91
+ w_var (Tensor): Weight variances, shape [D_out, D_in]
92
+ b_var (Tensor): Bias variances, shape [D_out]
93
+ """
94
+ def __init__(self, org_linear, w_var, b_var):
95
+ super().__init__()
96
+
97
+ self.weight = org_linear.weight.data
98
+ self.bias = org_linear.bias.data
99
+ self.w_var = w_var
100
+ self.b_var = b_var
101
+
102
+ def forward(self, a_mean, a_var):
103
+ """
104
+ Inputs:
105
+ a_mean (Tensor): Input mean, shape [N, D_in]
106
+ a_var (Tensor): Input variance, shape [N, D_in]
107
+
108
+ Outputs:
109
+ h_mean (Tensor): Output mean, shape [N, D_out]
110
+ h_var (Tensor): Output variance, shape [N, D_out]
111
+ """
112
+
113
+ if a_var is None:
114
+ a_var = torch.zeros_like(a_mean).to(a_mean.device)
115
+
116
+ h_mean, h_var = forward_aW_diag(a_mean, a_var, self.weight, self.bias, self.w_var, self.b_var)
117
+
118
+ return h_mean, h_var
119
+
120
+ class SUQ_Activation_Diag(nn.Module):
121
+ """
122
+ Activation layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
123
+
124
+ Wraps a standard activation function and applies a first-order approximation to propagate input variance through the nonlinearity. See the SUQ paper for theoretical background and assumptions.
125
+
126
+ Inputs:
127
+ afun (Callable): A PyTorch activation function (e.g. nn.ReLU())
128
+ """
129
+
130
+ def __init__(self, afun):
131
+ super().__init__()
132
+ self.afun = afun
133
+
134
+ def forward(self, h_mean, h_var):
135
+ """
136
+ Inputs:
137
+ h_mean (Tensor): Input mean before activation, shape [N, D]
138
+ h_var (Tensor): Input variance before activation, shape [N, D]
139
+
140
+ Outputs:
141
+ a_mean (Tensor): Activated output mean, shape [N, D]
142
+ a_var (Tensor): Approximated output variance, shape [N, D]
143
+ """
144
+ a_mean, a_var = forward_activation_implicit_diag(self.afun, h_mean, h_var)
145
+ return a_mean, a_var
146
+
147
+ class SUQ_BatchNorm_Diag(nn.Module):
148
+ """
149
+ BatchNorm layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
150
+
151
+ Wraps `nn.BatchNorm1d` and adjusts input variance using batch normalization statistics and scale parameters. See the SUQ paper for theoretical background and assumptions.
152
+
153
+ Inputs:
154
+ BatchNorm (nn.BatchNorm1d): The original batch norm layer
155
+ """
156
+
157
+ def __init__(self, BatchNorm):
158
+ super().__init__()
159
+
160
+ self.BatchNorm = BatchNorm
161
+
162
+ def forward(self, x_mean, x_var):
163
+ """
164
+ Inputs:
165
+ x_mean (Tensor): Input mean, shape [B, D]
166
+ x_var (Tensor): Input variance, shape [B, D]
167
+
168
+ Outputs:
169
+ out_mean (Tensor): Output mean after batch normalization, shape [B, D]
170
+ out_var (Tensor): Output variance after batch normalization, shape [B, D]
171
+ """
172
+
173
+ with torch.no_grad():
174
+
175
+ out_mean = self.BatchNorm.forward(x_mean)
176
+ out_var = forward_batch_norm_diag(x_mean, x_var, self.BatchNorm.weight, 1e-5)
177
+
178
+ return out_mean, out_var
179
+
180
+
181
+ class SUQ_MLP_Diag(SUQ_Base):
182
+ """
183
+ Multilayer perceptron model with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
184
+
185
+ Wraps a standard MLP, converting its layers into SUQ-compatible components.
186
+ Supports both classification and regression via predictive Gaussian approximation.
187
+
188
+ Note:
189
+ The input model should correspond to the latent function only:
190
+ - For regression, this is the full model (including final output layer).
191
+ - For classification, exclude the softmax layer and pass only the logit-producing part.
192
+
193
+ Inputs:
194
+ org_model (nn.Module): The original MLP model to convert
195
+ posterior_variance (Tensor): Flattened posterior variance vector
196
+ likelihood (str): Either 'classification' or 'regression'
197
+ scale_init (float, optional): Initial scale factor
198
+ sigma_noise (float, optional): noise level (for regression)
199
+ """
200
+
201
+ def __init__(self, org_model, posterior_variance, likelihood, scale_init = 1.0, sigma_noise = None):
202
+ super().__init__(likelihood, scale_init)
203
+
204
+ self.sigma_noise = sigma_noise
205
+ self.convert_model(org_model, posterior_variance)
206
+
207
+ def forward_latent(self, data, out_var = None):
208
+ """
209
+ Compute the predictive mean and variance of the latent function before applying the likelihood.
210
+
211
+ Traverses the model layer by layer, propagating mean and variance through each SUQ-wrapped layer.
212
+
213
+ Inputs:
214
+ data (Tensor): Input data, shape [B, D]
215
+ out_var (Tensor or None): Optional input variance, shape [B, D]
216
+
217
+ Outputs:
218
+ out_mean (Tensor): Output mean after final layer, shape [B, D_out]
219
+ out_var (Tensor): Output variance after final layer, shape [B, D_out]
220
+ """
221
+
222
+ out_mean = data
223
+
224
+ if isinstance(self.model, nn.Sequential):
225
+ for layer in self.model:
226
+ out_mean, out_var = layer.forward(out_mean, out_var)
227
+ ## TODO: support other model types
228
+
229
+ out_var = out_var / self.scale_factor
230
+
231
+ return out_mean, out_var
232
+
233
+ def forward(self, data):
234
+ """
235
+ Compute the predictive distribution based on the model's likelihood setting.
236
+
237
+ For classification, returns class probabilities via a probit approximation.
238
+ For regression, returns the latent mean and total predictive variance.
239
+
240
+ Inputs:
241
+ data (Tensor): Input data, shape [B, D]
242
+
243
+ Outputs:
244
+ If classification:
245
+ Tensor: Class probabilities, shape [B, num_classes]
246
+ If regression:
247
+ Tuple[Tensor, Tensor]: Output mean and total variance, shape [B, D_out]
248
+ """
249
+
250
+ out_mean, out_var = self.forward_latent(data)
251
+
252
+ if self.likelihood == 'classification':
253
+ kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)
254
+ return torch.softmax(kappa * out_mean, dim=-1)
255
+
256
+ if self.likelihood == 'regression':
257
+ return out_mean, out_var + self.sigma_noise ** 2
258
+
259
+ def convert_model(self, org_model, posterior_variance):
260
+ """
261
+ Converts a deterministic MLP into a SUQ-compatible model with diagonal posterior.
262
+
263
+ Each layer is replaced with its corresponding SUQ module (e.g. linear, activation, batchnorm), using the provided flattened posterior variance vector.
264
+
265
+ Inputs:
266
+ org_model (nn.Module): The original model to convert (latent function only)
267
+ posterior_variance (Tensor): Flattened posterior variance for Bayesian parameters
268
+ """
269
+
270
+ p_model = copy.deepcopy(org_model)
271
+
272
+ loc = 0
273
+ for n, layer in p_model.named_modules():
274
+ if isinstance(layer, nn.Linear):
275
+
276
+ D_out, D_in = layer.weight.data.shape
277
+ num_param = torch.numel(parameters_to_vector(layer.parameters()))
278
+ num_weight_param = D_out * D_in
279
+
280
+ covariance_block = posterior_variance[loc : loc + num_param]
281
+
282
+ """
283
+ w_var: [D_out, D_in], w_var[k][i] = var(w_ki)
284
+ b_var: [D_out, ] b_var[k]: var(b_k)
285
+ """
286
+
287
+ b_var = torch.zeros_like(layer.bias.data).to(layer.bias.data.device)
288
+ w_var = torch.zeros_like(layer.weight.data).to(layer.bias.data.device)
289
+
290
+ for k in range(D_out):
291
+ b_var[k] = covariance_block[num_weight_param + k]
292
+ for i in range(D_in):
293
+ w_var[k][i] = covariance_block[k * D_in + i]
294
+
295
+ new_layer = SUQ_Linear_Diag(layer, w_var, b_var)
296
+
297
+ loc += num_param
298
+ setattr(p_model, n, new_layer)
299
+
300
+ if isinstance(layer, nn.BatchNorm1d):
301
+ new_layer = SUQ_BatchNorm_Diag(layer)
302
+ setattr(p_model, n, new_layer)
303
+
304
+ if type(layer).__name__ in torch.nn.modules.activation.__all__:
305
+ new_layer = SUQ_Activation_Diag(layer)
306
+ setattr(p_model, n, new_layer)
307
+
308
+ self.model = p_model
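
As a quick sanity check (not part of the package), the closed-form moments returned by forward_aW_diag can be compared against Monte Carlo samples drawn under the same diagonal-Gaussian assumptions; the shapes, seed, and variance magnitudes below are arbitrary placeholders.

import torch
from suq.diag_suq_mlp import forward_aW_diag

torch.manual_seed(0)
N, D_in, D_out = 4, 5, 3
a_mean, a_var = torch.randn(N, D_in), 0.1 * torch.rand(N, D_in)
weight, bias = torch.randn(D_out, D_in), torch.randn(D_out)
w_var, b_var = 0.05 * torch.rand(D_out, D_in), 0.05 * torch.rand(D_out)

h_mean, h_var = forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var)

# Sample activations, weights and biases from their diagonal Gaussians and
# compare the empirical moments of h = a W^T + b with the closed form.
S = 20000
a = a_mean + a_var.sqrt() * torch.randn(S, N, D_in)
W = weight + w_var.sqrt() * torch.randn(S, D_out, D_in)
b = bias + b_var.sqrt() * torch.randn(S, D_out)
h = torch.einsum('snd,skd->snk', a, W) + b[:, None, :]

print((h.mean(0) - h_mean).abs().max())  # small (Monte Carlo error)
print((h.var(0) - h_var).abs().max())    # small (Monte Carlo error)
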
suq/diag_suq_transformer.py ADDED
@@ -0,0 +1,627 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import numpy as np
5
+ from suq.diag_suq_mlp import forward_aW_diag
6
+ from suq.base_suq import SUQ_Base
7
+
8
+ def forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var, bias = None):
9
+ """
10
+ Pass a distribution with diagonal covariance through a Bayesian linear layer with diagonal covariance
11
+
12
+ Given e ~ N(e_mean, e_cov) and W ~ N(W_mean, W_cov), calculate the mean and variance of h = eW.T + b.
13
+
14
+ We only make the weight Bayesian; the bias is treated deterministically.
15
+
16
+ Note that since we always assume the input to the next layer has diagonal covariance, we only compute the variance of h here.
17
+
18
+ Input
19
+ e_mean: [B, T, D_in] embedding mean
20
+ e_var: [B, T, D_in] embedding variance
21
+ w_mean: [D_out, D_in] weight mean
22
+ w_var: [D_out, D_in] weight variance, w_var[k][i]: var(w_ki)
23
+ Output
24
+ h_mean: [B, T, D_out]
25
+ h_var: [B, T, D_out] h_var[k] = var(h_k)
26
+ """
27
+
28
+ # calculate mean(h)
29
+ h_mean = F.linear(e_mean, w_mean, bias)
30
+
31
+ # calculate var(h)
32
+ weight_mean2_var_sum = w_mean ** 2 + w_var # [D_out, D_in]
33
+ h_var = e_mean **2 @ w_var.T + e_var @ weight_mean2_var_sum.T
34
+
35
+ return h_mean, h_var
36
+
37
+ def forward_linear_diag_determinstic_weight(e_mean, e_var, weight, bias = None):
38
+ """
39
+ Pass a distribution with diagonal covariance through a linear layer
40
+
41
+ Given e ~ N(e_mean, e_var) and a deterministic weight W and bias b, calculate the mean and variance of h = eW.T + b.
42
+
43
+ Note that since we always assume the input to the next layer has diagonal covariance, we only compute the variance of h here.
44
+
45
+ Input
46
+ e_mean: [B, T, D_in] embedding mean
47
+ e_var: [B, T, D_in] embedding variance
48
+ weight: [D_out, D_in] weight matrix
49
+ Output
50
+ h_mean: [B, T, D_out]
51
+ h_var: [B, T, D_out] h_var[k] = var(h_k)
52
+ """
53
+
54
+ h_mean = F.linear(e_mean, weight, bias)
55
+ h_var = F.linear(e_var, weight ** 2, None)
56
+
57
+ return h_mean, h_var
58
+
59
+ @torch.enable_grad()
60
+ def forward_activation_diag(activation_func, h_mean, h_var):
61
+ """
62
+ Pass a distribution with diagonal covariance through an activation layer.
63
+
64
+ Given h ~ N(h_mean, h_cov), g(·), where h_cov is a diagonal matrix,
65
+ approximate the distribution of a = g(h) as
66
+ a ~ N(g(h_mean), g'(h_mean)^T h_var g'(h_mean))
67
+
68
+ Input
69
+ activation_func: g(·)
70
+ h_mean: [B, T, D] input mean
71
+ h_var: [B, T, D] input variance
72
+
73
+ Output
74
+ a_mean: [B, T, D]
75
+ a_var: [B, T, D]
76
+ """
77
+
78
+ h_mean_grad = h_mean.detach().clone().requires_grad_()
79
+
80
+ a_mean = activation_func(h_mean_grad)
81
+ a_mean.retain_grad()
82
+ a_mean.backward(torch.ones_like(a_mean)) #[B, T, D]
83
+
84
+ nabla = h_mean_grad.grad #[B, T, D]
85
+ a_var = nabla ** 2 * h_var
86
+
87
+ return a_mean.detach(), a_var
88
+
89
+ def forward_layer_norm_diag(e_mean, e_var, ln_weight, ln_eps):
90
+ """
91
+ Pass a distribution with diagonal covariance through LayerNorm layer
92
+
93
+ Input
94
+ e_mean: mean of input distribution [B, T, D]
95
+ e_var: variance of input distribution [B, T, D]
96
+ ln_weight: layer norm scale factor
97
+ ln_eps: layer norm eps
98
+
99
+ Output
100
+ output_var [B, T, D]
101
+ """
102
+
103
+ # calculate the var
104
+ input_mean_var = e_mean.var(dim=-1, keepdim=True, unbiased=False) # [B, T, 1]
105
+ scale_factor = (1 / (input_mean_var + ln_eps)) * ln_weight **2 # [B, T, D]
106
+ output_var = scale_factor * e_var # [B, T, D]
107
+
108
+ return output_var
109
+
110
+ def forward_value_cov_Bayesian_W(W_v, W_v_var, input_mean, input_var, n_h, D_v, diag_cov = False):
111
+ """
112
+ Given value matrix W_v ~ N(mean(W), var(W)) and input E ~ N(mean(E), var(E))
113
+ Compute the covariance of output v = W_v @ E
114
+
115
+ Input:
116
+ n_h: number of attention heads
117
+ D_v: dimension of value, n_h * D_v = D
118
+ W_v: value weight matrix [D, D]
119
+ W_v_var: variance of value matrix, [D, D]
120
+ input_mean: mean of input [B, T, D]
121
+ input_var: variance of input [B, T, D]
122
+ diag_cov: whether input only has diag covariance
123
+
124
+ Output:
125
+ v_cov [B, T, n_h, D_v, D_v] or v_var [B, T, n_h, D_v]
126
+ """
127
+
128
+ B, T, D = input_var.size()
129
+
130
+ if not diag_cov:
131
+ ## compute general covariance
132
+ W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D)
133
+ # [D, D] -> [1, 1, n_h, D_v, D]
134
+ input_var_reshaped = input_var.reshape(B, T, 1, 1, D)
135
+ # [B, T, D] -> [B, T, 1, 1, D]
136
+ v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
137
+ # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
138
+ v_cov = torch.matmul(W_v_reshaped, v_cov)
139
+ # [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v] -> [B, T, n_h, D_v, D_v]
140
+
141
+ ## add missing part for variance
142
+ W_v_var_reshaped = W_v_var.reshape(1, 1, n_h, D_v, D)
143
+ #[D, D] -> [1, 1, n_h, D_v, D]
144
+ input_var_plus_mean_square = input_var_reshaped + input_mean.reshape(B, T, 1, 1, D)**2 #[B, T, 1, 1, D]
145
+ extra_var_term = torch.sum(input_var_plus_mean_square * W_v_var_reshaped, dim=[4]) # [B, T, n_h, D_v, D] -> [B, T, n_h, D_v]
146
+ v_cov = v_cov + torch.diag_embed(extra_var_term)
147
+
148
+ return v_cov
149
+
150
+ else:
151
+ weight_mean2_var_sum = W_v **2 + W_v_var # [D, D]
152
+ v_var = input_mean **2 @ W_v_var.T + input_var @ weight_mean2_var_sum.T # [B, T, D]
153
+
154
+ return v_var.reshape(B, T, n_h, D_v)
155
+
156
+ def forward_value_cov_determinstic_W(W_v, input_var, n_h, D_v, diag_cov = False):
157
+ """
158
+ Given a deterministic value matrix W_v and input E ~ N(mean(E), var(E))
159
+ Compute the covariance of output v = W_v @ E
160
+
161
+
162
+ Input:
163
+ n_h: number of attention heads
164
+ D_v: dimension of value, n_h * D_v = D
165
+ W_v: value weight matrix [D, D], which can be reshaped into [n_h, D_v, D]
166
+ input_var: variance of input [B, T, D]
167
+ diag_cov: whether input only has diag covariance
168
+
169
+ Output:
170
+ v_cov [B, T, n_h, D_v, D_v] or v_var [B, T, n_h, D_v]
171
+ """
172
+
173
+ B, T, D = input_var.size()
174
+
175
+ if not diag_cov:
176
+ W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D)
177
+ #[n_h, D_v, D] -> [1, 1, n_h, D_v, D]
178
+ input_var_reshaped = input_var.reshape(B, T, 1, 1, D)
179
+ # [B, T, D] -> [B, T, 1, 1, D]
180
+ v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
181
+ # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
182
+ v_cov = torch.matmul(W_v_reshaped, v_cov)
183
+ # [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v] -> [B, T, n_h, D_v, D_v]
184
+
185
+ return v_cov
186
+
187
+ else:
188
+ v_var = input_var @ (W_v ** 2).T
189
+
190
+ return v_var.reshape(B, T, n_h, D_v)
191
+
192
+ def forward_QKV_cov(attention_score, v_cov, diag_cov = False):
193
+ """
194
+ given attention score (QK^T) and V ~ N(mean(V), cov(V))
195
+ compute the covariance of output E = (QK^T) V
196
+
197
+ Input:
198
+ attention_score: [B, n_h, T, T] attention_score[t] is token t's attention score for all other tokens
199
+ v_cov: [B, T, n_h, D_v, D_v] or [B, T, n_h, D_v], covariance of value
200
+ diag_cov: whether input only has diag covariance
201
+
202
+ Output:
203
+ QKV_cov: [B, n_h, T, D_v, D_v] or [B, T, n_h, D_v], covariance of output E
204
+ """
205
+ if diag_cov:
206
+ B, T, n_h, D_v = v_cov.size()
207
+ QKV_cov = attention_score **2 @ v_cov.transpose(1, 2) # [B, n_h, T, D_v]
208
+ # v_cov [B, T, n_h, D_v] -> [B, n_h, T, D_v]
209
+ # [B, n_h, T, T] @ [B, n_h, T, D_v] -> [B, n_h, T, D_v]
210
+ else:
211
+
212
+ B, T, n_h, D_v, _ = v_cov.size()
213
+
214
+ QKV_cov = attention_score **2 @ v_cov.permute(0, 2, 1, 3, 4).reshape(B, n_h, T, D_v * D_v) # [B, n_h, T, D_v * D_v]
215
+ # v_cov [B, T, n_h, D_v, D_v] -> [B, n_h, T, D_v * D_v]
216
+ # [B, n_h, T, T] @ [B, n_h, T, D_v * D_v] -> [B, n_h, T, D_v * D_v]
217
+ QKV_cov = QKV_cov.reshape(B, n_h, T, D_v, D_v)
218
+
219
+ return QKV_cov
220
+
221
+ def forward_fuse_multi_head_cov(QKV_cov, project_W, diag_cov = False):
222
+ """
223
+ Given the concatenated multi-head embedding E ~ N(mean(E), cov(E)) and the projection weight matrix W,
224
+ compute the variance of each output dimension.
225
+
226
+ Input:
227
+ QKV_cov: [B, n_h, T, D_v, D_v] or [B, n_h, T, D_v]
228
+ project_W: [D, D] D_out x D_in (n_h * D_v)
229
+ diag_cov: whether input only has diag covariance
230
+
231
+ Output:
232
+ output_var [B, T, D]
233
+ """
234
+ if diag_cov:
235
+ B, n_h, T, D_v = QKV_cov.size()
236
+ output_var = QKV_cov.permute(0, 2, 1, 3).reshape(B, T, n_h * D_v) @ (project_W ** 2).T
237
+ # QKV_cov [B, n_h, T, D_v] -> [B, T, n_h, D_v] -> [B, T, n_h * D_v]
238
+
239
+ return output_var
240
+
241
+ else:
242
+ B, n_h, T, D_v, _ = QKV_cov.size()
243
+ D, _ = project_W.shape
244
+
245
+ project_W_reshaped_1 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, D_v, 1)
246
+ # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, D_v, 1]
247
+ project_W_reshaped_2 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, 1, D_v)
248
+ # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, 1, D_v]
249
+
250
+ project_W_outer = torch.bmm(project_W_reshaped_1, project_W_reshaped_2).reshape(n_h, D, D_v, D_v).permute(1, 0, 2, 3) # [D, n_h, D_v, D_v]
251
+ # [n_h * D, D_v, D_v] @ [n_h * D, 1, D_v] -> [n_h * D, D_v, D_v] -> [D, n_h, D_v, D_v]
252
+
253
+ output_var_einsum = torch.einsum('dhij,bthij->dbt', project_W_outer, QKV_cov.permute(0, 2, 1, 3, 4))
254
+
255
+ return output_var_einsum.permute(1, 2, 0)
256
+
257
+ class SUQ_LayerNorm_Diag(nn.Module):
258
+ """
259
+ LayerNorm module with uncertainty propagation under SUQ.
260
+
261
+ Wraps `nn.LayerNorm` and propagates input variance analytically using the input's per-token normalization statistics and the layer's scale parameter. See the SUQ paper for theoretical background and assumptions.
262
+
263
+ Inputs:
264
+ LayerNorm (nn.LayerNorm): The original layer norm module to wrap
265
+ """
266
+
267
+ def __init__(self, LayerNorm):
268
+ super().__init__()
269
+
270
+ self.LayerNorm = LayerNorm
271
+
272
+ def forward(self, x_mean, x_var):
273
+ """
274
+ Inputs:
275
+ x_mean (Tensor): Input mean, shape [B, T, D]
276
+ x_var (Tensor): Input variance, shape [B, T, D]
277
+
278
+ Outputs:
279
+ out_mean (Tensor): Output mean after layer norm, shape [B, T, D]
280
+ out_var (Tensor): Output variance after layer norm, shape [B, T, D]
281
+ """
282
+ with torch.no_grad():
283
+
284
+ out_mean = self.LayerNorm.forward(x_mean)
285
+ out_var = forward_layer_norm_diag(x_mean, x_var, self.LayerNorm.weight, 1e-5)
286
+
287
+ return out_mean, out_var
288
+
289
+
290
+ class SUQ_Classifier_Diag(nn.Module):
291
+ """
292
+ Classifier head with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
293
+
294
+ Wraps a standard linear classifier and applies closed-form mean and variance propagation.
295
+ See the SUQ paper for theoretical background and assumptions.
296
+
297
+ Inputs:
298
+ classifier (nn.Linear): The final classification head
299
+ w_var (Tensor): Weight variances, shape [D_out, D_in]
300
+ b_var (Tensor): Bias variances, shape [D_out]
301
+ """
302
+
303
+ def __init__(self, classifier, w_var, b_var):
304
+ super().__init__()
305
+
306
+ self.weight = classifier.weight
307
+ self.bias = classifier.bias
308
+ self.w_var = w_var.reshape(self.weight.shape)
309
+ self.b_var = b_var.reshape(self.bias.shape)
310
+
311
+ def forward(self, x_mean, x_var):
312
+ """
313
+ Inputs:
314
+ x_mean (Tensor): Input mean, shape [B, D]
315
+ x_var (Tensor): Input variance, shape [B, D]
316
+
317
+ Outputs:
318
+ h_mean (Tensor): Output mean, shape [B, D_out]
319
+ h_var (Tensor): Output variance, shape [B, D_out]
320
+ """
321
+ with torch.no_grad():
322
+ h_mean, h_var = forward_aW_diag(x_mean, x_var, self.weight.data, self.bias.data, self.w_var, self.b_var)
323
+ return h_mean, h_var
324
+
325
+ class SUQ_TransformerMLP_Diag(nn.Module):
326
+ """
327
+ MLP submodule of a transformer block with uncertainty propagation under SUQ.
328
+
329
+ Supports both deterministic and Bayesian forward modes with closed-form variance propagation.
330
+ Used internally in `SUQ_Transformer_Block_Diag`.
331
+
332
+ Inputs:
333
+ MLP (nn.Module): Original MLP submodule
334
+ determinstic (bool): Whether to treat the MLP weights as deterministic
335
+ w_fc_var (Tensor, optional): Variance of the first linear layer (if Bayesian)
336
+ w_proj_var (Tensor, optional): Variance of the second linear layer (if Bayesian)
337
+ """
338
+
339
+ def __init__(self, MLP, determinstic = True, w_fc_var = None, w_proj_var = None):
340
+ super().__init__()
341
+
342
+ self.MLP = MLP
343
+ self.determinstic = determinstic
344
+ if not determinstic:
345
+ self.w_fc_var = w_fc_var.reshape(self.MLP.c_fc.weight.shape)
346
+ self.w_proj_var = w_proj_var.reshape(self.MLP.c_proj.weight.shape)
347
+
348
+ def forward(self, x_mean, x_var):
349
+ """
350
+ Inputs:
351
+ x_mean (Tensor): Input mean, shape [B, T, D]
352
+ x_var (Tensor): Input variance, shape [B, T, D]
353
+
354
+ Outputs:
355
+ h_mean (Tensor): Output mean, shape [B, T, D]
356
+ h_var (Tensor): Output variance, shape [B, T, D]
357
+ """
358
+
359
+ # first fc layer
360
+ with torch.no_grad():
361
+ if self.determinstic:
362
+ h_mean, h_var = forward_linear_diag_determinstic_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.MLP.c_fc.bias.data)
363
+ else:
364
+ h_mean, h_var = forward_linear_diag_Bayesian_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.w_fc_var, self.MLP.c_fc.bias.data)
365
+ # activation function
366
+ h_mean, h_var = forward_activation_diag(self.MLP.gelu, h_mean, h_var)
367
+ # second fc layer
368
+ with torch.no_grad():
369
+ if self.determinstic:
370
+ h_mean, h_var = forward_linear_diag_determinstic_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.MLP.c_proj.bias.data)
371
+ else:
372
+ h_mean, h_var = forward_linear_diag_Bayesian_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.w_proj_var, self.MLP.c_proj.bias.data)
373
+
374
+ return h_mean, h_var
375
+
376
+ class SUQ_Attention_Diag(nn.Module):
377
+ """
378
+ Self-attention module with uncertainty propagation under SUQ.
379
+
380
+ Supports deterministic and Bayesian value projections, with optional diagonal covariance assumptions. For details, see Section A.6 of the SUQ paper.
381
+ Used internally in `SUQ_Transformer_Block_Diag`.
382
+
383
+ Inputs:
384
+ Attention (nn.Module): The original attention module
385
+ determinstic (bool): Whether to treat value projections as deterministic
386
+ diag_cov (bool): If True, only compute the diagonal of the value covariance
387
+ W_v_var (Tensor, optional): Posterior variance for value matrix (if Bayesian)
388
+ """
389
+
390
+ def __init__(self, Attention, determinstic = True, diag_cov = False, W_v_var = None):
391
+ super().__init__()
392
+
393
+ self.Attention = Attention
394
+ self.determinstic = determinstic
395
+ self.diag_cov = diag_cov
396
+
397
+ if not self.determinstic:
398
+ self.W_v_var = W_v_var # [D * D]
399
+
400
+ def forward(self, x_mean, x_var):
401
+ """
402
+ Inputs:
403
+ x_mean (Tensor): Input mean, shape [B, T, D]
404
+ x_var (Tensor): Input variance, shape [B, T, D]
405
+
406
+ Outputs:
407
+ output_mean (Tensor): Output mean after attention, shape [B, T, D]
408
+ output_var (Tensor): Output variance after attention, shape [B, T, D]
409
+ """
410
+
411
+ with torch.no_grad():
412
+
413
+ output_mean, attention_score = self.Attention.forward(x_mean, True)
414
+
415
+ n_h = self.Attention.n_head
416
+ B, T, D = x_mean.size()
417
+ D_v = D // n_h
418
+
419
+ W_v = self.Attention.c_attn_v.weight.data
420
+ project_W = self.Attention.c_proj.weight.data
421
+
422
+ if self.determinstic:
423
+ v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)
424
+ else:
425
+ v_cov = forward_value_cov_Bayesian_W(W_v, self.W_v_var.reshape(D, D), x_mean, x_var, n_h, D_v, self.diag_cov)
426
+
427
+ QKV_cov = forward_QKV_cov(attention_score, v_cov, self.diag_cov)
428
+ output_var = forward_fuse_multi_head_cov(QKV_cov, project_W, self.diag_cov)
429
+
430
+ return output_mean, output_var
431
+
432
+ class SUQ_Transformer_Block_Diag(nn.Module):
433
+ """
434
+ Single transformer block with uncertainty propagation under SUQ.
435
+
436
+ Wraps LayerNorm, attention, and MLP submodules with uncertainty-aware versions.
437
+ Used in `SUQ_ViT_Diag` to form a full transformer stack.
438
+
439
+ Inputs:
440
+ MLP (nn.Module): Original MLP submodule
441
+ Attention (nn.Module): Original attention submodule
442
+ LN_1 (nn.LayerNorm): Pre-attention layer norm
443
+ LN_2 (nn.LayerNorm): Pre-MLP layer norm
444
+ MLP_determinstic (bool): Whether to treat MLP as deterministic
445
+ Attn_determinstic (bool): Whether to treat attention as deterministic
446
+ diag_cov (bool): If True, only compute the diagonal of the value covariance
447
+ w_fc_var (Tensor or None): Posterior variance of MLP input projection (if Bayesian)
448
+ w_proj_var (Tensor or None): Posterior variance of MLP output projection (if Bayesian)
449
+ W_v_var (Tensor or None): Posterior variance of value matrix (if Bayesian)
450
+ """
451
+
452
+
453
+ def __init__(self, MLP, Attention, LN_1, LN_2, MLP_determinstic, Attn_determinstic, diag_cov = False, w_fc_var = None, w_proj_var = None, W_v_var = None):
454
+ super().__init__()
455
+
456
+ self.ln_1 = SUQ_LayerNorm_Diag(LN_1)
457
+ self.ln_2 = SUQ_LayerNorm_Diag(LN_2)
458
+ self.attn = SUQ_Attention_Diag(Attention, Attn_determinstic, diag_cov, W_v_var)
459
+ self.mlp = SUQ_TransformerMLP_Diag(MLP, MLP_determinstic, w_fc_var, w_proj_var)
460
+
461
+ def forward(self, x_mean, x_var):
462
+ """
463
+ Inputs:
464
+ x_mean (Tensor): Input mean, shape [B, T, D]
465
+ x_var (Tensor): Input variance, shape [B, T, D]
466
+
467
+ Outputs:
468
+ h_mean (Tensor): Output mean after transformer block, shape [B, T, D]
469
+ h_var (Tensor): Output variance after transformer block, shape [B, T, D]
470
+ """
471
+
472
+ h_mean, h_var = self.ln_1(x_mean, x_var)
473
+ h_mean, h_var = self.attn(h_mean, h_var)
474
+ h_mean = h_mean + x_mean
475
+ h_var = h_var + x_var
476
+
477
+ old_h_mean, old_h_var = h_mean, h_var
478
+
479
+ h_mean, h_var = self.ln_2(h_mean, h_var)
480
+ h_mean, h_var = self.mlp(h_mean, h_var)
481
+ h_mean = h_mean + old_h_mean
482
+ h_var = h_var + old_h_var
483
+
484
+ return h_mean, h_var
485
+
486
+
487
+ class SUQ_ViT_Diag(SUQ_Base):
488
+ """
489
+ Vision Transformer model with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.
490
+
491
+ Wraps a ViT architecture into a structured uncertainty-aware model by replacing parts
492
+ of the network with SUQ-compatible blocks. Allows selective Bayesian treatment of MLP
493
+ and attention modules within each transformer block.
494
+
495
+ Currently supports classification only. See the SUQ paper for theoretical background and assumptions.
496
+
497
+ Inputs:
498
+ ViT (nn.Module): A Vision Transformer model structured like `examples/vit_model.py`
499
+ posterior_variance (Tensor): Flattened posterior variance vector
500
+ MLP_determinstic (bool): Whether MLP submodules are treated as deterministic
501
+ Attn_determinstic (bool): Whether attention submodules are treated as deterministic
502
+ scale_init (float, optional): Initial value for the scale factor
503
+ attention_diag_cov (bool): If True, only compute the diagonal of the value covariance
504
+ likelihood (str): Currently only 'classification' is supported
505
+ num_det_blocks (int): Number of transformer blocks to leave deterministic (from the bottom up)
506
+ """
507
+
508
+ def __init__(self, ViT, posterior_variance, MLP_determinstic, Attn_determinstic, scale_init = 1.0, attention_diag_cov = False, likelihood = 'classification', num_det_blocks = 10):
509
+ super().__init__(likelihood, scale_init)
510
+
511
+ if likelihood not in ['classification']:
512
+ raise ValueError(f"{likelihood} not supported for ViT")
513
+
514
+
515
+ self.transformer = nn.ModuleDict(dict(
516
+ pte = ViT.transformer.pte,
517
+ h = nn.ModuleList(),
518
+ ln_f = SUQ_LayerNorm_Diag(ViT.transformer.ln_f)
519
+ ))
520
+
521
+ self.scale_factor = nn.Parameter(torch.Tensor([scale_init]))
522
+
523
+ num_param_c_fc = ViT.transformer.h[0].mlp.c_fc.weight.numel()
524
+ num_param_c_proj = ViT.transformer.h[0].mlp.c_proj.weight.numel()
525
+ num_param_value_matrix = ViT.transformer.h[0].attn.c_proj.weight.numel()
526
+
527
+ index = 0
528
+ for block_index in range(len(ViT.transformer.h)):
529
+
530
+ if block_index < num_det_blocks:
531
+ self.transformer.h.append(ViT.transformer.h[block_index])
532
+ else:
533
+ if not MLP_determinstic:
534
+ w_fc_var = posterior_variance[index: index + num_param_c_fc]
535
+ index += num_param_c_fc
536
+ w_proj_var = posterior_variance[index: index + num_param_c_proj]
537
+ index += num_param_c_proj
538
+ self.transformer.h.append(
539
+ SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp,
540
+ ViT.transformer.h[block_index].attn,
541
+ ViT.transformer.h[block_index].ln_1,
542
+ ViT.transformer.h[block_index].ln_2,
543
+ MLP_determinstic,
544
+ Attn_determinstic,
545
+ attention_diag_cov,
546
+ w_fc_var,
547
+ w_proj_var,
548
+ None))
549
+
550
+ if not Attn_determinstic:
551
+ w_v_var = posterior_variance[index : index + num_param_value_matrix]
552
+ index += num_param_value_matrix
553
+ self.transformer.h.append(
554
+ SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp,
555
+ ViT.transformer.h[block_index].attn,
556
+ ViT.transformer.h[block_index].ln_1,
557
+ ViT.transformer.h[block_index].ln_2,
558
+ MLP_determinstic,
559
+ Attn_determinstic,
560
+ attention_diag_cov,
561
+ None,
562
+ None,
563
+ w_v_var))
564
+
565
+ num_param_classifier_weight = ViT.classifier.weight.numel()
566
+ self.classifier = SUQ_Classifier_Diag(ViT.classifier, posterior_variance[index: index + num_param_classifier_weight], posterior_variance[index + num_param_classifier_weight:])
567
+
568
+ def forward_latent(self, pixel_values, interpolate_pos_encoding = None):
569
+
570
+ """
571
+ Compute the predictive mean and variance of the ViT's latent output before applying the final likelihood layer.
572
+
573
+ Traverses the full transformer stack with uncertainty propagation.
574
+
575
+ Inputs:
576
+ pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
577
+ interpolate_pos_encoding (optional): Optional positional embedding interpolation
578
+
579
+ Outputs:
580
+ x_mean (Tensor): Predicted latent mean at the [CLS] token, shape [B, D]
581
+ x_var (Tensor): Predicted latent variance at the [CLS] token, shape [B, D]
582
+ """
583
+
584
+ device = pixel_values.device
585
+
586
+ x_mean = self.transformer.pte(
587
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
588
+ )
589
+
590
+ # pass through model
591
+ x_var = torch.zeros_like(x_mean, device = device)
592
+
593
+ for i, block in enumerate(self.transformer.h):
594
+
595
+ if isinstance(block, SUQ_Transformer_Block_Diag):
596
+
597
+ x_mean, x_var = block(x_mean, x_var)
598
+ else:
599
+ x_mean = block(x_mean)
600
+
601
+ x_mean, x_var = self.transformer.ln_f(x_mean, x_var)
602
+
603
+ x_mean, x_var = self.classifier(x_mean[:, 0, :], x_var[:, 0, :])
604
+ x_var = x_var / self.scale_factor.to(device)
605
+
606
+ return x_mean, x_var
607
+
608
+ def forward(self, pixel_values, interpolate_pos_encoding = None):
609
+ """
610
+ Compute predictive class probabilities using a probit approximation.
611
+
612
+ Performs a full forward pass through the ViT with uncertainty propagation, and
613
+ produces softmax-normalized class probabilities for classification.
614
+
615
+ Inputs:
616
+ pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
617
+ interpolate_pos_encoding (optional): Optional positional embedding interpolation
618
+
619
+ Outputs:
620
+ Tensor: Predicted class probabilities, shape [B, num_classes]
621
+ """
622
+
623
+ x_mean, x_var = self.forward_latent(pixel_values, interpolate_pos_encoding)
624
+ kappa = 1 / torch.sqrt(1. + np.pi / 8 * x_var)
625
+
626
+ return torch.softmax(kappa * x_mean, dim=-1)
627
+
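
To make the shape bookkeeping in the attention utilities above concrete, the following standalone sketch (random placeholder tensors, deterministic value path only) traces a per-token input variance through forward_value_cov_determinstic_W, forward_QKV_cov, and forward_fuse_multi_head_cov; the batch size, sequence length, and head sizes are arbitrary assumptions.

import torch
from suq.diag_suq_transformer import (
    forward_value_cov_determinstic_W,
    forward_QKV_cov,
    forward_fuse_multi_head_cov,
)

B, T, n_h, D_v = 2, 4, 3, 8
D = n_h * D_v

x_var = torch.rand(B, T, D)                               # per-token input variance
W_v = torch.randn(D, D)                                   # value matrix [D, D]
W_proj = torch.randn(D, D)                                # output projection [D_out, D_in]
attn = torch.softmax(torch.randn(B, n_h, T, T), dim=-1)   # attention scores

v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v, diag_cov=False)
qkv_cov = forward_QKV_cov(attn, v_cov, diag_cov=False)
out_var = forward_fuse_multi_head_cov(qkv_cov, W_proj, diag_cov=False)

print(v_cov.shape)    # [B, T, n_h, D_v, D_v] per-head value covariance
print(qkv_cov.shape)  # [B, n_h, T, D_v, D_v] after mixing tokens with squared scores
print(out_var.shape)  # [B, T, D] per-dimension variance after the projection
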
suq/streamline_layer.py ADDED
@@ -0,0 +1,23 @@
1
+ from .diag_suq_mlp import (
2
+ SUQ_Linear_Diag,
3
+ SUQ_Activation_Diag,
4
+ SUQ_BatchNorm_Diag
5
+ )
6
+ from .diag_suq_transformer import (
7
+ SUQ_TransformerMLP_Diag,
8
+ SUQ_Attention_Diag,
9
+ SUQ_LayerNorm_Diag,
10
+ SUQ_Classifier_Diag,
11
+ SUQ_Transformer_Block_Diag
12
+ )
13
+
14
+ __all__ = [
15
+ "SUQ_Linear_Diag",
16
+ "SUQ_Activation_Diag",
17
+ "SUQ_BatchNorm_Diag",
18
+ "SUQ_TransformerMLP_Diag",
19
+ "SUQ_Attention_Diag",
20
+ "SUQ_LayerNorm_Diag",
21
+ "SUQ_Classifier_Diag",
22
+ "SUQ_Transformer_Block_Diag"
23
+ ]
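
The re-exported layers above can also be composed by hand. Below is a minimal sketch (constant placeholder variances, arbitrary sizes, not part of the package) that propagates a mean/variance pair through one wrapped linear layer and one wrapped activation.

import torch
import torch.nn as nn
from suq.diag_suq_mlp import SUQ_Linear_Diag, SUQ_Activation_Diag

lin = nn.Linear(6, 4)
w_var = 1e-2 * torch.ones_like(lin.weight)   # placeholder weight variances
b_var = 1e-2 * torch.ones_like(lin.bias)     # placeholder bias variances

layer = SUQ_Linear_Diag(lin, w_var, b_var)
act = SUQ_Activation_Diag(nn.Tanh())

x_mean = torch.randn(5, 6)
h_mean, h_var = layer(x_mean, None)   # a None input variance is treated as zero
a_mean, a_var = act(h_mean, h_var)
print(a_mean.shape, a_var.shape)      # both [5, 4]
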
suq-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: suq
3
+ Version: 0.1.0
4
+ Summary: Streamlined Uncertainty Quantification (SUQ)
5
+ Home-page: https://github.com/AaltoML/SUQ
6
+ Author: Rui Li, Marcus Klasson, Arno Solin, Martin Trapp
7
+ License: MIT
8
+ Classifier: Programming Language :: Python :: 3
9
+ License-File: LICENSE
10
+ Requires-Dist: torch>=1.10
11
+ Requires-Dist: numpy>=1.21
12
+ Requires-Dist: tqdm>=4.60
13
+ Dynamic: author
14
+ Dynamic: classifier
15
+ Dynamic: home-page
16
+ Dynamic: license
17
+ Dynamic: license-file
18
+ Dynamic: requires-dist
19
+ Dynamic: summary
suq-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
1
+ suq/SUQ_MLP.py,sha256=ciegIOz2Y0vtJ4Uc56dhwHjeaAoLAKqQQnSTgyh8Sqc,497
2
+ suq/SUQ_ViT.py,sha256=6BpHMOLf1qhVRgqM3guAcS60PpYJ1L1CEWdIerIKrv4,841
3
+ suq/__init__.py,sha256=5LGGZQ6wwjEfwsmzx55lSA4nfBx1N1qQpzgFff8zpJM,120
4
+ suq/base_suq.py,sha256=77LggEGn5m3h472kRLVghM1j7WCY4zJih3undVdIoS0,7963
5
+ suq/diag_suq_mlp.py,sha256=yE4p2pQtD97g6GvebkROqNDPFuJUhrwZH9zRkWAiSWk,10985
6
+ suq/diag_suq_transformer.py,sha256=QfGkPCf1DJ1cVIrkDp7tD8bImtCgt3XoUnngtgfwIAc,25576
7
+ suq/streamline_layer.py,sha256=sew0BodioNjElVpSq0c87c0ORRA5yGDbLTy1eeLp8VI,504
8
+ suq-0.1.0.dist-info/licenses/LICENSE,sha256=XbdRUcHQPFVqh9MgS85KpgRmx09FAPBW5gOHyzjJz3c,1064
9
+ suq-0.1.0.dist-info/METADATA,sha256=lcWlRUIWAM-LKK8MqbDehGjJ3QX1_Fl7nDCtZ-Jb1V4,497
10
+ suq-0.1.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
11
+ suq-0.1.0.dist-info/top_level.txt,sha256=_8edEeGJ6W1u-AhLS6x16sRWGGKwumOGPl7EBsW-v5g,4
12
+ suq-0.1.0.dist-info/RECORD,,
suq-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (78.1.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
suq-0.1.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 AaltoML
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
suq-0.1.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ suq