dsipts 1.1.10__tar.gz → 1.1.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dsipts might be problematic.

Files changed (88)
  1. {dsipts-1.1.10 → dsipts-1.1.12}/PKG-INFO +1 -1
  2. {dsipts-1.1.10 → dsipts-1.1.12}/pyproject.toml +1 -1
  3. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/data_structure.py +57 -20
  4. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Autoformer.py +2 -1
  5. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/CrossFormer.py +2 -1
  6. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/D3VAE.py +2 -1
  7. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Diffusion.py +3 -0
  8. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConv.py +2 -1
  9. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConvED.py +2 -1
  10. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Duet.py +2 -1
  11. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ITransformer.py +5 -8
  12. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Informer.py +2 -1
  13. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/LinearTS.py +2 -1
  14. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/PatchTST.py +3 -0
  15. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/RNN.py +2 -1
  16. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Samformer.py +3 -1
  17. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Simple.py +3 -1
  18. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TFT.py +4 -0
  19. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TIDE.py +4 -1
  20. dsipts-1.1.12/src/dsipts/models/TTM.py +158 -0
  21. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TimeXER.py +3 -1
  22. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/base.py +47 -35
  23. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/base_v2.py +53 -38
  24. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/layers.py +6 -2
  25. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/PKG-INFO +1 -1
  26. dsipts-1.1.10/src/dsipts/models/TTM.py +0 -252
  27. {dsipts-1.1.10 → dsipts-1.1.12}/README.md +0 -0
  28. {dsipts-1.1.10 → dsipts-1.1.12}/setup.cfg +0 -0
  29. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/__init__.py +0 -0
  30. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/__init__.py +0 -0
  31. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/monash.py +0 -0
  32. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_management/public_datasets.py +0 -0
  33. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/__init__.py +0 -0
  34. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/modifiers.py +0 -0
  35. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/utils.py +0 -0
  36. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Persistent.py +0 -0
  37. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/VQVAEA.py +0 -0
  38. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/VVA.py +0 -0
  39. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/__init__.py +0 -0
  40. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/autoformer/__init__.py +0 -0
  41. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/autoformer/layers.py +0 -0
  42. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/__init__.py +0 -0
  43. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/attn.py +0 -0
  44. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
  45. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_embed.py +0 -0
  46. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
  47. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/__init__.py +0 -0
  48. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
  49. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/embedding.py +0 -0
  50. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/encoder.py +0 -0
  51. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/model.py +0 -0
  52. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/neural_operations.py +0 -0
  53. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/resnet.py +0 -0
  54. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/d3vae/utils.py +0 -0
  55. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/__init__.py +0 -0
  56. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/duet/masked.py +0 -0
  57. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/__init__.py +0 -0
  58. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/attn.py +0 -0
  59. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/decoder.py +0 -0
  60. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/embed.py +0 -0
  61. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/informer/encoder.py +0 -0
  62. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/Embed.py +0 -0
  63. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
  64. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
  65. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/itransformer/__init__.py +0 -0
  66. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/patchtst/__init__.py +0 -0
  67. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/patchtst/layers.py +0 -0
  68. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/samformer/__init__.py +0 -0
  69. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/samformer/utils.py +0 -0
  70. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/tft/__init__.py +0 -0
  71. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/tft/sub_nn.py +0 -0
  72. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/timexer/Layers.py +0 -0
  73. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/timexer/__init__.py +0 -0
  74. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/__init__.py +0 -0
  75. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
  76. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/consts.py +0 -0
  77. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
  78. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ttm/utils.py +0 -0
  79. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/utils.py +0 -0
  80. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/__init__.py +0 -0
  81. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/minigpt.py +0 -0
  82. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/vva/vqvae.py +0 -0
  83. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/xlstm/__init__.py +0 -0
  84. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/xlstm/xLSTM.py +0 -0
  85. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/SOURCES.txt +0 -0
  86. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/dependency_links.txt +0 -0
  87. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/requires.txt +0 -0
  88. {dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts.egg-info/top_level.txt +0 -0
{dsipts-1.1.10 → dsipts-1.1.12}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dsipts
-Version: 1.1.10
+Version: 1.1.12
 Summary: Unified library for timeseries modelling
 Author-email: Andrea Gobbi <agobbi@fbk.eu>
 Project-URL: Homepage, https://github.com/DSIP-FBK/DSIPTS
{dsipts-1.1.10 → dsipts-1.1.12}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dsipts"
-version = "1.1.10"
+version = "1.1.12"
 description = "Unified library for timeseries modelling"
 readme = "README.md"
 requires-python = "==3.11.13"
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/data_structure/data_structure.py
@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time

-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+

 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
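
Note: the new DummyScaler mirrors the scikit-learn scaler interface (fit / transform / inverse_transform / fit_transform) as a no-op, so downstream code can treat "no scaling" like any other scaler. A minimal usage sketch, not part of dsipts (the prepare helper is hypothetical):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    class DummyScaler:  # same no-op interface as in the hunk above
        def fit(self, x): pass
        def transform(self, x): return x
        def inverse_transform(self, x): return x
        def fit_transform(self, x): return x

    def prepare(series, scale: bool):
        # caller code never branches on whether scaling is enabled
        scaler = StandardScaler() if scale else DummyScaler()
        return scaler, scaler.fit_transform(series)

    scaler, scaled = prepare(np.random.randn(100, 2), scale=False)
    assert np.allclose(scaler.inverse_transform(scaled), scaled)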
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))

     def enrich(self,dataset,columns):
-        if columns =='hour':
-            dataset[columns] = dataset.time.dt.hour
-        elif columns=='dow':
-            dataset[columns] = dataset.time.dt.weekday
-        elif columns=='month':
-            dataset[columns] = dataset.time.dt.month
-        elif columns=='minute':
-            dataset[columns] = dataset.time.dt.minute
-        else:
-            if columns not in dataset.columns:
-                beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)

     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
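
Note: the enrich helper derives calendar covariates through the pandas .dt accessor, which raises AttributeError when the time column is not datetime-like; that failure is what the new bare except now reports instead of crashing. A standalone sketch of the same derivations:

    import pandas as pd

    df = pd.DataFrame({"time": pd.date_range("2024-01-01", periods=48, freq="h")})
    df["hour"] = df.time.dt.hour      # 0..23
    df["dow"] = df.time.dt.weekday    # 0=Monday .. 6=Sunday
    df["month"] = df.time.dt.month    # 1..12
    df["minute"] = df.time.dt.minute  # 0..59

    # a non-datetime column triggers the path the except clause handles
    try:
        pd.DataFrame({"time": [1, 2, 3]}).time.dt.hour
    except AttributeError:
        print("not a temporal index")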
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+        past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var

@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
@@ -665,7 +683,8 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)

         self.config = config
-
+
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)

@@ -790,8 +809,17 @@ class TimeSeries():
             weight_exists = False
             beauty_string('I can not load a previous model','section',self.verbose)

+        self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
+        if self.model.can_be_compiled():
+            try:
+                self.model = torch.compile(self.model)
+                beauty_string('Model COMPILED','block',self.verbose)
+
+            except:
+                beauty_string('Can not compile the model','block',self.verbose)
+        else:
+            beauty_string('Model can not still be compiled, be patient','block',self.verbose)

-

         if OLD_PL:
             trainer = pl.Trainer(default_root_dir=dirpath,
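
Note: together with the per-model can_be_compiled() flags added further down, training now opts models into torch.compile and falls back to eager mode if compilation fails. A condensed sketch of the pattern (maybe_compile is a hypothetical name, not a dsipts function):

    import torch
    from torch import nn

    def maybe_compile(model: nn.Module) -> nn.Module:
        model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
        can = getattr(model, "can_be_compiled", lambda: False)
        if can():
            try:
                # torch.compile returns an OptimizedModule wrapper
                return torch.compile(model)
            except Exception:
                pass  # keep the eager model on unsupported setups
        return model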
@@ -873,10 +901,19 @@ class TimeSeries():
         self.losses = pd.DataFrame()

         try:
+
             if OLD_PL:
-                self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    self.model = self.model._orig_mod
+                    self.model.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.load_from_checkpoint(self.checkpoint_file_last)
             else:
-                self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)
+                if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
+                    mm = self.model._orig_mod
+                    self.model = mm.__class__.load_from_checkpoint(self.checkpoint_file_last)
+                else:
+                    self.model = self.model.__class__.load_from_checkpoint(self.checkpoint_file_last)

         except Exception as _:
             beauty_string(f'There is a problem loading the weights on file MAYBE CHANGED HOW WEIGHTS ARE LOADED {self.checkpoint_file_last}','section',self.verbose)
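
Note: torch.compile wraps the LightningModule in a torch._dynamo.eval_frame.OptimizedModule, so checkpoint loading must first recover the underlying module; the diff uses the wrapper's _orig_mod attribute for that. A minimal sketch of the unwrap step:

    import torch

    def unwrap_compiled(model):
        # OptimizedModule keeps the original module in ._orig_mod;
        # load_from_checkpoint must be called on the underlying class
        if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
            return model._orig_mod
        return model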
@@ -1164,6 +1201,6 @@ class TimeSeries():
                 self.model = self.model.load_from_checkpoint(tmp_path,verbose=self.verbose,)
             else:
                 self.model = self.model.__class__.load_from_checkpoint(tmp_path,verbose=self.verbose,)
-
+            self.model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
         except Exception as e:
             beauty_string(f'There is a problem loading the weights on file {tmp_path} {e}','section',self.verbose)
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Autoformer.py
@@ -148,7 +148,8 @@ class Autoformer(Base):
             projection=nn.Linear(d_model, self.out_channels*self.mul, bias=True)
         )
         self.projection = nn.Linear(self.past_channels,self.out_channels*self.mul )
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):

{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/CrossFormer.py
@@ -114,7 +114,8 @@ class CrossFormer(Base):



-
+    def can_be_compiled(self):
+        return True

     def forward(self, batch):

{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/D3VAE.py
@@ -98,7 +98,8 @@ class D3VAE(Base):
         self.gamma = 0.01
         self.lambda1 = 1.0

-
+    def can_be_compiled(self):
+        return False
     def configure_optimizers(self):
         """
         Each model has optim_config and scheduler_config
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Diffusion.py
@@ -425,6 +425,9 @@ class Diffusion(Base):
         loss = self.compute_loss(batch,out)
         return loss

+    def can_be_compiled(self):
+        return False
+
     # function to concat embedded categorical variables
     def cat_categorical_vars(self, batch:dict):
         """Extracting categorical context about past and future
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConv.py
@@ -234,7 +234,8 @@ class DilatedConv(Base):
         self.return_additional_loss = True


-
+    def can_be_compiled(self):
+        return True

     def forward(self, batch):
         """It is mandatory to implement this method
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/DilatedConvED.py
@@ -228,7 +228,8 @@ class DilatedConvED(Base):
             nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate) ,
             Permute() if use_bn else nn.Identity() ,
             nn.Linear(hidden_RNN ,self.mul))
-
+    def can_be_compiled(self):
+        return True


     def forward(self, batch):
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Duet.py
@@ -136,7 +136,8 @@ class Duet(Base):
             activation(),
             nn.Linear(dim*2,self.out_channels*self.mul ))

-
+    def can_be_compiled(self):
+        return False
     def forward(self, batch:dict)-> float:
         # x: [Batch, Input length, Channel]
         x_enc = batch['x_num_past'].to(self.device)
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/ITransformer.py
@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables

 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables


@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
@@ -107,6 +101,9 @@ class ITransformer(Base):
         )
         self.projector = nn.Linear(d_model, self.future_steps*self.mul, bias=True)

+    def can_be_compiled(self):
+        return True
+
     def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         if self.use_norm:
             # Normalization from Non-stationary Transformer
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Informer.py
@@ -124,7 +124,8 @@ class Informer(Base):



-
+    def can_be_compiled(self):
+        return True

     def forward(self,batch):
         #x_enc, x_mark_enc, x_dec, x_mark_dec,enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/LinearTS.py
@@ -143,7 +143,8 @@ class LinearTS(Base):
             activation(),
             nn.BatchNorm1d(hidden_size//8) if use_bn else nn.Dropout(dropout_rate) ,
             nn.Linear(hidden_size//8,self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
     def forward(self, batch):

         x = batch['x_num_past'].to(self.device)
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/PatchTST.py
@@ -133,6 +133,9 @@ class PatchTST(Base):

         #self.final_linear = nn.Sequential(nn.Linear(past_channels,past_channels//2),activation(),nn.Dropout(dropout_rate), nn.Linear(past_channels//2,out_channels) )

+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch): # x: [Batch, Input length, Channel]

{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/RNN.py
@@ -148,7 +148,8 @@ class RNN(Base):
             activation(),
             MyBN(hidden_RNN//8) if use_bn else nn.Dropout(dropout_rate) ,
             nn.Linear(hidden_RNN//8,1)))
-
+    def can_be_compiled(self):
+        return True


     def forward(self, batch):
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Samformer.py
@@ -85,7 +85,9 @@ class Samformer(Base):
             activation(),
             nn.Linear(dim*2,self.out_channels*self.mul ))

-
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch:dict)-> float:

         x = batch['x_num_past'].to(self.device)
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/Simple.py
@@ -67,7 +67,9 @@ class Simple(Base):
         self.linear = (nn.Sequential(nn.Linear(emb_past_out_channel*self.past_steps+emb_fut_out_channel*self.future_steps+self.past_steps*self.past_channels+self.future_channels*self.future_steps,hidden_size),
                        activation(),nn.Dropout(dropout_rate),
                        nn.Linear(hidden_size,self.out_channels*self.future_steps*self.mul)))
-
+    def can_be_compiled(self):
+        return True
+
     def forward(self, batch):

         x = batch['x_num_past'].to(self.device)
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TFT.py
@@ -111,6 +111,10 @@ class TFT(Base):

         self.outLinear = nn.Linear(d_model, self.out_channels*self.mul)

+    def can_be_compiled(self):
+        return False
+
+
     def forward(self, batch:dict) -> torch.Tensor:
         """Temporal Fusion Transformer

{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TIDE.py
@@ -106,7 +106,10 @@ class TIDE(Base):

         # linear for Y lookback
         self.linear_target = nn.Linear(self.past_steps*self.out_channels, self.future_steps*self.out_channels*self.mul)
-
+
+    def can_be_compiled(self):
+        return False
+

     def forward(self, batch:dict)-> float:
         """training process of the diffusion network
dsipts-1.1.12/src/dsipts/models/TTM.py (new file)
@@ -0,0 +1,158 @@
+import torch
+import numpy as np
+from torch import nn
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+from .ttm.utils import get_model, get_frequency_token, count_parameters, DEFAULT_FREQUENCY_MAPPING
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+
+class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 model_path:str,
+                 prefer_l1_loss:bool, # exog: set true to use l1 loss
+                 prefer_longer_context:bool,
+                 prediction_channel_indices,
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
+                 decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
+                 fcm_context_length,
+                 fcm_use_mixer,
+                 fcm_mix_layers,
+                 fcm_prepend_past,
+                 enable_forecast_channel_mixing,
+                 **kwargs)->None:
+
+        super().__init__(**kwargs)
+        self.save_hyperparameters(logger=False)
+
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
+        else:
+            self.index_fut_cat = []
+        self.freq = freq
+
+        base_freq_token = get_frequency_token(self.freq) # e.g., shape [n_token] or scalar
+        # ensure it's a tensor of integer type
+        if not torch.is_tensor(base_freq_token):
+            base_freq_token = torch.tensor(base_freq_token)
+        base_freq_token = base_freq_token.long()
+        self.register_buffer("token", base_freq_token, persistent=True)
+
+
+        self.model = get_model(
+            model_path=model_path,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
+            prefer_l1_loss=prefer_l1_loss,
+            prefer_longer_context=prefer_longer_context,
+            num_input_channels=self.past_channels+len(self.embs_past), #giusto
+            decoder_mode=decoder_mode,
+            prediction_channel_indices=list(prediction_channel_indices),
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
+            fcm_context_length=fcm_context_length,
+            fcm_use_mixer=fcm_use_mixer,
+            fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
+            fcm_prepend_past=fcm_prepend_past,
+            enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+        )
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """
+        Freeze the backbone of the model.
+        This is useful when you want to fine-tune only the head of the model.
+        """
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
+        # Freeze the backbone of the model
+        for param in self.model.backbone.parameters():
+            param.requires_grad = False
+        # Count params
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
+
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+
+    def can_be_compiled(self):
+        return True
+
+    def forward(self, batch):
+        x_enc = batch['x_num_past'].to(self.device)
+        original_indexes = batch['idx_target'][0].tolist()
+
+
+        if 'x_cat_past' in batch.keys():
+            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
+            past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
+        else:
+            past_values = x_enc
+
+        future_values = torch.zeros_like(past_values).to(self.device)
+        future_values = future_values[:,:self.future_steps,:]
+
+        if 'x_num_future' in batch.keys():
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
+        if 'x_cat_future' in batch.keys():
+            x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
+
+
+        #investigating!! problem with dynamo!
+        #freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
+
+        batch_size = past_values.shape[0]
+        freq_token = self.token.repeat(batch_size).long().to(self.device)
+
+
+        res = self.model(
+            past_values= past_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
+            past_observed_mask = None,
+            future_observed_mask = None,
+            output_hidden_states = False,
+            return_dict = False,
+            freq_token= freq_token,#[0:past_values.shape[0]], ##investigating
+            static_categorical_values = None
+        )
+
+
+        BS = res.shape[0]
+        return res.reshape(BS,self.future_steps,-1,self.mul)
+
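
Note: the commented-out lines hint at why the frequency token is precomputed in __init__ and stored via register_buffer: recomputing it inside forward broke dynamo tracing, while a registered buffer follows the module across devices and is captured cleanly in the compiled graph. A small sketch of that pattern (FreqTokenHolder is illustrative, not a dsipts class):

    import torch
    from torch import nn

    class FreqTokenHolder(nn.Module):
        def __init__(self, token_id: int):
            super().__init__()
            # buffers move with .to(device) and live in state_dict
            self.register_buffer("token", torch.tensor(token_id).long(), persistent=True)

        def forward(self, batch_size: int) -> torch.Tensor:
            return self.token.repeat(batch_size)  # one token per batch element

    holder = FreqTokenHolder(3)
    print(holder(4))  # tensor([3, 3, 3, 3])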
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/TimeXER.py
@@ -125,7 +125,9 @@ class TimeXER(Base):



-
+    def can_be_compiled(self):
+        return True
+


     def forward(self, batch:dict)-> float:
{dsipts-1.1.10 → dsipts-1.1.12}/src/dsipts/models/base.py
@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics = []
-        self.validation_epoch_metrics = []
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+

         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
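
Note: replacing the Python lists with registered scalar buffers keeps the running loss on-device; the epoch value is then a sum divided by a count, reset after logging. A self-contained sketch of the accumulate-then-reset pattern (MetricAccumulator is a hypothetical name):

    import torch
    from torch import nn

    class MetricAccumulator(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("total", torch.tensor(0.0))
            self.register_buffer("count", torch.tensor(0))

        def update(self, loss: torch.Tensor) -> None:
            self.total += loss.detach()  # detach: keep no graph across steps
            self.count += 1

        def mean_and_reset(self) -> torch.Tensor:
            avg = self.total / self.count.clamp(min=1)
            self.total.zero_()
            self.count.zero_()
            return avg

    acc = MetricAccumulator()
    for loss in (torch.tensor(0.5), torch.tensor(1.5)):
        acc.update(loss)
    print(acc.mean_and_reset())  # tensor(1.)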
@@ -295,7 +298,8 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)

-        self.train_epoch_metrics.append(loss.item())
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss

@@ -311,27 +315,20 @@ class Base(pl.LightningModule):
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-            if self.use_quantiles:
-                idx = 1
-            else:
-                idx = 0
-            #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-                for i in range(batch['y'].shape[2]):
-                    real = batch['y'][0,:,i].cpu().detach().numpy()
-                    pred = y_hat[0,:,i,idx].cpu().detach().numpy()
-                    fig, ax = plt.subplots(figsize=(7,5))
-                    ax.plot(real,'o-',label='real')
-                    ax.plot(pred,'o-',label='pred')
-                    ax.legend()
-                    ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-                    self.logger.experiment.track(Image(fig), name='cm_training_end')
-                    #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-
-        return self.compute_loss(batch,y_hat)+score
-
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []

     def validation_epoch_end(self, outs):
@@ -339,14 +336,30 @@ class Base(pl.LightningModule):

         :meta private:
         """
-        if len(outs)==0:
-            loss = 10000
-            beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            loss = torch.stack(outs).mean()
-
-        self.log("val_loss", loss.item(),sync_dist=True)
-        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {loss.item():.4f}','info',self.verbose)
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)

     def training_epoch_end(self, outs):
         """
@@ -355,12 +368,11 @@ class Base(pl.LightningModule):

         :meta private:
         """

-        loss = sum(outs['loss'] for outs in outs) / len(outs)
-        self.log("train_loss", loss.item(),sync_dist=True)
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1

-        self.train_loss_epoch = loss.item()
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation