dsipts 1.1.12__py3-none-any.whl → 1.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic. Click here for more details.

dsipts/__init__.py CHANGED
@@ -26,6 +26,9 @@ from .models.TTM import TTM
26
26
  from .models.Samformer import Samformer
27
27
  from .models.Duet import Duet
28
28
  from .models.Simple import Simple
29
+ from .models.TimesNet import TimesNet
30
+ from .models.TimeKAN import TimeKAN
31
+ from .version import __version__
29
32
  try:
30
33
  import lightning.pytorch as pl
31
34
  from .models.base_v2 import Base
@@ -44,5 +47,5 @@ __all__ = [
44
47
  "RNN", "LinearTS", "Persistent", "D3VAE", "DilatedConv", "TFT",
45
48
  "Informer", "VVA", "VQVAEA", "CrossFormer", "Autoformer", "PatchTST",
46
49
  "Diffusion", "DilatedConvED", "TIDE", "ITransformer", "TimeXER",
47
- "TTM", "Samformer", "Duet", "Base", "Simple"
50
+ "TTM", "Samformer", "Duet", "Base", "Simple","TimesNet","TimeKAN"
48
51
  ]
@@ -6,7 +6,7 @@ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
6
6
  from sklearn.preprocessing import *
7
7
  from torch.utils.data import DataLoader
8
8
  from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
9
-
9
+ from torch.utils.data.sampler import WeightedRandomSampler
10
10
  try:
11
11
 
12
12
  #new version of lightning
@@ -249,7 +249,8 @@ class TimeSeries():
249
249
  check_past:bool=True,
250
250
  group:Union[None,str]=None,
251
251
  check_holes_and_duplicates:bool=True,
252
- silly_model:bool=False)->None:
252
+ silly_model:bool=False,
253
+ sampler_weights:Union[None,str]=None)->None:
253
254
  """ This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
254
255
  There are some checks:
255
256
  1- the duplicates will be removed taking the first instance
@@ -270,6 +271,7 @@ class TimeSeries():
270
271
  group (str or None, optional): if not None the time series dataset is considered composed of homogeneous timeseries coming from different realizations (for example points of sale, cities, locations) and the relative series are not split during the sample generation. Defaults to None
271
272
  check_holes_and_duplicates (bool, optional): if False duplicates or holes will not checked, the dataloader can not correctly work, disable at your own risk. Defaults True
272
273
  silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes through the decoder part of your model (if any)
274
+ sampler_weights (str or None, optional): if it is a column name it will be used as weight for the sampler. Careful that the weight of the sample is the weight value of the first target value (index)
273
275
  """
274
276
 
275
277
 
@@ -322,7 +324,7 @@ class TimeSeries():
322
324
  if group is not None:
323
325
  if group not in cat_past_var:
324
326
  beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
325
- self.cat_var.append(group)
327
+ self.cat_past_var.append(group)
326
328
  if group not in cat_fut_var:
327
329
  beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
328
330
  self.cat_fut_var.append(group)
@@ -350,7 +352,7 @@ class TimeSeries():
350
352
  if silly_model:
351
353
  beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
352
354
  self.future_variables+=self.target_variables
353
-
355
+ self.sampler_weights = sampler_weights
354
356
  def plot(self):
