dsipts 1.1.12-py3-none-any.whl → 1.1.15-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of dsipts might be problematic; note, for example, the `debug_prediction = True` flag and the hard-coded dump paths under `/home/agobbi/Projects/ExpTS/` introduced in data_structure.py below.
- dsipts/__init__.py +4 -1
- dsipts/data_structure/data_structure.py +110 -32
- dsipts/data_structure/utils.py +4 -2
- dsipts/models/Persistent.py +2 -0
- dsipts/models/TTM.py +31 -9
- dsipts/models/TimeKAN.py +123 -0
- dsipts/models/TimesNet.py +96 -0
- dsipts/models/base.py +9 -1
- dsipts/models/base_v2.py +17 -7
- dsipts/models/timekan/Layers.py +284 -0
- dsipts/models/timekan/__init__.py +0 -0
- dsipts/models/timesnet/Layers.py +95 -0
- dsipts/models/timesnet/__init__.py +0 -0
- dsipts/version.py +1 -0
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/METADATA +56 -8
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/RECORD +18 -11
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/WHEEL +0 -0
- {dsipts-1.1.12.dist-info → dsipts-1.1.15.dist-info}/top_level.txt +0 -0
dsipts/__init__.py
CHANGED
```diff
@@ -26,6 +26,9 @@ from .models.TTM import TTM
 from .models.Samformer import Samformer
 from .models.Duet import Duet
 from .models.Simple import Simple
+from .models.TimesNet import TimesNet
+from .models.TimeKAN import TimeKAN
+from .version import __version__
 try:
     import lightning.pytorch as pl
     from .models.base_v2 import Base
@@ -44,5 +47,5 @@ __all__ = [
     "RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
     "Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
     "Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
-    "TTM", "Samformer", "Duet", "Base", "Simple"
+    "TTM", "Samformer", "Duet", "Base", "Simple","TimesNet","TimeKAN"
 ]
```
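Net effect of this hunk: the two new architectures and the package version string become importable at top level. A minimal usage sketch, assuming only what the diff shows (dsipts >= 1.1.15 installed):

```python
# Minimal sketch of the new top-level exports (requires dsipts >= 1.1.15).
from dsipts import TimesNet, TimeKAN, __version__

print(__version__)          # provided by the new dsipts/version.py
print(TimeKAN.description)  # class-level capability summary built via get_scope(...)
```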
dsipts/data_structure/data_structure.py
CHANGED

```diff
@@ -6,7 +6,7 @@ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
 from sklearn.preprocessing import *
 from torch.utils.data import DataLoader
 from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
-
+from torch.utils.data.sampler import WeightedRandomSampler
 try:
 
     #new version of lightning
@@ -249,7 +249,8 @@ class TimeSeries():
                     check_past:bool=True,
                     group:Union[None,str]=None,
                     check_holes_and_duplicates:bool=True,
-                    silly_model:bool=False
+                    silly_model:bool=False,
+                    sampler_weights:Union[None,str]=None)->None:
        """ This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
        There are some checks:
        1- the duplicates will tbe removed taking the first instance
@@ -270,6 +271,7 @@ class TimeSeries():
            group (str or None, optional): if not None the time serie dataset is considered composed by omogeneus timeseries coming from different realization (for example point of sales, cities, locations) and the relative series are not splitted during the sample generation. Defaults to None
            check_holes_and_duplicates (bool, optional): if False duplicates or holes will not checked, the dataloader can not correctly work, disable at your own risk. Defaults True
            silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes throught the decoder part of your model (if any)
+           sampler_weights group (str or None, optional): if it is a column name it will be used as weight for the sampler. Careful that the weight of the sample is the weight value of the fist target value (index)
        """
@@ -322,7 +324,7 @@ class TimeSeries():
        if group is not None:
            if group not in cat_past_var:
                beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
-               self.
+               self.cat_past_var.append(group)
            if group not in cat_fut_var:
                beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
                self.cat_fut_var.append(group)
@@ -350,7 +352,7 @@ class TimeSeries():
        if silly_model:
            beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
            self.future_variables+=self.target_variables
-
+       self.sampler_weights = sampler_weights
    def plot(self):
        """
        Easy way to control the loaded data
```
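These hunks thread a new `sampler_weights` argument (the name of a numeric weight column) through the loading routine and store it on the instance. A hedged sketch of how a caller might use it; the enclosing method is not named in this hunk, so `load_signal` and every keyword other than `sampler_weights` are assumptions:

```python
import numpy as np
import pandas as pd

# Hypothetical usage sketch: weight recent rows more heavily when sampling
# training windows. `TimeSeries(...)` / `load_signal(...)` are assumed names.
df = pd.DataFrame({
    "time": pd.date_range("2024-01-01", periods=500, freq="h"),
    "y": np.random.randn(500).cumsum(),
})
df["weight"] = np.linspace(0.1, 1.0, len(df))  # plain numeric column

# ts = TimeSeries("example")
# ts.load_signal(df, target_variables=["y"],
#                sampler_weights="weight")  # <- new in this release
```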
```diff
@@ -409,6 +411,7 @@ class TimeSeries():
         y_samples = []
         t_samples = []
         g_samples = []
+        sampler_weights_samples = []
 
         if starting_point is not None:
             kk = list(starting_point.keys())[0]
@@ -448,7 +451,8 @@ class TimeSeries():
 
         idx_target = []
         for c in self.target_variables:
-
+            if c in self.past_variables:
+                idx_target.append(self.past_variables.index(c))
 
         idx_target_future = []
 
@@ -475,7 +479,8 @@ class TimeSeries():
             if len(self.cat_fut_var)>0:
                 x_fut_cat = tmp[self.cat_fut_var].values
             y_target = tmp[self.target_variables].values
-
+            if self.sampler_weights is not None:
+                sampler_weights = tmp[self.sampler_weights].values.flatten()
 
             if starting_point is not None:
                 check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
@@ -512,6 +517,8 @@ class TimeSeries():
                     x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
 
                 y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
+                if self.sampler_weights is not None:
+                    sampler_weights_samples.append(sampler_weights[i+skip_stacked])
                 t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
                 g_samples.append(groups[i])
 
@@ -524,6 +531,8 @@ class TimeSeries():
             beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
 
         y_samples = np.stack(y_samples)
+        if self.sampler_weights is not None:
+            sampler_weights_samples = np.stack(sampler_weights_samples)
         t_samples = np.stack(t_samples)
         g_samples = np.stack(g_samples)
 
@@ -537,7 +546,6 @@ class TimeSeries():
         else:
             mod = 1.0
         dd = {'y':y_samples.astype(np.float32),
-
             'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
         if len(self.cat_past_var)>0:
             dd['x_cat_past'] = x_cat_past_samples
@@ -545,7 +553,10 @@ class TimeSeries():
             dd['x_cat_future'] = x_cat_future_samples
         if len(self.future_variables)>0:
             dd['x_num_future'] = x_num_future_samples.astype(np.float32)
-
+        if self.sampler_weights is not None:
+            dd['sampler_weights'] = sampler_weights_samples.astype(np.float32)
+        else:
+            dd['sampler_weights'] = np.ones(len(y_samples)).astype(np.float32)
         return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
 
 
```
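Note the semantics encoded above (and warned about in the docstring): each training window inherits a single scalar weight, taken from the weight column at the window's first target step (`sampler_weights[i+skip_stacked]`), and the dataset dict now always carries a `sampler_weights` array, defaulting to ones. A tiny numpy sketch of that rule:

```python
import numpy as np

# One scalar weight per window: the weight-column value at the window's
# first target index, mirroring sampler_weights[i + skip_stacked] above.
weights_col = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
future_steps, skip_stacked = 2, 0

window_weights = [weights_col[i + skip_stacked]
                  for i in range(len(weights_col) - future_steps)]
print(np.array(window_weights))  # [0.1 0.2 0.3 0.4]
```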
|
```diff
@@ -660,7 +671,7 @@ class TimeSeries():
                 if c!=self.group:
                     self.scaler_cat[f'{c}_{group}'] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
                     self.scaler_cat[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))
-
+
         dl_train = self.create_data_loader(train,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
         dl_validation = self.create_data_loader(validation,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
         if test.shape[0]>0:
@@ -753,10 +764,33 @@ class TimeSeries():
         else:
             self.modifier = None
 
+        if self.sampler_weights is not None:
+            beauty_string(f'USING SAMPLER IN TRAIN {min(train.sampler_weights)}-{max(train.sampler_weights)}','section',self.verbose)
+
+            sampler = WeightedRandomSampler(train.sampler_weights, num_samples= len(train))
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=False,sampler=sampler,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
 
-
+        else:
+            train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
         valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
-
+        debug_prediction = True
+
+        if debug_prediction:
+            dl = DataLoader(test, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
+            res = []
+            real = []
+            for batch in dl:
+
+
+                res.append(self.model.inference(batch).cpu().detach().numpy())
+                real.append(batch['y'].cpu().detach().numpy())
+
+            res = np.vstack(res)
+            real = np.vstack(real)
+            with open('/home/agobbi/Projects/ExpTS/tmp_beginning.pkl','wb') as f:
+                import pickle
+                pickle.dump([res,real],f)
+
         checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
                                               monitor='val_loss',
                                               save_last = True,
```
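The training dataloader now switches between plain shuffling and weighted sampling. A self-contained sketch of the same wiring; note that `sampler` and `shuffle=True` are mutually exclusive in `DataLoader`, which is why the new branch sets `shuffle=False`, and that `WeightedRandomSampler` draws with replacement by default. (The `debug_prediction` block with its absolute pickle paths reads like leftover debugging code rather than intended library behaviour.)

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Standalone sketch of the weighted-sampling branch above.
x, y = torch.randn(100, 3), torch.randn(100, 1)
weights = torch.linspace(0.01, 1.0, 100)          # one weight per sample

sampler = WeightedRandomSampler(weights, num_samples=len(weights))
dl = DataLoader(TensorDataset(x, y), batch_size=16,
                shuffle=False, sampler=sampler, drop_last=True)

for xb, yb in dl:                                 # high-weight samples are
    pass                                          # drawn far more often
```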
```diff
@@ -851,23 +885,24 @@ class TimeSeries():
 
 
 
-
+
         if auto_lr_find and (weight_exists is False):
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                if OLD_PL:
+                    lr_tuner = trainer.tune(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
+                    files = os.listdir(dirpath)
+                    for f in files:
+                        if '.lr_find' in f:
+                            os.remove(os.path.join(dirpath,f))
+                    self.model.optim_config['lr'] = lr_tuner['lr_find'].suggestion()
+                else:
+                    from lightning.pytorch.tuner import Tuner
+                    tuner = Tuner(trainer)
+                    lr_finder = tuner.lr_find(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
+                    self.model.optim_config['lr'] = lr_finder.suggestion() ## we are using it as optim key
+            except Exception as e:
+                beauty_string(f'There is a problem with the finding LR routine {e}','section',self.verbose)
 
-
-
         if OLD_PL:
             if weight_exists:
                 trainer.fit(self.model, train_dl,valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
```
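The learning-rate search is rewritten for both Lightning generations (`trainer.tune(...)` on the old API, `Tuner.lr_find(...)` on the new one) and wrapped in a try/except; the removed old-side lines were not captured by the diff viewer. A sketch of the new-style branch in isolation (model and dataloaders are placeholders):

```python
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

def suggest_lr(model, train_dl, valid_dl):
    # Mirrors the lightning >= 2.0 branch above.
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(model,
                              train_dataloaders=train_dl,
                              val_dataloaders=valid_dl)
    return lr_finder.suggestion()  # may be None if no good lr was found
```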
```diff
@@ -878,17 +913,36 @@ class TimeSeries():
                 trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
             else:
                 trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl)
+
         self.checkpoint_file_best = checkpoint_callback.best_model_path
         self.checkpoint_file_last = checkpoint_callback.last_model_path
         if self.checkpoint_file_last=='':
             beauty_string('There is a bug on saving last model I will try to fix it','info',self.verbose)
-            self.checkpoint_file_last =
+            self.checkpoint_file_last = os.path.join(dirpath, "last.ckpt")
+            trainer.save_checkpoint(os.path.join(dirpath, "last.ckpt"))
+        if self.checkpoint_file_best=='':
+            beauty_string('There is a bug on saving best model I will try to fix it','info',self.verbose)
+            self.checkpoint_file_best = os.path.join(dirpath, "checkpoint.ckpt")
+            trainer.save_checkpoint(os.path.join(dirpath, "checkpoint.ckpt"))
 
         self.dirpath = dirpath
 
         self.losses = mc.metrics
 
         files = os.listdir(dirpath)
+        if debug_prediction:
+            res = []
+            real = []
+            for batch in dl:
+                res.append(self.model.inference(batch).cpu().detach().numpy())
+                real.append(batch['y'].cpu().detach().numpy())
+
+            res = np.vstack(res)
+            real = np.vstack(real)
+            with open('/home/agobbi/Projects/ExpTS/tmp_after_training.pkl','wb') as f:
+                import pickle
+                pickle.dump([res,real],f)
+
         ##accrocchio per multi gpu
         for f in files:
             if '__losses__.csv' in f:
@@ -901,7 +955,6 @@ class TimeSeries():
         self.losses = pd.DataFrame()
 
         try:
-
             if OLD_PL:
                 if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
                     self.model = self.model._orig_mod
@@ -914,7 +967,21 @@ class TimeSeries():
                     self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
                 else:
                     self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
+            if debug_prediction:
+                res = []
+                real = []
+                for batch in dl:
+
+
+                    res.append(self.model.inference(batch).cpu().detach().numpy())
+                    real.append(batch['y'].cpu().detach().numpy())
 
+                res = np.vstack(res)
+                real = np.vstack(real)
+                with open('/home/agobbi/Projects/ExpTS/tmp_after_loading.pkl','wb') as f:
+                    import pickle
+                    pickle.dump([res,real],f)
+
         except Exception as _:
             beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
 
```
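Two things happen here: empty checkpoint paths reported by `ModelCheckpoint` are now repaired by writing a checkpoint explicitly, and the same `debug_prediction` dump is repeated after training and after reloading the weights (again with absolute `/home/agobbi/...` paths). A sketch of the fallback pattern, assuming a fitted `lightning.pytorch.Trainer`:

```python
import os

def ensure_checkpoint(trainer, recorded_path: str, dirpath: str, name: str) -> str:
    """If Lightning did not record a checkpoint path, write one explicitly
    so later load_from_checkpoint calls never receive an empty string."""
    if recorded_path == '':
        recorded_path = os.path.join(dirpath, name)
        trainer.save_checkpoint(recorded_path)   # standard Trainer API
    return recorded_path

# last = ensure_checkpoint(trainer, cb.last_model_path, dirpath, "last.ckpt")
# best = ensure_checkpoint(trainer, cb.best_model_path, dirpath, "checkpoint.ckpt")
```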
```diff
@@ -992,12 +1059,15 @@ class TimeSeries():
         beauty_string(f'Device used: {self.model.device}','info',self.verbose)
 
         for batch in dl:
+
+
             res.append(self.model.inference(batch).cpu().detach().numpy())
             real.append(batch['y'].cpu().detach().numpy())
-
+
         res = np.vstack(res)
-
         real = np.vstack(real)
+
+
         time = dl.dataset.t
         groups = dl.dataset.groups
         #import pdb
@@ -1026,7 +1096,7 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-            time = time.melt(id_vars=[
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
@@ -1048,7 +1118,8 @@ class TimeSeries():
 
         if self.group is not None:
             time[self.group] = groups
-
+
+            time = time.melt(id_vars=[self.group])
         else:
             time = time.melt()
         time.rename(columns={'value':'time','variable':'lag'},inplace=True)
```
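The grouped branch previously ended in a truncated `time.melt(id_vars=[` call; both occurrences now melt with the group column kept as an identifier. A small pandas illustration of the reshaping:

```python
import pandas as pd

# The lag columns are unpivoted while the group column is preserved.
time = pd.DataFrame({"lag_1": ["t1", "t2"], "lag_2": ["t2", "t3"]})
time["store"] = ["A", "B"]                 # plays the role of self.group

long = time.melt(id_vars=["store"])
long.rename(columns={"value": "time", "variable": "lag"}, inplace=True)
print(long)   # columns: store, lag, time -- one row per (group, lag)
```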
```diff
@@ -1186,15 +1257,21 @@ class TimeSeries():
 
             try:
                 tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
+                beauty_string(f"Loading {tmp_path}",'section',self.verbose)
+
             except Exception as _:
                 beauty_string('checkpoint_file_last not defined try to load best','section',self.verbose)
                 tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
+                beauty_string(f"Loading {tmp_path}",'section',self.verbose)
         else:
             try:
+
                 tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
+                beauty_string(f"Loading {tmp_path}",'section',self.verbose)
             except Exception as _:
                 beauty_string('checkpoint_file_best not defined try to load best','section',self.verbose)
                 tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
+                beauty_string(f"Loading {tmp_path}",'section',self.verbose)
         try:
             #with torch.serialization.add_safe_globals([ListConfig]):
             if OLD_PL:
@@ -1202,5 +1279,6 @@ class TimeSeries():
             else:
                 self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
             self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
+
         except Exception as e:
             beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
```
dsipts/data_structure/utils.py
CHANGED
```diff
@@ -142,12 +142,13 @@ class MyDataset(Dataset):
         Returns:
             torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
         """
+
         self.data = data
         self.t = t
         self.groups = groups
         self.idx_target = np.array(idx_target) if idx_target is not None else None
         self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
-
+        self.sampler_weights = data['sampler_weights']
 
 
     def __len__(self):
@@ -157,7 +158,8 @@ class MyDataset(Dataset):
     def __getitem__(self, idxs):
         sample = {}
         for k in self.data:
-
+            if k!='sampler_weights':
+                sample[k] = self.data[k][idxs]
         if self.idx_target is not None:
             sample['idx_target'] = self.idx_target
         if self.idx_target_future is not None:
```
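So the weights travel inside the `data` dict but are deliberately excluded from the batches returned by `__getitem__`; they are exposed only as a dataset attribute for the `WeightedRandomSampler`. A minimal re-implementation of that behaviour:

```python
import numpy as np
from torch.utils.data import Dataset

class WeightedDictDataset(Dataset):
    """Sketch of the MyDataset change: weights ride along in `data` but
    are surfaced for the sampler instead of being returned per batch."""
    def __init__(self, data: dict):
        self.data = data
        self.sampler_weights = data['sampler_weights']

    def __len__(self):
        return len(self.data['y'])

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items() if k != 'sampler_weights'}

ds = WeightedDictDataset({'y': np.zeros((8, 4)), 'sampler_weights': np.ones(8)})
assert 'sampler_weights' not in ds[0]          # not part of a training batch
assert len(ds.sampler_weights) == len(ds)      # but visible to the sampler
```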
dsipts/models/Persistent.py
CHANGED
dsipts/models/TTM.py
CHANGED
```diff
@@ -38,6 +38,8 @@ class TTM(Base):
                 fcm_mix_layers,
                 fcm_prepend_past,
                 enable_forecast_channel_mixing,
+                force_return,
+                few_shot = True,
                 **kwargs)->None:
 
         super().__init__(**kwargs)
@@ -48,7 +50,9 @@ class TTM(Base):
         self.index_fut = list(exogenous_channel_indices_cont)
 
         if len(exogenous_channel_indices_cat)>0:
-
+
+            self.index_fut_cat = [self.past_channels+c for c in list(exogenous_channel_indices_cat)]
+
         else:
             self.index_fut_cat = []
         self.freq = freq
@@ -75,6 +79,7 @@ class TTM(Base):
             fcm_use_mixer=fcm_use_mixer,
             fcm_mix_layers=fcm_mix_layers,
             freq=freq,
+            force_return=force_return,
             freq_prefix_tuning=freq_prefix_tuning,
             fcm_prepend_past=fcm_prepend_past,
             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
@@ -82,8 +87,9 @@ class TTM(Base):
         )
         hidden_size = self.model.config.hidden_size
         self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
-
-
+        if few_shot:
+            self._freeze_backbone()
+        self.zero_pad = (force_return=='zeropad')
     def _freeze_backbone(self):
         """
         Freeze the backbone of the model.
@@ -108,29 +114,44 @@ class TTM(Base):
         return input
 
     def can_be_compiled(self):
-
+
+        return True#True#not self.zero_pad
 
     def forward(self, batch):
         x_enc = batch['x_num_past'].to(self.device)
+
+
+        if self.zero_pad:
+            B,L,C = batch['x_num_past'].shape
+            x_enc = torch.zeros((B,512,C)).to(self.device)
+            x_enc[:,-L:,:] = batch['x_num_past'].to(self.device)
+        else:
+            x_enc = batch['x_num_past'].to(self.device)
         original_indexes = batch['idx_target'][0].tolist()
 
 
         if 'x_cat_past' in batch.keys():
-
-
+            if self.zero_pad:
+                B,L,C = batch['x_cat_past'].shape
+                x_mark_enc = torch.zeros((B,512,C)).to(self.device)
+                x_mark_enc[:,-L:,:] = batch['x_cat_past'].to(torch.float32).to(self.device)
+            else:
+                x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
             past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
         else:
             past_values = x_enc
+        B,L,C = past_values.shape
+        future_values = torch.zeros((B,self.future_steps,C)).to(self.device)
 
-        future_values = torch.zeros_like(past_values).to(self.device)
-        future_values = future_values[:,:self.future_steps,:]
 
+
         if 'x_num_future' in batch.keys():
             future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
         if 'x_cat_future' in batch.keys():
             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
             x_mark_dec = self._scaler_fut(x_mark_dec)
-            future_values[:,:,self.
+            future_values[:,:,self.index_fut_cat] = x_mark_dec
 
 
         #investigating!! problem with dynamo!
@@ -153,6 +174,7 @@ class TTM(Base):
 
 
         BS = res.shape[0]
+
         return res.reshape(BS,self.future_steps,-1,self.mul)
```
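The main functional change: when `force_return == 'zeropad'`, inputs shorter than the backbone's fixed context window are left-padded with zeros (the hunk hard-codes 512, a standard TinyTimeMixer context length), and `future_values` is now allocated with `future_steps` rows directly instead of slicing a zeros-like copy of the past. The padding logic in isolation:

```python
import torch

def zero_pad_context(x: torch.Tensor, context_length: int = 512) -> torch.Tensor:
    """Left-pad (B, L, C) with zeros so the real values occupy the most
    recent positions of a fixed-length context, as in the zero_pad branch."""
    B, L, C = x.shape
    if L >= context_length:
        return x[:, -context_length:, :]
    out = torch.zeros((B, context_length, C), dtype=x.dtype, device=x.device)
    out[:, -L:, :] = x
    return out

print(zero_pad_context(torch.randn(2, 96, 7)).shape)  # torch.Size([2, 512, 7])
```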
dsipts/models/TimeKAN.py
ADDED
```diff
@@ -0,0 +1,123 @@
+## Copyright https://github.com/huangst21/TimeKAN/blob/main/models/TimeKAN.py
+## Modified for notation alignmenet and batch structure
+## extended to what inside itransformer folder
+
+import torch
+import torch.nn as nn
+import numpy as np
+from .timekan.Layers import FrequencyDecomp,FrequencyMixing,series_decomp,Normalize
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+
+class TimeKAN(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 # specific params
+                 down_sampling_window:int,
+                 e_layers:int,
+                 moving_avg:int,
+                 down_sampling_layers: int,
+                 d_model: int,
+                 begin_order: int,
+                 use_norm:bool,
+                 **kwargs)->None:
+
+
+
+        super().__init__(**kwargs)
+
+        self.save_hyperparameters(logger=False)
+        self.e_layers = e_layers
+        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        #self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+
+
+
+        self.res_blocks = nn.ModuleList([FrequencyDecomp( self.past_steps,down_sampling_window,down_sampling_layers) for _ in range(e_layers)])
+        self.add_blocks = nn.ModuleList([FrequencyMixing(d_model,self.past_steps,begin_order,down_sampling_window,down_sampling_layers) for _ in range(e_layers)])
+
+        self.preprocess = series_decomp(moving_avg)
+        self.enc_in = self.past_channels + emb_past_out_channel
+        self.project = nn.Linear(self.enc_in,d_model)
+        self.layer = e_layers
+        self.normalize_layers = torch.nn.ModuleList(
+            [
+                Normalize(self.enc_in, affine=True, non_norm=use_norm)
+                for i in range(down_sampling_layers + 1)
+            ]
+        )
+        self.predict_layer = nn.Linear(
+            self.past_steps,
+            self.future_steps,
+        )
+        self.final_layer = nn.Linear(d_model, self.mul)
+        self.down_sampling_layers = down_sampling_layers
+        self.down_sampling_window = down_sampling_window
+    def can_be_compiled(self):
+        return True#True
+
+
+    def __multi_level_process_inputs(self, x_enc):
+        down_pool = torch.nn.AvgPool1d(self.down_sampling_window)
+        # B,T,C -> B,C,T
+        x_enc = x_enc.permute(0, 2, 1)
+        x_enc_ori = x_enc
+        x_enc_sampling_list = []
+        x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
+        for i in range(self.down_sampling_layers):
+            x_enc_sampling = down_pool(x_enc_ori)
+            x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
+            x_enc_ori = x_enc_sampling
+        x_enc = x_enc_sampling_list
+        return x_enc
+
+
+    def forward(self, batch:dict)-> float:
+
+        x_enc = batch['x_num_past'].to(self.device)
+        BS = x_enc.shape[0]
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS,None)
+
+        x_past = torch.cat([x_enc,emb_past],2)
+
+        x_enc = self.__multi_level_process_inputs(x_past)
+
+        x_list = []
+        for i, x in zip(range(len(x_enc)), x_enc, ):
+            B, T, N = x.size()
+            x = self.normalize_layers[i](x, 'norm')
+            x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
+            x_list.append(self.project(x.reshape(B, T, N)))
+
+
+
+        for i in range(self.layer):
+            x_list = self.res_blocks[i](x_list)
+            x_list = self.add_blocks[i](x_list)
+
+        dec_out = x_list[0]
+        dec_out = self.predict_layer(dec_out.permute(0, 2, 1)).permute( 0, 2, 1)
+        dec_out = self.final_layer(dec_out)
+
+        return dec_out.reshape(BS,self.future_steps,self.out_channels,self.mul)
```
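The core preprocessing step, `__multi_level_process_inputs`, builds a pyramid of progressively average-pooled versions of the input series; the `FrequencyDecomp`/`FrequencyMixing` blocks then exchange information across these resolutions. The downsampling in isolation:

```python
import torch

def multi_scale(x: torch.Tensor, window: int, layers: int) -> list:
    """Replicates the pyramid built by __multi_level_process_inputs:
    one (B, T/window**k, C) tensor per resolution level k."""
    pool = torch.nn.AvgPool1d(window)
    levels, cur = [x], x.permute(0, 2, 1)    # B,T,C -> B,C,T for pooling
    for _ in range(layers):
        cur = pool(cur)
        levels.append(cur.permute(0, 2, 1))  # back to B,T,C
    return levels

print([t.shape[1] for t in multi_scale(torch.randn(4, 96, 8), 2, 2)])  # [96, 48, 24]
```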