RRAEsTorch 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RRAEsTorch/AE_base/AE_base.py +104 -0
- RRAEsTorch/AE_base/__init__.py +1 -0
- RRAEsTorch/AE_classes/AE_classes.py +636 -0
- RRAEsTorch/AE_classes/__init__.py +1 -0
- RRAEsTorch/__init__.py +1 -0
- RRAEsTorch/config.py +95 -0
- RRAEsTorch/tests/test_AE_classes_CNN.py +76 -0
- RRAEsTorch/tests/test_AE_classes_MLP.py +73 -0
- RRAEsTorch/tests/test_fitting_CNN.py +109 -0
- RRAEsTorch/tests/test_fitting_MLP.py +133 -0
- RRAEsTorch/tests/test_mains.py +34 -0
- RRAEsTorch/tests/test_save.py +62 -0
- RRAEsTorch/tests/test_stable_SVD.py +37 -0
- RRAEsTorch/tests/test_wrappers.py +56 -0
- RRAEsTorch/trackers/__init__.py +1 -0
- RRAEsTorch/trackers/trackers.py +245 -0
- RRAEsTorch/training_classes/__init__.py +5 -0
- RRAEsTorch/training_classes/training_classes.py +977 -0
- RRAEsTorch/utilities/__init__.py +1 -0
- RRAEsTorch/utilities/utilities.py +1562 -0
- RRAEsTorch/wrappers/__init__.py +1 -0
- RRAEsTorch/wrappers/wrappers.py +237 -0
- rraestorch-0.1.0.dist-info/METADATA +90 -0
- rraestorch-0.1.0.dist-info/RECORD +27 -0
- rraestorch-0.1.0.dist-info/WHEEL +4 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE +21 -0
- rraestorch-0.1.0.dist-info/licenses/LICENSE copy +21 -0

RRAEsTorch/training_classes/training_classes.py

@@ -0,0 +1,977 @@

from __future__ import print_function
import numpy.random as random
import numpy as np
from RRAEsTorch.utilities import (
    remove_keys_from_dict,
    merge_dicts,
    loss_generator,
    eval_with_batches
)
import warnings
import os
import time
import dill
import shutil
from RRAEsTorch.wrappers import vmap_wrap, norm_wrap
from functools import partial
from RRAEsTorch.trackers import (
    Null_Tracker,
    RRAE_fixed_Tracker,
    RRAE_pars_Tracker,
    RRAE_gen_Tracker,
)
import matplotlib.pyplot as plt
from prettytable import PrettyTable
import torch
from torch.utils.data import TensorDataset, DataLoader
from adabelief_pytorch import AdaBelief

class Circular_list:
    """
    Creates a list of fixed size.
    Adds elements in a circular manner
    """

    def __init__(self, size):
        self.size = size
        self.buffer = [0.0] * size
        self.index = 0

    def add(self, value):
        self.buffer[self.index] = value
        self.index = (self.index + 1) % self.size

    def __iter__(self):
        for value in self.buffer:
            yield value

class Standard_Print():
    def __init__(self, aux, *args, **kwargs):
        self.aux = aux

    def __str__(self):
        message = ", ".join([f"{k}: {v}" for k, v in self.aux.items()])
        return message

