RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_classes/AE_classes.py +18 -14
- RRAEsTorch/tests/test_AE_classes_CNN.py +20 -26
- RRAEsTorch/tests/test_AE_classes_MLP.py +20 -28
- RRAEsTorch/tests/test_fitting_CNN.py +14 -14
- RRAEsTorch/tests/test_fitting_MLP.py +11 -13
- RRAEsTorch/tests/test_save.py +11 -11
- RRAEsTorch/training_classes/training_classes.py +78 -121
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/METADATA +1 -2
- rraestorch-0.1.7.dist-info/RECORD +22 -0
- RRAEsTorch/tests/test_wrappers.py +0 -56
- RRAEsTorch/utilities/utilities.py +0 -1562
- RRAEsTorch/wrappers/__init__.py +0 -1
- RRAEsTorch/wrappers/wrappers.py +0 -237
- rraestorch-0.1.5.dist-info/RECORD +0 -27
- rraestorch-0.1.5.dist-info/licenses/LICENSE copy +0 -21
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/WHEEL +0 -0
- {rraestorch-0.1.5.dist-info → rraestorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,1562 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
|
-
from typing import Optional
|
|
3
|
-
from torch.nn import Linear
|
|
4
|
-
from torch.nn import Conv2d, ConvTranspose2d, Conv1d, ConvTranspose1d, Conv3d, ConvTranspose3d
|
|
5
|
-
import numpy.random as random
|
|
6
|
-
import numpy as np
|
|
7
|
-
from operator import itemgetter
|
|
8
|
-
import numpy as np
|
|
9
|
-
from tqdm import tqdm
|
|
10
|
-
import itertools
|
|
11
|
-
from torchvision.ops import MLP
|
|
12
|
-
import torch
|
|
13
|
-
from torch.func import vmap
|
|
14
|
-
|
|
15
|
-
##################### ML Training/Evaluation functions ##########################
|
|
16
|
-
|
|
17
|
-
def eval_with_batches(
|
|
18
|
-
x,
|
|
19
|
-
batch_size,
|
|
20
|
-
call_func,
|
|
21
|
-
end_type="concat_and_resort",
|
|
22
|
-
str=None,
|
|
23
|
-
*args,
|
|
24
|
-
**kwargs,
|
|
25
|
-
):
|
|
26
|
-
"""Function to evaluate a model on a dataset using batches.
|
|
27
|
-
|
|
28
|
-
Parameters
|
|
29
|
-
----------
|
|
30
|
-
x : jnp.array
|
|
31
|
-
Data to which the model will be applied.
|
|
32
|
-
batch_size : int
|
|
33
|
-
The batch size
|
|
34
|
-
call_func : callable
|
|
35
|
-
The function that calls the model, usually model.__call__
|
|
36
|
-
end_type : str
|
|
37
|
-
What to do with the output, can use first, sum, mean, concat,
|
|
38
|
-
stack, and concat_and_resort. Note that the data is shuffled
|
|
39
|
-
when batched. The last option (and defaut value) resorts the
|
|
40
|
-
outputs to give back the same order as the input.
|
|
41
|
-
key_idx : int
|
|
42
|
-
Seed for random key to shuffle the data before batching.
|
|
43
|
-
|
|
44
|
-
Returns
|
|
45
|
-
-------
|
|
46
|
-
final_pred : jnp.array
|
|
47
|
-
The predictions over the batches.
|
|
48
|
-
"""
|
|
49
|
-
idxs = []
|
|
50
|
-
all_preds = []
|
|
51
|
-
|
|
52
|
-
if str is not None:
|
|
53
|
-
print(str)
|
|
54
|
-
fn = lambda x, *args, **kwargs: tqdm(x, *args, **kwargs)
|
|
55
|
-
else:
|
|
56
|
-
fn = lambda x, *args, **kwargs: x
|
|
57
|
-
|
|
58
|
-
if not (isinstance(x, tuple) or isinstance(x, list)):
|
|
59
|
-
x = [x]
|
|
60
|
-
x = [el.T for el in x]
|
|
61
|
-
|
|
62
|
-
for _, inputs in fn(
|
|
63
|
-
zip(
|
|
64
|
-
itertools.count(start=0),
|
|
65
|
-
dataloader(
|
|
66
|
-
[*x, np.arange(0, x[0].shape[0], 1)],
|
|
67
|
-
batch_size,
|
|
68
|
-
once=True,
|
|
69
|
-
),
|
|
70
|
-
),
|
|
71
|
-
total=int(x[0].shape[-1] / batch_size),
|
|
72
|
-
):
|
|
73
|
-
input_b = inputs[:-1]
|
|
74
|
-
idx = inputs[-1]
|
|
75
|
-
|
|
76
|
-
input_b = [el.T for el in input_b]
|
|
77
|
-
|
|
78
|
-
pred = call_func(*input_b, *args, **kwargs)
|
|
79
|
-
idxs.append(idx)
|
|
80
|
-
all_preds.append(pred)
|
|
81
|
-
if end_type == "first":
|
|
82
|
-
break
|
|
83
|
-
idxs = np.concatenate(idxs)
|
|
84
|
-
match end_type:
|
|
85
|
-
case "concat_and_resort":
|
|
86
|
-
final_pred = np.concatenate(all_preds, -1)[..., np.argsort(idxs)]
|
|
87
|
-
case "concat":
|
|
88
|
-
final_pred = np.concatenate(all_preds, -1)
|
|
89
|
-
case "stack":
|
|
90
|
-
final_pred = np.stack(all_preds, -1)
|
|
91
|
-
case "mean":
|
|
92
|
-
final_pred = sum(all_preds) / len(all_preds)
|
|
93
|
-
case "sum":
|
|
94
|
-
final_pred = sum(all_preds)
|
|
95
|
-
case "first":
|
|
96
|
-
final_pred = all_preds[0]
|
|
97
|
-
case _:
|
|
98
|
-
final_pred = all_preds
|
|
99
|
-
return final_pred
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def loss_generator(which=None, norm_loss_=None):
|
|
103
|
-
""" Allows users to use different loss functions by only providing strings. """
|
|
104
|
-
if norm_loss_ is None:
|
|
105
|
-
norm_loss_ = lambda x1, x2: torch.linalg.norm(x1 - x2) / torch.linalg.norm(x2) * 100
|
|
106
|
-
|
|
107
|
-
if (which == "default") or (which == "Vanilla"):
|
|
108
|
-
def loss_fun(model, input, out, **kwargs):
|
|
109
|
-
pred = model(input, keep_normalized=True)
|
|
110
|
-
aux = {"loss": norm_loss_(pred, out)}
|
|
111
|
-
return norm_loss_(pred, out), (aux, {})
|
|
112
|
-
|
|
113
|
-
elif which == "RRAE":
|
|
114
|
-
def loss_fun(model, input, out, *, k_max, **kwargs):
|
|
115
|
-
pred = model(input, k_max=k_max, keep_normalized=True)
|
|
116
|
-
aux = {"loss": norm_loss_(pred, out), "k_max": k_max}
|
|
117
|
-
return norm_loss_(pred, out), (aux, {})
|
|
118
|
-
|
|
119
|
-
elif which == "Sparse":
|
|
120
|
-
def loss_fun(
|
|
121
|
-
model, input, out, *, sparsity=0.05, beta=1.0, **kwargs
|
|
122
|
-
):
|
|
123
|
-
pred = model(input, keep_normalized=True)
|
|
124
|
-
lat = model.latent(input)
|
|
125
|
-
sparse_term = sparsity * torch.log(sparsity / (torch.mean(lat) + 1e-8)) + (
|
|
126
|
-
1 - sparsity
|
|
127
|
-
) * torch.log((1 - sparsity) / (1 - torch.mean(lat) + 1e-8))
|
|
128
|
-
|
|
129
|
-
aux = {"loss rec": norm_loss_(pred, out), "loss sparse": sparse_term}
|
|
130
|
-
return norm_loss_(pred, out) + beta * sparse_term, (aux, {})
|
|
131
|
-
|
|
132
|
-
elif which == "nuc":
|
|
133
|
-
def loss_fun(
|
|
134
|
-
model,
|
|
135
|
-
input,
|
|
136
|
-
out,
|
|
137
|
-
*,
|
|
138
|
-
lambda_nuc=0.001,
|
|
139
|
-
norm_loss=None,
|
|
140
|
-
find_layer=None,
|
|
141
|
-
**kwargs,
|
|
142
|
-
):
|
|
143
|
-
if norm_loss is None:
|
|
144
|
-
norm_loss = norm_loss_
|
|
145
|
-
pred = model(input)
|
|
146
|
-
|
|
147
|
-
if find_layer is None:
|
|
148
|
-
raise ValueError(
|
|
149
|
-
"To use LoRAE, you should specify how to find the layer for "
|
|
150
|
-
"which we add the nuclear norm in the loss. To do so, give the path "
|
|
151
|
-
"to the layer as loss kwargs to the trainor: "
|
|
152
|
-
'e.g.: \n"loss_kwargs": {"find_layer": lambda model: model.encode.layers[-2].layers_l[-1].weight} (for predefined CNN AE) \n'
|
|
153
|
-
'"loss_kwargs": {"find_layer": lambda model: model.encode.layers_l[-1].weight} (for predefined MLP AE).'
|
|
154
|
-
)
|
|
155
|
-
else:
|
|
156
|
-
weight = find_layer(model)
|
|
157
|
-
|
|
158
|
-
aux = {"loss rec": norm_loss_(pred, out), "loss nuc": torch.linalg.norm(weight, "nuc")}
|
|
159
|
-
return norm_loss(pred, out) + lambda_nuc * torch.linalg.norm(weight, "nuc"), (aux, {})
|
|
160
|
-
|
|
161
|
-
elif which == "VRRAE":
|
|
162
|
-
norm_loss_ = lambda pr, out: torch.linalg.norm(pr-out)/torch.linalg.norm(out)*100
|
|
163
|
-
|
|
164
|
-
def loss_fun(model, input, out, idx, epsilon, k_max, beta=None, **kwargs):
|
|
165
|
-
lat, means, logvars = model.latent(input, epsilon=epsilon, k_max=k_max, return_lat_dist=True)
|
|
166
|
-
pred = model.decode(lat)
|
|
167
|
-
kl_loss = torch.sum(
|
|
168
|
-
-0.5 * (1 + logvars - torch.square(means) - torch.exp(logvars))
|
|
169
|
-
)
|
|
170
|
-
loss_rec = norm_loss_(pred, out)
|
|
171
|
-
aux = {
|
|
172
|
-
"loss rec": loss_rec,
|
|
173
|
-
"loss kl": kl_loss,
|
|
174
|
-
}
|
|
175
|
-
if beta is None:
|
|
176
|
-
beta = lambda_fn(loss_rec, kl_loss)
|
|
177
|
-
aux["beta"] = beta
|
|
178
|
-
return loss_rec + beta*kl_loss, (aux, {})
|
|
179
|
-
|
|
180
|
-
elif which == "VAE":
|
|
181
|
-
norm_loss_ = lambda pr, out: torch.linalg.norm(pr-out)/torch.linalg.norm(out)*100
|
|
182
|
-
|
|
183
|
-
def lambda_fn(loss, loss_c):
|
|
184
|
-
return loss_c*torch.exp(-0.1382*loss)
|
|
185
|
-
|
|
186
|
-
def loss_fun(model, input, out, idx, epsilon, beta=None, **kwargs):
|
|
187
|
-
lat, means, logvars = model.latent(input, epsilon=epsilon, length=input.shape[3], return_lat_dist=True)
|
|
188
|
-
pred = model.decode(lat)
|
|
189
|
-
kl_loss = torch.sum(
|
|
190
|
-
-0.5 * (1 + logvars - torch.square(means) - torch.exp(logvars))
|
|
191
|
-
)
|
|
192
|
-
loss_rec = norm_loss_(pred, out)
|
|
193
|
-
aux = {
|
|
194
|
-
"loss rec": loss_rec,
|
|
195
|
-
"loss kl": kl_loss,
|
|
196
|
-
}
|
|
197
|
-
if beta is None:
|
|
198
|
-
beta = lambda_fn(loss_rec, kl_loss)
|
|
199
|
-
aux["beta"] = beta
|
|
200
|
-
return loss_rec + beta*kl_loss, (aux, {})
|
|
201
|
-
|
|
202
|
-
elif "Contractive":
|
|
203
|
-
def loss_fun(model, input, out, *, beta=1.0, find_weight=None, **kwargs):
|
|
204
|
-
assert find_weight is not None
|
|
205
|
-
lat = model.latent(input)
|
|
206
|
-
pred = model(input, keep_normalized=True)
|
|
207
|
-
W = find_weight(model)
|
|
208
|
-
W = find_weight(model)
|
|
209
|
-
dh = lat * (1 - lat)
|
|
210
|
-
dh = dh.T
|
|
211
|
-
loss_contr = torch.sum(torch.matmul(dh**2, torch.square(W)))
|
|
212
|
-
aux = {"loss": norm_loss_(pred, out), "cont": loss_contr}
|
|
213
|
-
aux = {"loss": norm_loss_(pred, out), "cont": loss_contr}
|
|
214
|
-
return norm_loss_(pred, out) + beta * loss_contr, (aux, {})
|
|
215
|
-
else:
|
|
216
|
-
raise ValueError(f"{which} is an Unknown loss type")
|
|
217
|
-
return loss_fun
|
|
218
|
-
|
|
219
|
-
def dataloader(arrays, batch_size, p_vals=None, once=False):
|
|
220
|
-
""" JAX copatible dataloader to batch data randomly and differently
|
|
221
|
-
between epochs. """
|
|
222
|
-
dataset_size = arrays[0].shape[0]
|
|
223
|
-
arrays = [array if array is not None else [None] * dataset_size for array in arrays]
|
|
224
|
-
indices = np.arange(dataset_size)
|
|
225
|
-
kk = 0
|
|
226
|
-
|
|
227
|
-
while True:
|
|
228
|
-
perm = random.permutation(indices)
|
|
229
|
-
start = 0
|
|
230
|
-
end = batch_size
|
|
231
|
-
while end <= dataset_size:
|
|
232
|
-
batch_perm = perm[start:end]
|
|
233
|
-
arrs = tuple(
|
|
234
|
-
itemgetter(*batch_perm)(array) for array in arrays
|
|
235
|
-
) # Works for lists and arrays
|
|
236
|
-
if batch_size != 1:
|
|
237
|
-
yield [np.array(arr) for arr in arrs]
|
|
238
|
-
else:
|
|
239
|
-
yield [
|
|
240
|
-
[arr] if arr is None else np.expand_dims(np.array(arr), axis=0)
|
|
241
|
-
for arr in arrs
|
|
242
|
-
]
|
|
243
|
-
start = end
|
|
244
|
-
end = start + batch_size
|
|
245
|
-
if once:
|
|
246
|
-
if dataset_size % batch_size != 0:
|
|
247
|
-
batch_perm = perm[-(dataset_size % batch_size) :]
|
|
248
|
-
arrs = tuple(
|
|
249
|
-
itemgetter(*batch_perm)(array) for array in arrays
|
|
250
|
-
) # Works for lists and arrays
|
|
251
|
-
if dataset_size % batch_size == 1:
|
|
252
|
-
yield [
|
|
253
|
-
[arr] if arr is None else np.expand_dims(np.array(arr), 0)
|
|
254
|
-
for arr in arrs
|
|
255
|
-
]
|
|
256
|
-
else:
|
|
257
|
-
yield [[arr] if arr is None else np.array(arr) for arr in arrs]
|
|
258
|
-
break
|
|
259
|
-
kk += 1
|
|
260
|
-
|
|
261
|
-
##################### ML synthetic data generation ########################
|
|
262
|
-
|
|
263
|
-
def np_vmap(func, to_array=True):
|
|
264
|
-
""" Similar to JAX's vmap but in numpy.
|
|
265
|
-
Slow but useful if we want to use the same syntax."""
|
|
266
|
-
def map_func(*arrays, args=None, kwargs=None):
|
|
267
|
-
sols = []
|
|
268
|
-
for elems in zip(*arrays):
|
|
269
|
-
if (args is None) and (kwargs is None):
|
|
270
|
-
sols.append(func(*elems))
|
|
271
|
-
elif (args is not None) and (kwargs is not None):
|
|
272
|
-
sols.append(func(*elems, *args, **kwargs))
|
|
273
|
-
elif args is not None:
|
|
274
|
-
sols.append(func(*elems, *args))
|
|
275
|
-
else:
|
|
276
|
-
sols.append(func(*elems, **kwargs))
|
|
277
|
-
try:
|
|
278
|
-
if isinstance(sols[0], list) or isinstance(sols[0], tuple):
|
|
279
|
-
final_sols = []
|
|
280
|
-
for i in range(len(sols[0])):
|
|
281
|
-
final_sols.append(np.array([sol[i] for sol in sols]))
|
|
282
|
-
return final_sols
|
|
283
|
-
return np.array([np.squeeze(np.stack(sol, axis=0)) for sol in sols])
|
|
284
|
-
except:
|
|
285
|
-
if to_array:
|
|
286
|
-
return np.array(sols)
|
|
287
|
-
else:
|
|
288
|
-
return sols
|
|
289
|
-
|
|
290
|
-
return map_func
|
|
291
|
-
|
|
292
|
-
def divide_return(
|
|
293
|
-
inp_all,
|
|
294
|
-
p_all=None,
|
|
295
|
-
output=None,
|
|
296
|
-
prop_train=0.8,
|
|
297
|
-
test_end=0,
|
|
298
|
-
eps=1,
|
|
299
|
-
pre_func_in=lambda x: x,
|
|
300
|
-
pre_func_out=lambda x: x,
|
|
301
|
-
args=(),
|
|
302
|
-
):
|
|
303
|
-
"""p_all of shape (P x N) and y_all of shape (T x N).
|
|
304
|
-
The function divides into train/test according to the parameters
|
|
305
|
-
to allow the test set to be interpolated linearly from the training set
|
|
306
|
-
(if possible). If test_end is specified this is overwridden to only take
|
|
307
|
-
the lest test_end values for testing.
|
|
308
|
-
|
|
309
|
-
NOTE: pre_func_in and pre_func_out are functions you want to apply over
|
|
310
|
-
the input and output but you can not do it on all the data since it is
|
|
311
|
-
too big (e.g. conversion to float). These functions will be applied on
|
|
312
|
-
batches during training/evaluation of the Network."""
|
|
313
|
-
|
|
314
|
-
if test_end == 0:
|
|
315
|
-
idx_test = random.permutation(inp_all.shape[-1])[
|
|
316
|
-
: int(inp_all.shape[-1] * (1 - prop_train))
|
|
317
|
-
]
|
|
318
|
-
|
|
319
|
-
x_test = inp_all[..., idx_test]
|
|
320
|
-
x_train = random.permutation(
|
|
321
|
-
np.delete(inp_all, idx_test, -1), -1
|
|
322
|
-
)
|
|
323
|
-
else:
|
|
324
|
-
if p_all is not None:
|
|
325
|
-
p_test = p_all[-test_end:]
|
|
326
|
-
p_train = p_all[: len(p_all) - test_end]
|
|
327
|
-
x_test = inp_all[..., -test_end:]
|
|
328
|
-
x_train = inp_all[..., : inp_all.shape[-1] - test_end]
|
|
329
|
-
|
|
330
|
-
if output is None:
|
|
331
|
-
output_train = x_train
|
|
332
|
-
output_test = x_test
|
|
333
|
-
else:
|
|
334
|
-
output_test = output[idx_test]
|
|
335
|
-
output_train = np.delete(output, idx_test, 0)
|
|
336
|
-
|
|
337
|
-
if p_all is not None:
|
|
338
|
-
p_train = np.expand_dims(p_train, -1) if len(p_train.shape) == 1 else p_train
|
|
339
|
-
p_test = np.expand_dims(p_test, -1) if len(p_test.shape) == 1 else p_test
|
|
340
|
-
else:
|
|
341
|
-
p_train = None
|
|
342
|
-
p_test = None
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
return (
|
|
346
|
-
torch.tensor(x_train, dtype=torch.float32),
|
|
347
|
-
torch.tensor(x_test, dtype=torch.float32),
|
|
348
|
-
torch.tensor(p_train, dtype=torch.float32) if p_train is not None else None,
|
|
349
|
-
torch.tensor(p_test, dtype=torch.float32) if p_test is not None else None,
|
|
350
|
-
torch.tensor(output_train, dtype=torch.float32),
|
|
351
|
-
torch.tensor(output_test, dtype=torch.float32),
|
|
352
|
-
pre_func_in,
|
|
353
|
-
pre_func_out,
|
|
354
|
-
args,
|
|
355
|
-
)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
def get_data(problem, folder=None, train_size=1000, test_size=10000, **kwargs):
|
|
359
|
-
"""Function that generates the examples presented in the paper."""
|
|
360
|
-
import numpy as np
|
|
361
|
-
match problem:
|
|
362
|
-
case "2d_gaussian_shift_scale":
|
|
363
|
-
D = 64 # Dimension of the domain
|
|
364
|
-
Ntr = train_size # Number of training samples
|
|
365
|
-
Nte = test_size # Number of test samples
|
|
366
|
-
sigma = 0.2
|
|
367
|
-
|
|
368
|
-
def gaussian_2d(x, y, x0, y0, sigma):
|
|
369
|
-
return np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
|
|
370
|
-
|
|
371
|
-
x = np.linspace(-1, 1, D)
|
|
372
|
-
y = np.linspace(-1, 1, D)
|
|
373
|
-
X, Y = np.meshgrid(x, y)
|
|
374
|
-
# Create training data
|
|
375
|
-
train_data = []
|
|
376
|
-
x0_vals = np.linspace(-0.5, 0.5, int(np.sqrt(Ntr))+1)
|
|
377
|
-
y0_vals = np.linspace(-0.5, 0.5, int(np.sqrt(Ntr))+1)
|
|
378
|
-
x0_mesh, y0_mesh = np.meshgrid(x0_vals, y0_vals)
|
|
379
|
-
x0_mesh = x0_mesh.flatten()
|
|
380
|
-
y0_mesh = y0_mesh.flatten()
|
|
381
|
-
|
|
382
|
-
for i in range(Ntr):
|
|
383
|
-
train_data.append(gaussian_2d(X, Y, x0_mesh[i], y0_mesh[i], sigma))
|
|
384
|
-
train_data = np.stack(train_data, axis=-1)
|
|
385
|
-
|
|
386
|
-
# Create test data
|
|
387
|
-
x0_vals_test = np.random.uniform(-0.5, 0.5, Nte)
|
|
388
|
-
y0_vals_test = np.random.uniform(-0.5, 0.5, Nte)
|
|
389
|
-
x0_mesh_test = x0_vals_test
|
|
390
|
-
y0_mesh_test = y0_vals_test
|
|
391
|
-
|
|
392
|
-
test_data = []
|
|
393
|
-
for i in range(Nte):
|
|
394
|
-
test_data.append(gaussian_2d(X, Y, x0_mesh_test[i], y0_mesh_test[i], sigma))
|
|
395
|
-
test_data = np.stack(test_data, axis=-1)
|
|
396
|
-
|
|
397
|
-
# Normalize the data
|
|
398
|
-
train_data = (train_data - np.mean(train_data)) / np.std(train_data)
|
|
399
|
-
test_data = (test_data - np.mean(test_data)) / np.std(test_data)
|
|
400
|
-
# Split the data into training and test sets
|
|
401
|
-
x_train = torch.tensor(np.expand_dims(train_data, 0), dtype=torch.float32)
|
|
402
|
-
x_test = torch.tensor(np.expand_dims(test_data, 0), dtype=torch.float32)
|
|
403
|
-
y_train = torch.tensor(np.expand_dims(train_data, 0), dtype=torch.float32)
|
|
404
|
-
y_test = torch.tensor(np.expand_dims(test_data, 0), dtype=torch.float32)
|
|
405
|
-
p_train = torch.tensor(np.stack([x0_mesh, y0_mesh], axis=-1), dtype=torch.float32)
|
|
406
|
-
p_test = torch.tensor(np.stack([x0_mesh_test, y0_mesh_test], axis=-1), dtype=torch.float32)
|
|
407
|
-
return x_train, x_test, p_train, p_test, y_train, y_test, lambda x: x, lambda x: x, ()
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
case "CIFAR-10":
|
|
411
|
-
import pickle
|
|
412
|
-
import os
|
|
413
|
-
|
|
414
|
-
def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
|
|
415
|
-
with open(os.path.join(cifar10_dataset_folder_path, 'data_batch_' + str(batch_id)), mode='rb') as file:
|
|
416
|
-
batch = pickle.load(file, encoding='latin1')
|
|
417
|
-
features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
|
|
418
|
-
labels = batch['labels']
|
|
419
|
-
return features, labels
|
|
420
|
-
|
|
421
|
-
def load_cifar10(cifar10_dataset_folder_path):
|
|
422
|
-
x_train = []
|
|
423
|
-
y_train = []
|
|
424
|
-
for batch_id in range(1, 6):
|
|
425
|
-
features, labels = load_cifar10_batch(cifar10_dataset_folder_path, batch_id)
|
|
426
|
-
x_train.extend(features)
|
|
427
|
-
y_train.extend(labels)
|
|
428
|
-
x_train = np.array(x_train)
|
|
429
|
-
y_train = np.array(y_train)
|
|
430
|
-
with open(os.path.join(cifar10_dataset_folder_path, 'test_batch'), mode='rb') as file:
|
|
431
|
-
batch = pickle.load(file, encoding='latin1')
|
|
432
|
-
x_test = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
|
|
433
|
-
y_test = np.array(batch['labels'])
|
|
434
|
-
return x_train, x_test, y_train, y_test
|
|
435
|
-
|
|
436
|
-
cifar10_dataset_folder_path = folder
|
|
437
|
-
x_train, x_test, y_train, y_test = load_cifar10(cifar10_dataset_folder_path)
|
|
438
|
-
pre_func_in = lambda x: np.array(x, dtype=np.float32) / 255.0
|
|
439
|
-
pre_func_out = lambda x: np.array(x, dtype=np.float32) / 255.0
|
|
440
|
-
x_train = np.swapaxes(x_train, 0, -1)
|
|
441
|
-
x_test = np.swapaxes(x_test, 0, -1)
|
|
442
|
-
return x_train, x_test, None, None, x_train, x_test, pre_func_in, pre_func_out, ()
|
|
443
|
-
|
|
444
|
-
case "CelebA":
|
|
445
|
-
data_res = 160
|
|
446
|
-
import os
|
|
447
|
-
from PIL import Image
|
|
448
|
-
import numpy as np
|
|
449
|
-
from skimage.transform import resize
|
|
450
|
-
|
|
451
|
-
if os.path.exists(f"{folder}/celeba_data_{data_res}.npy"):
|
|
452
|
-
print("Loading data from file")
|
|
453
|
-
data = np.load(f"{folder}/celeba_data_{data_res}.npy")
|
|
454
|
-
else:
|
|
455
|
-
print("Loading data and processing...")
|
|
456
|
-
data = np.load(f"{folder}/celeba_data.npy")
|
|
457
|
-
celeb_transform = lambda im: np.astype(
|
|
458
|
-
resize(im, (data_res, data_res, 3), order=1, anti_aliasing=True)
|
|
459
|
-
* 255.0,
|
|
460
|
-
np.uint8,
|
|
461
|
-
)
|
|
462
|
-
all_data = []
|
|
463
|
-
for i in tqdm(range(data.shape[0])):
|
|
464
|
-
all_data.append(celeb_transform(data[i]))
|
|
465
|
-
|
|
466
|
-
data = np.stack(all_data, axis=0)
|
|
467
|
-
data = np.swapaxes(data, 0, 3)
|
|
468
|
-
np.save(f"{folder}/celeba_data_{data_res}.npy", data)
|
|
469
|
-
|
|
470
|
-
print("Data shape: ", data.shape)
|
|
471
|
-
x_train = data[..., :162770]
|
|
472
|
-
x_test = data[..., 182638:]
|
|
473
|
-
y_train = x_train
|
|
474
|
-
y_test = x_test
|
|
475
|
-
pre_func_in = lambda x: np.astype(x, np.float32) / 255.0
|
|
476
|
-
pre_func_out = lambda x: np.astype(x, np.float32) / 255.0
|
|
477
|
-
return (
|
|
478
|
-
x_train,
|
|
479
|
-
x_test,
|
|
480
|
-
None,
|
|
481
|
-
None,
|
|
482
|
-
y_train,
|
|
483
|
-
y_test,
|
|
484
|
-
pre_func_in,
|
|
485
|
-
pre_func_out,
|
|
486
|
-
(),
|
|
487
|
-
)
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
case "shift":
|
|
491
|
-
ts = np.linspace(0, 2 * np.pi, 100)
|
|
492
|
-
|
|
493
|
-
def sf_func(s):
|
|
494
|
-
return np.sin(ts - s * np.pi)
|
|
495
|
-
|
|
496
|
-
p_vals = np.linspace(0, 1.8, 200)[:-1] # 18
|
|
497
|
-
y_shift = np_vmap(sf_func)(p_vals).T
|
|
498
|
-
p_test = np.linspace(0, np.max(p_vals), 500)[1:-1]
|
|
499
|
-
y_test = np_vmap(sf_func)(p_test).T
|
|
500
|
-
y_all = np.concatenate([y_shift, y_test], axis=-1)
|
|
501
|
-
p_all = np.concatenate([p_vals, p_test], axis=0)
|
|
502
|
-
return divide_return(y_all, p_all, test_end=y_test.shape[-1])
|
|
503
|
-
|
|
504
|
-
case "gaussian_shift":
|
|
505
|
-
ts = np.linspace(0, 2 * np.pi, 200)
|
|
506
|
-
def gauss_shift(s):
|
|
507
|
-
return np.exp(-((ts - s) ** 2) / 0.1) # Smaller width
|
|
508
|
-
p_vals = np.linspace(1, 2 * np.pi +1, 20)
|
|
509
|
-
ts = np.linspace(0, 2 * np.pi + 2, 500)
|
|
510
|
-
y_shift = np_vmap(gauss_shift)(p_vals).T
|
|
511
|
-
p_test = np.linspace(np.min(p_vals), np.max(p_vals), 500)[1:-1]
|
|
512
|
-
y_test = np_vmap(gauss_shift)(p_test).T
|
|
513
|
-
y_all = np.concatenate([y_shift, y_test], axis=-1)
|
|
514
|
-
p_all = np.concatenate([p_vals, p_test], axis=0)
|
|
515
|
-
return divide_return(y_all, p_all, test_end=y_test.shape[-1])
|
|
516
|
-
|
|
517
|
-
case "stairs":
|
|
518
|
-
Tend = 3.5 # [s]
|
|
519
|
-
NT = 500
|
|
520
|
-
nt = NT + 1
|
|
521
|
-
times = np.linspace(0, Tend, nt)
|
|
522
|
-
freq = 1 # [Hz] # 3
|
|
523
|
-
wrad = 2 * np.pi * freq
|
|
524
|
-
nAmp = 100 # 60
|
|
525
|
-
yo = 2.3
|
|
526
|
-
Amp = np.arange(1, 5, 0.1)
|
|
527
|
-
phases = np.linspace(1 / 4 * Tend, 3 / 4 * Tend, nAmp)
|
|
528
|
-
p_vals = Amp
|
|
529
|
-
|
|
530
|
-
def find_ph(amp):
|
|
531
|
-
return phases[0] + (amp - Amp[0]) / (Amp[1] - Amp[0]) * (
|
|
532
|
-
phases[1] - phases[0]
|
|
533
|
-
)
|
|
534
|
-
|
|
535
|
-
def create_escal(amp):
|
|
536
|
-
return np.cumsum(
|
|
537
|
-
(
|
|
538
|
-
(
|
|
539
|
-
np.abs(
|
|
540
|
-
(
|
|
541
|
-
amp
|
|
542
|
-
* np.sqrt(times)
|
|
543
|
-
* np.sin(wrad * (times - find_ph(amp)))
|
|
544
|
-
)
|
|
545
|
-
- yo
|
|
546
|
-
)
|
|
547
|
-
+ (
|
|
548
|
-
(
|
|
549
|
-
amp
|
|
550
|
-
* np.sqrt(times)
|
|
551
|
-
* np.sin(wrad * (times - find_ph(amp)))
|
|
552
|
-
)
|
|
553
|
-
- yo
|
|
554
|
-
)
|
|
555
|
-
)
|
|
556
|
-
/ 2
|
|
557
|
-
)
|
|
558
|
-
** 5
|
|
559
|
-
)
|
|
560
|
-
|
|
561
|
-
y_shift_old = np_vmap(create_escal)(p_vals).T
|
|
562
|
-
y_shift = np_vmap(
|
|
563
|
-
lambda y: (y - np.mean(y_shift_old)) / np.std(y_shift_old)
|
|
564
|
-
)(y_shift_old.T).T
|
|
565
|
-
y_shift = y_shift[:, ~np.isnan(y_shift).any(axis=0)]
|
|
566
|
-
|
|
567
|
-
p_test = random.uniform(
|
|
568
|
-
np.min(p_vals) * 1.00001,
|
|
569
|
-
np.max(p_vals) * 0.99999,
|
|
570
|
-
(300,),
|
|
571
|
-
)
|
|
572
|
-
y_test = np_vmap(
|
|
573
|
-
lambda y: (y - np.mean(y_shift_old)) / np.std(y_shift_old)
|
|
574
|
-
)(np_vmap(create_escal)(p_test)).T
|
|
575
|
-
y_all = np.concatenate([y_shift, y_test], axis=-1)
|
|
576
|
-
p_all = np.concatenate([p_vals, p_test], axis=0)
|
|
577
|
-
ts = np.arange(0, y_shift.shape[0], 1)
|
|
578
|
-
return divide_return(y_all, p_all, test_end=y_test.shape[-1])
|
|
579
|
-
|
|
580
|
-
case "mult_freqs":
|
|
581
|
-
p_vals_0 = np.repeat(np.linspace(0.5 * np.pi, np.pi, 15), 15)
|
|
582
|
-
p_vals_1 = np.tile(np.linspace(0.3 * np.pi, 0.8 * np.pi, 15), 15)
|
|
583
|
-
p_vals = np.stack([p_vals_0, p_vals_1], axis=-1)
|
|
584
|
-
ts = np.arange(0, 5 * np.pi, 0.01)
|
|
585
|
-
y_shift = np_vmap(lambda p: np.sin(p[0] * ts) + np.sin(p[1] * ts))(
|
|
586
|
-
p_vals
|
|
587
|
-
).T
|
|
588
|
-
|
|
589
|
-
p_vals_0 = random.uniform(
|
|
590
|
-
p_vals_0[0] * 1.001,
|
|
591
|
-
p_vals_0[-1] * 0.999,
|
|
592
|
-
(1000,),
|
|
593
|
-
)
|
|
594
|
-
p_vals_1 = random.uniform(
|
|
595
|
-
p_vals_1[0] * 1.001,
|
|
596
|
-
p_vals_1[-1] * 0.999,
|
|
597
|
-
(1000,),
|
|
598
|
-
)
|
|
599
|
-
p_test = np.stack([p_vals_0, p_vals_1], axis=-1)
|
|
600
|
-
y_test = np_vmap(lambda p: np.sin(p[0] * ts) + np.sin(p[1] * ts))(
|
|
601
|
-
p_test
|
|
602
|
-
).T
|
|
603
|
-
y_all = np.concatenate([y_shift, y_test], axis=-1)
|
|
604
|
-
p_all = np.concatenate([p_vals, p_test], axis=0)
|
|
605
|
-
return divide_return(y_all, p_all, test_end=y_test.shape[-1])
|
|
606
|
-
|
|
607
|
-
case "mult_gausses":
|
|
608
|
-
|
|
609
|
-
p_vals_0 = np.repeat(np.linspace(1, 3, 10), 100)
|
|
610
|
-
p_vals_1 = np.tile(np.linspace(4, 6, 10), 100)
|
|
611
|
-
p_vals = np.stack([p_vals_0, p_vals_1], axis=-1)
|
|
612
|
-
p_test_0 = random.uniform(
|
|
613
|
-
p_vals_0[0] * 1.001,
|
|
614
|
-
p_vals_0[-1] * 0.999,
|
|
615
|
-
(1000,),
|
|
616
|
-
)
|
|
617
|
-
p_test_1 = random.uniform(
|
|
618
|
-
p_vals_1[0] * 1.001,
|
|
619
|
-
p_vals_1[-1] * 0.999,
|
|
620
|
-
(1000,),
|
|
621
|
-
)
|
|
622
|
-
p_test = np.stack([p_test_0, p_test_1], axis=-1)
|
|
623
|
-
|
|
624
|
-
ts = np.arange(0, 6, 0.005)
|
|
625
|
-
|
|
626
|
-
def gauss(a, b, c, t):
|
|
627
|
-
return a * np.exp(-((t - b) ** 2) / (2 * c**2))
|
|
628
|
-
|
|
629
|
-
a = 1.3
|
|
630
|
-
c = 0.2
|
|
631
|
-
y_shift = np_vmap(
|
|
632
|
-
lambda p: gauss(a, p[0], c, ts) + gauss(-a, p[1], c, ts)
|
|
633
|
-
)(p_vals).T
|
|
634
|
-
y_test = np_vmap(
|
|
635
|
-
lambda p: gauss(a, p[0], c, ts) + gauss(-a, p[1], c, ts)
|
|
636
|
-
)(p_test).T
|
|
637
|
-
y_all = np.concatenate([y_shift, y_test], axis=-1)
|
|
638
|
-
p_all = np.concatenate([p_vals, p_test], axis=0)
|
|
639
|
-
return divide_return(y_all, p_all, test_end=y_test.shape[-1])
|
|
640
|
-
|
|
641
|
-
case "fashion_mnist":
|
|
642
|
-
import pandas
|
|
643
|
-
x_train = pandas.read_csv("fashin_mnist/fashion-mnist_train.csv").to_numpy().T[1:]
|
|
644
|
-
x_test = pandas.read_csv("fashin_mnist/fashion-mnist_test.csv").to_numpy().T[1:]
|
|
645
|
-
y_all = np.concatenate([x_train, x_test], axis=-1)
|
|
646
|
-
y_all = np.reshape(y_all, (1, 28, 28, -1))
|
|
647
|
-
pre_func_in = lambda x: np.astype(x, np.float32) / 255
|
|
648
|
-
return divide_return(y_all, None, test_end=x_test.shape[-1], pre_func_in=pre_func_in, pre_func_out=pre_func_in)
|
|
649
|
-
|
|
650
|
-
case "mnist":
|
|
651
|
-
import os
|
|
652
|
-
import gzip
|
|
653
|
-
import numpy as np
|
|
654
|
-
import pickle as pkl
|
|
655
|
-
|
|
656
|
-
if os.path.exists(f"{folder}/mnist_data.npy"):
|
|
657
|
-
print("Loading data from file")
|
|
658
|
-
with open(f"{folder}/mnist_data.npy", "rb") as f:
|
|
659
|
-
train_images, train_labels, test_images, test_labels = pkl.load(f)
|
|
660
|
-
else:
|
|
661
|
-
print("Loading data and processing...")
|
|
662
|
-
|
|
663
|
-
def load_mnist_images(filename):
|
|
664
|
-
with gzip.open(filename, 'rb') as f:
|
|
665
|
-
data = np.frombuffer(f.read(), np.uint8, offset=16)
|
|
666
|
-
data = data.reshape(-1, 28, 28)
|
|
667
|
-
return data
|
|
668
|
-
|
|
669
|
-
def load_mnist_labels(filename):
|
|
670
|
-
with gzip.open(filename, 'rb') as f:
|
|
671
|
-
data = np.frombuffer(f.read(), np.uint8, offset=8)
|
|
672
|
-
return data
|
|
673
|
-
|
|
674
|
-
def load_mnist(path):
|
|
675
|
-
train_images = load_mnist_images(os.path.join(path, 'train-images-idx3-ubyte.gz'))
|
|
676
|
-
train_labels = load_mnist_labels(os.path.join(path, 'train-labels-idx1-ubyte.gz'))
|
|
677
|
-
test_images = load_mnist_images(os.path.join(path, 't10k-images-idx3-ubyte.gz'))
|
|
678
|
-
test_labels = load_mnist_labels(os.path.join(path, 't10k-labels-idx1-ubyte.gz'))
|
|
679
|
-
return (train_images, train_labels), (test_images, test_labels)
|
|
680
|
-
|
|
681
|
-
def preprocess_mnist(images):
|
|
682
|
-
images = images.astype(np.float32) / 255.0
|
|
683
|
-
images = np.expand_dims(images, axis=1) # Add channel dimension
|
|
684
|
-
return images
|
|
685
|
-
|
|
686
|
-
def get_mnist_data(path):
|
|
687
|
-
(train_images, train_labels), (test_images, test_labels) = load_mnist(path)
|
|
688
|
-
train_images = preprocess_mnist(train_images)
|
|
689
|
-
test_images = preprocess_mnist(test_images)
|
|
690
|
-
return train_images, train_labels, test_images, test_labels
|
|
691
|
-
|
|
692
|
-
train_images, train_labels, test_images, test_labels = get_mnist_data(folder)
|
|
693
|
-
train_images = np.swapaxes(np.moveaxis(train_images, 1, -1), 0, -1)
|
|
694
|
-
test_images = np.swapaxes(np.moveaxis(test_images, 1, -1), 0, -1)
|
|
695
|
-
with open(f"{folder}/mnist_data.npy", "wb") as f:
|
|
696
|
-
pkl.dump((train_images, train_labels, test_images, test_labels), f)
|
|
697
|
-
|
|
698
|
-
return (
|
|
699
|
-
train_images,
|
|
700
|
-
test_images,
|
|
701
|
-
None,
|
|
702
|
-
None,
|
|
703
|
-
train_images,
|
|
704
|
-
test_images,
|
|
705
|
-
lambda x: x,
|
|
706
|
-
lambda x: x,
|
|
707
|
-
(),
|
|
708
|
-
)
|
|
709
|
-
|
|
710
|
-
case _:
|
|
711
|
-
raise ValueError(f"Problem {problem} not recognized")
|
|
712
|
-
|
|
713
|
-
########################### SVD functions #################################
|
|
714
|
-
|
|
715
|
-
def stable_SVD(A):
|
|
716
|
-
""" Computes a numerically stable SVD with optional noise injection.
|
|
717
|
-
|
|
718
|
-
A: input matrix (..., m, n)
|
|
719
|
-
noise_std: std of noise to inject into singular values
|
|
720
|
-
eps: numerical stability constant
|
|
721
|
-
"""
|
|
722
|
-
return StableSVD.apply(A)
|
|
723
|
-
|
|
724
|
-
class StableSVD(torch.autograd.Function):
|
|
725
|
-
@staticmethod
|
|
726
|
-
def forward(ctx, A):
|
|
727
|
-
"""
|
|
728
|
-
A: input matrix (..., m, n)
|
|
729
|
-
noise_std: std of noise to inject into singular values
|
|
730
|
-
eps: numerical stability constant
|
|
731
|
-
"""
|
|
732
|
-
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
|
|
733
|
-
|
|
734
|
-
ctx.save_for_backward(U, S, Vh)
|
|
735
|
-
|
|
736
|
-
return U, S, Vh
|
|
737
|
-
|
|
738
|
-
@staticmethod
|
|
739
|
-
def backward(ctx, dU, dS, dVh):
|
|
740
|
-
U, S, Vh = ctx.saved_tensors
|
|
741
|
-
|
|
742
|
-
scale = S[..., :1]
|
|
743
|
-
mask = (S / scale * 100 >= 1e-6)
|
|
744
|
-
S_stable = S * mask
|
|
745
|
-
|
|
746
|
-
k = S.shape[-1]
|
|
747
|
-
|
|
748
|
-
V = Vh.transpose(-2, -1)
|
|
749
|
-
dV = dVh.transpose(-2, -1)
|
|
750
|
-
|
|
751
|
-
S_i = S_stable.unsqueeze(-1)
|
|
752
|
-
S_j = S_stable.unsqueeze(-2)
|
|
753
|
-
|
|
754
|
-
diff = S_i**2 - S_j**2
|
|
755
|
-
zero = diff == 0
|
|
756
|
-
F = torch.zeros_like(diff)
|
|
757
|
-
F[~zero] = 1.0 / diff[~zero]
|
|
758
|
-
|
|
759
|
-
Ut_dU = U.transpose(-2, -1) @ dU
|
|
760
|
-
Vt_dV = V.transpose(-2, -1) @ dV
|
|
761
|
-
|
|
762
|
-
K_U = F * (Ut_dU - Ut_dU.transpose(-2, -1))
|
|
763
|
-
K_V = F * (Vt_dV - Vt_dV.transpose(-2, -1))
|
|
764
|
-
|
|
765
|
-
K = torch.diag_embed(dS)
|
|
766
|
-
K = K + K_U @ torch.diag_embed(S)
|
|
767
|
-
K = K + torch.diag_embed(S) @ K_V
|
|
768
|
-
|
|
769
|
-
dA = U @ K @ Vh
|
|
770
|
-
|
|
771
|
-
return dA, None, None
|
|
772
|
-
|
|
773
|
-
######################## Typical Neural Network Architecturs ##############################
|
|
774
|
-
|
|
775
|
-
class MLP_with_linear(torch.nn.Module):
    """An MLP followed by optional bias-free linear (matrix) multiplications.

    Similar to an Equinox MLP but with additional linear multiplications by
    matrices.  The only new parameter is ``linear_l``: the number of matrices
    applied after the MLP.  E.g. with ``linear_l = 2`` the model consists of
    an MLP followed by two matrix multiplications.

    This is especially useful for IRMAE and LoRAE, which use matrix
    multiplications in their latent space.
    """

    def __init__(
        self,
        *,
        in_size,
        out_size,
        width_size,
        depth,
        linear_l=0,
        **kwargs,
    ):
        super().__init__()

        if depth == 0:
            # Degenerate case: a single linear layer.
            self.mlp = MLP(
                in_channels=in_size,
                hidden_channels=[out_size],
                **kwargs
            )
        else:
            if isinstance(width_size, int):
                width_size = [width_size] * depth

            hidden_dims = width_size + [out_size]

            self.mlp = MLP(
                in_channels=in_size,
                hidden_channels=hidden_dims,
                **kwargs
            )

        # Extra square, bias-free linear layers applied after the MLP.
        self.layers_l = torch.nn.ModuleList(
            [Linear(out_size, out_size, bias=False) for _ in range(linear_l)]
        )

    def forward(self, x, *args, **kwargs):
        x = self.mlp(x, *args, **kwargs)
        # BUG FIX: the original iterated ``self.layers_l[:-1]``, so the last
        # configured linear layer was never applied (and with linear_l == 1
        # none were applied at all, leaving dead parameters).
        for layer in self.layers_l:
            x = layer(x)
        return x
|
|
831
|
-
class Conv2d_(Conv2d):
    """``Conv2d`` that tolerates (and discards) an ``output_padding`` kwarg.

    Lets conv and transposed-conv layers be constructed from one kwargs dict.
    """

    def __init__(self, *args, **kwargs):
        kwargs.pop("output_padding", None)
        super().__init__(*args, **kwargs)
|
838
|
-
class Conv1d_(Conv1d):
    """``Conv1d`` that tolerates (and discards) an ``output_padding`` kwarg.

    Lets conv and transposed-conv layers be constructed from one kwargs dict.
    """

    def __init__(self, *args, **kwargs):
        kwargs.pop("output_padding", None)
        super().__init__(*args, **kwargs)
|
845
|
-
|
|
846
|
-
class MLCNN(torch.nn.Module):
    """A multilayer CNN: a stack of conv (or transposed-conv) layers.

    Layer ``i`` maps ``CNN_widths_b[i] -> CNN_widths[i]`` channels using the
    shared ``kernel_conv``/``stride``/``padding``/``dilation`` settings.
    """

    def __init__(
        self,
        start_dim,
        out_dim,
        stride,
        padding,
        kernel_conv,
        dilation,
        CNN_widths,
        activation=torch.nn.functional.relu,
        final_activation=lambda x: x,
        transpose=False,
        output_padding=None,
        dimension=2,
        *,
        kwargs_cnn={},
        **kwargs,
    ):
        """Note: if provided as lists, activations should be one less than
        widths.  The last activation is specified by ``final_activation``.
        """
        super().__init__()

        # BUG FIX: the original scalar branch read an undefined ``CNNs_num``
        # (UnboundLocalError).  A scalar width now builds a single layer with
        # that many output channels (matching how callers pass out_dim).
        if not isinstance(CNN_widths, list):
            CNN_widths = [CNN_widths]

        CNN_widths_b = [start_dim] + CNN_widths[:-1]

        if dimension == 2:
            fn = Conv2d_ if not transpose else ConvTranspose2d
        elif dimension == 1:
            fn = Conv1d_ if not transpose else ConvTranspose1d
        else:
            # Previously ``fn`` stayed unbound and crashed later with a
            # confusing NameError.
            raise ValueError(f"Unsupported dimension: {dimension}")

        layers = [
            fn(
                CNN_widths_b[i],
                CNN_widths[i],
                kernel_size=kernel_conv,
                stride=stride,
                padding=padding,
                dilation=dilation,
                output_padding=output_padding,
                **kwargs_cnn,
            )
            for i in range(len(CNN_widths))
        ]
        self.layers = torch.nn.ModuleList(layers)

        # One activation per layer; the last one is overridden in forward().
        self.activation = [activation for _ in CNN_widths]
        self.final_activation = final_activation

    def forward(self, x):
        for layer, act in zip(self.layers[:-1], self.activation[:-1]):
            x = act(layer(x))
        return self.final_activation(self.layers[-1](x))
|
915
|
-
|
|
916
|
-
class CNNs_with_MLP(torch.nn.Module):
    """Encoder built from a CNN stack followed by an MLP head.

    A dummy forward pass through the CNN stack determines the flattened
    feature size, so the MLP input dimension is inferred automatically.
    """

    def __init__(
        self,
        width,
        height,
        out,
        channels=1,
        width_CNNs=[32, 64, 128, 256],
        kernel_conv=3,
        stride=2,
        padding=1,
        dilation=1,
        mlp_width=None,
        mlp_depth=0,
        dimension=2,
        final_activation=lambda x: x,
        *,
        kwargs_cnn={},
        kwargs_mlp={},
    ):
        super().__init__()

        if mlp_depth != 0:
            if mlp_width is not None:
                assert (
                    mlp_width >= out
                ), "Choose a bigger (or equal) MLP width than the latent space in the encoder."
            else:
                mlp_width = out

        # Scalar ``width_CNNs`` means a single conv layer of that width.
        try:
            last_width = width_CNNs[-1]
        except TypeError:  # BUG FIX: was a bare ``except`` — narrowed
            last_width = width_CNNs

        mlcnn = MLCNN(
            channels,
            last_width,
            stride,
            padding,
            kernel_conv,
            dilation,
            width_CNNs,
            dimension=dimension,
            final_activation=torch.nn.functional.relu,
            **kwargs_cnn,
        )

        # Probe the CNN with zeros to find the flattened feature size.
        if dimension == 2:
            final_Ds = mlcnn(torch.zeros((channels, height, width))).shape[-2:]
            fD = final_Ds[0] * final_Ds[1]
        elif dimension == 1:
            fD = mlcnn(torch.zeros((channels, height))).shape[-1]
        else:
            # Previously ``fD`` stayed unbound here (UnboundLocalError).
            raise ValueError(f"Unsupported dimension: {dimension}")

        mlp = MLP_with_linear(
            in_size=fD * last_width,
            out_size=out,
            width_size=mlp_width,
            depth=mlp_depth,
            **kwargs_mlp,
        )

        self.layers = torch.nn.ModuleList([mlcnn, mlp])
        self.final_act = lambda x: final_activation(x)

    def forward(self, x, *args, **kwargs):
        # Designed for unbatched inputs (e.g. under vmap): the whole CNN
        # output is flattened before the MLP head.
        x = self.layers[0](x)
        x = torch.unsqueeze(torch.flatten(x), -1)
        x = self.layers[1](torch.squeeze(x))
        x = self.final_act(x)
        return x
|
|
995
|
-
def prev_D_CNN_trans(D0, D1, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None):
    """Invert the ConvTranspose2d size formula ``num`` times from ``(D0, D1)``.

    Computes the input sizes a stack of ``num`` transposed convs would need
    to produce an output of size ``(D0, D1)``, recording each
    (ceil-rounded) intermediate size.

    Returns:
        ``(all_D0s, all_D1s)`` — sizes from the output inwards; callers use
        ``[-1]`` for the innermost size.
    """
    # BUG FIX: the original used mutable default lists AND recursed without
    # passing the accumulators, so results leaked across separate calls and
    # the first appended value was lost (only ``[-1]`` was reliable).  Use
    # None sentinels and thread the accumulators through the recursion.
    all_D0s = [] if all_D0s is None else all_D0s
    all_D1s = [] if all_D1s is None else all_D1s

    pad = int_to_lst(pad, 2)
    ker = int_to_lst(ker, 2)
    st = int_to_lst(st, 2)
    dil = int_to_lst(dil, 2)
    outpad = int_to_lst(outpad, 2)

    if num == 0:
        return all_D0s, all_D1s

    all_D0s.append(int(np.ceil(D0)))
    all_D1s.append(int(np.ceil(D1)))

    return prev_D_CNN_trans(
        (D0 + 2 * pad[0] - dil[0] * (ker[0] - 1) - 1 - outpad[0]) / st[0] + 1,
        (D1 + 2 * pad[1] - dil[1] * (ker[1] - 1) - 1 - outpad[1]) / st[1] + 1,
        pad,
        ker,
        st,
        dil,
        outpad,
        num - 1,
        all_D0s,
        all_D1s,
    )
|
1019
|
-
|
|
1020
|
-
def find_padding_convT(D, data_dim0, ker, st, dil, outpad):
    """Placeholder: return ``D`` unchanged.

    NOTE(review): every parameter other than ``D`` is ignored — presumably
    a padding search was planned here; confirm before relying on it.
    """
    return D
|
|
1024
|
-
|
|
1025
|
-
def next_CNN_trans(O0, O1, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None):
    """Apply the ConvTranspose2d output-size formula ``num`` times.

    Starting from ``(O0, O1)``, records every intermediate size produced by
    a stack of ``num`` transposed convs.

    Returns:
        ``(all_D0s, all_D1s)`` — sizes from the input outwards; callers use
        ``[-1]`` for the final size.
    """
    # BUG FIX: the original used mutable default lists AND recursed without
    # passing the accumulators, so results leaked across separate calls
    # (only ``[-1]`` was reliable).  Use None sentinels and thread the
    # accumulators through the recursion.
    all_D0s = [] if all_D0s is None else all_D0s
    all_D1s = [] if all_D1s is None else all_D1s

    pad = int_to_lst(pad, 2)
    ker = int_to_lst(ker, 2)
    st = int_to_lst(st, 2)
    dil = int_to_lst(dil, 2)
    outpad = int_to_lst(outpad, 2)

    if num == 0:
        return all_D0s, all_D1s

    all_D0s.append(int(O0))
    all_D1s.append(int(O1))

    return next_CNN_trans(
        (O0 - 1) * st[0] + dil[0] * (ker[0] - 1) - 2 * pad[0] + 1 + outpad[0],
        (O1 - 1) * st[1] + dil[1] * (ker[1] - 1) - 2 * pad[1] + 1 + outpad[1],
        pad,
        ker,
        st,
        dil,
        outpad,
        num - 1,
        all_D0s,
        all_D1s,
    )
|
1049
|
-
|
|
1050
|
-
def is_convT_valid(D0, D1, data_dim0, data_dim1, pad, ker, st, dil, outpad, nums):
    """Check whether ``nums`` transposed convs map ``(D0, D1)`` onto the data size.

    Returns ``(ok0, ok1, final_D0, final_D1)``: exact-match flags against
    ``data_dim0``/``data_dim1`` plus the sizes actually produced.
    """
    heights, widths = next_CNN_trans(
        D0, D1, pad, ker, st, dil, outpad, nums, all_D0s=[], all_D1s=[]
    )
    last_h = heights[-1]
    last_w = widths[-1]
    return last_h == data_dim0, last_w == data_dim1, last_h, last_w
|
|
1056
|
-
class MLP_with_CNNs_trans(torch.nn.Module):
    """Decoder: an MLP followed by a stack of transposed CNNs.

    The MLP expands the latent vector to the smallest feature map, the
    transposed convs upsample it, and a final stride-1 conv crops the
    result to exactly ``(height, width)``.
    """

    def __init__(
        self,
        width,
        height,
        inp,
        channels,
        out_after_mlp=32,
        width_CNNs=[64, 32],
        kernel_conv=3,
        stride=2,
        padding=1,
        dilation=1,
        output_padding=1,
        final_activation=lambda x: x,
        mlp_width=None,
        mlp_depth=0,
        dimension=2,
        *,
        kwargs_cnn={},
        kwargs_mlp={},
    ):
        super().__init__()

        width = 10 if width is None else width  # a default value to avoid errors

        # Work out the feature-map sizes backwards from the target size.
        D0s, D1s = prev_D_CNN_trans(
            height,
            width,
            padding,
            kernel_conv,
            stride,
            dilation,
            output_padding,
            len(width_CNNs) + 1,
            all_D0s=[],
            all_D1s=[],
        )

        first_D0 = D0s[-1]
        first_D1 = D1s[-1]

        # Forward-check the sizes the transposed stack actually produces.
        _, _, final_D0, final_D1 = is_convT_valid(
            first_D0,
            first_D1,
            height,
            width,
            padding,
            kernel_conv,
            stride,
            dilation,
            output_padding,
            len(width_CNNs) + 1,
        )

        if dimension == 2:
            self.d_shape = [first_D0, first_D1]
        elif dimension == 1:
            self.d_shape = [first_D0]
        else:
            # Previously an unsupported dimension left ``d_shape`` and
            # ``final_conv`` unbound and crashed later with NameError.
            raise ValueError(f"Unsupported dimension: {dimension}")

        mlcnn = MLCNN(
            out_after_mlp,
            width_CNNs[-1],
            stride,
            padding,
            kernel_conv,
            dilation,
            width_CNNs,
            dimension=dimension,
            transpose=True,
            output_padding=output_padding,
            final_activation=torch.nn.functional.relu,
            **kwargs_cnn,
        )

        if mlp_depth != 0:
            if mlp_width is not None:
                assert (
                    mlp_width >= inp
                ), "Choose a bigger (or equal) MLP width than the latent space in decoder."
            else:
                mlp_width = inp

        mlp = MLP_with_linear(
            in_size=inp,
            out_size=out_after_mlp * int(np.prod(self.d_shape)),
            width_size=mlp_width,
            depth=mlp_depth,
            **kwargs_mlp,
        )

        # Final stride-1 conv crops the overshoot down to the target size.
        if dimension == 2:
            final_conv = Conv2d(
                width_CNNs[-1],
                channels,
                kernel_size=(1 + (final_D0 - height), 1 + (final_D1 - width)),
                stride=1,
                padding=0,
                dilation=1,
            )
        else:  # dimension == 1 (other values rejected above)
            final_conv = Conv1d(
                width_CNNs[-1],
                channels,
                kernel_size=(1 + (final_D0 - height)),
                stride=1,
                padding=0,
                dilation=1,
            )

        self.start_dim = inp
        self.final_act = final_activation
        self.layers = torch.nn.ModuleList([mlp, mlcnn, final_conv])
        self.out_after_mlp = out_after_mlp

    def forward(self, x, *args, **kwargs):
        x = self.layers[0](x)
        x = torch.reshape(x, (self.out_after_mlp, *self.d_shape))
        x = self.layers[1](x)
        x = self.layers[2](x)
        x = self.final_act(x)
        return x
|
|
1190
|
-
class Conv3d_(Conv3d):
    """``Conv3d`` that tolerates (and discards) an ``output_padding`` kwarg."""

    def __init__(self, *args, **kwargs):
        kwargs.pop("output_padding", None)
        super().__init__(*args, **kwargs)
|
1196
|
-
|
|
1197
|
-
class MLCNN3D(torch.nn.Module):
    """A multilayer 3-D CNN: a stack of Conv3d / ConvTranspose3d layers."""

    def __init__(
        self,
        start_dim,
        out_dim,
        stride,
        padding,
        kernel_conv,
        dilation,
        CNN_widths,
        activation=torch.nn.functional.relu,
        final_activation=lambda x: x,
        transpose=False,
        output_padding=None,
        *,
        kwargs_cnn={},
        **kwargs,
    ):
        """Note: if provided as lists, activations should be one less than
        widths.  The last activation is specified by ``final_activation``."""
        super().__init__()

        # BUG FIX: the original scalar branch read an undefined ``CNNs_num``
        # (UnboundLocalError).  A scalar width now builds a single layer with
        # that many output channels (matching how callers pass out_dim).
        if not isinstance(CNN_widths, list):
            CNN_widths = [CNN_widths]

        CNN_widths_b = [start_dim] + CNN_widths[:-1]
        fn = Conv3d_ if not transpose else ConvTranspose3d

        layers = [
            fn(
                CNN_widths_b[i],
                CNN_widths[i],
                kernel_size=kernel_conv,
                stride=stride,
                padding=padding,
                dilation=dilation,
                output_padding=output_padding,
                **kwargs_cnn,
            )
            for i in range(len(CNN_widths))
        ]
        self.layers = torch.nn.ModuleList(layers)

        # One activation per layer; the last one is overridden in forward().
        self.activation = [activation for _ in CNN_widths]
        self.final_activation = final_activation

    def forward(self, x):
        for layer, act in zip(self.layers[:-1], self.activation[:-1]):
            x = act(layer(x))
        return self.final_activation(self.layers[-1](x))
|
1259
|
-
|
|
1260
|
-
class CNN3D_with_MLP(torch.nn.Module):
    """Encoder built from a 3-D CNN stack followed by an MLP head.

    A dummy forward pass through the CNN stack determines the flattened
    feature size, so the MLP input dimension is inferred automatically.
    """

    def __init__(
        self,
        depth,
        height,
        width,
        out,
        channels=1,
        width_CNNs=[32, 64, 128, 256],
        kernel_conv=3,
        stride=2,
        padding=1,
        dilation=1,
        mlp_width=None,
        mlp_depth=0,
        final_activation=lambda x: x,
        *,
        kwargs_cnn={},
        kwargs_mlp={},
    ):
        super().__init__()

        if mlp_depth != 0:
            if mlp_width is not None:
                assert (
                    mlp_width >= out
                ), "Choose a bigger (or equal) MLP width than the latent space in the encoder."
            else:
                mlp_width = out

        # Scalar ``width_CNNs`` means a single conv layer of that width.
        try:
            last_width = width_CNNs[-1]
        except TypeError:  # BUG FIX: was a bare ``except`` — narrowed
            last_width = width_CNNs

        mlcnn3d = MLCNN3D(
            channels,
            last_width,
            stride,
            padding,
            kernel_conv,
            dilation,
            width_CNNs,
            final_activation=torch.nn.functional.relu,
            **kwargs_cnn,
        )

        # Probe the CNN with zeros to find the flattened feature size.
        final_Ds = mlcnn3d(torch.zeros((channels, depth, height, width))).shape[-3:]
        mlp = MLP_with_linear(
            in_size=final_Ds[0] * final_Ds[1] * final_Ds[2] * last_width,
            out_size=out,
            width_size=mlp_width,
            depth=mlp_depth,
            **kwargs_mlp,
        )

        self.layers = torch.nn.ModuleList([mlcnn3d, mlp])
        self.final_act = lambda x: final_activation(x)

    def forward(self, x, *args, **kwargs):
        # Designed for unbatched inputs: the whole CNN output is flattened
        # before the MLP head.
        x = self.layers[0](x)
        x = torch.unsqueeze(torch.flatten(x), -1)
        x = self.layers[1](torch.squeeze(x))
        x = self.final_act(x)
        return x
|
|
1331
|
-
def prev_D_CNN3D_trans(D0, D1, D2, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None, all_D2s=None):
    """Invert the ConvTranspose3d size formula ``num`` times from ``(D0, D1, D2)``.

    Computes the input sizes a stack of ``num`` 3-D transposed convs would
    need to produce an output of size ``(D0, D1, D2)``, recording each
    (ceil-rounded) intermediate size.

    Returns:
        ``(all_D0s, all_D1s, all_D2s)`` — sizes from the output inwards;
        callers use ``[-1]`` for the innermost size.
    """
    # BUG FIX: the original used mutable default lists AND recursed without
    # passing the accumulators, so results leaked across separate calls and
    # the first appended value was lost (only ``[-1]`` was reliable).  Use
    # None sentinels and thread the accumulators through the recursion.
    all_D0s = [] if all_D0s is None else all_D0s
    all_D1s = [] if all_D1s is None else all_D1s
    all_D2s = [] if all_D2s is None else all_D2s

    pad = int_to_lst(pad, 3)
    ker = int_to_lst(ker, 3)
    st = int_to_lst(st, 3)
    dil = int_to_lst(dil, 3)
    outpad = int_to_lst(outpad, 3)

    if num == 0:
        return all_D0s, all_D1s, all_D2s

    all_D0s.append(int(np.ceil(D0)))
    all_D1s.append(int(np.ceil(D1)))
    all_D2s.append(int(np.ceil(D2)))

    return prev_D_CNN3D_trans(
        (D0 + 2 * pad[0] - dil[0] * (ker[0] - 1) - 1 - outpad[0]) / st[0] + 1,
        (D1 + 2 * pad[1] - dil[1] * (ker[1] - 1) - 1 - outpad[1]) / st[1] + 1,
        (D2 + 2 * pad[2] - dil[2] * (ker[2] - 1) - 1 - outpad[2]) / st[2] + 1,
        pad,
        ker,
        st,
        dil,
        outpad,
        num - 1,
        all_D0s,
        all_D1s,
        all_D2s,
    )
|
|
1357
|
-
def find_padding_conv3dT(D, data_dim0, ker, st, dil, outpad):
    """Placeholder: return ``D`` unchanged.

    NOTE(review): every parameter other than ``D`` is ignored — presumably
    a padding search was planned here; confirm before relying on it.
    """
    return D
|
|
1360
|
-
def next_CNN3D_trans(O0, O1, O2, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None, all_D2s=None):
    """Apply the ConvTranspose3d output-size formula ``num`` times.

    Starting from ``(O0, O1, O2)``, records every intermediate size produced
    by a stack of ``num`` 3-D transposed convs.

    Returns:
        ``(all_D0s, all_D1s, all_D2s)`` — sizes from the input outwards;
        callers use ``[-1]`` for the final size.
    """
    # BUG FIX: the original used mutable default lists AND recursed without
    # passing the accumulators, so results leaked across separate calls
    # (only ``[-1]`` was reliable).  Use None sentinels and thread the
    # accumulators through the recursion.
    all_D0s = [] if all_D0s is None else all_D0s
    all_D1s = [] if all_D1s is None else all_D1s
    all_D2s = [] if all_D2s is None else all_D2s

    pad = int_to_lst(pad, 3)
    ker = int_to_lst(ker, 3)
    st = int_to_lst(st, 3)
    dil = int_to_lst(dil, 3)
    outpad = int_to_lst(outpad, 3)

    if num == 0:
        return all_D0s, all_D1s, all_D2s

    all_D0s.append(int(O0))
    all_D1s.append(int(O1))
    all_D2s.append(int(O2))

    return next_CNN3D_trans(
        (O0 - 1) * st[0] + dil[0] * (ker[0] - 1) - 2 * pad[0] + 1 + outpad[0],
        (O1 - 1) * st[1] + dil[1] * (ker[1] - 1) - 2 * pad[1] + 1 + outpad[1],
        (O2 - 1) * st[2] + dil[2] * (ker[2] - 1) - 2 * pad[2] + 1 + outpad[2],
        pad,
        ker,
        st,
        dil,
        outpad,
        num - 1,
        all_D0s,
        all_D1s,
        all_D2s,
    )
|
1386
|
-
def is_conv3dT_valid(D0, D1, D2, data_dim0, data_dim1, data_dim2, pad, ker, st, dil, outpad, nums):
    """Check whether ``nums`` 3-D transposed convs map ``(D0, D1, D2)`` onto
    the data size.

    Returns ``(ok0, ok1, ok2, final_D0, final_D1, final_D2)``: exact-match
    flags against the data dims plus the sizes actually produced.
    """
    depths, heights, widths = next_CNN3D_trans(
        D0, D1, D2, pad, ker, st, dil, outpad, nums,
        all_D0s=[], all_D1s=[], all_D2s=[],
    )
    last_d = depths[-1]
    last_h = heights[-1]
    last_w = widths[-1]
    return (
        last_d == data_dim0,
        last_h == data_dim1,
        last_w == data_dim2,
        last_d,
        last_h,
        last_w,
    )
|
|
1393
|
-
class MLP_with_CNN3D_trans(torch.nn.Module):
    """Decoder: an MLP followed by a stack of 3-D transposed CNNs.

    The MLP expands the latent vector to the smallest feature volume, the
    transposed convs upsample it, and a final stride-1 Conv3d crops the
    result to exactly ``(depth, height, width)``.
    """

    def __init__(
        self,
        depth,
        height,
        width,
        inp,
        channels,
        out_after_mlp=32,
        width_CNNs=[64, 32],
        kernel_conv=3,
        stride=2,
        padding=1,
        dilation=1,
        output_padding=1,
        final_activation=lambda x: x,
        mlp_width=None,
        mlp_depth=0,
        *,
        kwargs_cnn={},
        kwargs_mlp={},
    ):
        super().__init__()

        # Work out the feature-volume sizes backwards from the target size.
        D0s, D1s, D2s = prev_D_CNN3D_trans(
            depth,
            height,
            width,
            padding,
            kernel_conv,
            stride,
            dilation,
            output_padding,
            len(width_CNNs) + 1,
            all_D0s=[],
            all_D1s=[],
            all_D2s=[],
        )

        first_D0 = D0s[-1]
        first_D1 = D1s[-1]
        first_D2 = D2s[-1]

        # Forward-check the sizes the transposed stack actually produces.
        _, _, _, final_D0, final_D1, final_D2 = is_conv3dT_valid(
            first_D0,
            first_D1,
            first_D2,
            depth,
            height,
            width,
            padding,
            kernel_conv,
            stride,
            dilation,
            output_padding,
            len(width_CNNs) + 1,
        )

        mlcnn3d = MLCNN3D(
            out_after_mlp,
            width_CNNs[-1],
            stride,
            padding,
            kernel_conv,
            dilation,
            width_CNNs,
            transpose=True,
            output_padding=output_padding,
            final_activation=torch.nn.functional.relu,
            **kwargs_cnn,
        )

        if mlp_depth != 0:
            if mlp_width is not None:
                assert (
                    mlp_width >= inp
                ), "Choose a bigger (or equal) MLP width than the latent space in decoder."
            else:
                mlp_width = inp

        mlp = MLP_with_linear(
            in_size=inp,
            out_size=out_after_mlp * first_D0 * first_D1 * first_D2,
            width_size=mlp_width,
            depth=mlp_depth,
            **kwargs_mlp,
        )

        # Final stride-1 conv crops the overshoot down to the target size.
        final_conv = Conv3d(
            width_CNNs[-1],
            channels,
            kernel_size=(
                1 + (final_D0 - depth),
                1 + (final_D1 - height),
                1 + (final_D2 - width),
            ),
            stride=1,
            padding=0,
            dilation=1,
        )

        self.start_dim = inp
        self.first_D0 = first_D0
        self.first_D1 = first_D1
        self.first_D2 = first_D2
        self.final_act = final_activation
        self.layers = torch.nn.ModuleList([mlp, mlcnn3d, final_conv])
        self.out_after_mlp = out_after_mlp

    def forward(self, x, *args, **kwargs):
        # BUG FIX: removed a leftover debug ``print(x.shape)``.
        x = self.layers[0](x)
        x = torch.reshape(
            x, (self.out_after_mlp, self.first_D0, self.first_D1, self.first_D2)
        )
        x = self.layers[1](x)
        x = self.layers[2](x)
        x = self.final_act(x)
        return x
|
|
1521
|
-
################# Other Useful functions ################################
|
|
1522
|
-
|
|
1523
|
-
def int_to_lst(x, len=1):
    """Repeat an int into a list of length ``len``; pass anything else through.

    NOTE(review): the parameter name ``len`` shadows the builtin; it is kept
    unchanged for backward compatibility with keyword callers.
    """
    return [x] * len if isinstance(x, int) else x
|
|
1529
|
-
def remove_keys_from_dict(d, keys):
    """Return a copy of ``d`` without the entries whose key is in ``keys``."""
    filtered = {}
    for key, value in d.items():
        if key not in keys:
            filtered[key] = value
    return filtered
|
|
1532
|
-
def merge_dicts(d1, d2):
    """Merge two dicts into a new one; on key clashes ``d2`` wins."""
    merged = dict(d1)
    merged.update(d2)
    return merged
|
|
1535
|
-
def v_print(s, v, f=False):
    """Verbosity-gated print: emit ``s`` only when ``v`` is truthy.

    ``f`` is forwarded to ``print`` as its ``flush`` argument.
    """
    if not v:
        return
    print(s, flush=f)
|
|
1539
|
-
def countList(lst1, lst2):
    """Interleave two lists: ``[a0, b0, a1, b1, ...]``.

    Iterates ``len(lst2)`` pairs, so ``lst1`` must be at least as long as
    ``lst2`` (an IndexError is raised otherwise, as in the original).
    """
    interleaved = []
    for idx in range(len(lst2)):
        interleaved.append(lst1[idx])
        interleaved.append(lst2[idx])
    return interleaved
|
|
1542
|
-
|
|
1543
|
-
class Sample(torch.nn.Module):
    """Reparameterisation-trick sampling layer: ``mean + exp(logvar/2) * eps``.

    The noise ``epsilon`` is supplied externally (defaulting to 0, i.e. no
    sampling).  NOTE(review): the original docstring described a JAX seed
    trick that does not apply to this torch port.
    """

    def __init__(self, sample_dim):
        super().__init__()
        # Dimension of the sampled latent vector.
        self.sample_dim = sample_dim

    def forward(self, mean, logvar, epsilon=None, ret=False, *args, **kwargs):
        """Return ``mean + exp(0.5 * logvar) * epsilon``.

        With ``ret=True`` the mean and logvar are returned as well.
        """
        epsilon = 0 if epsilon is None else epsilon
        sample = mean + torch.exp(0.5 * logvar) * epsilon
        if ret:
            return sample, mean, logvar
        return sample

    def create_epsilon(self, seed, shape):
        """Draw standard-normal noise of the given ``shape``.

        BUG FIX: the original ignored ``seed`` entirely, so the noise was
        never reproducible.  A seeded generator is used instead.
        """
        return random.default_rng(seed).normal(size=shape)