dsipts 1.1.10.tar.gz → 1.1.12.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts has been flagged as possibly problematic.
- {dsipts-1.1.10 → dsipts-1.1.12}/PKG-INFO +1 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/pyproject.toml +1 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/data_structure.py +57 -20
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Autoformer.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/CrossFormer.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/D3VAE.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Diffusion.py +3 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConv.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConvED.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Duet.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ITransformer.py +5 -8
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Informer.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/LinearTS.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/PatchTST.py +3 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/RNN.py +2 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Samformer.py +3 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Simple.py +3 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TFT.py +4 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TIDE.py +4 -1
- dsipts-1.1.12/src/dsipts/models/TTM.py +158 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TimeXER.py +3 -1
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/base.py +47 -35
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/base_v2.py +53 -38
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/layers.py +6 -2
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/PKG-INFO +1 -1
- dsipts-1.1.10/src/dsipts/models/TTM.py +0 -252
- {dsipts-1.1.10 → dsipts-1.1.12}/README.md +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/setup.cfg +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/utils.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/top_level.txt +0 -0
@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time
 
-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+
 
 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
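The new DummyScaler is a no-op object exposing the sklearn scaler interface (fit/transform/inverse_transform/fit_transform), so code that expects a scaler can be handed one that leaves the data untouched. A minimal usage sketch (the surrounding calls are illustrative, not taken from the diff):

    # DummyScaler can stand in wherever an sklearn-style scaler is expected.
    import numpy as np
    from sklearn.preprocessing import StandardScaler

    x = np.array([[1.0], [2.0], [3.0]])
    for scaler in (StandardScaler(), DummyScaler()):
        scaler.fit(x)
        z = scaler.transform(x)                      # DummyScaler returns x unchanged
        assert np.allclose(scaler.inverse_transform(z), x)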
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))
 
     def enrich(self,dataset,columns):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)
 
     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
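The rewritten enrich derives calendar features from the dataset's `time` column (which the method assumes is a datetime column, hence the except branch). An illustrative equivalent of what it computes:

    import pandas as pd

    df = pd.DataFrame({'time': pd.date_range('2024-01-01', periods=48, freq='h')})
    df['hour']   = df.time.dt.hour      # what enrich(df, 'hour') produces
    df['dow']    = df.time.dt.weekday   # enrich(df, 'dow')
    df['month']  = df.time.dt.month     # enrich(df, 'month')
    df['minute'] = df.time.dt.minute    # enrich(df, 'minute')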
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+        past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var
 
@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
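The recurring `list(np.sort(...))` additions in this file all address the same issue: column lists built from Python set unions have no stable order, so the feature layout could differ between runs (and thus between training and inference). Sorting makes the ordering deterministic; a minimal illustration:

    import numpy as np

    cols = list({'y', 'temp', 'load'}.union({'hour'}))  # order can vary run to run
    cols = list(np.sort(cols))                          # deterministic: alphabetical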
@@ -665,7 +683,8 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)
 
         self.config = config
-
+
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)
 
@@ -790,8 +809,17 @@ class TimeSeries():
             weight_exists = False
             beauty_string('I can not load a previous model','section',self.verbose)
 
+        self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
+        if self.model.can_be_compiled():
+            try:
+                self.model = torch.compile(self.model)
+                beauty_string('Model COMPILED','block',self.verbose)
+
+            except:
+                beauty_string('Can not compile the model','block',self.verbose)
+        else:
+            beauty_string('Model can not still be compiled, be patient','block',self.verbose)
 
-
 
         if OLD_PL:
             trainer = pl.Trainer(default_root_dir=dirpath,
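This is the core of the 1.1.12 change: the model is moved to the device and then conditionally wrapped in torch.compile, guarded by the new per-model can_be_compiled() hook and a try/except for backends where compilation fails. A stand-alone sketch of the same pattern (the class is illustrative):

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        def can_be_compiled(self):      # per-model opt-in, as in this release
            return True
        def forward(self, x):
            return x * 2.0

    model = TinyModel()
    model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    if model.can_be_compiled():
        try:
            model = torch.compile(model)   # wrap for graph compilation
        except Exception:
            pass                           # keep the eager model on failure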
@@ -873,10 +901,19 @@ class TimeSeries():
         self.losses = pd.DataFrame()
 
         try:
+
             if OLD_PL:
-                self.model
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    self.model = self.model._orig_mod
+                    self.model.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
             else:
-                self.model
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    mm = self.model._orig_mod
+                    self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
 
         except Exception as _:
             beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
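Checkpoint loading needs the extra isinstance branches because torch.compile returns an OptimizedModule wrapper: its class is no longer the LightningModule subclass, so `__class__.load_from_checkpoint` would resolve to the wrong type, and its state_dict keys carry an `_orig_mod.` prefix. A small helper expressing the same unwrap (hypothetical, not part of the package):

    import torch

    def unwrap_compiled(model):
        """Return the original module if `model` was wrapped by torch.compile."""
        if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
            return model._orig_mod
        return model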
@@ -1164,6 +1201,6 @@ class TimeSeries():
                 self.model = self.model.load_from_checkpoint(tmp_path,verbose=self.verbose,)
             else:
                 self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
-
+            self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
         except Exception as e:
             beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
@@ -148,7 +148,8 @@ class Autoformer(Base):
             projection=nn.Linear(d_model, self.out_channels*self.mul, bias=True)
         )
         self.projection = nn.Linear(self.past_channels,self.out_channels*self.mul )
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):
 
 
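The same two-line hook recurs across the model files in this release (CrossFormer, D3VAE, Diffusion, DilatedConv, Duet, Informer, RNN, Samformer, TimeXER, ...): each model declares whether its forward pass is safe to trace with torch.compile. A hedged sketch of the overall shape; the Base default shown here is an assumption, since the diff does not include it:

    class Base:                          # stands in for the Lightning base class
        def can_be_compiled(self) -> bool:
            return False                 # assumed conservative default

    class Autoformer(Base):
        def can_be_compiled(self) -> bool:
            return True                  # opts in to torch.compile

    class Diffusion(Base):
        def can_be_compiled(self) -> bool:
            return False                 # opts out (e.g. an untraceable sampling loop)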
@@ -425,6 +425,9 @@ class Diffusion(Base):
         loss = self.compute_loss(batch,out)
         return loss
 
+    def can_be_compiled(self):
+        return False
+
     # function to concat embedded categorical variables
     def cat_categorical_vars(self, batch:dict):
         """Extracting categorical context about past and future
@@ -228,7 +228,8 @@ class DilatedConvED(Base):
             nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate) ,
             Permute() if use_bn else nn.Identity() ,
             nn.Linear(hidden_RNN ,self.mul))
-
+    def can_be_compiled(self):
+        return True
 
 
     def forward(self, batch):
@@ -136,7 +136,8 @@ class Duet(Base):
             activation(),
             nn.Linear(dim*2,self.out_channels*self.mul ))
 
-
+    def can_be_compiled(self):
+        return False
     def forward(self, batch:dict)-> float:
         # x: [Batch, Input length, Channel]
         x_enc = batch['x_num_past'].to(self.device)
@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
 
 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables
 
 
 
@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
 
     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
@@ -107,6 +101,9 @@ class ITransformer(Base):
         )
         self.projector = nn.Linear(d_model, self.future_steps*self.mul, bias=True)
 
+    def can_be_compiled(self):
+        return True
+
     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         if self.use_norm:
             # Normalization from Non-stationary Transformer
@@ -143,7 +143,8 @@ class LinearTS(Base):
             activation(),
             nn.BatchNorm1d(hidden_size//8) if use_bn else nn.Dropout(dropout_rate) ,
             nn.Linear(hidden_size//8,self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):
 
         x = batch['x_num_past'].to(self.device)
@@ -133,6 +133,9 @@ class PatchTST(Base):
 
         #self.final_linear = nn.Sequential(nn.Linear(past_channels,past_channels//2),activation(),nn.Dropout(dropout_rate), nn.Linear(past_channels//2,out_channels) )
 
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch): # x: [Batch, Input length, Channel]
 
 
@@ -67,7 +67,9 @@ class Simple(Base):
         self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
                                      activation(),nn.Dropout(dropout_rate),
                                      nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch):
 
         x = batch['x_num_past'].to(self.device)
@@ -106,7 +106,10 @@ class TIDE(Base):
 
         # linear for Y lookback
         self.linear_target = nn.Linear(self.past_steps*self.out_channels, self.future_steps*self.out_channels*self.mul)
-
+
+    def can_be_compiled(self):
+        return False
+
 
     def forward(self, batch:dict)-> float:
         """training process of the diffusion network
@@ -0,0 +1,158 @@
+import torch
+import numpy as np
+from torch import nn
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+from .ttm.utils import get_model, get_frequency_token, count_parameters, DEFAULT_FREQUENCY_MAPPING
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+
+class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 model_path:str,
+                 prefer_l1_loss:bool, # exog: set true to use l1 loss
+                 prefer_longer_context:bool,
+                 prediction_channel_indices,
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
+                 decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
+                 fcm_context_length,
+                 fcm_use_mixer,
+                 fcm_mix_layers,
+                 fcm_prepend_past,
+                 enable_forecast_channel_mixing,
+                 **kwargs)->None:
+
+        super().__init__(**kwargs)
+        self.save_hyperparameters(logger=False)
+
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
+        else:
+            self.index_fut_cat = []
+        self.freq = freq
+
+        base_freq_token = get_frequency_token(self.freq) # e.g., shape [n_token] or scalar
+        # ensure it's a tensor of integer type
+        if not torch.is_tensor(base_freq_token):
+            base_freq_token = torch.tensor(base_freq_token)
+        base_freq_token = base_freq_token.long()
+        self.register_buffer("token", base_freq_token, persistent=True)
+
+
+        self.model = get_model(
+            model_path=model_path,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
+            prefer_l1_loss=prefer_l1_loss,
+            prefer_longer_context=prefer_longer_context,
+            num_input_channels=self.past_channels+len(self.embs_past), #giusto
+            decoder_mode=decoder_mode,
+            prediction_channel_indices=list(prediction_channel_indices),
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
+            fcm_context_length=fcm_context_length,
+            fcm_use_mixer=fcm_use_mixer,
+            fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
+            fcm_prepend_past=fcm_prepend_past,
+            enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+
+        )
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """
+        Freeze the backbone of the model.
+        This is useful when you want to fine-tune only the head of the model.
+        """
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
+        # Freeze the backbone of the model
+        for param in self.model.backbone.parameters():
+            param.requires_grad = False
+        # Count params
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
+
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+
+    def can_be_compiled(self):
+        return True
+
+    def forward(self, batch):
+        x_enc = batch['x_num_past'].to(self.device)
+        original_indexes = batch['idx_target'][0].tolist()
+
+
+        if 'x_cat_past' in batch.keys():
+            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
+            past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
+        else:
+            past_values = x_enc
+
+        future_values = torch.zeros_like(past_values).to(self.device)
+        future_values = future_values[:,:self.future_steps,:]
+
+        if 'x_num_future' in batch.keys():
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
+        if 'x_cat_future' in batch.keys():
+            x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
+
+
+        #investigating!! problem with dynamo!
+        #freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
+
+        batch_size = past_values.shape[0]
+        freq_token = self.token.repeat(batch_size).long().to(self.device)
+
+
+        res = self.model(
+            past_values= past_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
+            past_observed_mask = None,
+            future_observed_mask = None,
+            output_hidden_states = False,
+            return_dict = False,
+            freq_token= freq_token,#[0:past_values.shape[0]], ##investigating
+            static_categorical_values = None
+        )
+
+
+        BS = res.shape[0]
+        return res.reshape(BS,self.future_steps,-1,self.mul)
+
+
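The registered `token` buffer is the fix hinted at by the "#investigating!! problem with dynamo!" comments above: computing the frequency token inside forward() goes through Python-level mapping code that torch.compile cannot trace, while repeating a pre-registered tensor does trace cleanly. A stand-alone sketch of the same trick (class and names are illustrative):

    import torch
    from torch import nn

    class FreqTokenDemo(nn.Module):
        def __init__(self, token_id: int):
            super().__init__()
            # computed once; moves with .to(device); traceable by dynamo
            self.register_buffer("token", torch.tensor(token_id).long())

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.token.repeat(x.shape[0])   # one token per batch row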
@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics
-        self.validation_epoch_metrics
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
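Switching the epoch metrics from plain attributes to registered buffers means they follow the module across devices (important once the model may be compiled and moved to CUDA) and can be reset in place with zero_(). A minimal sketch of the accumulate/average/reset cycle used in the hunks below:

    import torch
    from torch import nn

    class MetricDemo(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("epoch_metrics", torch.tensor(0.0))
            self.register_buffer("epoch_count", torch.tensor(0))

        def step(self, loss: torch.Tensor):
            self.epoch_metrics += loss.detach()   # detach: keep no autograd graph
            self.epoch_count += 1

        def epoch_end(self) -> torch.Tensor:
            avg = self.epoch_metrics / self.epoch_count
            self.epoch_metrics.zero_()            # in-place reset for the next epoch
            self.epoch_count.zero_()
            return avg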
@@ -295,7 +298,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -311,27 +315,20 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
-        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-
-
-
-
-
-
-
-
-
-
-
-                return self.compute_loss(batch,y_hat)+score
-
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
 
     def validation_epoch_end(self, outs):
         """
@@ -339,14 +336,30 @@ class Base(pl.LightningModule):
 
         :meta private:
         """
-        if len(
-
-
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
     def training_epoch_end(self, outs):
         """
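The per-channel plots are pushed to the experiment tracker through Aim's figure type; `Image` and `plt` are assumed to be imported at the top of base.py, since the imports are not part of this hunk. A stand-alone equivalent of the tracking call:

    import matplotlib.pyplot as plt
    from aim import Image, Run

    run = Run()                                    # illustrative Aim run
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.plot([1.0, 2.0, 3.0], 'o-', label='real')
    ax.legend()
    run.track(Image(fig), name='cm_training_end')  # stores the figure as an image
    plt.close(fig)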
@@ -355,12 +368,11 @@ class Base(pl.LightningModule):
         :meta private:
         """
 
-        loss =
-        self.log("train_loss", loss
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1
 
-        self.train_loss_epoch = loss
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation