RRAEsTorch 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,1562 +0,0 @@
- from collections.abc import Callable
- from typing import Optional
- from torch.nn import Linear
- from torch.nn import Conv2d, ConvTranspose2d, Conv1d, ConvTranspose1d, Conv3d, ConvTranspose3d
- import numpy.random as random
- import numpy as np
- from operator import itemgetter
- from tqdm import tqdm
- import itertools
- from torchvision.ops import MLP
- import torch
- from torch.func import vmap
-
- ##################### ML Training/Evaluation functions ##########################
-
- def eval_with_batches(
-     x,
-     batch_size,
-     call_func,
-     end_type="concat_and_resort",
-     str=None,
-     *args,
-     **kwargs,
- ):
-     """Function to evaluate a model on a dataset using batches.
-
-     Parameters
-     ----------
-     x : np.ndarray
-         Data to which the model will be applied.
-     batch_size : int
-         The batch size.
-     call_func : callable
-         The function that calls the model, usually model.__call__.
-     end_type : str
-         What to do with the output; can be first, sum, mean, concat,
-         stack, or concat_and_resort. Note that the data is shuffled
-         when batched. The last option (and default value) resorts the
-         outputs to give back the same order as the input.
-     str : str, optional
-         Description printed before evaluation; if given, a tqdm
-         progress bar is shown over the batches.
-
-     Returns
-     -------
-     final_pred : np.ndarray
-         The predictions over the batches.
-     """
-     idxs = []
-     all_preds = []
-
-     if str is not None:
-         print(str)
-         fn = lambda x, *args, **kwargs: tqdm(x, *args, **kwargs)
-     else:
-         fn = lambda x, *args, **kwargs: x
-
-     if not (isinstance(x, tuple) or isinstance(x, list)):
-         x = [x]
-     x = [el.T for el in x]
-
-     for _, inputs in fn(
-         zip(
-             itertools.count(start=0),
-             dataloader(
-                 [*x, np.arange(0, x[0].shape[0], 1)],
-                 batch_size,
-                 once=True,
-             ),
-         ),
-         total=int(x[0].shape[-1] / batch_size),
-     ):
-         input_b = inputs[:-1]
-         idx = inputs[-1]
-
-         input_b = [el.T for el in input_b]
-
-         pred = call_func(*input_b, *args, **kwargs)
-         idxs.append(idx)
-         all_preds.append(pred)
-         if end_type == "first":
-             break
-     idxs = np.concatenate(idxs)
-     match end_type:
-         case "concat_and_resort":
-             final_pred = np.concatenate(all_preds, -1)[..., np.argsort(idxs)]
-         case "concat":
-             final_pred = np.concatenate(all_preds, -1)
-         case "stack":
-             final_pred = np.stack(all_preds, -1)
-         case "mean":
-             final_pred = sum(all_preds) / len(all_preds)
-         case "sum":
-             final_pred = sum(all_preds)
-         case "first":
-             final_pred = all_preds[0]
-         case _:
-             final_pred = all_preds
-     return final_pred
-
-
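A minimal usage sketch for eval_with_batches (illustrative; the toy model and the features-by-samples layout are assumptions inferred from the transposes above, not documented package behavior):

    import numpy as np

    model = lambda batch: batch * 2.0       # toy stand-in for model.__call__
    data = np.random.rand(8, 100)           # 8 features, 100 samples (samples on the last axis)
    preds = eval_with_batches(data, 16, model, end_type="concat_and_resort")
    assert preds.shape == (8, 100)          # outputs come back in the input order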
- def loss_generator(which=None, norm_loss_=None):
-     """Allows users to use different loss functions by only providing strings."""
-     if norm_loss_ is None:
-         norm_loss_ = lambda x1, x2: torch.linalg.norm(x1 - x2) / torch.linalg.norm(x2) * 100
-
-     if (which == "default") or (which == "Vanilla"):
-         def loss_fun(model, input, out, **kwargs):
-             pred = model(input, keep_normalized=True)
-             aux = {"loss": norm_loss_(pred, out)}
-             return norm_loss_(pred, out), (aux, {})
-
-     elif which == "RRAE":
-         def loss_fun(model, input, out, *, k_max, **kwargs):
-             pred = model(input, k_max=k_max, keep_normalized=True)
-             aux = {"loss": norm_loss_(pred, out), "k_max": k_max}
-             return norm_loss_(pred, out), (aux, {})
-
-     elif which == "Sparse":
-         def loss_fun(
-             model, input, out, *, sparsity=0.05, beta=1.0, **kwargs
-         ):
-             pred = model(input, keep_normalized=True)
-             lat = model.latent(input)
-             sparse_term = sparsity * torch.log(sparsity / (torch.mean(lat) + 1e-8)) + (
-                 1 - sparsity
-             ) * torch.log((1 - sparsity) / (1 - torch.mean(lat) + 1e-8))
-
-             aux = {"loss rec": norm_loss_(pred, out), "loss sparse": sparse_term}
-             return norm_loss_(pred, out) + beta * sparse_term, (aux, {})
-
-     elif which == "nuc":
-         def loss_fun(
-             model,
-             input,
-             out,
-             *,
-             lambda_nuc=0.001,
-             norm_loss=None,
-             find_layer=None,
-             **kwargs,
-         ):
-             if norm_loss is None:
-                 norm_loss = norm_loss_
-             pred = model(input)
-
-             if find_layer is None:
-                 raise ValueError(
-                     "To use LoRAE, you should specify how to find the layer for "
-                     "which we add the nuclear norm in the loss. To do so, give the path "
-                     "to the layer as loss kwargs to the trainer: "
-                     'e.g.: \n"loss_kwargs": {"find_layer": lambda model: model.encode.layers[-2].layers_l[-1].weight} (for predefined CNN AE) \n'
-                     '"loss_kwargs": {"find_layer": lambda model: model.encode.layers_l[-1].weight} (for predefined MLP AE).'
-                 )
-             else:
-                 weight = find_layer(model)
-
-             aux = {"loss rec": norm_loss_(pred, out), "loss nuc": torch.linalg.norm(weight, "nuc")}
-             return norm_loss(pred, out) + lambda_nuc * torch.linalg.norm(weight, "nuc"), (aux, {})
-
-     elif which == "VRRAE":
-         norm_loss_ = lambda pr, out: torch.linalg.norm(pr - out) / torch.linalg.norm(out) * 100
-
-         def lambda_fn(loss, loss_c):
-             # heuristic beta schedule driven by the reconstruction loss
-             return loss_c * torch.exp(-0.1382 * loss)
-
-         def loss_fun(model, input, out, idx, epsilon, k_max, beta=None, **kwargs):
-             lat, means, logvars = model.latent(input, epsilon=epsilon, k_max=k_max, return_lat_dist=True)
-             pred = model.decode(lat)
-             kl_loss = torch.sum(
-                 -0.5 * (1 + logvars - torch.square(means) - torch.exp(logvars))
-             )
-             loss_rec = norm_loss_(pred, out)
-             aux = {
-                 "loss rec": loss_rec,
-                 "loss kl": kl_loss,
-             }
-             if beta is None:
-                 beta = lambda_fn(loss_rec, kl_loss)
-             aux["beta"] = beta
-             return loss_rec + beta * kl_loss, (aux, {})
-
-     elif which == "VAE":
-         norm_loss_ = lambda pr, out: torch.linalg.norm(pr - out) / torch.linalg.norm(out) * 100
-
-         def lambda_fn(loss, loss_c):
-             return loss_c * torch.exp(-0.1382 * loss)
-
-         def loss_fun(model, input, out, idx, epsilon, beta=None, **kwargs):
-             lat, means, logvars = model.latent(input, epsilon=epsilon, length=input.shape[3], return_lat_dist=True)
-             pred = model.decode(lat)
-             kl_loss = torch.sum(
-                 -0.5 * (1 + logvars - torch.square(means) - torch.exp(logvars))
-             )
-             loss_rec = norm_loss_(pred, out)
-             aux = {
-                 "loss rec": loss_rec,
-                 "loss kl": kl_loss,
-             }
-             if beta is None:
-                 beta = lambda_fn(loss_rec, kl_loss)
-             aux["beta"] = beta
-             return loss_rec + beta * kl_loss, (aux, {})
-
-     elif which == "Contractive":
-         def loss_fun(model, input, out, *, beta=1.0, find_weight=None, **kwargs):
-             assert find_weight is not None
-             lat = model.latent(input)
-             pred = model(input, keep_normalized=True)
-             W = find_weight(model)
-             dh = lat * (1 - lat)
-             dh = dh.T
-             loss_contr = torch.sum(torch.matmul(dh**2, torch.square(W)))
-             aux = {"loss": norm_loss_(pred, out), "cont": loss_contr}
-             return norm_loss_(pred, out) + beta * loss_contr, (aux, {})
-     else:
-         raise ValueError(f"{which} is an unknown loss type")
-     return loss_fun
-
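A sketch of how the generated losses are called (illustrative; `model` is assumed to expose the `__call__`/`latent`/`keep_normalized` interface the closures above rely on, which is not defined in this file):

    # relative-error reconstruction loss (in percent)
    loss_fn = loss_generator("default")
    loss, (aux, _) = loss_fn(model, x_batch, x_batch)

    # sparse autoencoder: reconstruction + beta * KL sparsity penalty
    sparse_fn = loss_generator("Sparse")
    loss, (aux, _) = sparse_fn(model, x_batch, x_batch, sparsity=0.05, beta=1.0)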
- def dataloader(arrays, batch_size, p_vals=None, once=False):
-     """JAX-compatible dataloader that batches data randomly and differently
-     between epochs."""
-     dataset_size = arrays[0].shape[0]
-     arrays = [array if array is not None else [None] * dataset_size for array in arrays]
-     indices = np.arange(dataset_size)
-     kk = 0
-
-     while True:
-         perm = random.permutation(indices)
-         start = 0
-         end = batch_size
-         while end <= dataset_size:
-             batch_perm = perm[start:end]
-             arrs = tuple(
-                 itemgetter(*batch_perm)(array) for array in arrays
-             )  # Works for lists and arrays
-             if batch_size != 1:
-                 yield [np.array(arr) for arr in arrs]
-             else:
-                 yield [
-                     [arr] if arr is None else np.expand_dims(np.array(arr), axis=0)
-                     for arr in arrs
-                 ]
-             start = end
-             end = start + batch_size
-         if once:
-             if dataset_size % batch_size != 0:
-                 batch_perm = perm[-(dataset_size % batch_size) :]
-                 arrs = tuple(
-                     itemgetter(*batch_perm)(array) for array in arrays
-                 )  # Works for lists and arrays
-                 if dataset_size % batch_size == 1:
-                     yield [
-                         [arr] if arr is None else np.expand_dims(np.array(arr), 0)
-                         for arr in arrs
-                     ]
-                 else:
-                     yield [[arr] if arr is None else np.array(arr) for arr in arrs]
-             break
-         kk += 1
-
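A usage sketch (illustrative shapes; samples are batched along axis 0, and with once=True a final remainder batch is yielded before the generator stops):

    xs = np.random.rand(100, 8)
    ys = np.random.rand(100, 1)
    for x_b, y_b in dataloader([xs, ys], batch_size=32, once=True):
        pass                                # three (32, ...) batches, then one (4, ...) batch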
- ##################### ML synthetic data generation ########################
-
- def np_vmap(func, to_array=True):
-     """Similar to JAX's vmap but in numpy.
-     Slow but useful if we want to use the same syntax."""
-     def map_func(*arrays, args=None, kwargs=None):
-         sols = []
-         for elems in zip(*arrays):
-             if (args is None) and (kwargs is None):
-                 sols.append(func(*elems))
-             elif (args is not None) and (kwargs is not None):
-                 sols.append(func(*elems, *args, **kwargs))
-             elif args is not None:
-                 sols.append(func(*elems, *args))
-             else:
-                 sols.append(func(*elems, **kwargs))
-         try:
-             if isinstance(sols[0], list) or isinstance(sols[0], tuple):
-                 final_sols = []
-                 for i in range(len(sols[0])):
-                     final_sols.append(np.array([sol[i] for sol in sols]))
-                 return final_sols
-             return np.array([np.squeeze(np.stack(sol, axis=0)) for sol in sols])
-         except Exception:
-             if to_array:
-                 return np.array(sols)
-             else:
-                 return sols
-
-     return map_func
-
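A quick example of the vmap-style syntax this enables (illustrative):

    f = lambda a, b: a + b
    np_vmap(f)(np.arange(3), np.arange(3) * 10)   # array([ 0, 11, 22]), like jax.vmap(f)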
- def divide_return(
-     inp_all,
-     p_all=None,
-     output=None,
-     prop_train=0.8,
-     test_end=0,
-     eps=1,
-     pre_func_in=lambda x: x,
-     pre_func_out=lambda x: x,
-     args=(),
- ):
-     """p_all of shape (P x N) and inp_all of shape (T x N).
-     The function divides into train/test according to the parameters
-     to allow the test set to be interpolated linearly from the training set
-     (if possible). If test_end is specified, this is overridden to only take
-     the last test_end values for testing.
-
-     NOTE: pre_func_in and pre_func_out are functions you want to apply over
-     the input and output but cannot apply to all the data at once because it
-     is too big (e.g. conversion to float). These functions will be applied on
-     batches during training/evaluation of the network."""
-
-     if test_end == 0:
-         idx_test = random.permutation(inp_all.shape[-1])[
-             : int(inp_all.shape[-1] * (1 - prop_train))
-         ]
-
-         x_test = inp_all[..., idx_test]
-         # shuffle the remaining samples along the last (sample) axis
-         perm_train = random.permutation(inp_all.shape[-1] - idx_test.shape[0])
-         x_train = np.delete(inp_all, idx_test, -1)[..., perm_train]
-         if p_all is not None:
-             p_test = p_all[idx_test]
-             p_train = np.delete(p_all, idx_test, 0)[perm_train]
-     else:
-         if p_all is not None:
-             p_test = p_all[-test_end:]
-             p_train = p_all[: len(p_all) - test_end]
-         x_test = inp_all[..., -test_end:]
-         x_train = inp_all[..., : inp_all.shape[-1] - test_end]
-
-     if output is None:
-         output_train = x_train
-         output_test = x_test
-     else:
-         output_test = output[idx_test]
-         output_train = np.delete(output, idx_test, 0)
-         if test_end == 0:
-             # keep outputs aligned with the shuffled training inputs
-             output_train = output_train[perm_train]
-
-     if p_all is not None:
-         p_train = np.expand_dims(p_train, -1) if len(p_train.shape) == 1 else p_train
-         p_test = np.expand_dims(p_test, -1) if len(p_test.shape) == 1 else p_test
-     else:
-         p_train = None
-         p_test = None
-
-     return (
-         torch.tensor(x_train, dtype=torch.float32),
-         torch.tensor(x_test, dtype=torch.float32),
-         torch.tensor(p_train, dtype=torch.float32) if p_train is not None else None,
-         torch.tensor(p_test, dtype=torch.float32) if p_test is not None else None,
-         torch.tensor(output_train, dtype=torch.float32),
-         torch.tensor(output_test, dtype=torch.float32),
-         pre_func_in,
-         pre_func_out,
-         args,
-     )
-
-
- def get_data(problem, folder=None, train_size=1000, test_size=10000, **kwargs):
-     """Function that generates the examples presented in the paper."""
-     match problem:
-         case "2d_gaussian_shift_scale":
-             D = 64  # Dimension of the domain
-             Ntr = train_size  # Number of training samples
-             Nte = test_size  # Number of test samples
-             sigma = 0.2
-
-             def gaussian_2d(x, y, x0, y0, sigma):
-                 return np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
-
-             x = np.linspace(-1, 1, D)
-             y = np.linspace(-1, 1, D)
-             X, Y = np.meshgrid(x, y)
-             # Create training data
-             train_data = []
-             x0_vals = np.linspace(-0.5, 0.5, int(np.sqrt(Ntr)) + 1)
-             y0_vals = np.linspace(-0.5, 0.5, int(np.sqrt(Ntr)) + 1)
-             x0_mesh, y0_mesh = np.meshgrid(x0_vals, y0_vals)
-             x0_mesh = x0_mesh.flatten()
-             y0_mesh = y0_mesh.flatten()
-
-             for i in range(Ntr):
-                 train_data.append(gaussian_2d(X, Y, x0_mesh[i], y0_mesh[i], sigma))
-             train_data = np.stack(train_data, axis=-1)
-
-             # Create test data
-             x0_vals_test = np.random.uniform(-0.5, 0.5, Nte)
-             y0_vals_test = np.random.uniform(-0.5, 0.5, Nte)
-             x0_mesh_test = x0_vals_test
-             y0_mesh_test = y0_vals_test
-
-             test_data = []
-             for i in range(Nte):
-                 test_data.append(gaussian_2d(X, Y, x0_mesh_test[i], y0_mesh_test[i], sigma))
-             test_data = np.stack(test_data, axis=-1)
-
-             # Normalize the data
-             train_data = (train_data - np.mean(train_data)) / np.std(train_data)
-             test_data = (test_data - np.mean(test_data)) / np.std(test_data)
-             # Split the data into training and test sets
-             x_train = torch.tensor(np.expand_dims(train_data, 0), dtype=torch.float32)
-             x_test = torch.tensor(np.expand_dims(test_data, 0), dtype=torch.float32)
-             y_train = torch.tensor(np.expand_dims(train_data, 0), dtype=torch.float32)
-             y_test = torch.tensor(np.expand_dims(test_data, 0), dtype=torch.float32)
-             p_train = torch.tensor(np.stack([x0_mesh, y0_mesh], axis=-1), dtype=torch.float32)
-             p_test = torch.tensor(np.stack([x0_mesh_test, y0_mesh_test], axis=-1), dtype=torch.float32)
-             return x_train, x_test, p_train, p_test, y_train, y_test, lambda x: x, lambda x: x, ()
-
-         case "CIFAR-10":
-             import pickle
-             import os
-
-             def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
-                 with open(os.path.join(cifar10_dataset_folder_path, 'data_batch_' + str(batch_id)), mode='rb') as file:
-                     batch = pickle.load(file, encoding='latin1')
-                 features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
-                 labels = batch['labels']
-                 return features, labels
-
-             def load_cifar10(cifar10_dataset_folder_path):
-                 x_train = []
-                 y_train = []
-                 for batch_id in range(1, 6):
-                     features, labels = load_cifar10_batch(cifar10_dataset_folder_path, batch_id)
-                     x_train.extend(features)
-                     y_train.extend(labels)
-                 x_train = np.array(x_train)
-                 y_train = np.array(y_train)
-                 with open(os.path.join(cifar10_dataset_folder_path, 'test_batch'), mode='rb') as file:
-                     batch = pickle.load(file, encoding='latin1')
-                 x_test = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
-                 y_test = np.array(batch['labels'])
-                 return x_train, x_test, y_train, y_test
-
-             cifar10_dataset_folder_path = folder
-             x_train, x_test, y_train, y_test = load_cifar10(cifar10_dataset_folder_path)
-             pre_func_in = lambda x: np.array(x, dtype=np.float32) / 255.0
-             pre_func_out = lambda x: np.array(x, dtype=np.float32) / 255.0
-             x_train = np.swapaxes(x_train, 0, -1)
-             x_test = np.swapaxes(x_test, 0, -1)
-             return x_train, x_test, None, None, x_train, x_test, pre_func_in, pre_func_out, ()
-
-         case "CelebA":
-             data_res = 160
-             import os
-             from PIL import Image
-             from skimage.transform import resize
-
-             if os.path.exists(f"{folder}/celeba_data_{data_res}.npy"):
-                 print("Loading data from file")
-                 data = np.load(f"{folder}/celeba_data_{data_res}.npy")
-             else:
-                 print("Loading data and processing...")
-                 data = np.load(f"{folder}/celeba_data.npy")
-                 # np.astype requires NumPy >= 2.0
-                 celeb_transform = lambda im: np.astype(
-                     resize(im, (data_res, data_res, 3), order=1, anti_aliasing=True)
-                     * 255.0,
-                     np.uint8,
-                 )
-                 all_data = []
-                 for i in tqdm(range(data.shape[0])):
-                     all_data.append(celeb_transform(data[i]))
-
-                 data = np.stack(all_data, axis=0)
-                 data = np.swapaxes(data, 0, 3)
-                 np.save(f"{folder}/celeba_data_{data_res}.npy", data)
-
-             print("Data shape: ", data.shape)
-             x_train = data[..., :162770]
-             x_test = data[..., 182638:]
-             y_train = x_train
-             y_test = x_test
-             pre_func_in = lambda x: np.astype(x, np.float32) / 255.0
-             pre_func_out = lambda x: np.astype(x, np.float32) / 255.0
-             return (
-                 x_train,
-                 x_test,
-                 None,
-                 None,
-                 y_train,
-                 y_test,
-                 pre_func_in,
-                 pre_func_out,
-                 (),
-             )
-
-         case "shift":
-             ts = np.linspace(0, 2 * np.pi, 100)
-
-             def sf_func(s):
-                 return np.sin(ts - s * np.pi)
-
-             p_vals = np.linspace(0, 1.8, 200)[:-1]
-             y_shift = np_vmap(sf_func)(p_vals).T
-             p_test = np.linspace(0, np.max(p_vals), 500)[1:-1]
-             y_test = np_vmap(sf_func)(p_test).T
-             y_all = np.concatenate([y_shift, y_test], axis=-1)
-             p_all = np.concatenate([p_vals, p_test], axis=0)
-             return divide_return(y_all, p_all, test_end=y_test.shape[-1])
-
-         case "gaussian_shift":
-             ts = np.linspace(0, 2 * np.pi, 200)
-
-             def gauss_shift(s):
-                 return np.exp(-((ts - s) ** 2) / 0.1)  # Smaller width
-
-             p_vals = np.linspace(1, 2 * np.pi + 1, 20)
-             ts = np.linspace(0, 2 * np.pi + 2, 500)
-             y_shift = np_vmap(gauss_shift)(p_vals).T
-             p_test = np.linspace(np.min(p_vals), np.max(p_vals), 500)[1:-1]
-             y_test = np_vmap(gauss_shift)(p_test).T
-             y_all = np.concatenate([y_shift, y_test], axis=-1)
-             p_all = np.concatenate([p_vals, p_test], axis=0)
-             return divide_return(y_all, p_all, test_end=y_test.shape[-1])
-
-         case "stairs":
-             Tend = 3.5  # [s]
-             NT = 500
-             nt = NT + 1
-             times = np.linspace(0, Tend, nt)
-             freq = 1  # [Hz]
-             wrad = 2 * np.pi * freq
-             nAmp = 100
-             yo = 2.3
-             Amp = np.arange(1, 5, 0.1)
-             phases = np.linspace(1 / 4 * Tend, 3 / 4 * Tend, nAmp)
-             p_vals = Amp
-
-             def find_ph(amp):
-                 return phases[0] + (amp - Amp[0]) / (Amp[1] - Amp[0]) * (
-                     phases[1] - phases[0]
-                 )
-
-             def create_escal(amp):
-                 return np.cumsum(
-                     (
-                         (
-                             np.abs(
-                                 (
-                                     amp
-                                     * np.sqrt(times)
-                                     * np.sin(wrad * (times - find_ph(amp)))
-                                 )
-                                 - yo
-                             )
-                             + (
-                                 (
-                                     amp
-                                     * np.sqrt(times)
-                                     * np.sin(wrad * (times - find_ph(amp)))
-                                 )
-                                 - yo
-                             )
-                         )
-                         / 2
-                     )
-                     ** 5
-                 )
-
-             y_shift_old = np_vmap(create_escal)(p_vals).T
-             y_shift = np_vmap(
-                 lambda y: (y - np.mean(y_shift_old)) / np.std(y_shift_old)
-             )(y_shift_old.T).T
-             y_shift = y_shift[:, ~np.isnan(y_shift).any(axis=0)]
-
-             p_test = random.uniform(
-                 np.min(p_vals) * 1.00001,
-                 np.max(p_vals) * 0.99999,
-                 (300,),
-             )
-             y_test = np_vmap(
-                 lambda y: (y - np.mean(y_shift_old)) / np.std(y_shift_old)
-             )(np_vmap(create_escal)(p_test)).T
-             y_all = np.concatenate([y_shift, y_test], axis=-1)
-             p_all = np.concatenate([p_vals, p_test], axis=0)
-             ts = np.arange(0, y_shift.shape[0], 1)
-             return divide_return(y_all, p_all, test_end=y_test.shape[-1])
-
-         case "mult_freqs":
-             p_vals_0 = np.repeat(np.linspace(0.5 * np.pi, np.pi, 15), 15)
-             p_vals_1 = np.tile(np.linspace(0.3 * np.pi, 0.8 * np.pi, 15), 15)
-             p_vals = np.stack([p_vals_0, p_vals_1], axis=-1)
-             ts = np.arange(0, 5 * np.pi, 0.01)
-             y_shift = np_vmap(lambda p: np.sin(p[0] * ts) + np.sin(p[1] * ts))(
-                 p_vals
-             ).T
-
-             p_vals_0 = random.uniform(
-                 p_vals_0[0] * 1.001,
-                 p_vals_0[-1] * 0.999,
-                 (1000,),
-             )
-             p_vals_1 = random.uniform(
-                 p_vals_1[0] * 1.001,
-                 p_vals_1[-1] * 0.999,
-                 (1000,),
-             )
-             p_test = np.stack([p_vals_0, p_vals_1], axis=-1)
-             y_test = np_vmap(lambda p: np.sin(p[0] * ts) + np.sin(p[1] * ts))(
-                 p_test
-             ).T
-             y_all = np.concatenate([y_shift, y_test], axis=-1)
-             p_all = np.concatenate([p_vals, p_test], axis=0)
-             return divide_return(y_all, p_all, test_end=y_test.shape[-1])
-
-         case "mult_gausses":
-             p_vals_0 = np.repeat(np.linspace(1, 3, 10), 100)
-             p_vals_1 = np.tile(np.linspace(4, 6, 10), 100)
-             p_vals = np.stack([p_vals_0, p_vals_1], axis=-1)
-             p_test_0 = random.uniform(
-                 p_vals_0[0] * 1.001,
-                 p_vals_0[-1] * 0.999,
-                 (1000,),
-             )
-             p_test_1 = random.uniform(
-                 p_vals_1[0] * 1.001,
-                 p_vals_1[-1] * 0.999,
-                 (1000,),
-             )
-             p_test = np.stack([p_test_0, p_test_1], axis=-1)
-
-             ts = np.arange(0, 6, 0.005)
-
-             def gauss(a, b, c, t):
-                 return a * np.exp(-((t - b) ** 2) / (2 * c**2))
-
-             a = 1.3
-             c = 0.2
-             y_shift = np_vmap(
-                 lambda p: gauss(a, p[0], c, ts) + gauss(-a, p[1], c, ts)
-             )(p_vals).T
-             y_test = np_vmap(
-                 lambda p: gauss(a, p[0], c, ts) + gauss(-a, p[1], c, ts)
-             )(p_test).T
-             y_all = np.concatenate([y_shift, y_test], axis=-1)
-             p_all = np.concatenate([p_vals, p_test], axis=0)
-             return divide_return(y_all, p_all, test_end=y_test.shape[-1])
-
-         case "fashion_mnist":
-             import pandas
-             x_train = pandas.read_csv("fashin_mnist/fashion-mnist_train.csv").to_numpy().T[1:]
-             x_test = pandas.read_csv("fashin_mnist/fashion-mnist_test.csv").to_numpy().T[1:]
-             y_all = np.concatenate([x_train, x_test], axis=-1)
-             y_all = np.reshape(y_all, (1, 28, 28, -1))
-             pre_func_in = lambda x: np.astype(x, np.float32) / 255
-             return divide_return(y_all, None, test_end=x_test.shape[-1], pre_func_in=pre_func_in, pre_func_out=pre_func_in)
-
-         case "mnist":
-             import os
-             import gzip
-             import pickle as pkl
-
-             if os.path.exists(f"{folder}/mnist_data.npy"):
-                 print("Loading data from file")
-                 with open(f"{folder}/mnist_data.npy", "rb") as f:
-                     train_images, train_labels, test_images, test_labels = pkl.load(f)
-             else:
-                 print("Loading data and processing...")
-
-                 def load_mnist_images(filename):
-                     with gzip.open(filename, 'rb') as f:
-                         data = np.frombuffer(f.read(), np.uint8, offset=16)
-                         data = data.reshape(-1, 28, 28)
-                     return data
-
-                 def load_mnist_labels(filename):
-                     with gzip.open(filename, 'rb') as f:
-                         data = np.frombuffer(f.read(), np.uint8, offset=8)
-                     return data
-
-                 def load_mnist(path):
-                     train_images = load_mnist_images(os.path.join(path, 'train-images-idx3-ubyte.gz'))
-                     train_labels = load_mnist_labels(os.path.join(path, 'train-labels-idx1-ubyte.gz'))
-                     test_images = load_mnist_images(os.path.join(path, 't10k-images-idx3-ubyte.gz'))
-                     test_labels = load_mnist_labels(os.path.join(path, 't10k-labels-idx1-ubyte.gz'))
-                     return (train_images, train_labels), (test_images, test_labels)
-
-                 def preprocess_mnist(images):
-                     images = images.astype(np.float32) / 255.0
-                     images = np.expand_dims(images, axis=1)  # Add channel dimension
-                     return images
-
-                 def get_mnist_data(path):
-                     (train_images, train_labels), (test_images, test_labels) = load_mnist(path)
-                     train_images = preprocess_mnist(train_images)
-                     test_images = preprocess_mnist(test_images)
-                     return train_images, train_labels, test_images, test_labels
-
-                 train_images, train_labels, test_images, test_labels = get_mnist_data(folder)
-                 train_images = np.swapaxes(np.moveaxis(train_images, 1, -1), 0, -1)
-                 test_images = np.swapaxes(np.moveaxis(test_images, 1, -1), 0, -1)
-                 with open(f"{folder}/mnist_data.npy", "wb") as f:
-                     pkl.dump((train_images, train_labels, test_images, test_labels), f)
-
-             return (
-                 train_images,
-                 test_images,
-                 None,
-                 None,
-                 train_images,
-                 test_images,
-                 lambda x: x,
-                 lambda x: x,
-                 (),
-             )
-
-         case _:
-             raise ValueError(f"Problem {problem} not recognized")
-
- ########################### SVD functions #################################
-
- def stable_SVD(A):
-     """Computes an SVD with a numerically stabilized backward pass.
-
-     A: input matrix (..., m, n)
-     """
-     return StableSVD.apply(A)
-
- class StableSVD(torch.autograd.Function):
-     @staticmethod
-     def forward(ctx, A):
-         """
-         A: input matrix (..., m, n)
-         """
-         U, S, Vh = torch.linalg.svd(A, full_matrices=False)
-
-         ctx.save_for_backward(U, S, Vh)
-
-         return U, S, Vh
-
-     @staticmethod
-     def backward(ctx, dU, dS, dVh):
-         U, S, Vh = ctx.saved_tensors
-
-         # zero out singular values that are negligibly small relative to the largest
-         scale = S[..., :1]
-         mask = (S / scale * 100 >= 1e-6)
-         S_stable = S * mask
-
-         V = Vh.transpose(-2, -1)
-         dV = dVh.transpose(-2, -1)
-
-         S_i = S_stable.unsqueeze(-1)
-         S_j = S_stable.unsqueeze(-2)
-
-         diff = S_i**2 - S_j**2
-         zero = diff == 0
-         F = torch.zeros_like(diff)
-         F[~zero] = 1.0 / diff[~zero]
-
-         Ut_dU = U.transpose(-2, -1) @ dU
-         Vt_dV = V.transpose(-2, -1) @ dV
-
-         K_U = F * (Ut_dU - Ut_dU.transpose(-2, -1))
-         K_V = F * (Vt_dV - Vt_dV.transpose(-2, -1))
-
-         K = torch.diag_embed(dS)
-         K = K + K_U @ torch.diag_embed(S)
-         K = K + torch.diag_embed(S) @ K_V
-
-         dA = U @ K @ Vh
-
-         # forward takes a single input, so backward returns a single gradient
-         return dA
-
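A quick gradient sanity check for the custom autograd function (illustrative; for a well-separated spectrum it should agree with PyTorch's built-in SVD gradient):

    A = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)
    U, S, Vh = stable_SVD(A)
    S.sum().backward()
    A2 = A.detach().clone().requires_grad_(True)
    torch.linalg.svd(A2, full_matrices=False).S.sum().backward()
    print(torch.allclose(A.grad, A2.grad))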
- ######################## Typical Neural Network Architectures ##############################
-
- class MLP_with_linear(torch.nn.Module):
-     """Similar to an Equinox MLP but with additional linear multiplications by matrices.
-
-     The only new parameter is linear_l: the number of matrices to use after the MLP.
-     e.g. if linear_l = 2, the model will be composed of an MLP and two matrix multiplications.
-
-     This is especially useful for IRMAE and LoRAE, which use matrix multiplications in their
-     latent space.
-     """
-
-     layers_l: tuple[Linear, ...]
-
-     def __init__(
-         self,
-         *,
-         in_size,
-         out_size,
-         width_size,
-         depth,
-         linear_l=0,
-         **kwargs,
-     ):
-         super().__init__()
-
-         if depth == 0:
-             self.mlp = MLP(
-                 in_channels=in_size,
-                 hidden_channels=[out_size],
-                 **kwargs
-             )
-         else:
-             if isinstance(width_size, int):
-                 width_size = [width_size] * depth
-
-             hidden_dims = width_size + [out_size]
-
-             self.mlp = MLP(
-                 in_channels=in_size,
-                 hidden_channels=hidden_dims,
-                 **kwargs
-             )
-
-         layers_l = []
-         if linear_l != 0:
-             for _ in range(linear_l):
-                 layers_l.append(
-                     Linear(out_size, out_size, bias=False)
-                 )
-
-         self.layers_l = torch.nn.ModuleList(layers_l)
-
-     def forward(self, x, *args, **kwargs):
-         x = self.mlp(x, *args, **kwargs)
-         for layer in self.layers_l:
-             x = layer(x)
-         return x
-
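A construction sketch in the IRMAE/LoRAE style (illustrative sizes):

    enc = MLP_with_linear(in_size=784, out_size=32, width_size=256, depth=2, linear_l=4)
    z = enc(torch.randn(784))            # MLP, then four bias-free linear maps
    w = enc.layers_l[-1].weight          # the weight targeted by the "nuc" loss above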
- class Conv2d_(Conv2d):
-     """For compatibility, accepting output_padding as kwarg"""
-     def __init__(self, *args, **kwargs):
-         if "output_padding" in kwargs:
-             kwargs.pop("output_padding")
-         super().__init__(*args, **kwargs)
-
- class Conv1d_(Conv1d):
-     """For compatibility, accepting output_padding as kwarg"""
-     def __init__(self, *args, **kwargs):
-         if "output_padding" in kwargs:
-             kwargs.pop("output_padding")
-         super().__init__(*args, **kwargs)
-
-
- class MLCNN(torch.nn.Module):
-     """A multilayer CNN model. It consists of multiple conv layers."""
-
-     layers: tuple
-     activation: Callable
-     final_activation: Callable
-
-     def __init__(
-         self,
-         start_dim,
-         out_dim,
-         stride,
-         padding,
-         kernel_conv,
-         dilation,
-         CNN_widths,
-         activation=torch.nn.functional.relu,
-         final_activation=lambda x: x,
-         transpose=False,
-         output_padding=None,
-         dimension=2,
-         *,
-         kwargs_cnn={},
-         **kwargs,
-     ):
-         """Note: if provided as lists, activations should be one less than widths.
-         The last activation is specified by "final_activation".
-         """
-         super().__init__()
-
-         if not isinstance(CNN_widths, list):
-             # NOTE: assumed behavior -- a scalar width is expanded to a single
-             # hidden conv layer followed by the output layer.
-             CNN_widths = [CNN_widths, out_dim]
-
-         CNN_widths_b = [start_dim] + CNN_widths[:-1]
-         layers = []
-         if dimension == 2:
-             fn = Conv2d_ if not transpose else ConvTranspose2d
-         elif dimension == 1:
-             fn = Conv1d_ if not transpose else ConvTranspose1d
-
-         for i in range(len(CNN_widths)):
-             layers.append(
-                 fn(
-                     CNN_widths_b[i],
-                     CNN_widths[i],
-                     kernel_size=kernel_conv,
-                     stride=stride,
-                     padding=padding,
-                     dilation=dilation,
-                     output_padding=output_padding,
-                     **kwargs_cnn,
-                 )
-             )
-
-         self.layers = torch.nn.ModuleList(layers)
-
-         self.activation = [lambda x: activation(x) for _ in CNN_widths]
-
-         self.final_activation = final_activation
-
-     def forward(self, x):
-         for layer, act in zip(self.layers[:-1], self.activation[:-1]):
-             x = layer(x)
-             x = act(x)
-         x = self.final_activation(self.layers[-1](x))
-         return x
-
-
- class CNNs_with_MLP(torch.nn.Module):
-     """Class mainly for creating encoders with CNNs.
-     The encoder is composed of multiple CNNs followed by an MLP.
-     """
-
-     layers: tuple[MLCNN, MLP]
-
-     def __init__(
-         self,
-         width,
-         height,
-         out,
-         channels=1,
-         width_CNNs=[32, 64, 128, 256],
-         kernel_conv=3,
-         stride=2,
-         padding=1,
-         dilation=1,
-         mlp_width=None,
-         mlp_depth=0,
-         dimension=2,
-         final_activation=lambda x: x,
-         *,
-         kwargs_cnn={},
-         kwargs_mlp={},
-     ):
-         super().__init__()
-
-         if mlp_depth != 0:
-             if mlp_width is not None:
-                 assert (
-                     mlp_width >= out
-                 ), "Choose a bigger (or equal) MLP width than the latent space in the encoder."
-             else:
-                 mlp_width = out
-
-         try:
-             last_width = width_CNNs[-1]
-         except TypeError:
-             last_width = width_CNNs
-
-         mlcnn = MLCNN(
-             channels,
-             last_width,
-             stride,
-             padding,
-             kernel_conv,
-             dilation,
-             width_CNNs,
-             dimension=dimension,
-             final_activation=torch.nn.functional.relu,
-             **kwargs_cnn,
-         )
-
-         if dimension == 2:
-             final_Ds = mlcnn(torch.zeros((channels, height, width))).shape[-2:]
-             fD = final_Ds[0] * final_Ds[1]
-         elif dimension == 1:
-             fD = mlcnn(torch.zeros((channels, height))).shape[-1]
-
-         mlp = MLP_with_linear(
-             in_size=fD * last_width,
-             out_size=out,
-             width_size=mlp_width,
-             depth=mlp_depth,
-             **kwargs_mlp,
-         )
-
-         act = lambda x: final_activation(x)
-         self.layers = torch.nn.ModuleList([mlcnn, mlp])
-         self.final_act = act
-
-     def forward(self, x, *args, **kwargs):
-         x = self.layers[0](x)
-         x = torch.unsqueeze(torch.flatten(x), -1)
-         x = self.layers[1](torch.squeeze(x))
-         x = self.final_act(x)
-         return x
-
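A construction sketch for a 2D encoder (illustrative; unbatched (C, H, W) input, matching the shape probe used in __init__):

    enc = CNNs_with_MLP(width=28, height=28, out=16, channels=1)
    z = enc(torch.zeros(1, 28, 28))      # -> 16-dimensional latent vector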
- def prev_D_CNN_trans(D0, D1, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None):
-     # fresh accumulators per call; the lists are threaded through the recursion
-     all_D0s = [] if all_D0s is None else all_D0s
-     all_D1s = [] if all_D1s is None else all_D1s
-
-     pad = int_to_lst(pad, 2)
-     ker = int_to_lst(ker, 2)
-     st = int_to_lst(st, 2)
-     dil = int_to_lst(dil, 2)
-     outpad = int_to_lst(outpad, 2)
-
-     if num == 0:
-         return all_D0s, all_D1s
-
-     all_D0s.append(int(np.ceil(D0)))
-     all_D1s.append(int(np.ceil(D1)))
-
-     return prev_D_CNN_trans(
-         (D0 + 2 * pad[0] - dil[0] * (ker[0] - 1) - 1 - outpad[0]) / st[0] + 1,
-         (D1 + 2 * pad[1] - dil[1] * (ker[1] - 1) - 1 - outpad[1]) / st[1] + 1,
-         pad,
-         ker,
-         st,
-         dil,
-         outpad,
-         num - 1,
-         all_D0s=all_D0s,
-         all_D1s=all_D1s,
-     )
-
-
- def find_padding_convT(D, data_dim0, ker, st, dil, outpad):
-     return D
-
-
- def next_CNN_trans(O0, O1, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None):
-     # fresh accumulators per call; the lists are threaded through the recursion
-     all_D0s = [] if all_D0s is None else all_D0s
-     all_D1s = [] if all_D1s is None else all_D1s
-
-     pad = int_to_lst(pad, 2)
-     ker = int_to_lst(ker, 2)
-     st = int_to_lst(st, 2)
-     dil = int_to_lst(dil, 2)
-     outpad = int_to_lst(outpad, 2)
-
-     if num == 0:
-         return all_D0s, all_D1s
-
-     all_D0s.append(int(O0))
-     all_D1s.append(int(O1))
-
-     return next_CNN_trans(
-         (O0 - 1) * st[0] + dil[0] * (ker[0] - 1) - 2 * pad[0] + 1 + outpad[0],
-         (O1 - 1) * st[1] + dil[1] * (ker[1] - 1) - 2 * pad[1] + 1 + outpad[1],
-         pad,
-         ker,
-         st,
-         dil,
-         outpad,
-         num - 1,
-         all_D0s=all_D0s,
-         all_D1s=all_D1s,
-     )
-
-
- def is_convT_valid(D0, D1, data_dim0, data_dim1, pad, ker, st, dil, outpad, nums):
-     all_D0s, all_D1s = next_CNN_trans(D0, D1, pad, ker, st, dil, outpad, nums, all_D0s=[], all_D1s=[])
-     final_D0 = all_D0s[-1]
-     final_D1 = all_D1s[-1]
-     return final_D0 == data_dim0, final_D1 == data_dim1, final_D0, final_D1
-
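next_CNN_trans applies the standard ConvTranspose2d output-size formula, O' = (O - 1) * s + d * (k - 1) - 2p + 1 + outpad, once per layer. A worked check (illustrative): with k=3, s=2, p=1, d=1, outpad=1 each step doubles the size, since (7 - 1) * 2 + 2 - 2 + 1 + 1 = 14:

    D0s, D1s = next_CNN_trans(7, 7, pad=1, ker=3, st=2, dil=1, outpad=1, num=3)
    print(D0s)                           # [7, 14, 28]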
- class MLP_with_CNNs_trans(torch.nn.Module):
-     """Class mainly for creating decoders with transposed CNNs.
-     The decoder is composed of an MLP followed by multiple transposed CNNs.
-     """
-
-     layers: tuple[MLCNN, Linear]
-     start_dim: int
-     d_shape: int
-     out_after_mlp: int
-     final_act: Callable
-
-     def __init__(
-         self,
-         width,
-         height,
-         inp,
-         channels,
-         out_after_mlp=32,
-         width_CNNs=[64, 32],
-         kernel_conv=3,
-         stride=2,
-         padding=1,
-         dilation=1,
-         output_padding=1,
-         final_activation=lambda x: x,
-         mlp_width=None,
-         mlp_depth=0,
-         dimension=2,
-         *,
-         kwargs_cnn={},
-         kwargs_mlp={},
-     ):
-         super().__init__()
-
-         width = 10 if width is None else width  # a default value to avoid errors
-
-         D0s, D1s = prev_D_CNN_trans(
-             height,
-             width,
-             padding,
-             kernel_conv,
-             stride,
-             dilation,
-             output_padding,
-             len(width_CNNs) + 1,
-             all_D0s=[],
-             all_D1s=[],
-         )
-
-         first_D0 = D0s[-1]
-         first_D1 = D1s[-1]
-
-         _, _, final_D0, final_D1 = is_convT_valid(
-             first_D0,
-             first_D1,
-             height,
-             width,
-             padding,
-             kernel_conv,
-             stride,
-             dilation,
-             output_padding,
-             len(width_CNNs) + 1,
-         )
-
-         if dimension == 2:
-             self.d_shape = [first_D0, first_D1]
-         elif dimension == 1:
-             self.d_shape = [first_D0]
-
-         mlcnn = MLCNN(
-             out_after_mlp,
-             width_CNNs[-1],
-             stride,
-             padding,
-             kernel_conv,
-             dilation,
-             width_CNNs,
-             dimension=dimension,
-             transpose=True,
-             output_padding=output_padding,
-             final_activation=torch.nn.functional.relu,
-             **kwargs_cnn,
-         )
-
-         if mlp_depth != 0:
-             if mlp_width is not None:
-                 assert (
-                     mlp_width >= inp
-                 ), "Choose a bigger (or equal) MLP width than the latent space in the decoder."
-             else:
-                 mlp_width = inp
-
-         mlp = MLP_with_linear(
-             in_size=inp,
-             out_size=out_after_mlp * int(np.prod(self.d_shape)),
-             width_size=mlp_width,
-             depth=mlp_depth,
-             **kwargs_mlp,
-         )
-
-         if dimension == 2:
-             final_conv = Conv2d(
-                 width_CNNs[-1],
-                 channels,
-                 kernel_size=(1 + (final_D0 - height), 1 + (final_D1 - width)),
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-             )
-         elif dimension == 1:
-             final_conv = Conv1d(
-                 width_CNNs[-1],
-                 channels,
-                 kernel_size=(1 + (final_D0 - height)),
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-             )
-
-         self.start_dim = inp
-
-         self.final_act = final_activation
-         self.layers = torch.nn.ModuleList([mlp, mlcnn, final_conv])
-         self.out_after_mlp = out_after_mlp
-
-     def forward(self, x, *args, **kwargs):
-         x = self.layers[0](x)
-         x = torch.reshape(x, (self.out_after_mlp, *self.d_shape))
-         x = self.layers[1](x)
-         x = self.layers[2](x)
-         x = self.final_act(x)
-         return x
-
- class Conv3d_(Conv3d):
-     """For compatibility, accepting output_padding as kwarg"""
-     def __init__(self, *args, **kwargs):
-         if "output_padding" in kwargs:
-             kwargs.pop("output_padding")
-         super().__init__(*args, **kwargs)
-
-
- class MLCNN3D(torch.nn.Module):
-     layers: tuple
-     activation: Callable
-     final_activation: Callable
-
-     def __init__(
-         self,
-         start_dim,
-         out_dim,
-         stride,
-         padding,
-         kernel_conv,
-         dilation,
-         CNN_widths,
-         activation=torch.nn.functional.relu,
-         final_activation=lambda x: x,
-         transpose=False,
-         output_padding=None,
-         *,
-         kwargs_cnn={},
-         **kwargs,
-     ):
-         """Note: if provided as lists, activations should be one less than widths.
-         The last activation is specified by "final_activation"."""
-         super().__init__()
-
-         if not isinstance(CNN_widths, list):
-             # NOTE: assumed behavior -- a scalar width is expanded to a single
-             # hidden conv layer followed by the output layer.
-             CNN_widths = [CNN_widths, out_dim]
-
-         CNN_widths_b = [start_dim] + CNN_widths[:-1]
-         layers = []
-         fn = Conv3d_ if not transpose else ConvTranspose3d
-         for i in range(len(CNN_widths)):
-             layers.append(
-                 fn(
-                     CNN_widths_b[i],
-                     CNN_widths[i],
-                     kernel_size=kernel_conv,
-                     stride=stride,
-                     padding=padding,
-                     dilation=dilation,
-                     output_padding=output_padding,
-                     **kwargs_cnn,
-                 )
-             )
-
-         self.layers = torch.nn.ModuleList(layers)
-
-         self.activation = [lambda x: activation(x) for _ in CNN_widths]
-
-         self.final_activation = final_activation
-
-     def forward(self, x):
-         for layer, act in zip(self.layers[:-1], self.activation[:-1]):
-             x = layer(x)
-             x = act(x)
-         x = self.final_activation(self.layers[-1](x))
-         return x
-
-
- class CNN3D_with_MLP(torch.nn.Module):
-     """Class mainly for creating encoders with CNNs.
-     The encoder is composed of multiple CNNs followed by an MLP.
-     """
-
-     layers: tuple[MLCNN3D, Linear]
-
-     def __init__(
-         self,
-         depth,
-         height,
-         width,
-         out,
-         channels=1,
-         width_CNNs=[32, 64, 128, 256],
-         kernel_conv=3,
-         stride=2,
-         padding=1,
-         dilation=1,
-         mlp_width=None,
-         mlp_depth=0,
-         final_activation=lambda x: x,
-         *,
-         kwargs_cnn={},
-         kwargs_mlp={},
-     ):
-         super().__init__()
-
-         if mlp_depth != 0:
-             if mlp_width is not None:
-                 assert (
-                     mlp_width >= out
-                 ), "Choose a bigger (or equal) MLP width than the latent space in the encoder."
-             else:
-                 mlp_width = out
-
-         try:
-             last_width = width_CNNs[-1]
-         except TypeError:
-             last_width = width_CNNs
-
-         mlcnn3d = MLCNN3D(
-             channels,
-             last_width,
-             stride,
-             padding,
-             kernel_conv,
-             dilation,
-             width_CNNs,
-             final_activation=torch.nn.functional.relu,
-             **kwargs_cnn,
-         )
-         final_Ds = mlcnn3d(torch.zeros((channels, depth, height, width))).shape[-3:]
-         mlp = MLP_with_linear(
-             in_size=final_Ds[0] * final_Ds[1] * final_Ds[2] * last_width,
-             out_size=out,
-             width_size=mlp_width,
-             depth=mlp_depth,
-             **kwargs_mlp,
-         )
-         act = lambda x: final_activation(x)
-         self.layers = torch.nn.ModuleList([mlcnn3d, mlp])
-         self.final_act = act
-
-     def forward(self, x, *args, **kwargs):
-         x = self.layers[0](x)
-         x = torch.unsqueeze(torch.flatten(x), -1)
-         x = self.layers[1](torch.squeeze(x))
-         x = self.final_act(x)
-         return x
-
- def prev_D_CNN3D_trans(D0, D1, D2, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None, all_D2s=None):
-     # fresh accumulators per call; the lists are threaded through the recursion
-     all_D0s = [] if all_D0s is None else all_D0s
-     all_D1s = [] if all_D1s is None else all_D1s
-     all_D2s = [] if all_D2s is None else all_D2s
-
-     pad = int_to_lst(pad, 3)
-     ker = int_to_lst(ker, 3)
-     st = int_to_lst(st, 3)
-     dil = int_to_lst(dil, 3)
-     outpad = int_to_lst(outpad, 3)
-
-     if num == 0:
-         return all_D0s, all_D1s, all_D2s
-
-     all_D0s.append(int(np.ceil(D0)))
-     all_D1s.append(int(np.ceil(D1)))
-     all_D2s.append(int(np.ceil(D2)))
-
-     return prev_D_CNN3D_trans(
-         (D0 + 2 * pad[0] - dil[0] * (ker[0] - 1) - 1 - outpad[0]) / st[0] + 1,
-         (D1 + 2 * pad[1] - dil[1] * (ker[1] - 1) - 1 - outpad[1]) / st[1] + 1,
-         (D2 + 2 * pad[2] - dil[2] * (ker[2] - 1) - 1 - outpad[2]) / st[2] + 1,
-         pad,
-         ker,
-         st,
-         dil,
-         outpad,
-         num - 1,
-         all_D0s=all_D0s,
-         all_D1s=all_D1s,
-         all_D2s=all_D2s,
-     )
-
- def find_padding_conv3dT(D, data_dim0, ker, st, dil, outpad):
-     return D
-
- def next_CNN3D_trans(O0, O1, O2, pad, ker, st, dil, outpad, num, all_D0s=None, all_D1s=None, all_D2s=None):
-     # fresh accumulators per call; the lists are threaded through the recursion
-     all_D0s = [] if all_D0s is None else all_D0s
-     all_D1s = [] if all_D1s is None else all_D1s
-     all_D2s = [] if all_D2s is None else all_D2s
-
-     pad = int_to_lst(pad, 3)
-     ker = int_to_lst(ker, 3)
-     st = int_to_lst(st, 3)
-     dil = int_to_lst(dil, 3)
-     outpad = int_to_lst(outpad, 3)
-
-     if num == 0:
-         return all_D0s, all_D1s, all_D2s
-
-     all_D0s.append(int(O0))
-     all_D1s.append(int(O1))
-     all_D2s.append(int(O2))
-
-     return next_CNN3D_trans(
-         (O0 - 1) * st[0] + dil[0] * (ker[0] - 1) - 2 * pad[0] + 1 + outpad[0],
-         (O1 - 1) * st[1] + dil[1] * (ker[1] - 1) - 2 * pad[1] + 1 + outpad[1],
-         (O2 - 1) * st[2] + dil[2] * (ker[2] - 1) - 2 * pad[2] + 1 + outpad[2],
-         pad,
-         ker,
-         st,
-         dil,
-         outpad,
-         num - 1,
-         all_D0s=all_D0s,
-         all_D1s=all_D1s,
-         all_D2s=all_D2s,
-     )
-
- def is_conv3dT_valid(D0, D1, D2, data_dim0, data_dim1, data_dim2, pad, ker, st, dil, outpad, nums):
-     all_D0s, all_D1s, all_D2s = next_CNN3D_trans(D0, D1, D2, pad, ker, st, dil, outpad, nums, all_D0s=[], all_D1s=[], all_D2s=[])
-     final_D0 = all_D0s[-1]
-     final_D1 = all_D1s[-1]
-     final_D2 = all_D2s[-1]
-     return final_D0 == data_dim0, final_D1 == data_dim1, final_D2 == data_dim2, final_D0, final_D1, final_D2
-
- class MLP_with_CNN3D_trans(torch.nn.Module):
-     """Class mainly for creating decoders with transposed 3D CNNs.
-     The decoder is composed of an MLP followed by multiple transposed CNNs.
-     """
-
-     layers: tuple[MLCNN3D, Linear]
-     start_dim: int
-     first_D0: int
-     first_D1: int
-     first_D2: int
-     out_after_mlp: int
-     final_act: Callable
-
-     def __init__(
-         self,
-         depth,
-         height,
-         width,
-         inp,
-         channels,
-         out_after_mlp=32,
-         width_CNNs=[64, 32],
-         kernel_conv=3,
-         stride=2,
-         padding=1,
-         dilation=1,
-         output_padding=1,
-         final_activation=lambda x: x,
-         mlp_width=None,
-         mlp_depth=0,
-         *,
-         kwargs_cnn={},
-         kwargs_mlp={},
-     ):
-         super().__init__()
-
-         D0s, D1s, D2s = prev_D_CNN3D_trans(
-             depth,
-             height,
-             width,
-             padding,
-             kernel_conv,
-             stride,
-             dilation,
-             output_padding,
-             len(width_CNNs) + 1,
-             all_D0s=[],
-             all_D1s=[],
-             all_D2s=[],
-         )
-
-         first_D0 = D0s[-1]
-         first_D1 = D1s[-1]
-         first_D2 = D2s[-1]
-
-         _, _, _, final_D0, final_D1, final_D2 = is_conv3dT_valid(
-             first_D0,
-             first_D1,
-             first_D2,
-             depth,
-             height,
-             width,
-             padding,
-             kernel_conv,
-             stride,
-             dilation,
-             output_padding,
-             len(width_CNNs) + 1,
-         )
-
-         mlcnn3d = MLCNN3D(
-             out_after_mlp,
-             width_CNNs[-1],
-             stride,
-             padding,
-             kernel_conv,
-             dilation,
-             width_CNNs,
-             transpose=True,
-             output_padding=output_padding,
-             final_activation=torch.nn.functional.relu,
-             **kwargs_cnn,
-         )
-
-         if mlp_depth != 0:
-             if mlp_width is not None:
-                 assert (
-                     mlp_width >= inp
-                 ), "Choose a bigger (or equal) MLP width than the latent space in the decoder."
-             else:
-                 mlp_width = inp
-
-         mlp = MLP_with_linear(
-             in_size=inp,
-             out_size=out_after_mlp * first_D0 * first_D1 * first_D2,
-             width_size=mlp_width,
-             depth=mlp_depth,
-             **kwargs_mlp,
-         )
-
-         final_conv = Conv3d(
-             width_CNNs[-1],
-             channels,
-             kernel_size=(1 + (final_D0 - depth), 1 + (final_D1 - height), 1 + (final_D2 - width)),
-             stride=1,
-             padding=0,
-             dilation=1,
-         )
-
-         self.start_dim = inp
-         self.first_D0 = first_D0
-         self.first_D1 = first_D1
-         self.first_D2 = first_D2
-         self.final_act = final_activation
-         self.layers = torch.nn.ModuleList([mlp, mlcnn3d, final_conv])
-         self.out_after_mlp = out_after_mlp
-
-     def forward(self, x, *args, **kwargs):
-         x = self.layers[0](x)
-         x = torch.reshape(x, (self.out_after_mlp, self.first_D0, self.first_D1, self.first_D2))
-         x = self.layers[1](x)
-         x = self.layers[2](x)
-         x = self.final_act(x)
-         return x
-
- ################# Other Useful functions ################################
-
- def int_to_lst(x, n=1):
-     """Integer to list of length n."""
-     if isinstance(x, int):
-         return [x] * n
-     return x
-
- def remove_keys_from_dict(d, keys):
-     return {k: v for k, v in d.items() if k not in keys}
-
- def merge_dicts(d1, d2):
-     return {**d1, **d2}
-
- def v_print(s, v, f=False):
-     if v:
-         print(s, flush=f)
-
- def countList(lst1, lst2):
-     """Interleave the elements of lst1 and lst2."""
-     return [sub[item] for item in range(len(lst2)) for sub in [lst1, lst2]]
-
-
- class Sample(torch.nn.Module):
-     """Class used to allow random sampling via the reparameterization trick.
-
-     Funny trick to allow the seed to change, since otherwise we end up
-     with the same noise epsilon for every forward pass.
-     """
-
-     sample_dim: int
-
-     def __init__(self, sample_dim):
-         super().__init__()
-         self.sample_dim = sample_dim
-
-     def forward(self, mean, logvar, epsilon=None, ret=False, *args, **kwargs):
-         epsilon = 0 if epsilon is None else epsilon
-         if ret:
-             return mean + torch.exp(0.5 * logvar) * epsilon, mean, logvar
-         return mean + torch.exp(0.5 * logvar) * epsilon
-
-     def create_epsilon(self, seed, shape):
-         # seeded generator so epsilon differs across calls with different seeds
-         return random.default_rng(seed).normal(size=shape)
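A reparameterization-trick sketch (illustrative shapes):

    sampler = Sample(sample_dim=8)
    mean, logvar = torch.zeros(8), torch.zeros(8)
    eps = torch.tensor(sampler.create_epsilon(seed=0, shape=(8,)), dtype=torch.float32)
    z = sampler(mean, logvar, epsilon=eps)   # mean + exp(0.5 * logvar) * eps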