suq-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- suq/SUQ_MLP.py +10 -0
- suq/SUQ_ViT.py +14 -0
- suq/__init__.py +4 -0
- suq/base_suq.py +181 -0
- suq/diag_suq_mlp.py +308 -0
- suq/diag_suq_transformer.py +627 -0
- suq/streamline_layer.py +23 -0
- suq-0.1.0.dist-info/METADATA +19 -0
- suq-0.1.0.dist-info/RECORD +12 -0
- suq-0.1.0.dist-info/WHEEL +5 -0
- suq-0.1.0.dist-info/licenses/LICENSE +21 -0
- suq-0.1.0.dist-info/top_level.txt +1 -0
suq/SUQ_MLP.py
ADDED
@@ -0,0 +1,10 @@
from .diag_suq_mlp import SUQ_MLP_Diag

def streamline_mlp(model, posterior, covariance_structure, likelihood, scale_init = 1.0):
    if covariance_structure == 'diag':
        return SUQ_MLP_Diag(org_model = model,
                            posterior_variance = posterior,
                            likelihood = likelihood,
                            scale_init = scale_init)
    else:
        raise NotImplementedError(f"Covariance structure '{covariance_structure}' is not implemented.")
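A minimal usage sketch for streamline_mlp (not shipped in this wheel): it assumes a small trained latent MLP and a flattened diagonal posterior variance vector (e.g. from a diagonal Laplace approximation); the 0.01 variances and random inputs below are placeholders. Note that base_suq.py imports suq.utils.utils, which is not among the files listed in this wheel, so the sketch assumes an environment where that helper module is available (for example the source repository).

# Hypothetical usage sketch; model, variances and inputs are illustrative only.
import torch
import torch.nn as nn
from suq.SUQ_MLP import streamline_mlp

device = 'cuda' if torch.cuda.is_available() else 'cpu'   # matches the device choice in base_suq.py

latent_mlp = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)).to(device)
n_params = sum(p.numel() for p in latent_mlp.parameters())
posterior_variance = 0.01 * torch.ones(n_params, device=device)   # placeholder diagonal posterior variances

suq_model = streamline_mlp(latent_mlp, posterior_variance,
                           covariance_structure='diag',
                           likelihood='classification')
probs = suq_model(torch.randn(4, 10, device=device))   # class probabilities, shape [4, 3]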
suq/SUQ_ViT.py
ADDED
@@ -0,0 +1,14 @@
from .diag_suq_transformer import SUQ_ViT_Diag

def streamline_vit(model, posterior, covariance_structure, likelihood, MLP_deterministic, Attn_deterministic, attention_diag_cov, num_det_blocks, scale_init = 1.0):
    if covariance_structure == 'diag':
        return SUQ_ViT_Diag(ViT = model,
                            posterior_variance = posterior,
                            MLP_determinstic = MLP_deterministic,
                            Attn_determinstic = Attn_deterministic,
                            likelihood = likelihood,
                            attention_diag_cov = attention_diag_cov,
                            num_det_blocks = num_det_blocks,
                            scale_init = scale_init)
    else:
        raise NotImplementedError(f"Covariance structure '{covariance_structure}' is not implemented.")
suq/__init__.py
ADDED
suq/base_suq.py
ADDED
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch.distributions import Categorical
from torch.distributions.normal import Normal
from torch.utils.data import DataLoader

from suq.utils.utils import torch_dataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class SUQ_Base(nn.Module):
    """
    Base class for SUQ models.

    Provides core functionality for:
    - Managing likelihood type (regression or classification)
    - Probit-based approximation for classification
    - NLPD-based fitting of the scale factor

    Inputs:
        likelihood (str): Either 'classification' or 'regression'
        scale_init (float): Initial value for the scale factor parameter
    """

    def __init__(self, likelihood, scale_init):
        super().__init__()

        if likelihood not in ['classification', 'regression']:
            raise ValueError(f"Invalid likelihood type {likelihood}")

        self.likelihood = likelihood
        self.scale_factor = nn.Parameter(torch.Tensor([scale_init]).to(device))

    def probit_approximation(self, out_mean, out_var):
        """
        Applies a probit approximation to compute class probabilities from the latent Gaussian distribution.

        Inputs:
            out_mean (Tensor): Latent function mean, shape [B, C]
            out_var (Tensor): Latent function variance, shape [B, C] or [B, C, C]

        Outputs:
            posterior_predict_mean (Tensor): Predicted class probabilities, shape [B, C]
        """

        if out_var.dim() == 3:
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var.diagonal(dim1=1, dim2=2))
        else:
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)

        posterior_predict_mean = torch.softmax(kappa * out_mean, dim=-1)
        return posterior_predict_mean

    def fit_scale_factor(self, data_loader, n_epoches, lr, speedup = True, verbose = False):
        """
        Fits the scale factor for predictive variance using negative log predictive density (NLPD).

        Inputs:
            data_loader (DataLoader): Dataloader containing (input, target) pairs
            n_epoches (int): Number of epochs for optimization
            lr (float): Learning rate for scale optimizer
            speedup (bool): If True (classification only), caches forward pass outputs to accelerate fitting
            verbose (bool): If True, prints NLPD at each epoch

        Outputs:
            total_train_nlpd (List[float]): Average NLPD per epoch over training data
        """
        print("fit scale factor")
        optimizer = torch.optim.Adam([self.scale_factor], lr)
        total_train_nlpd = []

        # store intermediate result and pack it into a data loader, so we only need to do one forward pass
        if speedup:

            if self.likelihood == 'regression':
                raise ValueError("Speed up not supported for regression atm")

            if self.likelihood == 'classification':

                f_mean = []
                f_var = []
                labels = []

                for (X, y) in tqdm(data_loader, desc="packing f_mean f_var into a dataloader"):
                    out_mean, out_var = self.forward_latent(X.to(device))
                    f_mean.append(out_mean.detach().cpu().numpy())
                    f_var.append(out_var.detach().cpu().numpy())
                    if y.dim() == 2:
                        labels.append(y.numpy().argmax(1).reshape(-1, 1))
                    if y.dim() == 1:
                        labels.append(y.numpy().reshape(-1, 1))

                f_mean = np.vstack(f_mean)
                f_var = np.vstack(f_var)
                labels = np.vstack(labels)

                scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
                scale_fit_dataloader = DataLoader(scale_fit_dataset, batch_size=16, shuffle=True)

                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for data_pair in scale_fit_dataloader:
                        x_mean, x_var_label = data_pair
                        num_class = x_mean.shape[1]
                        x_mean = x_mean.to(device)
                        x_var, label = x_var_label.split(num_class, dim=1)
                        x_var = x_var.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        x_var = x_var / self.scale_factor.to(device)
                        posterior_predict_mean = self.probit_approximation(x_mean, x_var)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Categorical(posterior_predict_mean)
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()
                    total_train_nlpd.append(running_nlpd / len(scale_fit_dataloader))
                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

                del scale_fit_dataloader
                del scale_fit_dataset

        else:

            if self.likelihood == 'classification':
                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for (data, label) in data_loader:

                        data = data.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        posterior_predict_mean = self.forward(data)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Categorical(posterior_predict_mean)
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()
                    total_train_nlpd.append(running_nlpd / len(data_loader))
                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

            if self.likelihood == 'regression':
                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for (data, label) in data_loader:
                        data = data.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        posterior_predict_mean, posterior_predict_var = self.forward(data)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Normal(posterior_predict_mean, posterior_predict_var.sqrt())
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()

                    total_train_nlpd.append(running_nlpd / len(data_loader))

                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

        return total_train_nlpd
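A standalone illustration (not part of the package) of the probit approximation used in SUQ_Base.probit_approximation above: logits are scaled by kappa = 1 / sqrt(1 + pi/8 * var) before the softmax, so larger latent variance pulls the predicted probabilities towards the uniform distribution. The means and variances below are placeholders.

import numpy as np
import torch

out_mean = torch.tensor([[2.0, 0.0, -1.0]])           # latent mean, shape [1, 3]
for var in (0.0, 1.0, 10.0):
    out_var = torch.full_like(out_mean, var)           # constant latent variance
    kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)   # probit scaling factor
    probs = torch.softmax(kappa * out_mean, dim=-1)
    print(var, probs.numpy().round(3))                 # probabilities flatten towards 1/3 as var grows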
suq/diag_suq_mlp.py
ADDED
@@ -0,0 +1,308 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.nn.utils import parameters_to_vector
import numpy as np
import copy

from suq.base_suq import SUQ_Base

def forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var):
    """
    Compute the mean and covariance of h = a @ W^T + b when the posterior has diagonal covariance.

    ----- Input -----
    a_mean: [N, D_in] mean(a)
    a_var: [N, D_in] a_var[i] = var(a_i)
    weight: [D_out, D_in] W
    bias: [D_out, ] b
    w_var: [D_out, D_in] w_var[k][i] = var(w_ki)
    b_var: [D_out, ] b_var[k] = var(b_k)
    ----- Output -----
    h_mean: [N, D_out]
    h_var: [N, D_out] h_var[k] = var(h_k)
    """

    # calculate mean(h)
    h_mean = F.linear(a_mean, weight, bias)

    # calculate var(h)
    weight_mean2_var_sum = weight ** 2 + w_var  # [D_out, D_in]
    h_var = a_mean ** 2 @ w_var.T + a_var @ weight_mean2_var_sum.T + b_var

    return h_mean, h_var


def forward_activation_implicit_diag(activation_func, h_mean, h_var):
    """
    Given h ~ N(h_mean, h_cov) and g(·), where h_cov is a diagonal matrix,
    approximate the distribution of a = g(h) as
    a ~ N(g(h_mean), g'(h_mean)^T h_var g'(h_mean))

    Input
        activation_func: g(·)
        h_mean: [N, D]
        h_var: [N, D], h_var[i] = var(h_i)

    Output
        a_mean: [N, D]
        a_var: [N, D]
    """

    h_mean_grad = h_mean.detach().clone().requires_grad_()

    a_mean = activation_func(h_mean_grad)
    a_mean.retain_grad()
    a_mean.backward(torch.ones_like(a_mean))  # [N, D]

    nabla = h_mean_grad.grad  # [N, D]
    a_var = nabla ** 2 * h_var

    return a_mean.detach(), a_var

def forward_batch_norm_diag(h_var, bn_weight, bn_running_var, bn_eps):
    """
    Pass a distribution with diagonal covariance through a BatchNorm layer.

    Input
        h_var: variance of input distribution [B, D]
        bn_weight: batch norm scale factor [D, ]
        bn_running_var: batch norm running variance [D, ]
        bn_eps: batch norm eps

    Output
        output_var [B, D]
    """

    scale_factor = (1 / (bn_running_var.reshape(1, -1) + bn_eps)) * bn_weight.reshape(1, -1) ** 2  # [1, D]
    output_var = scale_factor * h_var  # [B, D]

    return output_var

class SUQ_Linear_Diag(nn.Module):
    """
    Linear layer with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard `nn.Linear` layer and applies closed-form mean and variance propagation. See the SUQ paper for theoretical background and assumptions.

    Inputs:
        org_linear (nn.Linear): The original linear layer to wrap
        w_var (Tensor): Weight variances, shape [D_out, D_in]
        b_var (Tensor): Bias variances, shape [D_out]
    """
    def __init__(self, org_linear, w_var, b_var):
        super().__init__()

        self.weight = org_linear.weight.data
        self.bias = org_linear.bias.data
        self.w_var = w_var
        self.b_var = b_var

    def forward(self, a_mean, a_var):
        """
        Inputs:
            a_mean (Tensor): Input mean, shape [N, D_in]
            a_var (Tensor): Input variance, shape [N, D_in]

        Outputs:
            h_mean (Tensor): Output mean, shape [N, D_out]
            h_var (Tensor): Output variance, shape [N, D_out]
        """

        if a_var is None:
            a_var = torch.zeros_like(a_mean).to(a_mean.device)

        h_mean, h_var = forward_aW_diag(a_mean, a_var, self.weight, self.bias, self.w_var, self.b_var)

        return h_mean, h_var

class SUQ_Activation_Diag(nn.Module):
    """
    Activation layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard activation function and applies a first-order approximation to propagate input variance through the nonlinearity. See the SUQ paper for theoretical background and assumptions.

    Inputs:
        afun (Callable): A PyTorch activation function (e.g. nn.ReLU())
    """

    def __init__(self, afun):
        super().__init__()
        self.afun = afun

    def forward(self, h_mean, h_var):
        """
        Inputs:
            h_mean (Tensor): Input mean before activation, shape [N, D]
            h_var (Tensor): Input variance before activation, shape [N, D]

        Outputs:
            a_mean (Tensor): Activated output mean, shape [N, D]
            a_var (Tensor): Approximated output variance, shape [N, D]
        """
        a_mean, a_var = forward_activation_implicit_diag(self.afun, h_mean, h_var)
        return a_mean, a_var

class SUQ_BatchNorm_Diag(nn.Module):
    """
    BatchNorm layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps `nn.BatchNorm1d` and adjusts input variance using batch normalization statistics and scale parameters. See the SUQ paper for theoretical background and assumptions.

    Inputs:
        BatchNorm (nn.BatchNorm1d): The original batch norm layer
    """

    def __init__(self, BatchNorm):
        super().__init__()

        self.BatchNorm = BatchNorm

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, D]
            x_var (Tensor): Input variance, shape [B, D]

        Outputs:
            out_mean (Tensor): Output mean after batch normalization, shape [B, D]
            out_var (Tensor): Output variance after batch normalization, shape [B, D]
        """

        with torch.no_grad():

            out_mean = self.BatchNorm.forward(x_mean)
            out_var = forward_batch_norm_diag(x_mean, x_var, self.BatchNorm.weight, 1e-5)

        return out_mean, out_var


class SUQ_MLP_Diag(SUQ_Base):
    """
    Multilayer perceptron model with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard MLP, converting its layers into SUQ-compatible components.
    Supports both classification and regression via a predictive Gaussian approximation.

    Note:
        The input model should correspond to the latent function only:
        - For regression, this is the full model (including the final output layer).
        - For classification, exclude the softmax layer and pass only the logit-producing part.

    Inputs:
        org_model (nn.Module): The original MLP model to convert
        posterior_variance (Tensor): Flattened posterior variance vector
        likelihood (str): Either 'classification' or 'regression'
        scale_init (float, optional): Initial scale factor
        sigma_noise (float, optional): Noise level (for regression)
    """

    def __init__(self, org_model, posterior_variance, likelihood, scale_init = 1.0, sigma_noise = None):
        super().__init__(likelihood, scale_init)

        self.sigma_noise = sigma_noise
        self.convert_model(org_model, posterior_variance)

    def forward_latent(self, data, out_var = None):
        """
        Compute the predictive mean and variance of the latent function before applying the likelihood.

        Traverses the model layer by layer, propagating mean and variance through each SUQ-wrapped layer.

        Inputs:
            data (Tensor): Input data, shape [B, D]
            out_var (Tensor or None): Optional input variance, shape [B, D]

        Outputs:
            out_mean (Tensor): Output mean after final layer, shape [B, D_out]
            out_var (Tensor): Output variance after final layer, shape [B, D_out]
        """

        out_mean = data

        if isinstance(self.model, nn.Sequential):
            for layer in self.model:
                out_mean, out_var = layer.forward(out_mean, out_var)
        ##TODO: other type of model

        out_var = out_var / self.scale_factor

        return out_mean, out_var

    def forward(self, data):
        """
        Compute the predictive distribution based on the model's likelihood setting.

        For classification, uses a probit approximation.
        For regression, returns the latent mean and total predictive variance.

        Inputs:
            data (Tensor): Input data, shape [B, D]

        Outputs:
            If classification:
                Tensor: Class probabilities, shape [B, num_classes]
            If regression:
                Tuple[Tensor, Tensor]: Output mean and total variance, shape [B, D_out]
        """

        out_mean, out_var = self.forward_latent(data)

        if self.likelihood == 'classification':
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)
            return torch.softmax(kappa * out_mean, dim=-1)

        if self.likelihood == 'regression':
            return out_mean, out_var + self.sigma_noise ** 2

    def convert_model(self, org_model, posterior_variance):
        """
        Converts a deterministic MLP into a SUQ-compatible model with a diagonal posterior.

        Each layer is replaced with its corresponding SUQ module (e.g. linear, activation, batchnorm), using the provided flattened posterior variance vector.

        Inputs:
            org_model (nn.Module): The original model to convert (latent function only)
            posterior_variance (Tensor): Flattened posterior variance for Bayesian parameters
        """

        p_model = copy.deepcopy(org_model)

        loc = 0
        for n, layer in p_model.named_modules():
            if isinstance(layer, nn.Linear):

                D_out, D_in = layer.weight.data.shape
                num_param = torch.numel(parameters_to_vector(layer.parameters()))
                num_weight_param = D_out * D_in

                covariance_block = posterior_variance[loc : loc + num_param]

                """
                w_var: [D_out, D_in], w_var[k][i] = var(w_ki)
                b_var: [D_out, ] b_var[k] = var(b_k)
                """

                b_var = torch.zeros_like(layer.bias.data).to(layer.bias.data.device)
                w_var = torch.zeros_like(layer.weight.data).to(layer.bias.data.device)

                for k in range(D_out):
                    b_var[k] = covariance_block[num_weight_param + k]
                    for i in range(D_in):
                        w_var[k][i] = covariance_block[k * D_in + i]

                new_layer = SUQ_Linear_Diag(layer, w_var, b_var)

                loc += num_param
                setattr(p_model, n, new_layer)

            if isinstance(layer, nn.BatchNorm1d):
                new_layer = SUQ_BatchNorm_Diag(layer)
                setattr(p_model, n, new_layer)

            if type(layer).__name__ in torch.nn.modules.activation.__all__:
                new_layer = SUQ_Activation_Diag(layer)
                setattr(p_model, n, new_layer)

        self.model = p_model
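A standalone Monte Carlo check (not part of the package) of the moment-matching rule implemented in forward_aW_diag above: for a ~ N(a_mean, diag(a_var)), W ~ N(weight, diag(w_var)) and b ~ N(bias, diag(b_var)), the variance of h = a @ W.T + b is a_mean**2 @ w_var.T + a_var @ (weight**2 + w_var).T + b_var. All sizes below are illustrative placeholders.

import torch

torch.manual_seed(0)
N, D_in, D_out, S = 3, 4, 2, 200_000
a_mean, a_var = torch.randn(N, D_in), torch.rand(N, D_in)
weight, w_var = torch.randn(D_out, D_in), torch.rand(D_out, D_in)
bias, b_var = torch.randn(D_out), torch.rand(D_out)

# closed form, mirroring forward_aW_diag
h_var = a_mean ** 2 @ w_var.T + a_var @ (weight ** 2 + w_var).T + b_var

# Monte Carlo estimate of var(a @ W.T + b) over S joint samples of (a, W, b)
a = a_mean + a_var.sqrt() * torch.randn(S, N, D_in)
W = weight + w_var.sqrt() * torch.randn(S, D_out, D_in)
b = bias + b_var.sqrt() * torch.randn(S, D_out)
h = torch.einsum('sni,soi->sno', a, W) + b[:, None, :]
print(torch.allclose(h.var(dim=0), h_var, rtol=0.05, atol=0.05))   # expect True, up to Monte Carlo error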
suq/diag_suq_transformer.py
ADDED
@@ -0,0 +1,627 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from suq.diag_suq_mlp import forward_aW_diag
from suq.base_suq import SUQ_Base

def forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var, bias = None):
    """
    Pass a distribution with diagonal covariance through a Bayesian linear layer with diagonal covariance.

    Given e ~ N(e_mean, e_cov) and W ~ N(W_mean, W_cov), calculate the mean and variance of h = e @ W.T + b.

    Only the weight is treated as Bayesian; the bias is treated deterministically.

    Since we always assume the input to the next layer has diagonal covariance, only the variance of h is computed here.

    Input
        e_mean: [B, T, D_in] embedding mean
        e_var: [B, T, D_in] embedding variance
        w_mean: [D_out, D_in] weight mean
        w_var: [D_out, D_in] weight variance, w_var[k][i] = var(w_ki)
    Output
        h_mean: [B, T, D_out]
        h_var: [B, T, D_out] h_var[k] = var(h_k)
    """

    # calculate mean(h)
    h_mean = F.linear(e_mean, w_mean, bias)

    # calculate var(h)
    weight_mean2_var_sum = w_mean ** 2 + w_var  # [D_out, D_in]
    h_var = e_mean ** 2 @ w_var.T + e_var @ weight_mean2_var_sum.T

    return h_mean, h_var

def forward_linear_diag_determinstic_weight(e_mean, e_var, weight, bias = None):
    """
    Pass a distribution with diagonal covariance through a linear layer.

    Given e ~ N(e_mean, e_var) and a deterministic weight W and bias b, calculate the mean and variance of h = e @ W.T + b.

    Since we always assume the input to the next layer has diagonal covariance, only the variance of h is computed here.

    Input
        e_mean: [B, T, D_in] embedding mean
        e_var: [B, T, D_in] embedding variance
        weight: [D_out, D_in] weight
    Output
        h_mean: [B, T, D_out]
        h_var: [B, T, D_out] h_var[k] = var(h_k)
    """

    h_mean = F.linear(e_mean, weight, bias)
    h_var = F.linear(e_var, weight ** 2, None)

    return h_mean, h_var

@torch.enable_grad()
def forward_activation_diag(activation_func, h_mean, h_var):
    """
    Pass a distribution with diagonal covariance through an activation layer.

    Given h ~ N(h_mean, h_cov) and g(·), where h_cov is a diagonal matrix,
    approximate the distribution of a = g(h) as
    a ~ N(g(h_mean), g'(h_mean)^T h_var g'(h_mean))

    Input
        activation_func: g(·)
        h_mean: [B, T, D] input mean
        h_var: [B, T, D] input variance

    Output
        a_mean: [B, T, D]
        a_var: [B, T, D]
    """

    h_mean_grad = h_mean.detach().clone().requires_grad_()

    a_mean = activation_func(h_mean_grad)
    a_mean.retain_grad()
    a_mean.backward(torch.ones_like(a_mean))  # [B, T, D]

    nabla = h_mean_grad.grad  # [B, T, D]
    a_var = nabla ** 2 * h_var

    return a_mean.detach(), a_var

def forward_layer_norm_diag(e_mean, e_var, ln_weight, ln_eps):
    """
    Pass a distribution with diagonal covariance through a LayerNorm layer.

    Input
        e_mean: mean of input distribution [B, T, D]
        e_var: variance of input distribution [B, T, D]
        ln_weight: layer norm scale factor
        ln_eps: layer norm eps

    Output
        output_var [B, T, D]
    """

    # calculate the var
    input_mean_var = e_mean.var(dim=-1, keepdim=True, unbiased=False)  # [B, T, 1]
    scale_factor = (1 / (input_mean_var + ln_eps)) * ln_weight ** 2  # [B, T, D]
    output_var = scale_factor * e_var  # [B, T, D]

    return output_var

def forward_value_cov_Bayesian_W(W_v, W_v_var, input_mean, input_var, n_h, D_v, diag_cov = False):
    """
    Given a value matrix W_v ~ N(mean(W), var(W)) and input E ~ N(mean(E), var(E)),
    compute the covariance of the output v = W_v @ E.

    Input:
        n_h: number of attention heads
        D_v: dimension of value, n_h * D_v = D
        W_v: value weight matrix [D, D]
        W_v_var: variance of value matrix, [D, D]
        input_mean: mean of input [B, T, D]
        input_var: variance of input [B, T, D]
        diag_cov: whether the input only has diagonal covariance

    Output:
        v_cov [B, T, n_h, D_v, D_v] or v_var [B, T, n_h, D_v]
    """

    B, T, D = input_var.size()

    if not diag_cov:
        ## compute general covariance
        W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D)
        # [D, D] -> [1, 1, n_h, D_v, D]
        input_var_reshaped = input_var.reshape(B, T, 1, 1, D)
        # [B, T, D] -> [B, T, 1, 1, D]
        v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
        # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
        v_cov = torch.matmul(W_v_reshaped, v_cov)
        # [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v] -> [B, T, n_h, D_v, D_v]

        ## add missing part for variance
        W_v_var_reshaped = W_v_var.reshape(1, 1, n_h, D_v, D)
        # [D, D] -> [1, 1, n_h, D_v, D]
        input_var_plus_mean_square = input_var_reshaped + input_mean.reshape(B, T, 1, 1, D) ** 2  # [B, T, 1, 1, D]
        extra_var_term = torch.sum(input_var_plus_mean_square * W_v_var_reshaped, dim=[4])  # [B, T, n_h, D_v, D] -> [B, T, n_h, D_v]
        v_cov = v_cov + torch.diag_embed(extra_var_term)

        return v_cov

    else:
        weight_mean2_var_sum = W_v ** 2 + W_v_var  # [D, D]
        v_var = input_mean ** 2 @ W_v_var.T + input_var @ weight_mean2_var_sum.T  # [B, T, D]

        return v_var.reshape(B, T, n_h, D_v)

def forward_value_cov_determinstic_W(W_v, input_var, n_h, D_v, diag_cov = False):
    """
    Given a deterministic value matrix W_v and input E ~ N(mean(E), var(E)),
    compute the covariance of the output v = W_v @ E.

    Input:
        n_h: number of attention heads
        D_v: dimension of value, n_h * D_v = D
        W_v: value weight matrix [D, D], which can be reshaped into [n_h, D_v, D]
        input_var: variance of input [B, T, D]
        diag_cov: whether the input only has diagonal covariance

    Output:
        v_cov [B, T, n_h, D_v, D_v] or v_var [B, T, n_h, D_v]
    """

    B, T, D = input_var.size()

    if not diag_cov:
        W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D)
        # [n_h, D_v, D] -> [1, 1, n_h, D_v, D]
        input_var_reshaped = input_var.reshape(B, T, 1, 1, D)
        # [B, T, D] -> [B, T, 1, 1, D]
        v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
        # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
        v_cov = torch.matmul(W_v_reshaped, v_cov)
        # [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v] -> [B, T, n_h, D_v, D_v]

        return v_cov

    else:
        v_var = input_var @ (W_v ** 2).T

        return v_var.reshape(B, T, n_h, D_v)

def forward_QKV_cov(attention_score, v_cov, diag_cov = False):
    """
    Given the attention score (QK^T) and V ~ N(mean(V), cov(V)),
    compute the covariance of the output E = (QK^T) V.

    Input:
        attention_score: [B, n_h, T, T] attention_score[t] is token t's attention score for all other tokens
        v_cov: [B, T, n_h, D_v, D_v] or [B, T, n_h, D_v], covariance of value
        diag_cov: whether the input only has diagonal covariance

    Output:
        QKV_cov: [B, n_h, T, D_v, D_v] or [B, n_h, T, D_v], covariance of output E
    """
    if diag_cov:
        B, T, n_h, D_v = v_cov.size()
        QKV_cov = attention_score ** 2 @ v_cov.transpose(1, 2)  # [B, n_h, T, D_v]
        # v_cov [B, T, n_h, D_v] -> [B, n_h, T, D_v]
        # [B, n_h, T, T] @ [B, n_h, T, D_v] -> [B, n_h, T, D_v]
    else:

        B, T, n_h, D_v, _ = v_cov.size()

        QKV_cov = attention_score ** 2 @ v_cov.permute(0, 2, 1, 3, 4).reshape(B, n_h, T, D_v * D_v)  # [B, n_h, T, D_v * D_v]
        # v_cov [B, T, n_h, D_v, D_v] -> [B, n_h, T, D_v * D_v]
        # [B, n_h, T, T] @ [B, n_h, T, D_v * D_v] -> [B, n_h, T, D_v * D_v]
        QKV_cov = QKV_cov.reshape(B, n_h, T, D_v, D_v)

    return QKV_cov

def forward_fuse_multi_head_cov(QKV_cov, project_W, diag_cov = False):
    """
    Given the concatenated multi-head embedding E ~ N(mean(E), cov(E)) and the projection weight matrix W,
    compute the variance of each output dimension.

    Input:
        QKV_cov: [B, n_h, T, D_v, D_v] or [B, n_h, T, D_v]
        project_W: [D, D] D_out x D_in (n_h * D_v)
        diag_cov: whether the input only has diagonal covariance

    Output:
        output_var [B, T, D]
    """
    if diag_cov:
        B, n_h, T, D_v = QKV_cov.size()
        output_var = QKV_cov.permute(0, 2, 1, 3).reshape(B, T, n_h * D_v) @ project_W ** 2
        # QKV_cov [B, n_h, T, D_v] -> [B, T, n_h, D_v] -> [B, T, n_h * D_v]

        return output_var

    else:
        B, n_h, T, D_v, _ = QKV_cov.size()
        D, _ = project_W.shape

        project_W_reshaped_1 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, D_v, 1)
        # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, D_v, 1]
        project_W_reshaped_2 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, 1, D_v)
        # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, 1, D_v]

        project_W_outer = torch.bmm(project_W_reshaped_1, project_W_reshaped_2).reshape(n_h, D, D_v, D_v).permute(1, 0, 2, 3)  # [D, n_h, D_v, D_v]
        # [n_h * D, D_v, 1] @ [n_h * D, 1, D_v] -> [n_h * D, D_v, D_v] -> [D, n_h, D_v, D_v]

        output_var_einsum = torch.einsum('dhij,bthij->dbt', project_W_outer, QKV_cov.permute(0, 2, 1, 3, 4))

        return output_var_einsum.permute(1, 2, 0)

class SUQ_LayerNorm_Diag(nn.Module):
    """
    LayerNorm module with uncertainty propagation under SUQ.

    Wraps `nn.LayerNorm` and propagates input variance analytically using running statistics. See the SUQ paper for theoretical background and assumptions.

    Inputs:
        LayerNorm (nn.LayerNorm): The original layer norm module to wrap
    """

    def __init__(self, LayerNorm):
        super().__init__()

        self.LayerNorm = LayerNorm

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, T, D]
            x_var (Tensor): Input variance, shape [B, T, D]

        Outputs:
            out_mean (Tensor): Output mean after layer norm, shape [B, T, D]
            out_var (Tensor): Output variance after layer norm, shape [B, T, D]
        """
        with torch.no_grad():

            out_mean = self.LayerNorm.forward(x_mean)
            out_var = forward_layer_norm_diag(x_mean, x_var, self.LayerNorm.weight, 1e-5)

        return out_mean, out_var


class SUQ_Classifier_Diag(nn.Module):
    """
    Classifier head with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard linear classifier and applies closed-form mean and variance propagation.
    See the SUQ paper for theoretical background and assumptions.

    Inputs:
        classifier (nn.Linear): The final classification head
        w_var (Tensor): Weight variances, shape [D_out, D_in]
        b_var (Tensor): Bias variances, shape [D_out]
    """

    def __init__(self, classifier, w_var, b_var):
        super().__init__()

        self.weight = classifier.weight
        self.bias = classifier.bias
        self.w_var = w_var.reshape(self.weight.shape)
        self.b_var = b_var.reshape(self.bias.shape)

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, D]
            x_var (Tensor): Input variance, shape [B, D]

        Outputs:
            h_mean (Tensor): Output mean, shape [B, D_out]
            h_var (Tensor): Output variance, shape [B, D_out]
        """
        with torch.no_grad():
            h_mean, h_var = forward_aW_diag(x_mean, x_var, self.weight.data, self.bias.data, self.w_var, self.b_var)
        return h_mean, h_var

class SUQ_TransformerMLP_Diag(nn.Module):
    """
    MLP submodule of a transformer block with uncertainty propagation under SUQ.

    Supports both deterministic and Bayesian forward modes with closed-form variance propagation.
    Used internally in `SUQ_Transformer_Block_Diag`.

    Inputs:
        MLP (nn.Module): Original MLP submodule
        determinstic (bool): Whether to treat the MLP weights as deterministic
        w_fc_var (Tensor, optional): Variance of the first linear layer (if Bayesian)
        w_proj_var (Tensor, optional): Variance of the second linear layer (if Bayesian)
    """

    def __init__(self, MLP, determinstic = True, w_fc_var = None, w_proj_var = None):
        super().__init__()

        self.MLP = MLP
        self.determinstic = determinstic
        if not determinstic:
            self.w_fc_var = w_fc_var.reshape(self.MLP.c_fc.weight.shape)
            self.w_proj_var = w_proj_var.reshape(self.MLP.c_proj.weight.shape)

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, T, D]
            x_var (Tensor): Input variance, shape [B, T, D]

        Outputs:
            h_mean (Tensor): Output mean, shape [B, T, D]
            h_var (Tensor): Output variance, shape [B, T, D]
        """

        # first fc layer
        with torch.no_grad():
            if self.determinstic:
                h_mean, h_var = forward_linear_diag_determinstic_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.MLP.c_fc.bias.data)
            else:
                h_mean, h_var = forward_linear_diag_Bayesian_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.w_fc_var, self.MLP.c_fc.bias.data)
        # activation function
        h_mean, h_var = forward_activation_diag(self.MLP.gelu, h_mean, h_var)
        # second fc layer
        with torch.no_grad():
            if self.determinstic:
                h_mean, h_var = forward_linear_diag_determinstic_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.MLP.c_proj.bias.data)
            else:
                h_mean, h_var = forward_linear_diag_Bayesian_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.w_proj_var, self.MLP.c_proj.bias.data)

        return h_mean, h_var

class SUQ_Attention_Diag(nn.Module):
    """
    Self-attention module with uncertainty propagation under SUQ.

    Supports deterministic and Bayesian value projections, with optional diagonal covariance assumptions. For details see SUQ paper section A.6.
    Used internally in `SUQ_Transformer_Block_Diag`.

    Inputs:
        Attention (nn.Module): The original attention module
        determinstic (bool): Whether to treat value projections as deterministic
        diag_cov (bool): If True, only compute the diagonal covariance for the value
        W_v_var (Tensor, optional): Posterior variance for the value matrix (if Bayesian)
    """

    def __init__(self, Attention, determinstic = True, diag_cov = False, W_v_var = None):
        super().__init__()

        self.Attention = Attention
        self.determinstic = determinstic
        self.diag_cov = diag_cov

        if not self.determinstic:
            self.W_v_var = W_v_var  # [D * D]

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, T, D]
            x_var (Tensor): Input variance, shape [B, T, D]

        Outputs:
            output_mean (Tensor): Output mean after attention, shape [B, T, D]
            output_var (Tensor): Output variance after attention, shape [B, T, D]
        """

        with torch.no_grad():

            output_mean, attention_score = self.Attention.forward(x_mean, True)

            n_h = self.Attention.n_head
            B, T, D = x_mean.size()
            D_v = D // n_h

            W_v = self.Attention.c_attn_v.weight.data
            project_W = self.Attention.c_proj.weight.data

            if self.determinstic:
                v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)
            else:
                v_cov = forward_value_cov_Bayesian_W(W_v, self.W_v_var.reshape(D, D), x_mean, x_var, n_h, D_v, self.diag_cov)

            QKV_cov = forward_QKV_cov(attention_score, v_cov, self.diag_cov)
            output_var = forward_fuse_multi_head_cov(QKV_cov, project_W, self.diag_cov)

        return output_mean, output_var

class SUQ_Transformer_Block_Diag(nn.Module):
    """
    Single transformer block with uncertainty propagation under SUQ.

    Wraps LayerNorm, attention, and MLP submodules with uncertainty-aware versions.
    Used in `SUQ_ViT_Diag` to form a full transformer stack.

    Inputs:
        MLP (nn.Module): Original MLP submodule
        Attention (nn.Module): Original attention submodule
        LN_1 (nn.LayerNorm): Pre-attention layer norm
        LN_2 (nn.LayerNorm): Pre-MLP layer norm
        MLP_determinstic (bool): Whether to treat the MLP as deterministic
        Attn_determinstic (bool): Whether to treat the attention as deterministic
        diag_cov (bool): If True, only compute the diagonal covariance for the value
        w_fc_var (Tensor or None): Posterior variance of the MLP input projection (if Bayesian)
        w_proj_var (Tensor or None): Posterior variance of the MLP output projection (if Bayesian)
        W_v_var (Tensor or None): Posterior variance of the value matrix (if Bayesian)
    """

    def __init__(self, MLP, Attention, LN_1, LN_2, MLP_determinstic, Attn_determinstic, diag_cov = False, w_fc_var = None, w_proj_var = None, W_v_var = None):
        super().__init__()

        self.ln_1 = SUQ_LayerNorm_Diag(LN_1)
        self.ln_2 = SUQ_LayerNorm_Diag(LN_2)
        self.attn = SUQ_Attention_Diag(Attention, Attn_determinstic, diag_cov, W_v_var)
        self.mlp = SUQ_TransformerMLP_Diag(MLP, MLP_determinstic, w_fc_var, w_proj_var)

    def forward(self, x_mean, x_var):
        """
        Inputs:
            x_mean (Tensor): Input mean, shape [B, T, D]
            x_var (Tensor): Input variance, shape [B, T, D]

        Outputs:
            h_mean (Tensor): Output mean after the transformer block, shape [B, T, D]
            h_var (Tensor): Output variance after the transformer block, shape [B, T, D]
        """

        h_mean, h_var = self.ln_1(x_mean, x_var)
        h_mean, h_var = self.attn(h_mean, h_var)
        h_mean = h_mean + x_mean
        h_var = h_var + x_var

        old_h_mean, old_h_var = h_mean, h_var

        h_mean, h_var = self.ln_2(h_mean, h_var)
        h_mean, h_var = self.mlp(h_mean, h_var)
        h_mean = h_mean + old_h_mean
        h_var = h_var + old_h_var

        return h_mean, h_var


class SUQ_ViT_Diag(SUQ_Base):
    """
    Vision Transformer model with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a ViT architecture into a structured uncertainty-aware model by replacing parts
    of the network with SUQ-compatible blocks. Allows selective Bayesian treatment of MLP
    and attention modules within each transformer block.

    Currently supports classification only. See the SUQ paper for theoretical background and assumptions.

    Inputs:
        ViT (nn.Module): A Vision Transformer model structured like `examples/vit_model.py`
        posterior_variance (Tensor): Flattened posterior variance vector
        MLP_determinstic (bool): Whether MLP submodules are treated as deterministic
        Attn_determinstic (bool): Whether attention submodules are treated as deterministic
        scale_init (float, optional): Initial value for the scale factor
        attention_diag_cov (bool): If True, only compute the diagonal covariance for the value
        likelihood (str): Currently only supports 'classification'
        num_det_blocks (int): Number of transformer blocks to leave deterministic (from the bottom up)
    """

    def __init__(self, ViT, posterior_variance, MLP_determinstic, Attn_determinstic, scale_init = 1.0, attention_diag_cov = False, likelihood = 'clasification', num_det_blocks = 10):
        super().__init__(likelihood, scale_init)

        if likelihood not in ['classification']:
            raise ValueError(f"{likelihood} not supported for ViT")

        self.transformer = nn.ModuleDict(dict(
            pte = ViT.transformer.pte,
            h = nn.ModuleList(),
            ln_f = SUQ_LayerNorm_Diag(ViT.transformer.ln_f)
        ))

        self.scale_factor = nn.Parameter(torch.Tensor([scale_init]))

        num_param_c_fc = ViT.transformer.h[0].mlp.c_fc.weight.numel()
        num_param_c_proj = ViT.transformer.h[0].mlp.c_proj.weight.numel()
        num_param_value_matrix = ViT.transformer.h[0].attn.c_proj.weight.numel()

        index = 0
        for block_index in range(len(ViT.transformer.h)):

            if block_index < num_det_blocks:
                self.transformer.h.append(ViT.transformer.h[block_index])
            else:
                if not MLP_determinstic:
                    w_fc_var = posterior_variance[index: index + num_param_c_fc]
                    index += num_param_c_fc
                    w_proj_var = posterior_variance[index: index + num_param_c_proj]
                    index += num_param_c_proj
                    self.transformer.h.append(
                        SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp,
                                                   ViT.transformer.h[block_index].attn,
                                                   ViT.transformer.h[block_index].ln_1,
                                                   ViT.transformer.h[block_index].ln_2,
                                                   MLP_determinstic,
                                                   Attn_determinstic,
                                                   attention_diag_cov,
                                                   w_fc_var,
                                                   w_proj_var,
                                                   None))

                if not Attn_determinstic:
                    w_v_var = posterior_variance[index : index + num_param_value_matrix]
                    index += num_param_value_matrix
                    self.transformer.h.append(
                        SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp,
                                                   ViT.transformer.h[block_index].attn,
                                                   ViT.transformer.h[block_index].ln_1,
                                                   ViT.transformer.h[block_index].ln_2,
                                                   MLP_determinstic,
                                                   Attn_determinstic,
                                                   attention_diag_cov,
                                                   None,
                                                   None,
                                                   w_v_var))

        num_param_classifier_weight = ViT.classifier.weight.numel()
        self.classifier = SUQ_Classifier_Diag(ViT.classifier, posterior_variance[index: index + num_param_classifier_weight], posterior_variance[index + num_param_classifier_weight:])

    def forward_latent(self, pixel_values, interpolate_pos_encoding = None):
        """
        Compute the predictive mean and variance of the ViT's latent output before applying the final likelihood layer.

        Traverses the full transformer stack with uncertainty propagation.

        Inputs:
            pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
            interpolate_pos_encoding (optional): Optional positional embedding interpolation

        Outputs:
            x_mean (Tensor): Predicted latent mean at the [CLS] token, shape [B, D]
            x_var (Tensor): Predicted latent variance at the [CLS] token, shape [B, D]
        """

        device = pixel_values.device

        x_mean = self.transformer.pte(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # pass through model
        x_var = torch.zeros_like(x_mean, device = device)

        for i, block in enumerate(self.transformer.h):

            if isinstance(block, SUQ_Transformer_Block_Diag):
                x_mean, x_var = block(x_mean, x_var)
            else:
                x_mean = block(x_mean)

        x_mean, x_var = self.transformer.ln_f(x_mean, x_var)

        x_mean, x_var = self.classifier(x_mean[:, 0, :], x_var[:, 0, :])
        x_var = x_var / self.scale_factor.to(device)

        return x_mean, x_var

    def forward(self, pixel_values, interpolate_pos_encoding = None):
        """
        Compute predictive class probabilities using a probit approximation.

        Performs a full forward pass through the ViT with uncertainty propagation, and
        produces softmax-normalized class probabilities for classification.

        Inputs:
            pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
            interpolate_pos_encoding (optional): Optional positional embedding interpolation

        Outputs:
            Tensor: Predicted class probabilities, shape [B, num_classes]
        """

        x_mean, x_var = self.forward_latent(pixel_values, interpolate_pos_encoding)
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * x_var)

        return torch.softmax(kappa * x_mean, dim=-1)
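A standalone shape check (not part of the package) of the diag_cov code path defined above, with the three steps of forward_value_cov_determinstic_W, forward_QKV_cov and forward_fuse_multi_head_cov written out inline on small random placeholder tensors.

import torch

B, T, D, n_h = 2, 5, 8, 2
D_v = D // n_h
attn = torch.softmax(torch.randn(B, n_h, T, T), dim=-1)   # attention scores [B, n_h, T, T]
x_var = torch.rand(B, T, D)                               # input variance
W_v = torch.randn(D, D)                                   # value weight matrix
project_W = torch.randn(D, D)                             # output projection weight

v_var = (x_var @ (W_v ** 2).T).reshape(B, T, n_h, D_v)    # forward_value_cov_determinstic_W, diag_cov=True
QKV_var = attn ** 2 @ v_var.transpose(1, 2)               # forward_QKV_cov, diag_cov=True -> [B, n_h, T, D_v]
out_var = QKV_var.permute(0, 2, 1, 3).reshape(B, T, n_h * D_v) @ project_W ** 2   # forward_fuse_multi_head_cov
print(out_var.shape)                                      # torch.Size([2, 5, 8])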
suq/streamline_layer.py
ADDED
@@ -0,0 +1,23 @@
from .diag_suq_mlp import (
    SUQ_Linear_Diag,
    SUQ_Activation_Diag,
    SUQ_BatchNorm_Diag
)
from .diag_suq_transformer import (
    SUQ_TransformerMLP_Diag,
    SUQ_Attention_Diag,
    SUQ_LayerNorm_Diag,
    SUQ_Classifier_Diag,
    SUQ_Transformer_Block_Diag
)

__all__ = [
    "SUQ_Linear_Diag",
    "SUQ_Activation_Diag",
    "SUQ_BatchNorm_Diag",
    "SUQ_TransformerMLP_Diag",
    "SUQ_Attention_Diag",
    "SUQ_LayerNorm_Diag",
    "SUQ_Classifier_Diag",
    "SUQ_Transformer_Block_Diag"
]
suq-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,19 @@
Metadata-Version: 2.4
Name: suq
Version: 0.1.0
Summary: Streamlined Uncertainty Quantification (SUQ)
Home-page: https://github.com/AaltoML/SUQ
Author: Rui Li, Marcus Klasson, Arno Solin, Martin Trapp
License: MIT
Classifier: Programming Language :: Python :: 3
License-File: LICENSE
Requires-Dist: torch>=1.10
Requires-Dist: numpy>=1.21
Requires-Dist: tqdm>=4.60
Dynamic: author
Dynamic: classifier
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: requires-dist
Dynamic: summary
suq-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
suq/SUQ_MLP.py,sha256=ciegIOz2Y0vtJ4Uc56dhwHjeaAoLAKqQQnSTgyh8Sqc,497
suq/SUQ_ViT.py,sha256=6BpHMOLf1qhVRgqM3guAcS60PpYJ1L1CEWdIerIKrv4,841
suq/__init__.py,sha256=5LGGZQ6wwjEfwsmzx55lSA4nfBx1N1qQpzgFff8zpJM,120
suq/base_suq.py,sha256=77LggEGn5m3h472kRLVghM1j7WCY4zJih3undVdIoS0,7963
suq/diag_suq_mlp.py,sha256=yE4p2pQtD97g6GvebkROqNDPFuJUhrwZH9zRkWAiSWk,10985
suq/diag_suq_transformer.py,sha256=QfGkPCf1DJ1cVIrkDp7tD8bImtCgt3XoUnngtgfwIAc,25576
suq/streamline_layer.py,sha256=sew0BodioNjElVpSq0c87c0ORRA5yGDbLTy1eeLp8VI,504
suq-0.1.0.dist-info/licenses/LICENSE,sha256=XbdRUcHQPFVqh9MgS85KpgRmx09FAPBW5gOHyzjJz3c,1064
suq-0.1.0.dist-info/METADATA,sha256=lcWlRUIWAM-LKK8MqbDehGjJ3QX1_Fl7nDCtZ-Jb1V4,497
suq-0.1.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
suq-0.1.0.dist-info/top_level.txt,sha256=_8edEeGJ6W1u-AhLS6x16sRWGGKwumOGPl7EBsW-v5g,4
suq-0.1.0.dist-info/RECORD,,
suq-0.1.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 AaltoML

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
suq-0.1.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
suq