355
357
  """
356
358
  Easy way to control the loaded data
@@ -409,6 +411,7 @@ class TimeSeries():
409
411
  y_samples = []
410
412
  t_samples = []
411
413
  g_samples = []
414
+ sampler_weights_samples = []
412
415
 
413
416
  if starting_point is not None:
414
417
  kk = list(starting_point.keys())[0]
@@ -448,7 +451,8 @@ class TimeSeries():
448
451
 
449
452
  idx_target = []
450
453
  for c in self.target_variables:
451
- idx_target.append(self.past_variables.index(c))
454
+ if c in self.past_variables:
455
+ idx_target.append(self.past_variables.index(c))
452
456
 
453
457
  idx_target_future = []
454
458
 
@@ -475,7 +479,8 @@ class TimeSeries():
475
479
  if len(self.cat_fut_var)>0:
476
480
  x_fut_cat = tmp[self.cat_fut_var].values
477
481
  y_target = tmp[self.target_variables].values
478
-
482
+ if self.sampler_weights is not None:
483
+ sampler_weights = tmp[self.sampler_weights].values.flatten()
479
484
 
480
485
  if starting_point is not None:
481
486
  check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
@@ -512,6 +517,8 @@ class TimeSeries():
512
517
  x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
513
518
 
514
519
  y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
520
+ if self.sampler_weights is not None:
521
+ sampler_weights_samples.append(sampler_weights[i+skip_stacked])
515
522
  t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
516
523
  g_samples.append(groups[i])
517
524
 
@@ -524,6 +531,8 @@ class TimeSeries():
524
531
  beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
525
532
 
526
533
  y_samples = np.stack(y_samples)
534
+ if self.sampler_weights is not None:
535
+ sampler_weights_samples = np.stack(sampler_weights_samples)
527
536
  t_samples = np.stack(t_samples)
528
537
  g_samples = np.stack(g_samples)
529
538
 
@@ -537,7 +546,6 @@ class TimeSeries():
537
546
  else:
538
547
  mod = 1.0
539
548
  dd = {'y':y_samples.astype(np.float32),
540
-
541
549
  'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
542
550
  if len(self.cat_past_var)>0:
543
551
  dd['x_cat_past'] = x_cat_past_samples
@@ -545,7 +553,10 @@ class TimeSeries():
545
553
  dd['x_cat_future'] = x_cat_future_samples
546
554
  if len(self.future_variables)>0:
547
555
  dd['x_num_future'] = x_num_future_samples.astype(np.float32)
548
-
556
+ if self.sampler_weights is not None:
557
+ dd['sampler_weights'] = sampler_weights_samples.astype(np.float32)
558
+ else:
559
+ dd['sampler_weights'] = np.ones(len(y_samples)).astype(np.float32)
549
560
  return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
550
561
 
551
562
 
@@ -660,7 +671,7 @@ class TimeSeries():
660
671
  if c!=self.group:
661
672
  self.scaler_cat[f'{c}_{group}'] = OrdinalEncoder(dtype=np.int32,handle_unknown= 'use_encoded_value',unknown_value=train[c].nunique())
662
673
  self.scaler_cat[f'{c}_{group}'].fit(tmp[c].values.reshape(-1,1))
663
-
674
+
664
675
  dl_train = self.create_data_loader(train,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
665
676
  dl_validation = self.create_data_loader(validation,past_steps,future_steps,shift,keep_entire_seq_while_shifting,starting_point,skip_step)
666
677
  if test.shape[0]>0:
@@ -753,10 +764,33 @@ class TimeSeries():
753
764
  else:
754
765
  self.modifier = None
755
766
 
767
+ if self.sampler_weights is not None:
768
+ beauty_string(f'USING SAMPLER IN TRAIN {min(train.sampler_weights)}-{max(train.sampler_weights)}','section',self.verbose)
769
+
770
+ sampler = WeightedRandomSampler(train.sampler_weights, num_samples= len(train))
771
+ train_dl = DataLoader(train, batch_size = batch_size , shuffle=False,sampler=sampler,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
756
772
 
757
- train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
773
+ else:
774
+ train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
758
775
  valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
759
-
776
+ debug_prediction = True
777
+
778
+ if debug_prediction:
779
+ dl = DataLoader(test, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
780
+ res = []
781
+ real = []
782
+ for batch in dl:
783
+
784
+
785
+ res.append(self.model.inference(batch).cpu().detach().numpy())
786
+ real.append(batch['y'].cpu().detach().numpy())
787
+
788
+ res = np.vstack(res)
789
+ real = np.vstack(real)
790
+ with open('/home/agobbi/Projects/ExpTS/tmp_beginning.pkl','wb') as f:
791
+ import pickle
792
+ pickle.dump([res,real],f)
793
+
760
794
  checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
761
795
  monitor='val_loss',
762
796
  save_last = True,
@@ -851,23 +885,24 @@ class TimeSeries():
851
885
 
852
886
 
853
887
 
854
-
888
+
855
889
  if auto_lr_find and (weight_exists is False):
856
- if OLD_PL:
857
- lr_tuner = trainer.tune(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
858
- files = os.listdir(dirpath)
859
- for f in files:
860
- if '.lr_find' in f:
861
- os.remove(os.path.join(dirpath,f))
862
- self.model.optim_config['lr'] = lr_tuner['lr_find'].suggestion()
863
- else:
864
- from lightning.pytorch.tuner import Tuner
865
- tuner = Tuner(trainer)
866
- lr_finder = tuner.lr_find(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
867
- self.model.optim_config['lr'] = lr_finder.suggestion() ## we are using it as optim key
890
+ try:
891
+ if OLD_PL:
892
+ lr_tuner = trainer.tune(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
893
+ files = os.listdir(dirpath)
894
+ for f in files:
895
+ if '.lr_find' in f:
896
+ os.remove(os.path.join(dirpath,f))
897
+ self.model.optim_config['lr'] = lr_tuner['lr_find'].suggestion()
898
+ else:
899
+ from lightning.pytorch.tuner import Tuner
900
+ tuner = Tuner(trainer)
901
+ lr_finder = tuner.lr_find(self.model,train_dataloaders=train_dl,val_dataloaders = valid_dl)
902
+ self.model.optim_config['lr'] = lr_finder.suggestion() ## we are using it as optim key
903
+ except Exception as e:
904
+ beauty_string(f'There is a problem with the finding LR routine {e}','section',self.verbose)
868
905
 
869
-
870
-
871
906
  if OLD_PL:
872
907
  if weight_exists:
873
908
  trainer.fit(self.model, train_dl,valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
@@ -878,17 +913,36 @@ class TimeSeries():
878
913
  trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl,ckpt_path=os.path.join(dirpath,'last.ckpt'))
879
914
  else:
880
915
  trainer.fit(self.model, train_dataloaders = train_dl,val_dataloaders = valid_dl)
916
+
881
917
  self.checkpoint_file_best = checkpoint_callback.best_model_path
882
918
  self.checkpoint_file_last = checkpoint_callback.last_model_path
883
919
  if self.checkpoint_file_last=='':
884
920
  beauty_string('There is a bug on saving last model I will try to fix it','info',self.verbose)
885
- self.checkpoint_file_last = checkpoint_callback.best_model_path.replace('checkpoint','last')
921
+ self.checkpoint_file_last = os.path.join(dirpath, "last.ckpt")
922
+ trainer.save_checkpoint(os.path.join(dirpath, "last.ckpt"))
923
+ if self.checkpoint_file_best=='':
924
+ beauty_string('There is a bug on saving best model I will try to fix it','info',self.verbose)
925
+ self.checkpoint_file_best = os.path.join(dirpath, "checkpoint.ckpt")
926
+ trainer.save_checkpoint(os.path.join(dirpath, "checkpoint.ckpt"))
886
927
 
887
928
  self.dirpath = dirpath
888
929
 
889
930
  self.losses = mc.metrics
890
931
 
891
932
  files = os.listdir(dirpath)
933
+ if debug_prediction:
934
+ res = []
935
+ real = []
936
+ for batch in dl:
937
+ res.append(self.model.inference(batch).cpu().detach().numpy())
938
+ real.append(batch['y'].cpu().detach().numpy())
939
+
940
+ res = np.vstack(res)
941
+ real = np.vstack(real)
942
+ with open('/home/agobbi/Projects/ExpTS/tmp_after_training.pkl','wb') as f:
943
+ import pickle
944
+ pickle.dump([res,real],f)
945
+
892
946
  ##accrocchio per multi gpu
893
947
  for f in files:
894
948
  if '__losses__.csv' in f:
@@ -901,7 +955,6 @@ class TimeSeries():
901
955
  self.losses = pd.DataFrame()
902
956
 
903
957
  try:
904
-
905
958
  if OLD_PL:
906
959
  if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
907
960
  self.model = self.model._orig_mod
@@ -914,7 +967,21 @@ class TimeSeries():
914
967
  self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
915
968
  else:
916
969
  self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
970
+ if debug_prediction:
971
+ res = []
972
+ real = []
973
+ for batch in dl:
974
+
975
+
976
+ res.append(self.model.inference(batch).cpu().detach().numpy())
977
+ real.append(batch['y'].cpu().detach().numpy())
917
978
 
979
+ res = np.vstack(res)
980
+ real = np.vstack(real)
981
+ with open('/home/agobbi/Projects/ExpTS/tmp_after_loading.pkl','wb') as f:
982
+ import pickle
983
+ pickle.dump([res,real],f)
984
+
918
985
  except Exception as _:
919
986
  beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
920
987
 
@@ -992,12 +1059,15 @@ class TimeSeries():
992
1059
  beauty_string(f'Device used: {self.model.device}','info',self.verbose)
993
1060
 
994
1061
  for batch in dl:
1062
+
1063
+
995
1064
  res.append(self.model.inference(batch).cpu().detach().numpy())
996
1065
  real.append(batch['y'].cpu().detach().numpy())
997
-
1066
+
998
1067
  res = np.vstack(res)
999
-
1000
1068
  real = np.vstack(real)
1069
+
1070
+
1001
1071
  time = dl.dataset.t
1002
1072
  groups = dl.dataset.groups
1003
1073
  #import pdb
@@ -1026,7 +1096,7 @@ class TimeSeries():
1026
1096
 
1027
1097
  if self.group is not None:
1028
1098
  time[self.group] = groups
1029
- time = time.melt(id_vars=['region'])
1099
+ time = time.melt(id_vars=[self.group])
1030
1100
  else:
1031
1101
  time = time.melt()
1032
1102
  time.rename(columns={'value':'time','variable':'lag'},inplace=True)
@@ -1048,7 +1118,8 @@ class TimeSeries():
1048
1118
 
1049
1119
  if self.group is not None:
1050
1120
  time[self.group] = groups
1051
- time = time.melt(id_vars=['region'])
1121
+
1122
+ time = time.melt(id_vars=[self.group])
1052
1123
  else:
1053
1124
  time = time.melt()
1054
1125
  time.rename(columns={'value':'time','variable':'lag'},inplace=True)
@@ -1186,15 +1257,21 @@ class TimeSeries():
1186
1257
 
1187
1258
  try:
1188
1259
  tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
1260
+ beauty_string(f"Loading {tmp_path}",'section',self.verbose)
1261
+
1189
1262
  except Exception as _:
1190
1263
  beauty_string('checkpoint_file_last not defined try to load best','section',self.verbose)
1191
1264
  tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
1265
+ beauty_string(f"Loading {tmp_path}",'section',self.verbose)
1192
1266
  else:
1193
1267
  try:
1268
+
1194
1269
  tmp_path = os.path.join(directory,self.checkpoint_file_best.split('/')[-1])
1270
+ beauty_string(f"Loading {tmp_path}",'section',self.verbose)
1195
1271
  except Exception as _:
1196
1272
  beauty_string('checkpoint_file_best not defined try to load best','section',self.verbose)
1197
1273
  tmp_path = os.path.join(directory,self.checkpoint_file_last.split('/')[-1])
1274
+ beauty_string(f"Loading {tmp_path}",'section',self.verbose)
1198
1275
  try:
1199
1276
  #with torch.serialization.add_safe_globals([ListConfig]):
1200
1277
  if OLD_PL:
@@ -1202,5 +1279,6 @@ class TimeSeries():
1202
1279
  else:
1203
1280
  self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
1204
1281
  self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
1282
+
1205
1283
  except Exception as e:
1206
1284
  beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
@@ -142,12 +142,13 @@ class MyDataset(Dataset):
142
142
  Returns:
143
143
  torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
144
144
  """
