RRAEsTorch-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,977 @@
1
+
2
+ from __future__ import print_function
3
+ import numpy.random as random
4
+ import numpy as np
5
+ from RRAEsTorch.utilities import (
6
+ remove_keys_from_dict,
7
+ merge_dicts,
8
+ loss_generator,
9
+ eval_with_batches
10
+ )
11
+ import warnings
12
+ import os
13
+ import time
14
+ import dill
15
+ import shutil
16
+ from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
17
+ from functools import partial
18
+ from RRAEsTorch.trackers import (
19
+ Null_Tracker,
20
+ RRAE_fixed_Tracker,
21
+ RRAE_pars_Tracker,
22
+ RRAE_gen_Tracker,
23
+ )
24
+ import matplotlib.pyplot as plt
25
+ from prettytable import PrettyTable
26
+ import torch
27
+ from torch.utils.data import TensorDataset, DataLoader
28
+ from adabelief_pytorch import AdaBelief
29
+
30
+ class Circular_list:
31
+ """
32
+ Creates a list of fixed size.
33
+ Adds elements in a circular manner
34
+ """
35
+
36
+ def __init__(self, size):
37
+ self.size = size
38
+ self.buffer = [0.0] * size
39
+ self.index = 0
40
+
41
+ def add(self, value):
42
+ self.buffer[self.index] = value
43
+ self.index = (self.index + 1) % self.size
44
+
45
+ def __iter__(self):
46
+ for value in self.buffer:
47
+ yield value
48
+
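+ # Illustrative usage (not part of the original module): Circular_list keeps only the
+ # last `size` values, overwriting the oldest entry once the buffer is full.
+ #
+ # losses = Circular_list(3)
+ # for v in [1.0, 2.0, 3.0, 4.0]:
+ #     losses.add(v)
+ # list(losses)  # -> [4.0, 2.0, 3.0]; slot 0 now holds the newest value
+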
49
+ class Standard_Print():
50
+ def __init__(self, aux, *args, **kwargs):
51
+ self.aux = aux
52
+
53
+ def __str__(self):
54
+ message = ", ".join([f"{k}: {v}" for k, v in self.aux.items()])
55
+ return message
56
+
57
+ class Pretty_Print(PrettyTable):
58
+ def __init__(self, aux, window_size=5, format_numbers=True, printer_settings={}):
59
+ self.aux = aux
60
+ self.format_numbers = format_numbers
61
+ self.set_title = False
62
+ self.first_print = True
63
+ self.window_size = window_size
64
+ self.index_new = 0
65
+ self.index_old = 0
66
+ super().__init__(**printer_settings)
67
+
68
+ def format_number(self, n):
69
+ if isinstance(n, int):
70
+ return "{:.0f}".format(n)
71
+ else:
72
+ return "{:.3f}".format(n)
73
+
74
+ def __str__(self):
75
+ data = list(self.aux.values())
76
+ if self.format_numbers:
77
+ data = list(map(self.format_number, data))
78
+
79
+ if self.first_print:
80
+ titles = list(self.aux.keys())
81
+ self.field_names = titles
82
+ self.title = "Results"
83
+ self.set_title = True
84
+ for _ in range(self.window_size):
85
+ self.add_row([" "]*len(titles))
86
+
87
+ self._rows[self.index_new] = data
88
+ print(super().__str__())
89
+ print(f"\033[{self.window_size+1}A", end="")
90
+ self.first_print = False
91
+
92
+
93
+ self._rows[self.index_new] = data
94
+
95
+ # PrettyTable.get_string() does more work than is needed here; only the relevant steps are kept below
96
+ # print( "\n".join(self.get_string(start=self.index_new,
97
+ # end=self.index_new+1,
98
+ # float_format="3.3").splitlines()[-2:]))
99
+
100
+ options = self._get_options({})
101
+
102
+ lines = []
103
+
104
+ # Get the rows we need to print, taking into account slicing, sorting, etc.
105
+ formatted_rows = [self._format_row(row) for row in self._rows[self.index_new : self.index_new+1]]
106
+
107
+ # Compute column widths
108
+ self._compute_widths(formatted_rows, options)
109
+ self._hrule = self._stringify_hrule(options)
110
+
111
+ # Add rows
112
+ if formatted_rows:
113
+ lines.append(
114
+ self._stringify_row(
115
+ formatted_rows[-1],
116
+ options,
117
+ self._stringify_hrule(options, where="bottom_"),
118
+ )
119
+ )
120
+
121
+ # Add bottom of border
122
+ lines.append(self._stringify_hrule(options, where="bottom_"))
123
+
124
+ print("\n".join(lines))
125
+
126
+
127
+ # Update indices
128
+ self.index_old = self.index_new
129
+ self.index_new = (self.index_new + 1) % self.window_size
130
+
131
+ # if we move to another printing cycle, push cursor back
132
+ # Dirty trick but works on ubuntu...
133
+ if (self.index_new - self.index_old) != 1:
134
+ print(f"\033[{self.window_size*2}A", end="")
135
+ # Factor of 2 is due to the lower line in the table
136
+
137
+ return '\033[1A'
138
+
139
+ class Print_Info:
140
+ def __init__(self, print_type="std", aux={}, *args, **kwargs):
141
+ check = (print_type.lower() == "std")
142
+ if check:
143
+ self.print_obj = Standard_Print(aux, *args, **kwargs)
144
+ else:
145
+ self.print_obj = Pretty_Print(aux, *args, **kwargs)
146
+
147
+ def update_aux(self, aux):
148
+ self.print_obj.aux = aux
149
+
150
+ def __str__(self):
151
+ return self.print_obj.__str__()
152
+
153
+
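+ # Illustrative usage (not part of the original module): the training loop builds one of
+ # these from the `verbatim` settings and refreshes it every `print_every` steps.
+ #
+ # info = Print_Info(print_type="std", aux={"Batch": 0, "loss": 0.123})
+ # print(info)                                   # -> "Batch: 0, loss: 0.123"
+ # info.update_aux({"Batch": 1, "loss": 0.101})  # values shown on the next print(info)
+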
154
+ class Trainor_class:
155
+ def __init__(
156
+ self,
157
+ in_train=None,
158
+ model_cls=None,
159
+ folder="",
160
+ file=None,
161
+ out_train=None,
162
+ norm_in="None",
163
+ norm_out="None",
164
+ methods_map=["__call__"],
165
+ methods_norm_in=["__call__"],
166
+ methods_norm_out=["__call__"],
167
+ call_map_count=1,
168
+ call_map_axis=-1,
169
+ **kwargs,
170
+ ):
171
+ if model_cls is not None:
172
+ orig_model_cls = model_cls
173
+ model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
174
+ model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
175
+ self.model = model_cls(**kwargs)
176
+ params_in = self.model.params_in
177
+ params_out = self.model.params_out
178
+ else:
179
+ orig_model_cls = None
180
+ params_in = None
181
+ params_out = None
182
+
183
+ self.all_kwargs = {
184
+ "kwargs": kwargs,
185
+ "params_in": params_in,
186
+ "params_out": params_out,
187
+ "norm_in": norm_in,
188
+ "norm_out": norm_out,
189
+ "call_map_axis": call_map_axis,
190
+ "call_map_count": call_map_count,
191
+ "orig_model_cls": orig_model_cls,
192
+ "methods_map": methods_map,
193
+ "methods_norm_in": methods_norm_in,
194
+ "methods_norm_out": methods_norm_out,
195
+ }
196
+
197
+ self.folder = folder
198
+ if folder != "":
199
+ if not os.path.exists(folder):
200
+ os.makedirs(folder)
201
+ self.file = file
202
+
203
+ def fit(
204
+ self,
205
+ input,
206
+ output,
207
+ loss_type="default", # should be string to use pre defined functions
208
+ loss=None, # a function loss(pred, true) to differentiate in the model
209
+ step_st=[3000, 3000],
210
+ lr_st=[1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9],
211
+ print_every=np.nan,
212
+ save_every=np.nan,
213
+ batch_size_st=[16, 16, 16, 16, 32],
214
+ regression=False,
215
+ verbose=True,
216
+ loss_kwargs={},
217
+ flush=False,
218
+ pre_func_inp=lambda x: x,
219
+ pre_func_out=lambda x: x,
220
+ fix_comp=lambda _: (),
221
+ tracker=Null_Tracker(),
222
+ stagn_window=20,
223
+ eps_fn=lambda lat, bs: None,
224
+ optimizer=AdaBelief,
225
+ verbatim = {
226
+ "print_type": "std",
227
+ "window_size" : 5,
228
+ "printer_settings":{"padding_width": 3}
229
+ },
230
+ save_losses=False,
231
+ input_val=None,
232
+ output_val=None,
233
+ latent_size=0
234
+ ):
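+ # Illustrative call (not part of the original module); `MyAutoencoder` and the tensors
+ # x, y are placeholders:
+ #
+ # trainor = Trainor_class(in_train=x, model_cls=MyAutoencoder, folder="runs", out_train=y)
+ # model, track = trainor.fit(
+ #     x, y,
+ #     step_st=[1000, 1000], lr_st=[1e-3, 1e-4], batch_size_st=[16, 32],
+ #     print_every=100,
+ #     fix_comp=lambda m: m._encode.parameters(),  # e.g. freeze the encoder
+ # )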
235
+ assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
236
+ assert isinstance(output, torch.Tensor), "Output should be a torch tensor"
237
+
238
+ from RRAEsTorch.utilities import v_print
239
+
240
+ if flush:
241
+ v_print = partial(v_print, f=True)
242
+ else:
243
+ v_print = partial(v_print, f=False)
244
+
245
+ training_params = {
246
+ "loss": loss,
247
+ "step_st": step_st,
248
+ "lr_st": lr_st,
249
+ "print_every": print_every,
250
+ "batch_size_st": batch_size_st,
251
+ "regression": regression,
252
+ "verbose": verbose,
253
+ "loss_kwargs": loss_kwargs,
254
+ }
255
+
256
+ self.all_kwargs = merge_dicts(self.all_kwargs, training_params) # Append dicts
257
+
258
+ model = self.model # Create alias for model
259
+
261
+
262
+ # Process loss function
263
+ if callable(loss_type):
264
+ loss_fun = loss_type
265
+ else:
266
+ loss_fun = loss_generator(loss_type, loss)
267
+
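+ # A custom loss passed via `loss_type` (or produced by `loss_generator`) is expected to
+ # follow this contract, inferred from how `make_step` below unpacks its result (a sketch,
+ # not part of the original module):
+ #
+ # def my_loss(model, input, out, idx, epsilon, **loss_kwargs):
+ #     pred = model(input)
+ #     loss = torch.mean((pred - out) ** 2)
+ #     return loss, ({"loss": loss}, {})  # (aux values to print, extra tracker kwargs)
+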
268
+ # Make step function
269
+ def make_step(model, input, out, optimizer, idx, epsilon, **loss_kwargs):
270
+
271
+ optimizer.zero_grad(set_to_none=True)
272
+ loss, aux = loss_fun(model, input, out, idx=idx, epsilon=epsilon, **loss_kwargs)
273
+
274
+ loss.backward()
275
+ optimizer.step()
276
+
277
+ return loss, model, optimizer, aux
278
+
279
+ # Create filter for splitting the model into differential and static portions
280
+
281
+ for p in fix_comp(model): #e.g. model._encode.parameters()
282
+ p.requires_grad = False
283
+
284
+
285
+ # Loop variables
286
+ t_all = 0.0 # Total time
287
+ avg_loss = np.inf
288
+ training_num = random.randint(0, 1000, (1,))[0]
289
+
290
+ # Window to store averages
291
+ store_window = min(stagn_window, sum(step_st))
292
+ prev_losses = Circular_list(store_window)
293
+
294
+ # Initialize tracker
295
+ track_params = tracker.init()
296
+ extra_track = {}
297
+
298
+ # Initialize printer object
299
+ print_info = Print_Info(**verbatim)
300
+
301
+ if save_losses:
302
+ all_losses = []
303
+
304
+
305
+
306
+ # Outer loop over the learning-rate / batch-size schedule
307
+ for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
308
+ try:
309
+ t_t = 0.0 # Zero time
310
+ optimizer = torch.optim.Adam(
311
+ filter(lambda p: p.requires_grad, model.parameters()), lr=lr
312
+ )
313
+
314
+ if (batch_size > input.shape[-1]) or batch_size == -1:
315
+ print(f"Setting batch size to: {input.shape[-1]}")
316
+ batch_size = input.shape[-1]
317
+
318
+ # Inner loop (batch)
319
+ inputT = input.permute(*range(input.ndim - 1, -1, -1))
320
+ outputT = output.permute(*range(output.ndim - 1, -1, -1))
321
+
322
+ dataset = TensorDataset(inputT, outputT, torch.arange(0, input.shape[-1], 1))
323
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
324
+ data_iter = iter(dataloader)
325
+
326
+ for step in range(steps):
327
+ try:
328
+ input_b, out_b, idx_b = next(data_iter)
329
+ except StopIteration:
330
+ # reached the end, recreate iterator (reshuffle)
331
+ data_iter = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True))
332
+ input_b, out_b, idx_b = next(data_iter)
333
+
334
+ start_time = time.perf_counter() # Start time
335
+ input_b = input_b.permute(*range(input_b.ndim - 1, -1, -1))
336
+ out_b = self.model.norm_out.default(None, pre_func_out(out_b)) # Pre-process batch out values
337
+ out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
338
+ input_b = pre_func_inp(input_b) # Pre-process batch input values
339
+ epsilon = eps_fn(latent_size, input_b.shape[-1])
340
+
341
+ step_kwargs = merge_dicts(loss_kwargs, track_params)
342
+
343
+ # Compute loss
344
+ loss, model, optimizer, (aux, extra_track) = make_step(
345
+ model,
346
+ input_b,
347
+ out_b,
348
+ optimizer,
349
+ idx_b,
350
+ epsilon,
351
+ **step_kwargs,
352
+ )
353
+
354
+ if input_val is not None:
355
+
356
+
357
+ idx = np.arange(input_val.shape[-1])
358
+ val_loss, _ = loss_fun(
359
+ model, input_val, output_val, idx=idx, epsilon=None, **step_kwargs
360
+ )
361
+ aux["val_loss"] = val_loss
362
+ else:
363
+ aux["val_loss"] = None
364
+
365
+ if save_losses:
366
+ all_losses.append(aux)
367
+
368
+ prev_losses.add(loss.item())
369
+
370
+ if step > stagn_window:
371
+ avg_loss = sum(prev_losses) / stagn_window
372
+
373
+ track_params = tracker(loss, avg_loss, track_params, **extra_track)
374
+
375
+ if track_params.get("stop_train"):
376
+ break
377
+
378
+ dt = time.perf_counter() - start_time # Execution time
379
+ t_t += dt # Batch execution time
380
+ t_all += dt # Total execution time
381
+
382
+ if (step % print_every) == 0 or step == steps - 1:
383
+ t_t = 0.0 # Reset Batch execution time
384
+
385
+ print_info.update_aux({"Batch": step, **aux, "Time [s]": dt, "Total time [s]": t_all})
386
+
387
+ print(print_info)
388
+
389
+ if track_params.get("load"):
390
+ self.load_model(f"checkpoint_k_{track_params.get('k_max')}")
391
+ self.del_file(f"checkpoint_k_{track_params.get('k_max')}")
392
+ model = self.model
393
+
394
+ optimizer = torch.optim.Adam(
395
+ filter(lambda p: p.requires_grad, model.parameters()), lr=lr
396
+ )
397
+
398
+ if track_params.get("save") or ((step % save_every) == 0) or torch.isnan(loss):
399
+ if torch.isnan(loss):
400
+ raise ValueError("Loss is nan, stopping training...")
401
+
402
+ self.model = model
403
+
404
+ if track_params.get("save"):
405
+ orig = f"checkpoint_k_{track_params.get('k_max')+1}"
406
+ self.del_file(f"checkpoint_k_{track_params.get('k_max')+2}")
407
+ checkpoint_filename = orig
408
+
409
+ else:
410
+ orig = (
411
+ f"checkpoint_{step}"
412
+ if not torch.isnan(loss)
413
+ else "checkpoint_bf_nan"
414
+ )
415
+
416
+ checkpoint_filename = f"{orig}_0.pkl"
417
+
418
+ if self.path_exists(checkpoint_filename):
419
+ i = 1
420
+ new_filename = f"{orig}_{i}.pkl"
421
+ while self.path_exists(new_filename):
422
+ i += 1
423
+ new_filename = f"{orig}_{i}.pkl"
424
+ checkpoint_filename = new_filename
425
+ self.save_model(checkpoint_filename)
426
+
427
+ except KeyboardInterrupt:
428
+ pass
429
+
430
+ if save_losses:
431
+ orig = "all_losses"
432
+ new_filename = f"{orig}_0.pkl"
433
+ i = 0
434
+ while self.path_exists(new_filename):
435
+ i += 1
436
+ new_filename = f"{orig}_{i}.pkl"
437
+
438
+ self.save_object(all_losses, new_filename)
439
+
440
+ model.eval()
441
+ self.model = model
442
+ self.batch_size = batch_size
443
+ self.t_all = t_all
444
+ return model, track_params | extra_track
445
+
446
+ def plot_training_losses(self, idx=0):
447
+ try:
448
+ with open(os.path.join(self.folder, f"all_losses_{idx}.pkl"), "rb") as f:
449
+ res_list = dill.load(f)
450
+ except FileNotFoundError:
451
+ raise ValueError("Losses where not saved during training, did you set save_losses=True in training_kwargs?")
452
+ training_losses = [r["loss"] for r in res_list]
453
+ val_losses = [r["val_loss"] for r in res_list]
454
+ plt.plot(training_losses, label="training loss")
455
+ if val_losses[0] is not None:
456
+ plt.plot(val_losses, label="val loss")
457
+ plt.legend()
458
+ plt.xlabel("Forward pass")
459
+ plt.show()
460
+
461
+ @torch.no_grad()
462
+ def evaluate(
463
+ self,
464
+ x_train_o=None,
465
+ y_train_o=None,
466
+ x_test_o=None,
467
+ y_test_o=None,
468
+ batch_size=None,
469
+ pre_func_inp=lambda x: x,
470
+ pre_func_out=lambda x: x,
471
+ call_func=None,
472
+ **kwargs,
473
+ ):
474
+ """Performs post-processing to find the relative error of the RRAE model.
475
+
476
+ Parameters:
477
+ -----------
478
+ y_test: jnp.array
479
+ The test data to be used for the error calculation.
480
+ x_test: jnp.array
481
+ The test input. If this is provided the error_test will be computed by sipmly giving
482
+ x_test to the model.
483
+ p_train: jnp.array
484
+ The training data to be used for the interpolation. If this is provided along with p_test (next),
485
+ the error_test will be computed by interpolating the latent space of the model and then decoding it.
486
+ p_test: jnp.array
487
+ The test parameters for which to interpolate.
488
+ save: bool
489
+ If anything other than False, the model as well as the results will be saved in f"{save}".pkl
490
+ """
491
+ call_func = (
492
+ (lambda x: self.model(pre_func_inp(x))) if call_func is None else call_func
493
+ )
494
+ if x_train_o is not None:
495
+ y_train_o = pre_func_out(y_train_o)
496
+ assert (
497
+ hasattr(self, "batch_size") or batch_size is not None
498
+ ), "You should either provide a batch_size or fit the model first."
499
+
500
+ x_train_oT = x_train_o.permute(*range(x_train_o.ndim - 1, -1, -1))
501
+ dataset = TensorDataset(x_train_oT)
502
+ batch_size = self.batch_size if batch_size is None else batch_size
503
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
504
+
505
+ pred = []
506
+ for x_b in dataloader:
507
+ x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
508
+ pred_batch = call_func(x_bT)
509
+ pred.append(pred_batch)
510
+
511
+ y_pred_train_o = torch.concatenate(pred, axis=-1)
512
+
513
+ self.error_train_o = (
514
+ torch.linalg.norm(y_pred_train_o - y_train_o)
515
+ / torch.linalg.norm(y_train_o)
516
+ * 100
517
+ )
518
+ print("Train error on original output: ", self.error_train_o)
519
+
520
+ y_pred_train = self.model.norm_out.default(None, y_pred_train_o)
521
+ y_train = self.model.norm_out.default(None, y_train_o)
522
+ self.error_train = (
523
+ torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
524
+ )
525
+ print("Train error on normalized output: ", self.error_train)
526
+
527
+ if x_test_o is not None:
528
+ y_test_o = pre_func_out(y_test_o)
529
+ x_test_oT = x_test_o.permute(*range(x_test_o.ndim - 1, -1, -1))
530
+ dataset = TensorDataset(x_test_oT)
531
+ batch_size = self.batch_size if batch_size is None else batch_size
532
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
533
+ pred = []
534
+ for x_b in dataloader:
535
+ x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
536
+ pred_batch = call_func(x_bT)
537
+ pred.append(pred_batch)
538
+ y_pred_test_o = torch.concatenate(pred, axis=-1)
539
+ self.error_test_o = (
540
+ torch.linalg.norm(y_pred_test_o - y_test_o)
541
+ / torch.linalg.norm(y_test_o)
542
+ * 100
543
+ )
544
+
545
+ print("Test error on original output: ", self.error_test_o)
546
+
547
+ y_test = self.model.norm_out.default(None, y_test_o)
548
+ y_pred_test = self.model.norm_out.default(None, y_pred_test_o)
549
+ self.error_test = (
550
+ torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
551
+ )
552
+ print("Test error on normalized output: ", self.error_test)
553
+
554
+ else:
555
+ self.error_test = None
556
+ self.error_test_o = None
557
+ y_pred_test_o = None
558
+ y_pred_test = None
559
+
560
+ print("Total training time: ", self.t_all)
561
+ return {
562
+ "error_train": self.error_train,
563
+ "error_test": self.error_test,
564
+ "error_train_o": self.error_train_o,
565
+ "error_test_o": self.error_test_o,
566
+ "y_pred_train_o": y_pred_train_o,
567
+ "y_pred_test_o": y_pred_test_o,
568
+ "y_pred_train": y_pred_train,
569
+ "y_pred_test": y_pred_test,
570
+ }
571
+ def path_exists(self, filename):
572
+ return os.path.exists(os.path.join(self.folder, filename))
573
+
574
+ def del_file(self, filename):
575
+ filename = os.path.join(self.folder, filename)
576
+ if os.path.exists(filename):
577
+ os.remove(filename)
578
+
579
+ def save_model(self, filename=None, erase=False, **kwargs):
580
+ """Saves the trainor class."""
581
+ if filename is None:
582
+ if (self.folder is None) or (self.file is None):
583
+ raise ValueError("You should provide a filename to save")
584
+ filename = os.path.join(self.folder, self.file)
585
+ if erase:
586
+ shutil.rmtree(self.folder)
587
+ os.makedirs(self.folder)
588
+ else:
589
+ filename = os.path.join(self.folder, filename)
590
+ if not os.path.exists(filename):
591
+ with open(filename, "a") as temp_file:
592
+ pass
593
+ os.utime(filename, None)
594
+ attr = merge_dicts(
595
+ remove_keys_from_dict(self.__dict__, ("model", "all_kwargs")),
596
+ kwargs,
597
+ )
598
+
599
+ save_dict = {
600
+ "model_state_dict": self.model.state_dict(),
601
+ "all_kwargs": self.all_kwargs,
602
+ "attr": attr
603
+ }
604
+
605
+ with open(filename, "wb") as f:
606
+ dill.dump(save_dict, f)
607
+
608
+ print(f"Model saved in {filename}")
609
+
610
+ def load_object(self, filename):
611
+ filename = os.path.join(self.folder, filename)
612
+ with open(filename, "rb") as f:
613
+ obj = dill.load(f)
614
+ return obj
615
+
616
+ def save_object(self, obj, filename):
617
+ filename = os.path.join(self.folder, filename)
618
+ with open(filename, "wb") as f:
619
+ dill.dump(obj, f)
620
+ print(f"Object saved in {filename}")
621
+
622
+ def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
623
+ """NOTE: fn_kwargs defines the functions of the model
624
+ (e.g. final_activation, inner activation), if
625
+ needed to be saved/loaded on different devices/OS.
626
+ """
627
+
628
+ if path is None:
629
+ filename = self.file if filename is None else filename
630
+ filename = os.path.join(self.folder, filename)
631
+ else:
632
+ filename = path
633
+
634
+ with open(filename, "rb") as f:
635
+ save_dict = dill.load(f)
636
+ self.all_kwargs = save_dict["all_kwargs"]
637
+ if orig_model_cls is None:
638
+ orig_model_cls = self.all_kwargs["orig_model_cls"]
639
+ else:
640
+ orig_model_cls = orig_model_cls
641
+ kwargs = self.all_kwargs["kwargs"]
642
+ self.call_map_axis = self.all_kwargs["call_map_axis"]
643
+ self.call_map_count = self.all_kwargs["call_map_count"]
644
+ self.params_in = self.all_kwargs["params_in"]
645
+ self.params_out = self.all_kwargs["params_out"]
646
+ self.norm_in = self.all_kwargs["norm_in"]
647
+ self.norm_out = self.all_kwargs["norm_out"]
648
+ try:
649
+ self.methods_map = self.all_kwargs["methods_map"]
650
+ self.methods_norm_in = self.all_kwargs["methods_norm_in"]
651
+ self.methods_norm_out = self.all_kwargs["methods_norm_out"]
652
+ except KeyError:
653
+ self.methods_map = ["encode", "decode"]
654
+ self.methods_norm_in = ["encode"]
655
+ self.methods_norm_out = ["decode"]
656
+
657
+ kwargs.update(fn_kwargs)
658
+
659
+ model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
660
+ model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)
661
+
662
+ model = model_cls(**kwargs)
663
+ model.load_state_dict(save_dict["model_state_dict"])
664
+ self.model = model
665
+ attributes = save_dict["attr"]
666
+
667
+ for key in attributes:
668
+ setattr(self, key, attributes[key])
669
+ if erase:
670
+ os.remove(filename)
671
+
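+ # Illustrative round trip (not part of the original module); paths resolve under self.folder:
+ #
+ # trainor.save_model("my_model.pkl")   # pickles the state dict, constructor kwargs and attributes
+ # trainor.load_model("my_model.pkl")   # rebuilds the wrapped model class and restores the weights
+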
672
+
673
+ class AE_Trainor_class(Trainor_class):
674
+ def __init__(self, *args, **kwargs):
675
+ super().__init__(*args, methods_map=["encode", "decode"], methods_norm_in=["encode"], methods_norm_out=["decode"], **kwargs)
676
+
677
+ def fit(self, *args, training_kwargs, **kwargs):
678
+ if "pre_func_inp" not in kwargs:
679
+ self.pre_func_inp = lambda x: x
680
+ else:
681
+ self.pre_func_inp = kwargs["pre_func_inp"]
682
+ training_kwargs = merge_dicts(kwargs, training_kwargs)
683
+ return super().fit(*args, **training_kwargs) # train model
684
+
685
+ # def AE_interpolate(
686
+ # self,
687
+ # p_train,
688
+ # p_test,
689
+ # x_train_o,
690
+ # y_test_o,
691
+ # batch_size=None,
692
+ # latent_func=None,
693
+ # decode_func=None,
694
+ # norm_out_func=None,
695
+ # ):
696
+ # """Interpolates the latent space of the model and then decodes it to find the output."""
697
+ # batch_size = self.batch_size if batch_size is None else batch_size
698
+
699
+ # if latent_func is None:
700
+ # call_func = lambda x: self.model.latent(x)
701
+ # else:
702
+ # call_func = latent_func
703
+
704
+ # latent_train = eval_with_batches(
705
+ # x_train_o,
706
+ # batch_size,
707
+ # call_func=call_func,
708
+ # str="Finding train latent space used for interpolation...",
709
+ # key_idx=0,
710
+ # )
711
+
712
+ # interpolation = Objects_Interpolator_nD()
713
+ # latent_test_interp = interpolation(p_test, p_train, latent_train)
714
+
715
+ # if decode_func is None:
716
+ # call_func = lambda x: self.model.decode(x)
717
+ # else:
718
+ # call_func = decode_func
719
+
720
+ # y_pred_interp_test_o = eval_with_batches(
721
+ # latent_test_interp,
722
+ # batch_size,
723
+ # call_func=call_func,
724
+ # str="Decoding interpolated latent space ...",
725
+ # key_idx=0,
726
+ # )
727
+
728
+ # self.error_interp_test_o = (
729
+ # jnp.linalg.norm(y_pred_interp_test_o - y_test_o)
730
+ # / jnp.linalg.norm(y_test_o)
731
+ # * 100
732
+ # )
733
+ # print(
734
+ # "Test (interpolation) error over original output: ",
735
+ # self.error_interp_test_o,
736
+ # )
737
+
738
+ # if norm_out_func is None:
739
+ # call_func = lambda x: self.model.norm_out(x)
740
+ # else:
741
+ # call_func = norm_out_func
742
+
743
+ # y_pred_interp_test = eval_with_batches(
744
+ # y_pred_interp_test_o,
745
+ # batch_size,
746
+ # call_func=call_func,
747
+ # str="Finding Normalized pred of interpolated latent space ...",
748
+ # key_idx=0,
749
+ # )
750
+
751
+ # y_test = eval_with_batches(
752
+ # y_test_o,
753
+ # batch_size,
754
+ # call_func=call_func,
755
+ # str="Finding Normalized output of interpolated latent space ...",
756
+ # key_idx=0,
757
+ # )
758
+ # self.error_interp_test = (
759
+ # jnp.linalg.norm(y_pred_interp_test - y_test) / jnp.linalg.norm(y_test) * 100
760
+ # )
761
+ # print(
762
+ # "Test (interpolation) error over normalized output: ",
763
+ # self.error_interp_test,
764
+ # )
765
+ # return {
766
+ # "error_interp_test": self.error_interp_test,
767
+ # "error_interp_test_o": self.error_interp_test_o,
768
+ # "y_pred_interp_test_o": y_pred_interp_test_o,
769
+ # "y_pred_interp_test": y_pred_interp_test,
770
+ # }
771
+
772
+
773
+ class RRAE_Trainor_class(AE_Trainor_class):
774
+ def __init__(self, *args, adapt=False, k_max=None, adap_type="None", **kwargs):
775
+ self.k_init = k_max
776
+ self.adap_type = adap_type
777
+ if k_max is not None:
778
+ kwargs["k_max"] = k_max
779
+
780
+ super().__init__(*args, **kwargs)
781
+ self.adapt = adapt
782
+
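+ # Illustrative usage (not part of the original module); `MyRRAE`, x and y are placeholders:
+ #
+ # trainor = RRAE_Trainor_class(in_train=x, model_cls=MyRRAE, out_train=y,
+ #                              k_max=10, adap_type="pars")
+ # model, track, ft_model, ft_track = trainor.fit(
+ #     x, y,
+ #     training_kwargs={"step_st": [2000], "lr_st": [1e-3], "batch_size_st": [16]},
+ #     ft_kwargs={"step_st": [1000], "lr_st": [1e-4], "batch_size_st": [16]},
+ # )
+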
783
+ def fit(self, *args, **kwargs):
784
+ if self.adap_type == "pars":
785
+ default_tracker = RRAE_pars_Tracker(k_init=self.k_init)
786
+ elif self.adap_type == "gen":
787
+ if self.k_init is None:
788
+ warnings.warn(
789
+ "k_max can not be None when using gen adaptive scheme, choose a big initial k_max to start with."
790
+ )
791
+ default_tracker = RRAE_gen_Tracker(k_init=self.k_init)
792
+ elif self.adap_type == "None":
793
+ if self.k_init is None:
794
+ warnings.warn(
795
+ "k_max can not be None when using fixed scheme, choose a fixed k_max to use."
796
+ )
797
+ default_tracker = RRAE_fixed_Tracker(k_init=self.k_init)
798
+
799
+ print("Training RRAEs...")
800
+
801
+ if "training_kwargs" in kwargs:
802
+ training_kwargs = kwargs["training_kwargs"]
803
+ kwargs.pop("training_kwargs")
804
+ else:
805
+ training_kwargs = {}
806
+
807
+ if "ft_kwargs" in kwargs:
808
+ ft_kwargs = kwargs["ft_kwargs"]
809
+ kwargs.pop("ft_kwargs")
810
+ else:
811
+ ft_kwargs = {}
812
+
813
+ if "pre_func_inp" not in kwargs:
814
+ self.pre_func_inp = lambda x: x
815
+ else:
816
+ self.pre_func_inp = kwargs["pre_func_inp"]
817
+
818
+ if "tracker" not in training_kwargs:
819
+ training_kwargs["tracker"] = default_tracker
820
+
821
+
822
+ training_kwargs = merge_dicts(kwargs, training_kwargs)
823
+
824
+ model, track_params = super().fit(*args, training_kwargs=training_kwargs) # train model
825
+
826
+ self.track_params = track_params # Save track parameters in class?
827
+
828
+ if "batch_size_st" in training_kwargs:
829
+ self.batch_size = training_kwargs["batch_size_st"][-1]
830
+ else:
831
+ self.batch_size = 16 # default value
832
+
833
+ if ft_kwargs:
834
+ if "get_basis" in ft_kwargs:
835
+ get_basis = ft_kwargs["get_basis"]
836
+ ft_kwargs.pop("get_basis")
837
+ else:
838
+ get_basis = True
839
+
840
+ if "ft_end_type" in ft_kwargs:
841
+ ft_end_type = ft_kwargs["ft_end_type"]
842
+ ft_kwargs.pop("ft_end_type")
843
+ else:
844
+ ft_end_type = "concat"
845
+
846
+ if "basis_call_kwargs" in ft_kwargs:
847
+ basis_call_kwargs = ft_kwargs["basis_call_kwargs"]
848
+ ft_kwargs.pop("basis_call_kwargs")
849
+ else:
851
+ basis_call_kwargs = {}
852
+
853
+ ft_model, ft_track_params = self.fine_tune_basis(
854
+ None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
855
+ ) # fine tune basis
856
+ self.ft_track_params = ft_track_params
857
+ else:
858
+ ft_model = None
859
+ ft_track_params = {}
860
+ return model, track_params, ft_model, ft_track_params
861
+
862
+ def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):
863
+
864
+ if "loss" in kwargs:
865
+ norm_loss_ = kwargs["loss"]
866
+ else:
867
+ print("Defaulting to L2 norm")
868
+ norm_loss_ = lambda x1, x2: 100 * (
869
+ torch.linalg.norm(x1 - x2) / torch.linalg.norm(x2)
870
+ )
871
+
872
+ if (basis is None):
873
+ with torch.no_grad():
874
+ inp = args[0] if len(args) > 0 else kwargs["input"] # needed by both branches below
875
+ if get_basis:
876
+
877
+ if "basis_batch_size" in kwargs:
878
+ basis_batch_size = kwargs["basis_batch_size"]
879
+ kwargs.pop("basis_batch_size")
880
+ else:
881
+ basis_batch_size = self.batch_size
882
+
883
+ basis_kwargs = basis_call_kwargs | self.track_params
884
+
885
+ inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
886
+ dataset = TensorDataset(inpT)
887
+ dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)
888
+
889
+ all_bases = []
890
+
891
+ for inp_b in dataloader:
892
+ inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
893
+ all_bases.append(self.model.latent(
894
+ self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
895
+ )[0]
896
+ )
897
+ if end_type == "concat":
898
+ all_bases = torch.concatenate(all_bases, axis=1)
899
+ print(all_bases.shape)
900
+ basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
901
+ self.basis = basis[:, : self.track_params["k_max"]]
902
+ else:
903
+ self.basis = all_bases
904
+ else:
905
+ bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
906
+ self.basis = torch.eye(bas.shape[0])
907
+ else:
908
+ self.basis = basis
909
+
910
+ def loss_fun(model, input, out, idx, epsilon, basis):
911
+ pred = model(input, epsilon=epsilon, apply_basis=basis, keep_normalized=True)
912
+ aux = {"loss": norm_loss_(pred, out)}
913
+ return norm_loss_(pred, out), (aux, {})
914
+
915
+ if "loss_type" in kwargs :
916
+ pass
917
+ else:
918
+ print("Defaulting to standard loss")
919
+ kwargs["loss_type"] = loss_fun
920
+
921
+ kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})
922
+
923
+ fix_comp = lambda model: model._encode.parameters()
924
+ print("Fine tuning the basis ...")
925
+ return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)
926
+
927
+ def evaluate(
928
+ self,
929
+ x_train_o=None,
930
+ y_train_o=None,
931
+ x_test_o=None,
932
+ y_test_o=None,
933
+ batch_size=None,
934
+ pre_func_inp=lambda x: x,
935
+ pre_func_out=lambda x: x,
936
+ call_func=None,
937
+ ):
938
+
939
+ call_func = (lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)) if call_func is None else call_func
940
+ res = super().evaluate(
941
+ x_train_o,
942
+ y_train_o,
943
+ x_test_o,
944
+ y_test_o,
945
+ batch_size,
946
+ call_func=call_func,
947
+ pre_func_inp=pre_func_inp,
948
+ pre_func_out=pre_func_out,
949
+ )
950
+ return res
951
+
952
+ def AE_interpolate(
953
+ self,
954
+ p_train,
955
+ p_test,
956
+ x_train_o,
957
+ y_test_o,
958
+ batch_size=None,
959
+ latent_func=None,
960
+ decode_func=None,
961
+ norm_out_func=None,
962
+ ):
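+ # NOTE: this delegates to AE_Trainor_class.AE_interpolate, which is currently commented
+ # out above, so calling this method requires re-enabling that implementation.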
963
+ call_func = (
964
+ (lambda x: self.model.latent(x, apply_basis=self.basis))
965
+ if latent_func is None
966
+ else latent_func
967
+ )
968
+ return super().AE_interpolate(
969
+ p_train,
970
+ p_test,
971
+ x_train_o,
972
+ y_test_o,
973
+ batch_size,
974
+ latent_func=call_func,
975
+ decode_func=decode_func,
976
+ norm_out_func=norm_out_func,
977
+ )