class Pretty_Print(PrettyTable):
    def __init__(self, aux, window_size=5, format_numbers=True, printer_settings={}):
        self.aux = aux
        self.format_numbers = format_numbers
        self.set_title = False
        self.first_print = True
        self.window_size = window_size
        self.index_new = 0
        self.index_old = 0
        super().__init__(**printer_settings)

    def format_number(self, n):
        if isinstance(n, int):
            return "{:.0f}".format(n)
        else:
            return "{:.3f}".format(n)

    def __str__(self):
        data = list(self.aux.values())
        if self.format_numbers == True:
            data = list(map(self.format_number, data))

        if self.first_print == True:
            titles = list(self.aux.keys())
            self.field_names = titles
            self.title = "Results"
            self.set_title = True
            for _ in range(self.window_size):
                self.add_row([" "]*len(titles))

            self._rows[self.index_new] = data
            print(super().__str__())
            print(f"\033[{self.window_size+1}A", end="")
            self.first_print = False

        self._rows[self.index_new] = data

        # This function does a lot of unnecessary things... Removed the parts that I don't want
        # print( "\n".join(self.get_string(start=self.index_new,
        #                   end=self.index_new+1,
        #                   float_format="3.3").splitlines()[-2:]))

        options = self._get_options({})

        lines = []

        # Get the rows we need to print, taking into account slicing, sorting, etc.
        formatted_rows = [self._format_row(row) for row in self._rows[self.index_new : self.index_new+1]]

        # Compute column widths
        self._compute_widths(formatted_rows, options)
        self._hrule = self._stringify_hrule(options)

        # Add rows
        if formatted_rows:
            lines.append(
                self._stringify_row(
                    formatted_rows[-1],
                    options,
                    self._stringify_hrule(options, where="bottom_"),
                )
            )

        # Add bottom of border
        lines.append(self._stringify_hrule(options, where="bottom_"))

        print("\n".join(lines))

        # Update indices
        self.index_old = self.index_new
        self.index_new = (self.index_new + 1) % self.window_size

        # if we move to another printing cycle, push cursor back
        # Dirty trick but works on ubuntu...
        if (self.index_new - self.index_old) != 1:
            print(f"\033[{self.window_size*2}A", end="")
            # Factor of 2 is due to the lower line in the table

        return '\033[1A'

class Print_Info(PrettyTable):
    def __init__(self, print_type="std", aux={}, *args, **kwargs):
        check = (print_type.lower() == "std")
        if check == True:
            self.print_obj = Standard_Print(aux, *args, **kwargs)
        else:
            self.print_obj = Pretty_Print(aux, *args, **kwargs)

    def update_aux(self, aux):
        self.print_obj.aux = aux

    def __str__(self):
        return self.print_obj.__str__()