145
+
145
146
  self.data = data
146
147
  self.t = t
147
148
  self.groups = groups
148
149
  self.idx_target = np.array(idx_target) if idx_target is not None else None
149
150
  self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
150
-
151
+ self.sampler_weights = data['sampler_weights']
151
152
 
152
153
 
153
154
  def __len__(self):
@@ -157,7 +158,8 @@ class MyDataset(Dataset):
157
158
  def __getitem__(self, idxs):
158
159
  sample = {}
159
160
  for k in self.data:
160
- sample[k] = self.data[k][idxs]
161
+ if k!='sampler_weights':
162
+ sample[k] = self.data[k][idxs]
161
163
  if self.idx_target is not None:
162
164
  sample['idx_target'] = self.idx_target
163
165
  if self.idx_target_future is not None:
@@ -30,6 +30,8 @@ class Persistent(Base):
30
30
  self.save_hyperparameters(logger=False)
31
31
  self.fake = nn.Linear(1,1)
32
32
  self.use_quantiles = False
33
+ def can_be_compiled(self):
34
+ return False
33
35
 
34
36
  def forward(self, batch):
35
37
 
dsipts/models/TTM.py CHANGED
@@ -38,6 +38,8 @@ class TTM(Base):
38
38
  fcm_mix_layers,
