dsipts 1.1.11__tar.gz → 1.1.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic; see this release's page on the package registry for more details.
- {dsipts-1.1.11 → dsipts-1.1.13}/PKG-INFO +1 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/pyproject.toml +1 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_structure/data_structure.py +50 -18
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_structure/utils.py +4 -2
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Autoformer.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/CrossFormer.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/D3VAE.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Diffusion.py +3 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/DilatedConv.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/DilatedConvED.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Duet.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ITransformer.py +3 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Informer.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/LinearTS.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/PatchTST.py +3 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/RNN.py +2 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Samformer.py +3 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Simple.py +3 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/TFT.py +4 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/TIDE.py +4 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/TTM.py +47 -15
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/TimeXER.py +3 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/base_v2.py +7 -8
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/duet/layers.py +6 -2
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts.egg-info/PKG-INFO +1 -1
- {dsipts-1.1.11 → dsipts-1.1.13}/README.md +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/setup.cfg +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/base.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/utils.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.11 → dsipts-1.1.13}/src/dsipts.egg-info/top_level.txt +0 -0
|
@@ -6,7 +6,7 @@ from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
|
|
|
6
6
|
from sklearn.preprocessing import *
|
|
7
7
|
from torch.utils.data import DataLoader
|
|
8
8
|
from .utils import extend_time_df,MetricsCallback, MyDataset, ActionEnum,beauty_string
|
|
9
|
-
|
|
9
|
+
from torch.utils.data.sampler import WeightedRandomSampler
|
|
10
10
|
try:
|
|
11
11
|
|
|
12
12
|
#new version of lightning
|
|
@@ -249,7 +249,8 @@ class TimeSeries():
|
|
|
249
249
|
check_past:bool=True,
|
|
250
250
|
group:Union[None,str]=None,
|
|
251
251
|
check_holes_and_duplicates:bool=True,
|
|
252
|
-
silly_model:bool=False
|
|
252
|
+
silly_model:bool=False,
|
|
253
|
+
sampler_weights:Union[None,str]=None)->None:
|
|
253
254
|
""" This is a crucial point in the data structure. We expect here to have a dataset with time as timestamp.
|
|
254
255
|
There are some checks:
|
|
255
256
|
1- the duplicates will tbe removed taking the first instance
|
|
@@ -270,6 +271,7 @@ class TimeSeries():
|
|
|
270
271
|
group (str or None, optional): if not None the time serie dataset is considered composed by omogeneus timeseries coming from different realization (for example point of sales, cities, locations) and the relative series are not splitted during the sample generation. Defaults to None
|
|
271
272
|
check_holes_and_duplicates (bool, optional): if False duplicates or holes will not checked, the dataloader can not correctly work, disable at your own risk. Defaults True
|
|
272
273
|
silly_model (bool, optional): if True, target variables will be added to the pool of the future variables. This can be useful to see if information passes throught the decoder part of your model (if any)
|
|
274
|
+
sampler_weights group (str or None, optional): if it is a column name it will be used as weight for the sampler. Careful that the weight of the sample is the weight value of the fist target value (index)
|
|
273
275
|
"""
|
|
274
276
|
|
|
275
277
|
|
|
@@ -322,7 +324,7 @@ class TimeSeries():
|
|
|
322
324
|
if group is not None:
|
|
323
325
|
if group not in cat_past_var:
|
|
324
326
|
beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
|
|
325
|
-
self.
|
|
327
|
+
self.cat_past_var.append(group)
|
|
326
328
|
if group not in cat_fut_var:
|
|
327
329
|
beauty_string(f'I will add {group} to the categorical past/future variables','info',self.verbose)
|
|
328
330
|
self.cat_fut_var.append(group)
|
|
@@ -350,7 +352,7 @@ class TimeSeries():
|
|
|
350
352
|
if silly_model:
|
|
351
353
|
beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
|
|
352
354
|
self.future_variables+=self.target_variables
|
|
353
|
-
|
|
355
|
+
self.sampler_weights = sampler_weights
|
|
354
356
|
def plot(self):
|
|
355
357
|
"""
|
|
356
358
|
Easy way to control the loaded data
|
|
@@ -409,6 +411,7 @@ class TimeSeries():
|
|
|
409
411
|
y_samples = []
|
|
410
412
|
t_samples = []
|
|
411
413
|
g_samples = []
|
|
414
|
+
sampler_weights_samples = []
|
|
412
415
|
|
|
413
416
|
if starting_point is not None:
|
|
414
417
|
kk = list(starting_point.keys())[0]
|
|
@@ -475,7 +478,8 @@ class TimeSeries():
|
|
|
475
478
|
if len(self.cat_fut_var)>0:
|
|
476
479
|
x_fut_cat = tmp[self.cat_fut_var].values
|
|
477
480
|
y_target = tmp[self.target_variables].values
|
|
478
|
-
|
|
481
|
+
if self.sampler_weights is not None:
|
|
482
|
+
sampler_weights = tmp[self.sampler_weights].values.flatten()
|
|
479
483
|
|
|
480
484
|
if starting_point is not None:
|
|
481
485
|
check = tmp[list(starting_point.keys())[0]].values == starting_point[list(starting_point.keys())[0]]
|
|
@@ -512,6 +516,8 @@ class TimeSeries():
|
|
|
512
516
|
x_cat_future_samples.append(x_fut_cat[i-shift+skip_stacked:i+future_steps-shift+skip_stacked])
|
|
513
517
|
|
|
514
518
|
y_samples.append(y_target[i+skip_stacked:i+future_steps+skip_stacked])
|
|
519
|
+
if self.sampler_weights is not None:
|
|
520
|
+
sampler_weights_samples.append(sampler_weights[i+skip_stacked])
|
|
515
521
|
t_samples.append(t[i+skip_stacked:i+future_steps+skip_stacked])
|
|
516
522
|
g_samples.append(groups[i])
|
|
517
523
|
|
|
@@ -524,6 +530,8 @@ class TimeSeries():
|
|
|
524
530
|
beauty_string('WARNING x_num_future_samples is empty and it should not','info',True)
|
|
525
531
|
|
|
526
532
|
y_samples = np.stack(y_samples)
|
|
533
|
+
if self.sampler_weights is not None:
|
|
534
|
+
sampler_weights_samples = np.stack(sampler_weights_samples)
|
|
527
535
|
t_samples = np.stack(t_samples)
|
|
528
536
|
g_samples = np.stack(g_samples)
|
|
529
537
|
|
|
@@ -537,7 +545,6 @@ class TimeSeries():
|
|
|
537
545
|
else:
|
|
538
546
|
mod = 1.0
|
|
539
547
|
dd = {'y':y_samples.astype(np.float32),
|
|
540
|
-
|
|
541
548
|
'x_num_past':(x_num_past_samples*mod).astype(np.float32)}
|
|
542
549
|
if len(self.cat_past_var)>0:
|
|
543
550
|
dd['x_cat_past'] = x_cat_past_samples
|
|
@@ -545,7 +552,10 @@ class TimeSeries():
|
|
|
545
552
|
dd['x_cat_future'] = x_cat_future_samples
|
|
546
553
|
if len(self.future_variables)>0:
|
|
547
554
|
dd['x_num_future'] = x_num_future_samples.astype(np.float32)
|
|
548
|
-
|
|
555
|
+
if self.sampler_weights is not None:
|
|
556
|
+
dd['sampler_weights'] = sampler_weights_samples.astype(np.float32)
|
|
557
|
+
else:
|
|
558
|
+
dd['sampler_weights'] = np.ones(len(y_samples)).astype(np.float32)
|
|
549
559
|
return MyDataset(dd,t_samples,g_samples,idx_target,idx_target_future)
|
|
550
560
|
|
|
551
561
|
|
|
@@ -683,10 +693,7 @@ class TimeSeries():
|
|
|
683
693
|
#self.model.apply(weight_init_zeros)
|
|
684
694
|
|
|
685
695
|
self.config = config
|
|
686
|
-
|
|
687
|
-
self.model = torch.compile(self.model)
|
|
688
|
-
except:
|
|
689
|
-
beauty_string('Can not compile the model','block',self.verbose)
|
|
696
|
+
|
|
690
697
|
|
|
691
698
|
beauty_string('Setting the model','block',self.verbose)
|
|
692
699
|
beauty_string(model,'',self.verbose)
|
|
@@ -756,8 +763,14 @@ class TimeSeries():
|
|
|
756
763
|
else:
|
|
757
764
|
self.modifier = None
|
|
758
765
|
|
|
766
|
+
if self.sampler_weights is not None:
|
|
767
|
+
beauty_string(f'USING SAMPLER IN TRAIN {min(train.sampler_weights)}-{max(train.sampler_weights)}','section',self.verbose)
|
|
759
768
|
|
|
760
|
-
|
|
769
|
+
sampler = WeightedRandomSampler(train.sampler_weights, num_samples= len(train))
|
|
770
|
+
train_dl = DataLoader(train, batch_size = batch_size , shuffle=False,sampler=sampler,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
|
|
771
|
+
|
|
772
|
+
else:
|
|
773
|
+
train_dl = DataLoader(train, batch_size = batch_size , shuffle=True,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
|
|
761
774
|
valid_dl = DataLoader(validation, batch_size = batch_size , shuffle=False,drop_last=True,num_workers=num_workers,persistent_workers=persistent_workers)
|
|
762
775
|
|
|
763
776
|
checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
|
|
@@ -812,8 +825,17 @@ class TimeSeries():
|
|
|
812
825
|
weight_exists = False
|
|
813
826
|
beauty_string('I can not load a previous model','section',self.verbose)
|
|
814
827
|
|
|
828
|
+
self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
|
|
829
|
+
if self.model.can_be_compiled():
|
|
830
|
+
try:
|
|
831
|
+
self.model = torch.compile(self.model)
|
|
832
|
+
beauty_string('Model COMPILED','block',self.verbose)
|
|
833
|
+
|
|
834
|
+
except:
|
|
835
|
+
beauty_string('Can not compile the model','block',self.verbose)
|
|
836
|
+
else:
|
|
837
|
+
beauty_string('Model can not still be compiled, be patient','block',self.verbose)
|
|
815
838
|
|
|
816
|
-
|
|
817
839
|
|
|
818
840
|
if OLD_PL:
|
|
819
841
|
trainer = pl.Trainer(default_root_dir=dirpath,
|
|
@@ -895,10 +917,19 @@ class TimeSeries():
|
|
|
895
917
|
self.losses = pd.DataFrame()
|
|
896
918
|
|
|
897
919
|
try:
|
|
920
|
+
|
|
898
921
|
if OLD_PL:
|
|
899
|
-
self.model
|
|
922
|
+
if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
|
|
923
|
+
self.model = self.model._orig_mod
|
|
924
|
+
self.model.load_from_checkpoint(self.checkpoint_file_last)
|
|
925
|
+
else:
|
|
926
|
+
self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
|
|
900
927
|
else:
|
|
901
|
-
self.model
|
|
928
|
+
if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
|
|
929
|
+
mm = self.model._orig_mod
|
|
930
|
+
self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
|
|
931
|
+
else:
|
|
932
|
+
self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
|
|
902
933
|
|
|
903
934
|
except Exception as _:
|
|
904
935
|
beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
|
|
@@ -1011,7 +1042,7 @@ class TimeSeries():
|
|
|
1011
1042
|
|
|
1012
1043
|
if self.group is not None:
|
|
1013
1044
|
time[self.group] = groups
|
|
1014
|
-
time = time.melt(id_vars=[
|
|
1045
|
+
time = time.melt(id_vars=[self.group])
|
|
1015
1046
|
else:
|
|
1016
1047
|
time = time.melt()
|
|
1017
1048
|
time.rename(columns={'value':'time','variable':'lag'},inplace=True)
|
|
@@ -1033,7 +1064,8 @@ class TimeSeries():
|
|
|
1033
1064
|
|
|
1034
1065
|
if self.group is not None:
|
|
1035
1066
|
time[self.group] = groups
|
|
1036
|
-
|
|
1067
|
+
|
|
1068
|
+
time = time.melt(id_vars=[self.group])
|
|
1037
1069
|
else:
|
|
1038
1070
|
time = time.melt()
|
|
1039
1071
|
time.rename(columns={'value':'time','variable':'lag'},inplace=True)
|
|
@@ -1186,6 +1218,6 @@ class TimeSeries():
|
|
|
1186
1218
|
self.model = self.model.load_from_checkpoint(tmp_path,verbose=self.verbose,)
|
|
1187
1219
|
else:
|
|
1188
1220
|
self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
|
|
1189
|
-
|
|
1221
|
+
self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
|
|
1190
1222
|
except Exception as e:
|
|
1191
1223
|
beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
|
|
@@ -142,12 +142,13 @@ class MyDataset(Dataset):
|
|
|
142
142
|
Returns:
|
|
143
143
|
torch.utils.data.Dataset: a torch Dataset to be used in a Dataloader
|
|
144
144
|
"""
|
|
145
|
+
|
|
145
146
|
self.data = data
|
|
146
147
|
self.t = t
|
|
147
148
|
self.groups = groups
|
|
148
149
|
self.idx_target = np.array(idx_target) if idx_target is not None else None
|
|
149
150
|
self.idx_target_future = np.array(idx_target_future) if idx_target_future is not None else None
|
|
150
|
-
|
|
151
|
+
self.sampler_weights = data['sampler_weights']
|
|
151
152
|
|
|
152
153
|
|
|
153
154
|
def __len__(self):
|
|
@@ -157,7 +158,8 @@ class MyDataset(Dataset):
|
|
|
157
158
|
def __getitem__(self, idxs):
|
|
158
159
|
sample = {}
|
|
159
160
|
for k in self.data:
|
|
160
|
-
|
|
161
|
+
if k!='sampler_weights':
|
|
162
|
+
sample[k] = self.data[k][idxs]
|
|
161
163
|
if self.idx_target is not None:
|
|
162
164
|
sample['idx_target'] = self.idx_target
|
|
163
165
|
if self.idx_target_future is not None:
|
|
@@ -148,7 +148,8 @@ class Autoformer(Base):
|
|
|
148
148
|
projection=nn.Linear(d_model, self.out_channels*self.mul, bias=True)
|
|
149
149
|
)
|
|
150
150
|
self.projection = nn.Linear(self.past_channels,self.out_channels*self.mul )
|
|
151
|
-
|
|
151
|
+
def can_be_compiled(self):
|
|
152
|
+
return True
|
|
152
153
|
def forward(self, batch):
|
|
153
154
|
|
|
154
155
|
|
|
@@ -425,6 +425,9 @@ class Diffusion(Base):
|
|
|
425
425
|
loss = self.compute_loss(batch,out)
|
|
426
426
|
return loss
|
|
427
427
|
|
|
428
|
+
def can_be_compiled(self):
|
|
429
|
+
return False
|
|
430
|
+
|
|
428
431
|
# function to concat embedded categorical variables
|
|
429
432
|
def cat_categorical_vars(self, batch:dict):
|
|
430
433
|
"""Extracting categorical context about past and future
|
|
@@ -228,7 +228,8 @@ class DilatedConvED(Base):
|
|
|
228
228
|
nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate) ,
|
|
229
229
|
Permute() if use_bn else nn.Identity() ,
|
|
230
230
|
nn.Linear(hidden_RNN ,self.mul))
|
|
231
|
-
|
|
231
|
+
def can_be_compiled(self):
|
|
232
|
+
return True
|
|
232
233
|
|
|
233
234
|
|
|
234
235
|
def forward(self, batch):
|
|
@@ -136,7 +136,8 @@ class Duet(Base):
|
|
|
136
136
|
activation(),
|
|
137
137
|
nn.Linear(dim*2,self.out_channels*self.mul ))
|
|
138
138
|
|
|
139
|
-
|
|
139
|
+
def can_be_compiled(self):
|
|
140
|
+
return False
|
|
140
141
|
def forward(self, batch:dict)-> float:
|
|
141
142
|
# x: [Batch, Input length, Channel]
|
|
142
143
|
x_enc = batch['x_num_past'].to(self.device)
|
|
@@ -101,6 +101,9 @@ class ITransformer(Base):
|
|
|
101
101
|
)
|
|
102
102
|
self.projector = nn.Linear(d_model, self.future_steps*self.mul, bias=True)
|
|
103
103
|
|
|
104
|
+
def can_be_compiled(self):
|
|
105
|
+
return True
|
|
106
|
+
|
|
104
107
|
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
|
|
105
108
|
if self.use_norm:
|
|
106
109
|
# Normalization from Non-stationary Transformer
|
|
@@ -143,7 +143,8 @@ class LinearTS(Base):
|
|
|
143
143
|
activation(),
|
|
144
144
|
nn.BatchNorm1d(hidden_size//8) if use_bn else nn.Dropout(dropout_rate) ,
|
|
145
145
|
nn.Linear(hidden_size//8,self.future_steps*self.mul)))
|
|
146
|
-
|
|
146
|
+
def can_be_compiled(self):
|
|
147
|
+
return True
|
|
147
148
|
def forward(self, batch):
|
|
148
149
|
|
|
149
150
|
x = batch['x_num_past'].to(self.device)
|
|
@@ -133,6 +133,9 @@ class PatchTST(Base):
|
|
|
133
133
|
|
|
134
134
|
#self.final_linear = nn.Sequential(nn.Linear(past_channels,past_channels//2),activation(),nn.Dropout(dropout_rate), nn.Linear(past_channels//2,out_channels) )
|
|
135
135
|
|
|
136
|
+
def can_be_compiled(self):
|
|
137
|
+
return True
|
|
138
|
+
|
|
136
139
|
def forward(self, batch): # x: [Batch, Input length, Channel]
|
|
137
140
|
|
|
138
141
|
|
|
@@ -67,7 +67,9 @@ class Simple(Base):
|
|
|
67
67
|
self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
|
|
68
68
|
activation(),nn.Dropout(dropout_rate),
|
|
69
69
|
nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
|
|
70
|
-
|
|
70
|
+
def can_be_compiled(self):
|
|
71
|
+
return True
|
|
72
|
+
|
|
71
73
|
def forward(self, batch):
|
|
72
74
|
|
|
73
75
|
x = batch['x_num_past'].to(self.device)
|
|
@@ -106,7 +106,10 @@ class TIDE(Base):
|
|
|
106
106
|
|
|
107
107
|
# linear for Y lookback
|
|
108
108
|
self.linear_target = nn.Linear(self.past_steps*self.out_channels, self.future_steps*self.out_channels*self.mul)
|
|
109
|
-
|
|
109
|
+
|
|
110
|
+
def can_be_compiled(self):
|
|
111
|
+
return False
|
|
112
|
+
|
|
110
113
|
|
|
111
114
|
def forward(self, batch:dict)-> float:
|
|
112
115
|
"""training process of the diffusion network
|
|
@@ -12,7 +12,7 @@ except:
|
|
|
12
12
|
from .base import Base
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
from .ttm.utils import get_model, get_frequency_token, count_parameters
|
|
15
|
+
from .ttm.utils import get_model, get_frequency_token, count_parameters, DEFAULT_FREQUENCY_MAPPING
|
|
16
16
|
from ..data_structure.utils import beauty_string
|
|
17
17
|
from .utils import get_scope
|
|
18
18
|
|
|
@@ -38,20 +38,32 @@ class TTM(Base):
|
|
|
38
38
|
fcm_mix_layers,
|
|
39
39
|
fcm_prepend_past,
|
|
40
40
|
enable_forecast_channel_mixing,
|
|
41
|
+
force_return,
|
|
41
42
|
**kwargs)->None:
|
|
42
43
|
|
|
43
44
|
super().__init__(**kwargs)
|
|
44
45
|
self.save_hyperparameters(logger=False)
|
|
45
46
|
|
|
47
|
+
|
|
46
48
|
|
|
47
49
|
self.index_fut = list(exogenous_channel_indices_cont)
|
|
48
50
|
|
|
49
51
|
if len(exogenous_channel_indices_cat)>0:
|
|
50
|
-
|
|
52
|
+
|
|
53
|
+
self.index_fut_cat = [self.past_channels+c for c in list(exogenous_channel_indices_cat)]
|
|
54
|
+
|
|
51
55
|
else:
|
|
52
56
|
self.index_fut_cat = []
|
|
53
57
|
self.freq = freq
|
|
54
58
|
|
|
59
|
+
base_freq_token = get_frequency_token(self.freq) # e.g., shape [n_token] or scalar
|
|
60
|
+
# ensure it's a tensor of integer type
|
|
61
|
+
if not torch.is_tensor(base_freq_token):
|
|
62
|
+
base_freq_token = torch.tensor(base_freq_token)
|
|
63
|
+
base_freq_token = base_freq_token.long()
|
|
64
|
+
self.register_buffer("token", base_freq_token, persistent=True)
|
|
65
|
+
|
|
66
|
+
|
|
55
67
|
self.model = get_model(
|
|
56
68
|
model_path=model_path,
|
|
57
69
|
context_length=self.past_steps,
|
|
@@ -66,6 +78,7 @@ class TTM(Base):
|
|
|
66
78
|
fcm_use_mixer=fcm_use_mixer,
|
|
67
79
|
fcm_mix_layers=fcm_mix_layers,
|
|
68
80
|
freq=freq,
|
|
81
|
+
force_return=force_return,
|
|
69
82
|
freq_prefix_tuning=freq_prefix_tuning,
|
|
70
83
|
fcm_prepend_past=fcm_prepend_past,
|
|
71
84
|
enable_forecast_channel_mixing=enable_forecast_channel_mixing,
|
|
@@ -74,7 +87,7 @@ class TTM(Base):
|
|
|
74
87
|
hidden_size = self.model.config.hidden_size
|
|
75
88
|
self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
|
|
76
89
|
self._freeze_backbone()
|
|
77
|
-
|
|
90
|
+
self.zero_pad = (force_return=='zeropad')
|
|
78
91
|
def _freeze_backbone(self):
|
|
79
92
|
"""
|
|
80
93
|
Freeze the backbone of the model.
|
|
@@ -98,34 +111,53 @@ class TTM(Base):
|
|
|
98
111
|
input[:,:,i] = input[:, :, i] / (e-1)
|
|
99
112
|
return input
|
|
100
113
|
|
|
101
|
-
|
|
114
|
+
def can_be_compiled(self):
|
|
115
|
+
|
|
116
|
+
return True#not self.zero_pad
|
|
117
|
+
|
|
102
118
|
def forward(self, batch):
|
|
103
|
-
x_enc = batch['x_num_past']
|
|
104
|
-
original_indexes = batch['idx_target'][0].tolist()
|
|
105
|
-
|
|
119
|
+
x_enc = batch['x_num_past'].to(self.device)
|
|
106
120
|
|
|
121
|
+
|
|
122
|
+
if self.zero_pad:
|
|
123
|
+
B,L,C = batch['x_num_past'].shape
|
|
124
|
+
x_enc = torch.zeros((B,512,C)).to(self.device)
|
|
125
|
+
x_enc[:,-L:,:] = batch['x_num_past'].to(self.device)
|
|
126
|
+
else:
|
|
127
|
+
x_enc = batch['x_num_past'].to(self.device)
|
|
128
|
+
original_indexes = batch['idx_target'][0].tolist()
|
|
107
129
|
|
|
108
130
|
|
|
109
131
|
if 'x_cat_past' in batch.keys():
|
|
110
|
-
|
|
111
|
-
|
|
132
|
+
if self.zero_pad:
|
|
133
|
+
B,L,C = batch['x_cat_past'].shape
|
|
134
|
+
x_mark_enc = torch.zeros((B,512,C)).to(self.device)
|
|
135
|
+
x_mark_enc[:,-L:,:] = batch['x_cat_past'].to(torch.float32).to(self.device)
|
|
136
|
+
else:
|
|
137
|
+
x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
|
|
138
|
+
x_mark_enc = self._scaler_past(x_mark_enc)
|
|
112
139
|
past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
|
|
113
140
|
else:
|
|
114
141
|
past_values = x_enc
|
|
142
|
+
B,L,C = past_values.shape
|
|
143
|
+
future_values = torch.zeros((B,self.future_steps,C)).to(self.device)
|
|
115
144
|
|
|
116
|
-
future_values = torch.zeros_like(past_values)
|
|
117
|
-
future_values = future_values[:,:self.future_steps,:]
|
|
118
145
|
|
|
146
|
+
|
|
119
147
|
if 'x_num_future' in batch.keys():
|
|
120
148
|
future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
|
|
121
149
|
if 'x_cat_future' in batch.keys():
|
|
122
150
|
x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
|
|
123
151
|
x_mark_dec = self._scaler_fut(x_mark_dec)
|
|
124
|
-
future_values[:,:,self.
|
|
152
|
+
future_values[:,:,self.index_fut_cat] = x_mark_dec
|
|
125
153
|
|
|
126
154
|
|
|
127
|
-
#investigating!!
|
|
128
|
-
freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
|
|
155
|
+
#investigating!! problem with dynamo!
|
|
156
|
+
#freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
|
|
157
|
+
|
|
158
|
+
batch_size = past_values.shape[0]
|
|
159
|
+
freq_token = self.token.repeat(batch_size).long().to(self.device)
|
|
160
|
+
|
|
129
161
|
|
|
130
162
|
res = self.model(
|
|
131
163
|
past_values= past_values,
|
|
@@ -134,7 +166,7 @@ class TTM(Base):
|
|
|
134
166
|
future_observed_mask = None,
|
|
135
167
|
output_hidden_states = False,
|
|
136
168
|
return_dict = False,
|
|
137
|
-
freq_token= freq_token, ##investigating
|
|
169
|
+
freq_token= freq_token,#[0:past_values.shape[0]], ##investigating
|
|
138
170
|
static_categorical_values = None
|
|
139
171
|
)
|
|
140
172
|
|
|
@@ -307,7 +307,7 @@ class Base(pl.LightningModule):
|
|
|
307
307
|
self.train_epoch_count +=1
|
|
308
308
|
return loss
|
|
309
309
|
|
|
310
|
-
|
|
310
|
+
|
|
311
311
|
def validation_step(self, batch, batch_idx):
|
|
312
312
|
"""
|
|
313
313
|
pythotrch lightening stuff
|
|
@@ -320,15 +320,14 @@ class Base(pl.LightningModule):
|
|
|
320
320
|
else:
|
|
321
321
|
y_hat = self(batch)
|
|
322
322
|
score = 0
|
|
323
|
-
|
|
323
|
+
#log_this_batch = (batch_idx == 0) and (self.count_epoch % int(max(self.trainer.max_epochs / 100,1)) == 1)
|
|
324
324
|
|
|
325
|
+
#if log_this_batch:
|
|
325
326
|
#track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
"y_hat": y_hat.detach().cpu()
|
|
331
|
-
})
|
|
327
|
+
self._val_outputs=[{
|
|
328
|
+
"y": batch['y'].detach().cpu(),
|
|
329
|
+
"y_hat": y_hat.detach().cpu()
|
|
330
|
+
}]
|
|
332
331
|
self.validation_epoch_metrics+= (self.compute_loss(batch,y_hat)+score).detach()
|
|
333
332
|
self.validation_epoch_count+=1
|
|
334
333
|
return None
|
|
@@ -219,7 +219,7 @@ class SparseDispatcher(object):
|
|
|
219
219
|
# expand according to batch index so we can just split by _part_sizes
|
|
220
220
|
inp_exp = inp[self._batch_index].squeeze(1)
|
|
221
221
|
return torch.split(inp_exp, self._part_sizes, dim=0)
|
|
222
|
-
|
|
222
|
+
|
|
223
223
|
def combine(self, expert_out, multiply_by_gates=True):
|
|
224
224
|
"""Sum together the expert output, weighted by the gates.
|
|
225
225
|
The slice corresponding to a particular batch element `b` is computed
|
|
@@ -234,7 +234,9 @@ class SparseDispatcher(object):
|
|
|
234
234
|
a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
|
|
235
235
|
"""
|
|
236
236
|
# apply exp to expert outputs, so we are not longer in log space
|
|
237
|
+
|
|
237
238
|
stitched = torch.cat(expert_out, 0)
|
|
239
|
+
|
|
238
240
|
if multiply_by_gates:
|
|
239
241
|
# stitched = stitched.mul(self._nonzero_gates)
|
|
240
242
|
stitched = torch.einsum("i...,ij->i...", stitched, self._nonzero_gates)
|
|
@@ -430,9 +432,11 @@ class Linear_extractor_cluster(nn.Module):
|
|
|
430
432
|
expert_inputs = dispatcher.dispatch(x_norm)
|
|
431
433
|
|
|
432
434
|
gates = dispatcher.expert_to_gates()
|
|
435
|
+
|
|
433
436
|
expert_outputs = [
|
|
434
437
|
self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
|
|
435
438
|
]
|
|
439
|
+
#y = dispatcher.combine([e for e in expert_outputs if len(e)>0])
|
|
440
|
+
#with torch._dynamo.disable():
|
|
436
441
|
y = dispatcher.combine(expert_outputs)
|
|
437
|
-
|
|
438
442
|
return y, loss
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|