class Trainor_class:
    def __init__(
        self,
        in_train=None,
        model_cls=None,
        folder="",
        file=None,
        out_train=None,
        norm_in="None",
        norm_out="None",
        methods_map=["__call__"],
        methods_norm_in=["__call__"],
        methods_norm_out=["__call__"],
        call_map_count=1,
        call_map_axis=-1,
        **kwargs,
    ):
        if model_cls is not None:
            orig_model_cls = model_cls
            model_cls = vmap_wrap(orig_model_cls, call_map_axis, call_map_count, methods_map)
            model_cls = norm_wrap(model_cls, in_train, norm_in, None, out_train, norm_out, None, methods_norm_in, methods_norm_out)
            self.model = model_cls(**kwargs)
            params_in = self.model.params_in
            params_out = self.model.params_out
        else:
            orig_model_cls = None
            params_in = None
            params_out = None

        self.all_kwargs = {
            "kwargs": kwargs,
            "params_in": params_in,
            "params_out": params_out,
            "norm_in": norm_in,
            "norm_out": norm_out,
            "call_map_axis": call_map_axis,
            "call_map_count": call_map_count,
            "orig_model_cls": orig_model_cls,
            "methods_map": methods_map,
            "methods_norm_in": methods_norm_in,
            "methods_norm_out": methods_norm_out,
        }

        self.folder = folder
        if folder != "":
            if not os.path.exists(folder):
                os.makedirs(folder)
        self.file = file

    def fit(
        self,
        input,
        output,
        loss_type="default",  # should be string to use pre defined functions
        loss=None,  # a function loss(pred, true) to differentiate in the model
        step_st=[3000, 3000],  # 000, 8000],
        lr_st=[1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9],
        print_every=np.nan,
        save_every=np.nan,
        batch_size_st=[16, 16, 16, 16, 32],
        regression=False,
        verbose=True,
        loss_kwargs={},
        flush=False,
        pre_func_inp=lambda x: x,
        pre_func_out=lambda x: x,
        fix_comp=lambda _: (),
        tracker=Null_Tracker(),
        stagn_window=20,
        eps_fn=lambda lat, bs: None,
        optimizer=AdaBelief,
        verbatim = {
            "print_type": "std",
            "window_size" : 5,
            "printer_settings":{"padding_width": 3}
        },
        save_losses=False,
        input_val=None,
        output_val=None,
        latent_size=0
    ):
        assert isinstance(input, torch.Tensor), "Input should be a torch tensor"
        assert isinstance(output, torch.Tensor), "Output should be a torch tensor"

        from RRAEsTorch.utilities import v_print

        if flush:
            v_print = partial(v_print, f=True)
        else:
            v_print = partial(v_print, f=False)

        training_params = {
            "loss": loss,
            "step_st": step_st,
            "lr_st": lr_st,
            "print_every": print_every,
            "batch_size_st": batch_size_st,
            "regression": regression,
            "verbose": verbose,
            "loss_kwargs": loss_kwargs,
        }

        self.all_kwargs = merge_dicts(self.all_kwargs, training_params)  # Append dicts

        model = self.model  # Create alias for model

        fn = lambda x: x if fn is None else fn

        # Process loss function
        if callable(loss_type):
            loss_fun = loss_type
        else:
            loss_fun = loss_generator(loss_type, loss)

        # Make step funciton
        def make_step(model, input, out, optimizer, idx, epsilon, **loss_kwargs):

            optimizer.zero_grad(set_to_none=True)
            loss, aux = loss_fun(model, input, out, idx=idx, epsilon=epsilon, **loss_kwargs)

            loss.backward()
            optimizer.step()

            return loss, model, optimizer, aux

        # Create filter for splitting the model into differential and static portions

        for p in fix_comp(model):  # e.g. model._encode.parameters()
            p.requires_grad = False

        # Loop variables
        t_all = 0.0  # Total time
        avg_loss = np.inf
        training_num = random.randint(0, 1000, (1,))[0]

        # Window to store averages
        store_window = min(stagn_window, sum(step_st))
        prev_losses = Circular_list(store_window)

        # Initialize tracker
        track_params = tracker.init()
        extra_track = {}

        # Initializer printer object
        print_info = Print_Info(**verbatim)

        if save_losses:
            all_losses = []

        # Outler Loop
        for steps, lr, batch_size in zip(step_st, lr_st, batch_size_st):
            try:
                t_t = 0.0  # Zero time
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
                )

                if (batch_size > input.shape[-1]) or batch_size == -1:
                    print(f"Setting batch size to: {input.shape[-1]}")
                    batch_size = input.shape[-1]

                # Inner loop (batch)
                inputT = input.permute(*range(input.ndim - 1, -1, -1))
                outputT = output.permute(*range(output.ndim - 1, -1, -1))

                dataset = TensorDataset(inputT, outputT, torch.arange(0, input.shape[-1], 1))
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
                data_iter = iter(dataloader)

                for step in range(steps):
                    try:
                        input_b, out_b, idx_b = next(data_iter)
                    except StopIteration:
                        # reached the end, recreate iterator (reshuffle)
                        data_iter = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True))
                        input_b, out_b, idx_b = next(data_iter)

                    start_time = time.perf_counter()  # Start time
                    input_b = input_b.permute(*range(input_b.ndim - 1, -1, -1))
                    out_b = self.model.norm_out.default(None, pre_func_out(out_b))  # Pre-process batch out values
                    out_b = out_b.permute(*range(out_b.ndim - 1, -1, -1))
                    input_b = pre_func_inp(input_b)  # Pre-process batch input values
                    epsilon = eps_fn(latent_size, input_b.shape[-1])

                    step_kwargs = merge_dicts(loss_kwargs, track_params)

                    # Compute loss
                    loss, model, optimizer, (aux, extra_track) = make_step(
                        model,
                        input_b,
                        out_b,
                        optimizer,
                        idx_b,
                        epsilon,
                        **step_kwargs,
                    )

                    if input_val is not None:

                        idx = np.arange(input_val.shape[-1])
                        val_loss, _ = loss_fun(
                            model, input_val, output_val, idx=idx, epsilon=None, **step_kwargs
                        )
                        aux["val_loss"] = val_loss
                    else:
                        aux["val_loss"] = None

                    if save_losses:
                        all_losses.append(aux)

                    prev_losses.add(loss.item())

                    if step > stagn_window:
                        avg_loss = sum(prev_losses) / stagn_window

                    track_params = tracker(loss, avg_loss, track_params, **extra_track)

                    if track_params.get("stop_train"):
                        break

                    dt = time.perf_counter() - start_time  # Execution time
                    t_t += dt  # Batch execution time
                    t_all += dt  # Total execution time

                    if (step % print_every) == 0 or step == steps - 1:
                        t_t = 0.0  # Reset Batch execution time

                        print_info.update_aux({"Batch": step, **aux, "Time [s]": dt, "Total time [s]": t_all})

                        print(print_info)

                    if track_params.get("load"):
                        self.load_model(f"checkpoint_k_{track_params.get('k_max')}")
                        self.del_file(f"checkpoint_k_{track_params.get('k_max')}")
                        model = self.model

                        optimizer = torch.optim.Adam(
                            filter(lambda p: p.requires_grad, model.parameters()), lr=lr
                        )

                    if track_params.get("save") or ((step % save_every) == 0) or torch.isnan(loss):
                        if torch.isnan(loss):
                            raise ValueError("Loss is nan, stopping training...")

                        self.model = model

                        if track_params.get("save"):
                            orig = f"checkpoint_k_{track_params.get('k_max')+1}"
                            self.del_file(f"checkpoint_k_{track_params.get('k_max')+2}")
                            checkpoint_filename = orig

                        else:
                            orig = (
                                f"checkpoint_{step}"
                                if not torch.isnan(loss)
                                else "checkpoint_bf_nan"
                            )

                            checkpoint_filename = f"{orig}_0.pkl"

                            if os.path.exists(checkpoint_filename):
                                i = 1
                                new_filename = f"{orig}_{i}.pkl"
                                while self.path_exists(new_filename):
                                    i += 1
                                    new_filename = f"{orig}_{i}.pkl"
                                checkpoint_filename = new_filename
                        self.save_model(checkpoint_filename)

            except KeyboardInterrupt:
                pass

        if save_losses:
            orig = "all_losses"
            new_filename = f"{orig}_0.pkl"
            i = 0
            while self.path_exists(new_filename):
                i += 1
                new_filename = f"{orig}_{i}.pkl"

            self.save_object(all_losses, new_filename)

        model.eval()
        self.model = model
        self.batch_size = batch_size
        self.t_all = t_all
        return model, track_params | extra_track

    def plot_training_losses(self, idx=0):
        try:
            with open(os.path.join(self.folder, f"all_losses_{idx}.pkl"), "rb") as f:
                res_list = dill.load(f)
        except FileNotFoundError:
            raise ValueError("Losses where not saved during training, did you set save_losses=True in training_kwargs?")
        training_losses = [r["loss"] for r in res_list]
        val_losses = [r["val_loss"] for r in res_list]
        plt.plot(training_losses, label="training loss")
        if val_losses[0] is not None:
            plt.plot(val_losses, label="val loss")
        plt.legend()
        plt.xlabel("Forward pass")
        plt.show()

    @torch.no_grad()
    def evaluate(
        self,
        x_train_o=None,
        y_train_o=None,
        x_test_o=None,
        y_test_o=None,
        batch_size=None,
        pre_func_inp=lambda x: x,
        pre_func_out=lambda x: x,
        call_func=None,
        **kwargs,
    ):
        """Performs post-processing to find the relative error of the RRAE model.

        Parameters:
        -----------
        y_test: jnp.array
            The test data to be used for the error calculation.
        x_test: jnp.array
            The test input. If this is provided the error_test will be computed by sipmly giving
            x_test to the model.
        p_train: jnp.array
            The training data to be used for the interpolation. If this is provided along with p_test (next),
            the error_test will be computed by interpolating the latent space of the model and then decoding it.
        p_test: jnp.array
            The test parameters for which to interpolate.
        save: bool
            If anything other than False, the model as well as the results will be saved in f"{save}".pkl
        """
        call_func = (
            (lambda x: self.model(pre_func_inp(x))) if call_func is None else call_func
        )
        if x_train_o is not None:
            y_train_o = pre_func_out(y_train_o)
            assert (
                hasattr(self, "batch_size") or batch_size is not None
            ), "You should either provide a batch_size or fit the model first."

            x_train_oT = x_train_o.permute(*range(x_train_o.ndim - 1, -1, -1))
            dataset = TensorDataset(x_train_oT)
            batch_size = self.batch_size if batch_size is None else batch_size
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

            pred = []
            for x_b in dataloader:
                x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
                pred_batch = call_func(x_bT)
                pred.append(pred_batch)

            y_pred_train_o = torch.concatenate(pred, axis=-1)

            self.error_train_o = (
                torch.linalg.norm(y_pred_train_o - y_train_o)
                / torch.linalg.norm(y_train_o)
                * 100
            )
            print("Train error on original output: ", self.error_train_o)

            y_pred_train = self.model.norm_out.default(None, y_pred_train_o)
            y_train = self.model.norm_out.default(None, y_train_o)
            self.error_train = (
                torch.linalg.norm(y_pred_train - y_train) / torch.linalg.norm(y_train) * 100
            )
            print("Train error on normalized output: ", self.error_train)

        if x_test_o is not None:
            y_test_o = pre_func_out(y_test_o)
            x_test_oT = x_test_o.permute(*range(x_test_o.ndim - 1, -1, -1))
            dataset = TensorDataset(x_test_oT)
            batch_size = self.batch_size if batch_size is None else batch_size
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
            pred = []
            for x_b in dataloader:
                x_bT = x_b[0].permute(*range(x_b[0].ndim - 1, -1, -1))
                pred_batch = call_func(x_bT)
                pred.append(pred_batch)
            y_pred_test_o = torch.concatenate(pred, axis=-1)
            self.error_test_o = (
                torch.linalg.norm(y_pred_test_o - y_test_o)
                / torch.linalg.norm(y_test_o)
                * 100
            )

            print("Test error on original output: ", self.error_test_o)

            y_test = self.model.norm_out.default(None, y_test_o)
            y_pred_test = self.model.norm_out.default(None, y_pred_test_o)
            self.error_test = (
                torch.linalg.norm(y_pred_test - y_test) / torch.linalg.norm(y_test) * 100
            )
            print("Test error on normalized output: ", self.error_test)

        else:
            self.error_test = None
            self.error_test_o = None
            y_pred_test_o = None
            y_pred_test = None

        print("Total training time: ", self.t_all)
        return {
            "error_train": self.error_train,
            "error_test": self.error_test,
            "error_train_o": self.error_train_o,
            "error_test_o": self.error_test_o,
            "y_pred_train_o": y_pred_train_o,
            "y_pred_test_o": y_pred_test_o,
            "y_pred_train": y_pred_train,
            "y_pred_test": y_pred_test,
        }

    def path_exists(self, filename):
        return os.path.exists(os.path.join(self.folder, filename))

    def del_file(self, filename):
        filename = os.path.join(self.folder, filename)
        if os.path.exists(filename):
            os.remove(filename)

    def save_model(self, filename=None, erase=False, **kwargs):
        """Saves the trainor class."""
        if filename is None:
            if (self.folder is None) or (self.file is None):
                raise ValueError("You should provide a filename to save")
            filename = os.path.join(self.folder, self.file)
            if erase:
                shutil.rmtree(self.folder)
                os.makedirs(self.folder)
        else:
            filename = os.path.join(self.folder, filename)
        if not os.path.exists(filename):
            with open(filename, "a") as temp_file:
                pass
            os.utime(filename, None)
        attr = merge_dicts(
            remove_keys_from_dict(self.__dict__, ("model", "all_kwargs")),
            kwargs,
        )

        save_dict = {
            "model_state_dict": self.model.state_dict(),
            "all_kwargs": self.all_kwargs,
            "attr": attr
        }

        with open(filename, "wb") as f:
            dill.dump(save_dict, f)

        print(f"Model saved in {filename}")

    def load_object(self, filename):
        filename = os.path.join(self.folder, filename)
        with open(filename, "rb") as f:
            object = dill.load(f)
        return object

    def save_object(self, obj, filename):
        filename = os.path.join(self.folder, filename)
        with open(filename, "wb") as f:
            dill.dump(obj, f)
        print(f"Object saved in {filename}")

    def load_model(self, filename=None, erase=False, path=None, orig_model_cls=None, **fn_kwargs):
        """NOTE: fn_kwargs defines the functions of the model
        (e.g. final_activation, inner activation), if
        needed to be saved/loaded on different devices/OS.
        """

        if path == None:
            filename = self.file if filename is None else filename
            filename = os.path.join(self.folder, filename)
        else:
            filename = path

        with open(filename, "rb") as f:
            save_dict = dill.load(f)
        self.all_kwargs = save_dict["all_kwargs"]
        if orig_model_cls is None:
            orig_model_cls = self.all_kwargs["orig_model_cls"]
        else:
            orig_model_cls = orig_model_cls
        kwargs = self.all_kwargs["kwargs"]
        self.call_map_axis = self.all_kwargs["call_map_axis"]
        self.call_map_count = self.all_kwargs["call_map_count"]
        self.params_in = self.all_kwargs["params_in"]
        self.params_out = self.all_kwargs["params_out"]
        self.norm_in = self.all_kwargs["norm_in"]
        self.norm_out = self.all_kwargs["norm_out"]
        try:
            self.methods_map = self.all_kwargs["methods_map"]
            self.methods_norm_in = self.all_kwargs["methods_norm_in"]
            self.methods_norm_out = self.all_kwargs["methods_norm_out"]
        except:
            self.methods_map = ["encode", "decode"]
            self.methods_norm_in = ["encode"]
            self.methods_norm_out = ["decode"]

        kwargs.update(fn_kwargs)

        model_cls = vmap_wrap(orig_model_cls, self.call_map_axis, self.call_map_count, self.methods_map)
        model_cls = norm_wrap(model_cls, None, self.norm_in, self.params_in, None, self.norm_out, self.params_out, self.methods_norm_in, self.methods_norm_out)

        model = model_cls(**kwargs)
        model.load_state_dict(save_dict["model_state_dict"])
        self.model = model
        attributes = save_dict["attr"]

        for key in attributes:
            setattr(self, key, attributes[key])
        if erase:
            os.remove(filename)