39
39
  fcm_prepend_past,
40
40
  enable_forecast_channel_mixing,
41
+ force_return,
42
+ few_shot = True,
41
43
  **kwargs)->None:
42
44
 
43
45
  super().__init__(**kwargs)
@@ -48,7 +50,9 @@ class TTM(Base):
48
50
  self.index_fut = list(exogenous_channel_indices_cont)
49
51
 
50
52
  if len(exogenous_channel_indices_cat)>0:
51
- self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
53
+
54
+ self.index_fut_cat = [self.past_channels+c for c in list(exogenous_channel_indices_cat)]
55
+
52
56
  else:
53
57
  self.index_fut_cat = []
54
58
  self.freq = freq
@@ -75,6 +79,7 @@ class TTM(Base):
75
79
  fcm_use_mixer=fcm_use_mixer,
76
80
  fcm_mix_layers=fcm_mix_layers,
77
81
  freq=freq,
82
+ force_return=force_return,
78
83
  freq_prefix_tuning=freq_prefix_tuning,
79
84
  fcm_prepend_past=fcm_prepend_past,
80
85
  enable_forecast_channel_mixing=enable_forecast_channel_mixing,
@@ -82,8 +87,9 @@ class TTM(Base):
82
87
  )
83
88
  hidden_size = self.model.config.hidden_size
