dsipts 1.1.9.tar.gz → 1.1.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- {dsipts-1.1.9 → dsipts-1.1.11}/PKG-INFO +1 -1
- {dsipts-1.1.9 → dsipts-1.1.11}/pyproject.toml +1 -1
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_structure/data_structure.py +38 -16
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ITransformer.py +2 -8
- dsipts-1.1.11/src/dsipts/models/TTM.py +145 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/base.py +52 -38
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/base_v2.py +55 -37
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/utils.py +7 -5
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts.egg-info/PKG-INFO +1 -1
- dsipts-1.1.9/src/dsipts/models/TTM.py +0 -252
- {dsipts-1.1.9 → dsipts-1.1.11}/README.md +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/setup.cfg +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_management/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_management/monash.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_management/public_datasets.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_structure/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_structure/modifiers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/data_structure/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Autoformer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/CrossFormer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/D3VAE.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Diffusion.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/DilatedConv.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/DilatedConvED.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Duet.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Informer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/LinearTS.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/PatchTST.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Persistent.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/RNN.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Samformer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/Simple.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/TFT.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/TIDE.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/TimeXER.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/VQVAEA.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/VVA.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/autoformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/autoformer/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/crossformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/crossformer/attn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/crossformer/cross_decoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/crossformer/cross_embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/crossformer/cross_encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/diffusion_process.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/embedding.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/model.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/neural_operations.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/resnet.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/d3vae/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/duet/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/duet/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/duet/masked.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/informer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/informer/attn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/informer/decoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/informer/embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/informer/encoder.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/itransformer/Embed.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/itransformer/SelfAttention_Family.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/itransformer/Transformer_EncDec.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/itransformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/patchtst/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/patchtst/layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/samformer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/samformer/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/tft/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/tft/sub_nn.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/timexer/Layers.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/timexer/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ttm/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ttm/configuration_tinytimemixer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ttm/consts.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ttm/modeling_tinytimemixer.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/ttm/utils.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/vva/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/vva/minigpt.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/vva/vqvae.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/xlstm/__init__.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts/models/xlstm/xLSTM.py +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts.egg-info/SOURCES.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts.egg-info/dependency_links.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts.egg-info/requires.txt +0 -0
- {dsipts-1.1.9 → dsipts-1.1.11}/src/dsipts.egg-info/top_level.txt +0 -0
src/dsipts/data_structure/data_structure.py

@@ -35,7 +35,18 @@ from .modifiers import *
 from aim.pytorch_lightning import AimLogger
 import time
 
-
+class DummyScaler():
+    def __init__(self):
+        pass
+    def fit(self,x):
+        pass
+    def transform(self,x):
+        return x
+    def inverse_transform(self,x):
+        return x
+    def fit_transform(self,x):
+        return x
+
 
 pd.options.mode.chained_assignment = None
 log = logging.getLogger(__name__)
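The new DummyScaler is a null-object stand-in exposing the same fit/transform/fit_transform/inverse_transform surface as a sklearn scaler, so scaling can be disabled without special-casing the pipeline. A minimal illustration of the duck-typing idea (the prepare helper is hypothetical, not dsipts API):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    def prepare(data, scaler):
        # Any object with the sklearn scaler interface works here,
        # including the no-op DummyScaler added in this release,
        # which simply returns its input unchanged.
        return scaler.fit_transform(data)

    data = np.array([[1.0], [2.0], [3.0]])
    print(prepare(data, StandardScaler()))  # standardized to zero mean, unit variance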
@@ -210,20 +221,23 @@ class TimeSeries():
         self.future_variables = []
         self.target_variables = ['signal']
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
-
+        self.num_var = list(np.sort(self.num_var))
 
     def enrich(self,dataset,columns):
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            if columns =='hour':
+                dataset[columns] = dataset.time.dt.hour
+            elif columns=='dow':
+                dataset[columns] = dataset.time.dt.weekday
+            elif columns=='month':
+                dataset[columns] = dataset.time.dt.month
+            elif columns=='minute':
+                dataset[columns] = dataset.time.dt.minute
+            else:
+                if columns not in dataset.columns:
+                    beauty_string(f'I can not automatically enrich column {columns}. Please contact the developers or add it manually to your dataset.','section',True)
+        except:
+            beauty_string(f'I can not automatically enrich column {columns}. Probably not a temporal index.','section',True)
 
     def load_signal(self,data:pd.DataFrame,
                     enrich_cat:List[str] = [],
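The rewritten enrich derives calendar features from the dataset's time column via the pandas .dt accessor, and falls back to a warning when the index is not temporal. A quick standalone illustration of what 'hour' and 'dow' produce (plain pandas, not dsipts code):

    import pandas as pd

    df = pd.DataFrame({'time': pd.date_range('2024-01-01', periods=3, freq='7h')})
    df['hour'] = df.time.dt.hour    # 0, 7, 14
    df['dow'] = df.time.dt.weekday  # 0 = Monday; 2024-01-01 is a Monday
    print(df)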
@@ -300,7 +314,7 @@ class TimeSeries():
         if check_past:
             beauty_string('I will update past column adding all target columns, if you want to avoid this beahviour please use check_pass as false','info',self.verbose)
             past_variables = list(set(past_variables).union(set(target_variables)))
-
+        past_variables = list(np.sort(past_variables))
         self.cat_past_var = cat_past_var
         self.cat_fut_var = cat_fut_var
 
@@ -321,14 +335,18 @@ class TimeSeries():
                 beauty_string('Categorical {c} already present, it will be added to categorical variable but not call the enriching function','info',self.verbose)
             else:
                 self.enrich(dataset,c)
+        self.cat_past_var = list(np.sort(self.cat_past_var))
+        self.cat_fut_var = list(np.sort(self.cat_fut_var))
+
         self.cat_var = list(set(self.cat_past_var+self.cat_fut_var)) ## all categorical data
-
+        self.cat_var = list(np.sort(self.cat_var))
         self.dataset = dataset
         self.past_variables = past_variables
         self.future_variables = future_variables
         self.target_variables = target_variables
         self.out_vars = len(target_variables)
         self.num_var = list(set(self.past_variables).union(set(self.future_variables)).union(set(self.target_variables)))
+        self.num_var = list(np.sort(self.num_var))
         if silly_model:
             beauty_string('YOU ARE TRAINING A SILLY MODEL WITH THE TARGETS IN THE INPUTS','section',self.verbose)
             self.future_variables+=self.target_variables
@@ -665,7 +683,11 @@ class TimeSeries():
         #self.model.apply(weight_init_zeros)
 
         self.config = config
-
+        try:
+            self.model = torch.compile(self.model)
+        except:
+            beauty_string('Can not compile the model','block',self.verbose)
+
         beauty_string('Setting the model','block',self.verbose)
         beauty_string(model,'',self.verbose)
 
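This release also wraps the model in torch.compile where available, falling back silently on platforms where compilation is unsupported. The guarded pattern looks like this in isolation (a generic sketch, not the dsipts source):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1)
    try:
        # torch.compile exists from PyTorch 2.0 on; it can still raise on
        # unsupported backends, so the fallback keeps the eager model.
        model = torch.compile(model)
    except Exception:
        print('Cannot compile the model, keeping the eager version')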
src/dsipts/models/ITransformer.py

@@ -8,6 +8,8 @@ import numpy as np
 from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
 from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
 from .itransformer.Embed import DataEmbedding_inverted
+from ..data_structure.utils import beauty_string
+from .utils import get_scope,get_activation,Embedding_cat_variables
 
 try:
     import lightning.pytorch as pl
@@ -17,12 +19,6 @@ except:
     import pytorch_lightning as pl
     OLD_PL = True
 from .base import Base
-from .utils import QuantileLossMO,Permute, get_activation
-
-from typing import List, Union
-from ..data_structure.utils import beauty_string
-from .utils import get_scope
-from .utils import Embedding_cat_variables
 
 
 
@@ -34,8 +30,6 @@ class ITransformer(Base):
     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
 
     def __init__(self,
-
-
                  # specific params
                  hidden_size:int,
                  d_model: int,
dsipts-1.1.11/src/dsipts/models/TTM.py (new file)

@@ -0,0 +1,145 @@
+import torch
+import numpy as np
+from torch import nn
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+
+
+from .ttm.utils import get_model, get_frequency_token, count_parameters
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+
+class TTM(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+    def __init__(self,
+                 model_path:str,
+                 prefer_l1_loss:bool, # exog: set true to use l1 loss
+                 prefer_longer_context:bool,
+                 prediction_channel_indices,
+                 exogenous_channel_indices_cont,
+                 exogenous_channel_indices_cat,
+                 decoder_mode,
+                 freq,
+                 freq_prefix_tuning,
+                 fcm_context_length,
+                 fcm_use_mixer,
+                 fcm_mix_layers,
+                 fcm_prepend_past,
+                 enable_forecast_channel_mixing,
+                 **kwargs)->None:
+
+        super().__init__(**kwargs)
+        self.save_hyperparameters(logger=False)
+
+
+        self.index_fut = list(exogenous_channel_indices_cont)
+
+        if len(exogenous_channel_indices_cat)>0:
+            self.index_fut_cat = (self.past_channels+len(self.embs_past))+list(exogenous_channel_indices_cat)
+        else:
+            self.index_fut_cat = []
+        self.freq = freq
+
+        self.model = get_model(
+            model_path=model_path,
+            context_length=self.past_steps,
+            prediction_length=self.future_steps,
+            prefer_l1_loss=prefer_l1_loss,
+            prefer_longer_context=prefer_longer_context,
+            num_input_channels=self.past_channels+len(self.embs_past), # correct
+            decoder_mode=decoder_mode,
+            prediction_channel_indices=list(prediction_channel_indices),
+            exogenous_channel_indices=self.index_fut + self.index_fut_cat,
+            fcm_context_length=fcm_context_length,
+            fcm_use_mixer=fcm_use_mixer,
+            fcm_mix_layers=fcm_mix_layers,
+            freq=freq,
+            freq_prefix_tuning=freq_prefix_tuning,
+            fcm_prepend_past=fcm_prepend_past,
+            enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+
+        )
+        hidden_size = self.model.config.hidden_size
+        self.model.prediction_head = torch.nn.Linear(hidden_size, self.out_channels*self.mul)
+        self._freeze_backbone()
+
+    def _freeze_backbone(self):
+        """
+        Freeze the backbone of the model.
+        This is useful when you want to fine-tune only the head of the model.
+        """
+        beauty_string(f"Number of params before freezing backbone:{count_parameters(self.model)}",'info',self.verbose)
+
+        # Freeze the backbone of the model
+        for param in self.model.backbone.parameters():
+            param.requires_grad = False
+        # Count params
+        beauty_string(f"Number of params after freezing the backbone: {count_parameters(self.model)}",'info',self.verbose)
+
+
+    def _scaler_past(self, input):
+        for i, e in enumerate(self.embs_past):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+    def _scaler_fut(self, input):
+        for i, e in enumerate(self.embs_fut):
+            input[:,:,i] = input[:, :, i] / (e-1)
+        return input
+
+
+    def forward(self, batch):
+        x_enc = batch['x_num_past']
+        original_indexes = batch['idx_target'][0].tolist()
+
+
+
+
+        if 'x_cat_past' in batch.keys():
+            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+            x_mark_enc = self._scaler_past(x_mark_enc)
+            past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
+        else:
+            past_values = x_enc
+
+        future_values = torch.zeros_like(past_values)
+        future_values = future_values[:,:self.future_steps,:]
+
+        if 'x_num_future' in batch.keys():
+            future_values[:,:,self.index_fut] = batch['x_num_future'].to(self.device)
+        if 'x_cat_future' in batch.keys():
+            x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
+            x_mark_dec = self._scaler_fut(x_mark_dec)
+            future_values[:,:,self.index_cat_fut] = x_mark_dec
+
+
+        #investigating!!
+        freq_token = get_frequency_token(self.freq).repeat(past_values.shape[0])
+
+        res = self.model(
+            past_values= past_values,
+            future_values= future_values,# future_values if future_values.shape[0]>0 else None,
+            past_observed_mask = None,
+            future_observed_mask = None,
+            output_hidden_states = False,
+            return_dict = False,
+            freq_token= freq_token, ##investigating
+            static_categorical_values = None
+        )
+
+
+        BS = res.shape[0]
+        return res.reshape(BS,self.future_steps,-1,self.mul)
+
+
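The rewritten TTM wrapper declares its capabilities as class attributes, routes the shared training arguments through **kwargs to Base instead of re-implementing them, and fine-tunes only a fresh linear prediction head sized out_channels*mul on top of the frozen TinyTimeMixer backbone. Freezing a backbone while keeping a trainable head follows the standard PyTorch pattern (a generic sketch, not dsipts code):

    import torch.nn as nn

    backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
    head = nn.Linear(32, 4)

    for p in backbone.parameters():
        p.requires_grad = False  # backbone weights stay fixed

    trainable = sum(p.numel() for p in head.parameters() if p.requires_grad)
    print(f'trainable params: {trainable}')  # only the head is updated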
src/dsipts/models/base.py

@@ -111,8 +111,11 @@ class Base(pl.LightningModule):
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics
-        self.validation_epoch_metrics
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
+
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
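Replacing plain Python attributes with register_buffer means the running loss accumulators are real tensors that move with the module across devices and receive no optimizer updates, which makes the per-epoch averaging safe under GPU training. A minimal sketch of the pattern (illustrative, not the dsipts source):

    import torch
    import torch.nn as nn

    class MetricTracker(nn.Module):
        def __init__(self):
            super().__init__()
            # Buffers follow .to(device) / .cuda() like parameters,
            # but get no gradients and no optimizer updates.
            self.register_buffer("loss_sum", torch.tensor(0.0))
            self.register_buffer("count", torch.tensor(0))

        def update(self, loss):
            self.loss_sum += loss.detach()  # detach: keep no autograd graph
            self.count += 1

        def epoch_average(self):
            avg = self.loss_sum / self.count
            self.loss_sum.zero_()  # reset in place for the next epoch
            self.count.zero_()
            return avg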
@@ -136,9 +139,9 @@
         self.is_classification = False
         if len(self.quantiles)>0:
             if self.loss_type=='cprs':
-                self.use_quantiles =
+                self.use_quantiles = True
                 self.mul = len(self.quantiles)
-                self.loss = CPRS()
+                self.loss = CPRS(alpha=self.persistence_weight)
             else:
                 assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
                 self.use_quantiles = True
@@ -193,7 +196,9 @@
         """
         if self.loss_type=='cprs':
             tmp = self(batch)
-
+            tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
+            return tmp
+            #return tmp.mean(axis=-1).unsqueeze(-1)
 
         return self(batch)
 
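For the CRPS loss the network emits an ensemble of samples along the last axis, and inference now collapses that ensemble into 5%/50%/95% quantile bands. How torch.quantile reshapes such a tensor (standalone sketch; the 4-D layout follows the batch/steps/channels/members convention used above):

    import torch

    ens = torch.randn(2, 12, 3, 100)  # batch, steps, channels, ensemble members
    q = torch.quantile(ens, torch.tensor([0.05, 0.5, 0.95]), dim=-1)
    print(q.shape)                 # torch.Size([3, 2, 12, 3]): quantile axis first
    bands = q.permute(1, 2, 3, 0)  # back to batch, steps, channels, quantile
    print(bands.shape)             # torch.Size([2, 12, 3, 3])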
@@ -293,7 +298,8 @@
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -309,27 +315,20 @@
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
-        #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
-
+
             if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
-
-
-
-
-
-
-
-
-
-
-
-        return self.compute_loss(batch,y_hat)+score
-
+                self._val_outputs.append({
+                    "y": batch['y'].detach().cpu(),
+                    "y_hat": y_hat.detach().cpu()
+                })
+        self.validation_epoch_metrics = (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+
+        return None #self.compute_loss(batch,y_hat)+score
+
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
 
     def validation_epoch_end(self, outs):
         """
@@ -337,14 +336,30 @@
 
         :meta private:
         """
-        if len(
-
-
-
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
+        self.log("val_loss", avg,sync_dist=True)
+        beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
     def training_epoch_end(self, outs):
         """
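Recent PyTorch Lightning versions dropped the outs argument semantics these epoch-end hooks relied on, so the code now buffers the first validation batch itself in on_validation_start and consumes it at epoch end. The same manual-buffer pattern in a current LightningModule (a skeleton showing only the hooks involved; forward and training logic omitted):

    import torch
    import lightning.pytorch as pl

    class Forecaster(pl.LightningModule):
        def on_validation_start(self):
            self._val_outputs = []  # reset the buffer each epoch

        def validation_step(self, batch, batch_idx):
            y_hat = self(batch)
            if batch_idx == 0:  # keep only the first batch for plotting
                self._val_outputs.append({"y": batch["y"].detach().cpu(),
                                          "y_hat": y_hat.detach().cpu()})

        def on_validation_epoch_end(self):
            if self._val_outputs:
                ys = torch.cat([o["y"] for o in self._val_outputs])
                ...  # plot / log the buffered predictions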
@@ -353,12 +368,11 @@
         :meta private:
         """
 
-        loss =
-        self.log("train_loss", loss
+        loss = self.train_epoch_metrics/self.global_step
+        self.log("train_loss", loss,sync_dist=True)
         self.count_epoch+=1
 
-        self.train_loss_epoch = loss
-
+        self.train_loss_epoch = loss
     def compute_loss(self,batch,y_hat):
         """
         custom loss calculation
src/dsipts/models/base_v2.py

@@ -15,7 +15,6 @@ from typing import List, Union
 from .utils import QuantileLossMO, CPRS
 import torch.nn as nn
 
-
 def standardize_momentum(x,order):
     mean = torch.mean(x,1).unsqueeze(1).repeat(1,x.shape[1],1)
     num = torch.pow(x-mean,order).mean(axis=1)
@@ -113,8 +112,13 @@
         self.train_loss_epoch = -100.0
         self.verbose = verbose
         self.name = self.__class__.__name__
-        self.train_epoch_metrics =
-        self.validation_epoch_metrics =
+        #self.train_epoch_metrics = 0
+        #self.validation_epoch_metrics = 0
+
+        self.register_buffer("train_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("validation_epoch_metrics", torch.tensor(0.0))
+        self.register_buffer("train_epoch_count", torch.tensor(0))
+        self.register_buffer("validation_epoch_count", torch.tensor(0))
 
         self.use_quantiles = True if len(quantiles)>0 else False
         self.quantiles = quantiles
@@ -138,9 +142,9 @@
         self.is_classification = False
         if len(self.quantiles)>0:
             if self.loss_type=='cprs':
-                self.use_quantiles =
+                self.use_quantiles = True
                 self.mul = len(self.quantiles)
-                self.loss = CPRS()
+                self.loss = CPRS(alpha=self.persistence_weight)
             else:
                 assert len(self.quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
                 self.use_quantiles = True
@@ -197,7 +201,9 @@
 
         if self.loss_type=='cprs':
             tmp = self(batch)
-
+            tmp = torch.quantile(tmp, torch.tensor([0.05, 0.5, 0.95]), dim=-1).permute(1,2,3,0)
+            return tmp
+            #return tmp.mean(axis=-1).unsqueeze(-1)
 
         return self(batch)
 
@@ -297,7 +303,8 @@
         y_hat = self(batch)
         loss = self.compute_loss(batch,y_hat)
 
-        self.train_epoch_metrics
+        self.train_epoch_metrics+=loss.detach()
+        self.train_epoch_count +=1
         return loss
 
 
@@ -314,41 +321,54 @@
         y_hat = self(batch)
         score = 0
         if batch_idx==0:
-
-            idx = 1
-        else:
-            idx = 0
+
         #track the predictions! We can do better than this but maybe it is better to firstly update pytorch-lightening
 
         if self.count_epoch%int(max(self.trainer.max_epochs/100,1))==1:
+            self._val_outputs.append({
+                "y": batch['y'].detach().cpu(),
+                "y_hat": y_hat.detach().cpu()
+            })
+        self.validation_epoch_metrics+= (self.compute_loss(batch,y_hat)+score).detach()
+        self.validation_epoch_count+=1
+        return None
 
-
-
-
-
-            ax.plot(real,'o-',label='real')
-            ax.plot(pred,'o-',label='pred')
-            ax.legend()
-            ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
-            self.logger.experiment.track(Image(fig), name='cm_training_end')
-            #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
-        self.validation_epoch_metrics.append(self.compute_loss(batch,y_hat)+score)
-        return
-
+    def on_validation_start(self):
+        # reset buffer each epoch
+        self._val_outputs = []
+
 
     def on_validation_epoch_end(self):
         """
        pythotrch lightening stuff
 
         :meta private:
-        """
+        """
 
-        if len(self.
-
-
-
-
+        if len(self._val_outputs)>0:
+            ys = torch.cat([o["y"] for o in self._val_outputs])
+            y_hats = torch.cat([o["y_hat"] for o in self._val_outputs])
+            if self.use_quantiles:
+                idx = 1
+            else:
+                idx = 0
+            for i in range(ys.shape[2]):
+                real = ys[0,:,i].cpu().detach().numpy()
+                pred = y_hats[0,:,i,idx].cpu().detach().numpy()
+                fig, ax = plt.subplots(figsize=(7,5))
+                ax.plot(real,'o-',label='real')
+                ax.plot(pred,'o-',label='pred')
+                ax.legend()
+                ax.set_title(f'Channel {i} first element first batch validation {int(100*self.count_epoch/self.trainer.max_epochs)}%')
+                self.logger.experiment.track(Image(fig), name='cm_training_end')
+                #self.log(f"example_{i}", np.stack([real, pred]).T,sync_dist=True)
+                plt.close(fig)
+
+
+        avg = self.validation_epoch_metrics/self.validation_epoch_count
+
+        self.validation_epoch_metrics.zero_()
+        self.validation_epoch_count.zero_()
         self.log("val_loss", avg,sync_dist=True)
         beauty_string(f'Epoch: {self.count_epoch} train error: {self.train_loss_epoch:.4f} validation loss: {avg:.4f}','info',self.verbose)
 
@@ -359,14 +379,12 @@
 
         :meta private:
         """
-
-
-        beauty_string(f'THIS IS A BUG, It should be polulated','info',self.verbose)
-        else:
-            avg = np.stack(self.train_epoch_metrics).mean()
+
+        avg = self.train_epoch_metrics/self.train_epoch_count
         self.log("train_loss", avg,sync_dist=True)
         self.count_epoch+=1
-        self.train_epoch_metrics
+        self.train_epoch_metrics.zero_()
+        self.train_epoch_count.zero_()
         self.train_loss_epoch = avg
 
     def compute_loss(self,batch,y_hat):
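One inconsistency worth flagging between the two refactored base classes: training_epoch_end in base.py divides the accumulated loss by self.global_step, which Lightning counts cumulatively across epochs, while base_v2.py divides by the per-epoch train_epoch_count that is zeroed at each epoch end. After the first epoch the two denominators diverge, as a quick arithmetic check shows (10 batches per epoch assumed):

    batches_per_epoch = 10
    epoch_loss_sum = 5.0                       # loss accumulated during one epoch

    print(epoch_loss_sum / batches_per_epoch)  # base_v2 style: 0.5 every epoch

    global_step = 3 * batches_per_epoch        # base style at the end of epoch 3
    print(epoch_loss_sum / global_step)        # ~0.167, shrinks as training runs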
src/dsipts/models/utils.py

@@ -633,7 +633,7 @@ class CPRS(nn.Module):
     with large ensembles.
     """
 
-    def __init__(self, alpha=0.
+    def __init__(self, alpha=0.5, reduction='mean'):
         super().__init__()
         self.alpha = alpha
         self.reduction = reduction
@@ -674,17 +674,19 @@ class CPRS(nn.Module):
         # Create mask to exclude diagonal (i=j)
         mask = ~torch.eye(n_members, dtype=torch.bool, device=ensemble.device)
         mask = mask.view(1, n_members, n_members, *[1]*(len(ensemble.shape)-2))
-
+
         # Apply mask and compute mean
-        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2))
+        pairwise_term = (pairwise_diffs * mask).sum(dim=(1, 2)) ##formula 3 second term
 
         # Combine terms according to afCRPS formula
-        loss = mae_term - (1 - epsilon) * pairwise_term
+        loss = mae_term - (1 - epsilon) * pairwise_term/ (2*n_members * (n_members - 1))
 
         # Apply weights if provided
         if weights is not None:
             loss = loss * weights
-
+        #if loss.mean()<-2:
+        #    import pdb
+        #    pdb.set_trace()
 
         # Apply reduction
         if self.reduction == 'none':
             return loss
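The substantive fix in this hunk is the normalization of the pairwise spread term. Reading the estimator off the code (and assuming mae_term is already the ensemble-mean absolute error against the observation), the almost-fair CRPS for an ensemble of $M$ members $x_i$ and observation $y$ is

$$\mathrm{afCRPS}(y,\{x_i\}) \;=\; \frac{1}{M}\sum_{i=1}^{M}\lvert x_i - y\rvert \;-\; \frac{1-\epsilon}{2M(M-1)}\sum_{i=1}^{M}\sum_{j\neq i}\lvert x_i - x_j\rvert$$

where $\epsilon$ is the small fairness correction (epsilon in the code). The 1.1.9 code omitted the $2M(M-1)$ divisor, so the spread term scaled with the squared ensemble size and could push the loss far below zero; the commented-out pdb guard on loss.mean() < -2 suggests that is how the bug was caught.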
dsipts-1.1.9/src/dsipts/models/TTM.py (deleted)

@@ -1,252 +0,0 @@
-import torch
-import numpy as np
-from torch import nn
-
-try:
-    import lightning.pytorch as pl
-    from .base_v2 import Base
-    OLD_PL = False
-except:
-    import pytorch_lightning as pl
-    OLD_PL = True
-    from .base import Base
-
-
-
-from typing import List,Union
-
-from .utils import QuantileLossMO
-from ..data_structure.utils import beauty_string
-from .ttm.utils import get_model, get_frequency_token, count_parameters, RMSELoss
-
-
-class TTM(Base):
-    def __init__(self,
-                 model_path:str,
-                 past_steps:int,
-                 future_steps:int,
-                 freq_prefix_tuning:bool,
-                 freq:str,
-                 prefer_l1_loss:bool, # exog: set true to use l1 loss
-                 prefer_longer_context:bool,
-                 loss_type:str,
-                 num_input_channels,
-                 prediction_channel_indices,
-                 exogenous_channel_indices,
-                 decoder_mode,
-                 fcm_context_length,
-                 fcm_use_mixer,
-                 fcm_mix_layers,
-                 fcm_prepend_past,
-                 enable_forecast_channel_mixing,
-                 out_channels:int,
-                 embs:List[int],
-                 remove_last = False,
-                 optim:Union[str,None]=None,
-                 optim_config:dict=None,
-                 scheduler_config:dict=None,
-                 verbose = False,
-                 use_quantiles=False,
-                 persistence_weight:float=0.0,
-                 quantiles:List[int]=[],
-                 **kwargs)->None:
-        """TODO and FIX for future and past categorical variables
-
-        Args:
-            model_path (str): _description_
-            past_steps (int): _description_
-            future_steps (int): _description_
-            freq_prefix_tuning (bool): _description_
-            freq (str): _description_
-            prefer_l1_loss (bool): _description_
-            loss_type (str): _description_
-            num_input_channels (_type_): _description_
-            prediction_channel_indices (_type_): _description_
-            exogenous_channel_indices (_type_): _description_
-            decoder_mode (_type_): _description_
-            fcm_context_length (_type_): _description_
-            fcm_use_mixer (_type_): _description_
-            fcm_mix_layers (_type_): _description_
-            fcm_prepend_past (_type_): _description_
-            enable_forecast_channel_mixing (_type_): _description_
-            out_channels (int): _description_
-            embs (List[int]): _description_
-            remove_last (bool, optional): _description_. Defaults to False.
-            optim (Union[str,None], optional): _description_. Defaults to None.
-            optim_config (dict, optional): _description_. Defaults to None.
-            scheduler_config (dict, optional): _description_. Defaults to None.
-            verbose (bool, optional): _description_. Defaults to False.
-            use_quantiles (bool, optional): _description_. Defaults to False.
-            persistence_weight (float, optional): _description_. Defaults to 0.0.
-            quantiles (List[int], optional): _description_. Defaults to [].
-        """
-        super(TTM, self).__init__(verbose)
-        self.save_hyperparameters(logger=False)
-        self.future_steps = future_steps
-        self.use_quantiles = use_quantiles
-        self.optim = optim
-        self.optim_config = optim_config
-        self.scheduler_config = scheduler_config
-        self.persistence_weight = persistence_weight
-        self.loss_type = loss_type
-        self.remove_last = remove_last
-        self.embs = embs
-        self.freq = freq
-        self.extend_variables = False
-
-        # NOTE: For Hydra
-        prediction_channel_indices = list(prediction_channel_indices)
-        exogenous_channel_indices = list(exogenous_channel_indices)
-
-        if len(quantiles)>0:
-            assert len(quantiles)==3, beauty_string('ONLY 3 quantiles premitted','info',True)
-            self.use_quantiles = True
-            self.mul = len(quantiles)
-            self.loss = QuantileLossMO(quantiles)
-            self.extend_variables = True
-            if out_channels * 3 != len(prediction_channel_indices):
-                prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
-                                                                                                                         exogenous_channel_indices,
-                                                                                                                         out_channels)
-        else:
-            self.mul = 1
-            if self.loss_type == 'mse':
-                self.loss = nn.MSELoss(reduction="mean")
-            elif self.loss_type == 'rmse':
-                self.loss = RMSELoss()
-            else:
-                self.loss = nn.L1Loss()
-
-        self.model = get_model(
-            model_path=model_path,
-            context_length=past_steps,
-            prediction_length=future_steps,
-            freq_prefix_tuning=freq_prefix_tuning,
-            freq=freq,
-            prefer_l1_loss=prefer_l1_loss,
-            prefer_longer_context=prefer_longer_context,
-            num_input_channels=num_input_channels,
-            decoder_mode=decoder_mode,
-            prediction_channel_indices=list(prediction_channel_indices),
-            exogenous_channel_indices=list(exogenous_channel_indices),
-            fcm_context_length=fcm_context_length,
-            fcm_use_mixer=fcm_use_mixer,
-            fcm_mix_layers=fcm_mix_layers,
-            fcm_prepend_past=fcm_prepend_past,
-            #loss='mse',
-            enable_forecast_channel_mixing=enable_forecast_channel_mixing,
-        )
-        self.__freeze_backbone()
-
-    def __add_quantile_features(self, prediction_channel_indices, exogenous_channel_indices, out_channels):
-        prediction_channel_indices = list(range(out_channels * 3))
-        exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices)+1)]
-        num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
-        return prediction_channel_indices, exogenous_channel_indices, num_input_channels
-
-    def __freeze_backbone(self):
-        """
-        Freeze the backbone of the model.
-        This is useful when you want to fine-tune only the head of the model.
-        """
-        print(
-            "Number of params before freezing backbone",
-            count_parameters(self.model),
-        )
-        # Freeze the backbone of the model
-        for param in self.model.backbone.parameters():
-            param.requires_grad = False
-        # Count params
-        print(
-            "Number of params after freezing the backbone",
-            count_parameters(self.model),
-        )
-
-    def __scaler(self, input):
-        #new_data = torch.tensor([MinMaxScaler().fit_transform(step_data) for step_data in data])
-        for i, e in enumerate(self.embs):
-            input[:,:,i] = input[:, :, i] / (e-1)
-        return input
-
-    def __build_tupla_indexes(self, size, target_idx, current_idx):
-        permute = list(range(size))
-        history = dict()
-        for j, i in enumerate(target_idx):
-            c = history.get(current_idx[j], current_idx[j])
-            permute[i], permute[c] = current_idx[j], i
-            history[i] = current_idx[j]
-
-
-    def __permute_indexes(self, values, target_idx, current_idx):
-        if current_idx is None or target_idx is None:
-            raise ValueError("Indexes cannot be None")
-        if sorted(current_idx) != sorted(target_idx):
-            return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
-        return values
-
-    def __extend_with_quantile_variables(self, x, original_indexes):
-        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
-        covariate_tensors = x[..., covariate_indexes]
-
-        new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]
-
-        new_original_indexes = list(range(len(original_indexes) * 3))
-        return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes
-
-    def forward(self, batch):
-        x_enc = batch['x_num_past']
-        original_indexes = batch['idx_target'][0].tolist()
-        original_indexes_future = batch['idx_target_future'][0].tolist()
-
-
-        if self.extend_variables:
-            x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)
-
-        if 'x_cat_past' in batch.keys():
-            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
-            x_mark_enc = self.__scaler(x_mark_enc)
-            past_values = torch.cat((x_enc,x_mark_enc), axis=-1).type(torch.float32)
-        else:
-            past_values = x_enc
-
-        x_dec = torch.tensor([]).to(self.device)
-        if 'x_num_future' in batch.keys():
-            x_dec = batch['x_num_future'].to(self.device)
-            if self.extend_variables:
-                x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
-        if 'x_cat_future' in batch.keys():
-            x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
-            x_mark_dec = self.__scaler(x_mark_dec)
-            future_values = torch.cat((x_dec, x_mark_dec), axis=-1).type(torch.float32)
-        else:
-            future_values = x_dec
-
-        if self.remove_last:
-            idx_target = batch['idx_target'][0]
-            x_start = x_enc[:,-1,idx_target].unsqueeze(1)
-            x_enc[:,:,idx_target]-=x_start
-
-
-        past_values = self.__permute_indexes(past_values, self.model.prediction_channel_indices, original_indexes)
-
-
-        future_values = self.__permute_indexes(future_values, self.model.prediction_channel_indices, original_indexes_future)
-
-        freq_token = get_frequency_token(self.freq).repeat(x_enc.shape[0])
-
-        res = self.model(
-            past_values= past_values,
-            future_values= future_values,
-            past_observed_mask = None,
-            future_observed_mask = None,
-            output_hidden_states = False,
-            return_dict = False,
-            freq_token= freq_token,
-            static_categorical_values = None
-        )
-        #args = None
-        #res = self.model(**args)
-        BS = res.shape[0]
-        return res.reshape(BS,self.future_steps,-1,self.mul)
-
-