class AE_Trainor_class(Trainor_class):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, methods_map=["encode", "decode"], methods_norm_in=["encode"], methods_norm_out=["decode"], **kwargs)

    def fit(self, *args, training_kwargs, **kwargs):
        if "pre_func_inp" not in kwargs:
            self.pre_func_inp = lambda x: x
        else:
            self.pre_func_inp = kwargs["pre_func_inp"]
        training_kwargs = merge_dicts(kwargs, training_kwargs)
        return super().fit(*args, **training_kwargs)  # train model

    # def AE_interpolate(
    #     self,
    #     p_train,
    #     p_test,
    #     x_train_o,
    #     y_test_o,
    #     batch_size=None,
    #     latent_func=None,
    #     decode_func=None,
    #     norm_out_func=None,
    # ):
    #     """Interpolates the latent space of the model and then decodes it to find the output."""
    #     batch_size = self.batch_size if batch_size is None else batch_size

    #     if latent_func is None:
    #         call_func = lambda x: self.model.latent(x)
    #     else:
    #         call_func = latent_func

    #     latent_train = eval_with_batches(
    #         x_train_o,
    #         batch_size,
    #         call_func=call_func,
    #         str="Finding train latent space used for interpolation...",
    #         key_idx=0,
    #     )

    #     interpolation = Objects_Interpolator_nD()
    #     latent_test_interp = interpolation(p_test, p_train, latent_train)

    #     if decode_func is None:
    #         call_func = lambda x: self.model.decode(x)
    #     else:
    #         call_func = decode_func

    #     y_pred_interp_test_o = eval_with_batches(
    #         latent_test_interp,
    #         batch_size,
    #         call_func=call_func,
    #         str="Decoding interpolated latent space ...",
    #         key_idx=0,
    #     )

    #     self.error_interp_test_o = (
    #         jnp.linalg.norm(y_pred_interp_test_o - y_test_o)
    #         / jnp.linalg.norm(y_test_o)
    #         * 100
    #     )
    #     print(
    #         "Test (interpolation) error over original output: ",
    #         self.error_interp_test_o,
    #     )

    #     if norm_out_func is None:
    #         call_func = lambda x: self.model.norm_out(x)
    #     else:
    #         call_func = norm_out_func

    #     y_pred_interp_test = eval_with_batches(
    #         y_pred_interp_test_o,
    #         batch_size,
    #         call_func=call_func,
    #         str="Finding Normalized pred of interpolated latent space ...",
    #         key_idx=0,
    #     )

    #     y_test = eval_with_batches(
    #         y_test_o,
    #         batch_size,
    #         call_func=call_func,
    #         str="Finding Normalized output of interpolated latent space ...",
    #         key_idx=0,
    #     )
    #     self.error_interp_test = (
    #         jnp.linalg.norm(y_pred_interp_test - y_test) / jnp.linalg.norm(y_test) * 100
    #     )
    #     print(
    #         "Test (interpolation) error over normalized output: ",
    #         self.error_interp_test,
    #     )
    #     return {
    #         "error_interp_test": self.error_interp_test,
    #         "error_interp_test_o": self.error_interp_test_o,
    #         "y_pred_interp_test_o": y_pred_interp_test_o,
    #         "y_pred_interp_test": y_pred_interp_test,
    #     }