84
89
  self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
85
- self._freeze_backbone()
86
-
90
+ if few_shot:
91
+ self._freeze_backbone()
92
+ self.zero_pad = (force_return=='zeropad')
87
93
  def _freeze_backbone(self):
88
94
  """
89
95
  Freeze the backbone of the model.
@@ -108,29 +114,44 @@ class TTM(Base):
108
114
  return input
109
115
 
110
116
  def can_be_compiled(self):
111
- return True
117
+
118
+ return True#True#not self.zero_pad
112
119
 
113
120
  def forward(self, batch):
114
121
  x_enc = batch['x_num_past'].to(self.device)
122
+
123
+
124
+ if self.zero_pad:
125
+ B,L,C = batch['x_num_past'].shape
126
+ x_enc = torch.zeros((B,512,C)).to(self.device)
127
+ x_enc[:,-L:,:] = batch['x_num_past'].to(self.device)
128
+ else:
129
+ x_enc = batch['x_num_past'].to(self.device)
115
130
  original_indexes = batch['idx_target'][0].tolist()
116
131
 
117
132
 
118
133
  if 'x_cat_past' in batch.keys():
119
- x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
120
- x_mark_enc = self._scaler_past(x_mark_enc)
134
+ if self.zero_pad:
135
+ B,L,C = batch['x_cat_past'].shape
136
+ x_mark_enc = torch.zeros((B,512,C)).to(self.device)
137
+ x_mark_enc[:,-L:,:] = batch['x_cat_past'].to(torch.float32).to(self.device)
138
+ else:
139
+ x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
140
+ x_mark_enc = self._scaler_past(x_mark_enc)
121
141
  past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
122
142
  else:
123
143
  past_values = x_enc
144
+ B,L,C = past_values.shape
145
+ future_values = torch.zeros((B,self.future_steps,C)).to(self.device)
124
146
 
125
- future_values = torch.zeros_like(past_values).to(self.device)
126
- future_values = future_values[:,:self.future_steps,:]
127
147
 
148
+
128
149
  if 'x_num_future' in batch.keys():
129
150
  future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
130
151
  if 'x_cat_future' in batch.keys():
131
152
  x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
132
153
  x_mark_dec = self._scaler_fut(x_mark_dec)
133
- future_values[:,:,self.index_cat_fut] = x_mark_dec
154
+ future_values[:,:,self.index_fut_cat] = x_mark_dec
134
155
 
135
156
 
136
157
  #investigating!! problem with dynamo!
@@ -153,6 +174,7 @@ class TTM(Base):
153
174
 
154
175
 
155
176
  BS = res.shape[0]
177
+
156
178
  return res.reshape(BS,self.future_steps,-1,self.mul)
157
179
 
158
180
 
@@ -0,0 +1,123 @@
1
+ ## Copyright https://github.com/huangst21/TimeKAN/blob/main/models/TimeKAN.py
2
+ ## Modified for notation alignment and batch structure
3
+ ## extended to match what is inside the itransformer folder
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ from .timekan.Layers import FrequencyDecomp,FrequencyMixing,series_decomp,Normalize
9
+ from ..data_structure.utils import beauty_string
10
+ from .utils import get_scope,get_activation,Embedding_cat_variables
11
+
12
+ try:
13
+ import lightning.pytorch as pl
14
+ from .base_v2 import Base
15
+ OLD_PL = False
16
+ except:
17
+ import pytorch_lightning as pl
18
+ OLD_PL = True
19
+ from .base import Base
20
+
21
+
22
+
23
+ class TimeKAN(Base):
24
+ handle_multivariate = True
25
+ handle_future_covariates = True
26
+ handle_categorical_variables = True
27
+ handle_quantile_loss = True
28
+ description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
29
+
30
+ def __init__(self,
31
+ # specific params
32
+ down_sampling_window:int,
33
+ e_layers:int,
34
+ moving_avg:int,
35
+ down_sampling_layers: int,
36
+ d_model: int,
37
+ begin_order: int,
38
+ use_norm:bool,
39
+ **kwargs)->None:
40
+
41
+
42
+
43
+
44
+ super().__init__(**kwargs)
45
+
46
+ self.save_hyperparameters(logger=False)
47
+ self.e_layers = e_layers
48
+ self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
49
+ #self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
50
+ emb_past_out_channel = self.emb_past.output_channels
51
+
52
+
53
+
54
+ self.res_blocks = nn.ModuleList([FrequencyDecomp( self.past_steps,down_sampling_window,down_sampling_layers) for _ in range(e_layers)])
55
+ self.add_blocks = nn.ModuleList([FrequencyMixing(d_model,self.past_steps,begin_order,down_sampling_window,down_sampling_layers) for _ in range(e_layers)])
56
+
57
+ self.preprocess = series_decomp(moving_avg)
58
+ self.enc_in = self.past_channels + emb_past_out_channel
59
+ self.project = nn.Linear(self.enc_in,d_model)
60
+ self.layer = e_layers
61
+ self.normalize_layers = torch.nn.ModuleList(
62
+ [
63
+ Normalize(self.enc_in, affine=True, non_norm=use_norm)
64
+ for i in range(down_sampling_layers + 1)
65
+ ]
66
+ )
67
+ self.predict_layer =nn. Linear(
68
+ self.past_steps,
69
+ self.future_steps,
70
+ )
71
+ self.final_layer = nn.Linear(d_model, self.mul)
72
+ self.down_sampling_layers = down_sampling_layers
73
+ self.down_sampling_window = down_sampling_window
74
+ def can_be_compiled(self):
75
+ return True#True
76
+
77
+
78
+ def __multi_level_process_inputs(self, x_enc):
79
+ down_pool = torch.nn.AvgPool1d(self.down_sampling_window)
80
+ # B,T,C -> B,C,T
81
+ x_enc = x_enc.permute(0, 2, 1)
82
+ x_enc_ori = x_enc
83
+ x_enc_sampling_list = []
84
+ x_enc_sampling_list.append(x_enc.permute(0, 2, 1))
85
+ for i in range(self.down_sampling_layers):
86
+ x_enc_sampling = down_pool(x_enc_ori)
87
+ x_enc_sampling_list.append(x_enc_sampling.permute(0, 2, 1))
88
+ x_enc_ori = x_enc_sampling
89
+ x_enc = x_enc_sampling_list
90
+ return x_enc
91
+
92
+
93
+ def forward(self, batch:dict)-> float:
94
+
95
+ x_enc = batch['x_num_past'].to(self.device)
96
+ BS = x_enc.shape[0]
97
+ if 'x_cat_past' in batch.keys():
98
+ emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
99
+ else:
100
+ emb_past = self.emb_past(BS,None)
101
+
102
+ x_past = torch.cat([x_enc,emb_past],2)
103
+
104
+ x_enc = self.__multi_level_process_inputs(x_past)
105
+
106
+ x_list = []
107
+ for i, x in zip(range(len(x_enc)), x_enc, ):
108
+ B, T, N = x.size()
109
+ x = self.normalize_layers[i](x, 'norm')
110
+ x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
111
+ x_list.append(self.project(x.reshape(B, T, N)))
112
+
113
+
114
+
115
+ for i in range(self.layer):
116
+ x_list = self.res_blocks[i](x_list)
117
+ x_list = self.add_blocks[i](x_list)
118
+
119
+ dec_out = x_list[0]
120
+ dec_out = self.predict_layer(dec_out.permute(0, 2, 1)).permute( 0, 2, 1)
121
+ dec_out = self.final_layer(dec_out)
122
+
123
+ return dec_out.reshape(BS,self.future_steps,self.out_channels,self.mul)