class RRAE_Trainor_class(AE_Trainor_class):
    def __init__(self, *args, adapt=False, k_max=None, adap_type="None", **kwargs):
        self.k_init = k_max
        self.adap_type = adap_type
        if k_max is not None:
            kwargs["k_max"] = k_max

        super().__init__(*args, **kwargs)
        self.adapt = adapt

    def fit(self, *args, **kwargs):
        if self.adap_type == "pars":
            default_tracker = RRAE_pars_Tracker(k_init=self.k_init)
        elif self.adap_type == "gen":
            if self.k_init is None:
                warnings.warn(
                    "k_max can not be None when using gen adaptive scheme, choose a big initial k_max to start with."
                )
            default_tracker = RRAE_gen_Tracker(k_init=self.k_init)
        elif self.adap_type == "None":
            if self.k_init is None:
                warnings.warn(
                    "k_max can not be None when using fixed scheme, choose a fixed k_max to use."
                )
            default_tracker = RRAE_fixed_Tracker(k_init=self.k_init)

        print("Training RRAEs...")

        if "training_kwargs" in kwargs:
            training_kwargs = kwargs["training_kwargs"]
            kwargs.pop("training_kwargs")
        else:
            training_kwargs = {}

        if "ft_kwargs" in kwargs:
            ft_kwargs = kwargs["ft_kwargs"]
            kwargs.pop("ft_kwargs")
        else:
            ft_kwargs = {}

        if "pre_func_inp" not in kwargs:
            self.pre_func_inp = lambda x: x
        else:
            self.pre_func_inp = kwargs["pre_func_inp"]

        if "tracker" not in training_kwargs:
            training_kwargs["tracker"] = default_tracker

        training_kwargs = merge_dicts(kwargs, training_kwargs)

        model, track_params = super().fit(*args, training_kwargs=training_kwargs)  # train model

        self.track_params = track_params  # Save track parameters in class?

        if "batch_size_st" in training_kwargs:
            self.batch_size = training_kwargs["batch_size_st"][-1]
        else:
            self.batch_size = 16  # default value

        if ft_kwargs:
            if "get_basis" in ft_kwargs:
                get_basis = ft_kwargs["get_basis"]
                ft_kwargs.pop("get_basis")
            else:
                get_basis = True

            if "ft_end_type" in ft_kwargs:
                ft_end_type = ft_kwargs["ft_end_type"]
                ft_kwargs.pop("ft_end_type")
            else:
                ft_end_type = "concat"

            if "basis_call_kwargs" in ft_kwargs:
                basis_call_kwargs = ft_kwargs["basis_call_kwargs"]
                ft_kwargs.pop("basis_call_kwargs")
            else:
                ft_end_type = "concat"
                basis_call_kwargs = {}

            ft_model, ft_track_params = self.fine_tune_basis(
                None, args=args, kwargs=ft_kwargs, get_basis=get_basis, end_type=ft_end_type, basis_call_kwargs=basis_call_kwargs
            )  # fine tune basis
            self.ft_track_params = ft_track_params
        else:
            ft_model = None
            ft_track_params = {}
        return model, track_params, ft_model, ft_track_params

    def fine_tune_basis(self, basis=None, get_basis=True, end_type="concat", basis_call_kwargs={}, *, args, kwargs):

        if "loss" in kwargs:
            norm_loss_ = kwargs["loss"]
        else:
            print("Defaulting to L2 norm")
            norm_loss_ = lambda x1, x2: 100 * (
                torch.linalg.norm(x1 - x2) / torch.linalg.norm(x2)
            )

        if (basis is None):
            with torch.no_grad():
                if get_basis:
                    inp = args[0] if len(args) > 0 else kwargs["input"]

                    if "basis_batch_size" in kwargs:
                        basis_batch_size = kwargs["basis_batch_size"]
                        kwargs.pop("basis_batch_size")
                    else:
                        basis_batch_size = self.batch_size

                    basis_kwargs = basis_call_kwargs | self.track_params

                    inpT = inp.permute(*range(inp.ndim - 1, -1, -1))
                    dataset = TensorDataset(inpT)
                    dataloader = DataLoader(dataset, batch_size=basis_batch_size, shuffle=False)

                    all_bases = []

                    for inp_b in dataloader:
                        inp_bT = inp_b[0].permute(*range(inp_b[0].ndim - 1, -1, -1))
                        all_bases.append(self.model.latent(
                            self.pre_func_inp(inp_bT), get_basis_coeffs=True, **basis_kwargs
                            )[0]
                        )
                    if end_type == "concat":
                        all_bases = torch.concatenate(all_bases, axis=1)
                        print(all_bases.shape)
                        basis = torch.linalg.svd(all_bases, full_matrices=False)[0]
                        self.basis = basis[:, : self.track_params["k_max"]]
                    else:
                        self.basis = all_bases
                else:
                    bas = self.model.latent(self.pre_func_inp(inp[..., 0:1]), get_basis_coeffs=True, **self.track_params)[0]
                    self.basis = torch.eye(bas.shape[0])
        else:
            self.basis = basis

        def loss_fun(model, input, out, idx, epsilon, basis):
            pred = model(input, epsilon=epsilon, apply_basis=basis, keep_normalized=True)
            aux = {"loss": norm_loss_(pred, out)}
            return norm_loss_(pred, out), (aux, {})

        if "loss_type" in kwargs:
            pass
        else:
            print("Defaulting to standard loss")
            kwargs["loss_type"] = loss_fun

        kwargs.setdefault("loss_kwargs", {}).update({"basis": self.basis})

        fix_comp = lambda model: model._encode.parameters()
        print("Fine tuning the basis ...")
        return super().fit(*args, fix_comp=fix_comp, training_kwargs=kwargs)

    def evaluate(
        self,
        x_train_o=None,
        y_train_o=None,
        x_test_o=None,
        y_test_o=None,
        batch_size=None,
        pre_func_inp=lambda x: x,
        pre_func_out=lambda x: x,
        call_func=None,
    ):

        call_func = lambda x: self.model(pre_func_inp(x), apply_basis=self.basis, epsilon=None)
        res = super().evaluate(
            x_train_o,
            y_train_o,
            x_test_o,
            y_test_o,
            batch_size,
            call_func=call_func,
            pre_func_inp=pre_func_inp,
            pre_func_out=pre_func_out,
        )
        return res

    def AE_interpolate(
        self,
        p_train,
        p_test,
        x_train_o,
        y_test_o,
        batch_size=None,
        latent_func=None,
        decode_func=None,
        norm_out_func=None,
    ):
        call_func = lambda x: (
            self.model.latent(x, apply_basis=self.basis)
            if latent_func is None
            else latent_func
        )
        return super().AE_interpolate(
            p_train,
            p_test,
            x_train_o,
            y_test_o,
            batch_size,
            latent_func=call_func,
            decode_func=decode_func,
            norm_out_func=norm_out_func,
        )