dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/TIDE.py ADDED
@@ -0,0 +1,296 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from .tft import sub_nn
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except ImportError:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from .utils import QuantileLossMO
+ from typing import List, Union
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+ from .utils import Embedding_cat_variables
+
+
+ class TIDE(Base):
+     handle_multivariate = True
+     handle_future_covariates = True
+     handle_categorical_variables = True
+     handle_quantile_loss = True
+     description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+     def __init__(self,
+                  hidden_size: int,
+                  d_model: int,
+                  n_add_enc: int,
+                  n_add_dec: int,
+                  dropout_rate: float,
+                  activation: str = '',
+                  **kwargs) -> None:
+         """TiDE model. Long-term Forecasting with TiDE: Time-series Dense Encoder
+         https://arxiv.org/abs/2304.08424
+
+         Args:
+             hidden_size (int): size of the hidden layers.
+             d_model (int): dimensionality of the model.
+             n_add_enc (int): total number of encoder residual blocks (one initial block plus n_add_enc-1 additional ones).
+             n_add_dec (int): total number of decoder residual blocks (one initial block plus n_add_dec-1 additional ones).
+             dropout_rate (float): dropout rate applied inside the residual blocks.
+             activation (str, optional): name of the activation class, evaluated at runtime (e.g. 'nn.GELU'). Defaults to '' (nn.ReLU).
+             **kwargs: additional keyword arguments passed to the parent class.
+         """
+
+         super().__init__(**kwargs)
+         self.save_hyperparameters(logger=False)
+
+         # self.dropout = dropout_rate
+
+         self.hidden_size = hidden_size  # r
+         self.d_model = d_model  # r^tilde
+
+         # for the other numerical variables in the past
+         self.aux_past_channels = self.past_channels - self.out_channels
+         self.linear_aux_past = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_past_channels)])
+
+         # for numerical variables in the future
+         self.aux_fut_channels = self.future_channels
+         self.linear_aux_fut = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_fut_channels)])
+
+         # changed in 1.1.5:
+         # embedding categorical for both past and future (ASSUMING BOTH AVAILABLE OR NEITHER)
+         # self.seq_len = self.past_steps + self.future_steps
+         # self.emb_cat_var = sub_nn.embedding_cat_variables(self.seq_len, future_steps, hidden_size, embs, self.device)
+
+         self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         emb_past_out_channel = self.emb_past.output_channels
+         emb_fut_out_channel = self.emb_fut.output_channels
+
+         ## FEATURE PROJECTION
+         # past
+         if self.aux_past_channels > 0:
+             self.feat_proj_past = ResidualBlock(hidden_size + emb_past_out_channel, d_model, dropout_rate, activation)
+         else:
+             self.feat_proj_past = ResidualBlock(emb_past_out_channel, d_model, dropout_rate, activation)
+         # future
+         if self.aux_fut_channels > 0:
+             self.feat_proj_fut = ResidualBlock(hidden_size + emb_fut_out_channel, d_model, dropout_rate, activation)
+         else:
+             self.feat_proj_fut = ResidualBlock(emb_fut_out_channel, d_model, dropout_rate, activation)
+
+         ## ENCODER
+         self.enc_dim_input = self.past_steps * self.out_channels + (self.past_steps + self.future_steps) * d_model
+         self.enc_dim_output = self.future_steps * d_model
+         self.first_encoder = ResidualBlock(self.enc_dim_input, self.enc_dim_output, dropout_rate, activation)
+         self.aux_encoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_enc)])
+
+         ## DECODER
+         self.first_decoder = ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation)
+         self.aux_decoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_dec)])
+
+         ## TEMPORAL DECODER
+         self.temporal_decoder = ResidualBlock(2 * d_model, self.out_channels * self.mul, dropout_rate, activation)
+
+         # linear for Y lookback
+         self.linear_target = nn.Linear(self.past_steps * self.out_channels, self.future_steps * self.out_channels * self.mul)
+
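+     # Dimension bookkeeping for the dense encoder above (illustrative numbers,
+     # not a released configuration): with past_steps L=64, future_steps H=16,
+     # out_channels=1 and d_model=32, enc_dim_input = 64*1 + (64+16)*32 = 2624
+     # and enc_dim_output = 16*32 = 512; the target lookback and the projected
+     # past/future covariates are flattened into one vector per sample before
+     # entering the dense encoder-decoder stack.
+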
+     def forward(self, batch: dict) -> torch.Tensor:
+         """Forward pass of TiDE.
+
+         Args:
+             batch (dict): batch of variables produced by the dataloader
+
+         Returns:
+             torch.Tensor: predictions of shape [B, future_steps, out_channels, mul]
+         """
+
+         # LOADING AUTOREGRESSIVE CONTEXT OF TARGET VARIABLES
+         num_past = batch['x_num_past'].to(self.device)
+         idx_target = batch['idx_target'][0]
+         y_past = num_past[:, :, idx_target]
+         B = y_past.shape[0]
+
+         # LOADING EMBEDDED CATEGORICAL VARIABLES
+         # emb_cat_past, emb_cat_fut = self.cat_categorical_vars(batch)
+         if 'x_cat_future' in batch.keys():
+             emb_fut = self.emb_fut(B, batch['x_cat_future'].to(self.device))
+         else:
+             emb_fut = self.emb_fut(B, None)
+         if 'x_cat_past' in batch.keys():
+             emb_past = self.emb_past(B, batch['x_cat_past'].to(self.device))
+         else:
+             emb_past = self.emb_past(B, None)
+
+         # emb_cat_past = torch.mean(emb_cat_past, dim=2)
+         # emb_cat_fut = torch.mean(emb_cat_fut, dim=2)
+
+         ### LOADING PAST AND FUTURE NUMERICAL VARIABLES
+         # load the auxiliary numerical variables into the model
+         if self.aux_past_channels > 0:  # if we have more numerical variables about the past
+             aux_num_past = self.remove_var(num_past, idx_target, 2)  # remove the autoregressive variable
+             assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS", 'section', True)  # check that we get the expected number of past variables
+             # embed each variable separately, then average the embeddings
+             aux_emb_num_past = torch.Tensor().to(self.device)
+             for i, layer in enumerate(self.linear_aux_past):
+                 aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
+                 aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
+             aux_emb_num_past = torch.mean(aux_emb_num_past, dim=2)
+         else:
+             aux_emb_num_past = None  # no auxiliary past variables available
+
+         if self.aux_fut_channels > 0:  # if we have more numerical variables about the future
+             # AUX means AUXILIARY variables
+             aux_num_fut = batch['x_num_future'].to(self.device)
+             assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND {aux_num_fut.size(2)} VARS", 'section', True)  # check that we get the expected number of future variables
+             # embed each variable separately, then average the embeddings
+             aux_emb_num_fut = torch.Tensor().to(self.device)
+             for j, layer in enumerate(self.linear_aux_fut):
+                 aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
+                 aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
+             aux_emb_num_fut = torch.mean(aux_emb_num_fut, dim=2)
+         else:
+             aux_emb_num_fut = None  # no auxiliary future variables available
+
+         # past^tilde
+         if self.aux_past_channels > 0:
+             emb_past = torch.cat((emb_past, aux_emb_num_past), dim=2)  # [B, L, 2R]
+         proj_past = self.feat_proj_past(emb_past, True)  # [B, L, R^tilde]
+
+         # fut^tilde
+         if self.aux_fut_channels > 0:
+             emb_fut = torch.cat((emb_fut, aux_emb_num_fut), dim=2)  # [B, H, 2R]
+         proj_fut = self.feat_proj_fut(emb_fut, True)  # [B, H, R^tilde]
+
+         concat = torch.cat((y_past.view(B, -1), proj_past.view(B, -1), proj_fut.view(B, -1)), dim=1)  # [B, L*out_channels + (L+H)*R^tilde]
+         dense_enc = self.first_encoder(concat)
+         for lay_enc in self.aux_encoder:
+             dense_enc = lay_enc(dense_enc)
+
+         dense_dec = self.first_decoder(dense_enc)
+         for lay_dec in self.aux_decoder:
+             dense_dec = lay_dec(dense_dec)
+
+         temp_dec_input = torch.cat((dense_dec.view(B, self.future_steps, self.d_model), proj_fut), dim=2)
+         temp_dec_output = self.temporal_decoder(temp_dec_input, False)
+         temp_dec_output = temp_dec_output.view(B, self.future_steps, self.out_channels, self.mul)
+
+         linear_regr = self.linear_target(y_past.view(B, -1))
+         linear_output = linear_regr.view(B, self.future_steps, self.out_channels, self.mul)
+
+         output = temp_dec_output + linear_output
+         return output
+
+     '''
+     # function to concat embedded categorical variables
+     def cat_categorical_vars(self, batch: dict):
+         """Extracting categorical context about past and future
+
+         Args:
+             batch (dict): keys checked -> ['x_cat_past', 'x_cat_future']
+
+         Returns:
+             List[torch.Tensor, torch.Tensor]: cat_emb_past, cat_emb_fut
+         """
+         cat_past = None
+         cat_fut = None
+         # GET AVAILABLE CATEGORICAL CONTEXT
+         if 'x_cat_past' in batch.keys():
+             cat_past = batch['x_cat_past'].to(self.device)
+         if 'x_cat_future' in batch.keys():
+             cat_fut = batch['x_cat_future'].to(self.device)
+         # CONCAT THEM, according to self.emb_cat_var usage
+         if cat_past is None:
+             emb_cat_full = self.emb_cat_var(batch['x_num_past'].shape[0], self.device)
+         else:
+             cat_full = torch.cat((cat_past, cat_fut), dim=1)
+             emb_cat_full = self.emb_cat_var(cat_full, self.device)
+         cat_emb_past = emb_cat_full[:, :self.past_steps, :, :]
+         cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]
+
+         return cat_emb_past, cat_emb_fut
+
+     # function to extract from batch['x_num_past'] all variables except the autoregressive one
+     '''
+     def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: list, dimension: int) -> torch.Tensor:
+         """Remove variables from a tensor at the chosen dimension and positions
+
+         Args:
+             tensor (torch.Tensor): starting tensor
+             indexes_to_exclude (list): indices of the chosen dimension we want to exclude
+             dimension (int): single dimension of the tensor on which to work (not a list of dims!)
+
+         Returns:
+             torch.Tensor: new tensor without the chosen variables
+         """
+         remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
+         # Select the desired sub-tensor
+         extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)
+
+         return extracted_subtensors
+
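+     # Illustration (hypothetical shapes): with a [B, L, 5] tensor and
+     # indexes_to_exclude=[0], remove_var(tensor, [0], 2) keeps channels 1..4
+     # and returns a [B, L, 4] tensor; this is how the autoregressive target
+     # is dropped from 'x_num_past' in forward().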
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ''):
+         """Residual Block, the basic layer of the architecture.
+
+         MLP with one hidden layer, activation, dropout and a skip connection.
+         Nominally of dimension d_model, but in_size and out_size are explicit
+         to handle the dimensions at the different stages of the network.
+
+         Args:
+             in_size (int): input dimension
+             out_size (int): output dimension
+             dropout_rate (float): dropout rate
+             activation_fun (str, optional): name of the activation class to use in the Residual Block, e.g. 'nn.GELU'. Defaults to nn.ReLU.
+         """
+         super().__init__()
+
+         self.direct_linear = nn.Linear(in_size, out_size, bias=False)
+
+         if activation_fun == '':
+             self.act = nn.ReLU()
+         else:
+             activation = eval(activation_fun)
+             self.act = activation()
+         self.lin = nn.Linear(in_size, out_size)
+         self.dropout = nn.Dropout(dropout_rate)
+
+         self.final_norm = nn.LayerNorm(out_size)
+
+     def forward(self, x, apply_final_norm=True):
+         direct_x = self.direct_linear(x)
+
+         x = self.dropout(self.lin(self.act(x)))
+
+         out = x + direct_x
+         if apply_final_norm:
+             return self.final_norm(out)
+         return out
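
ResidualBlock is the only building block TiDE composes: the feature projections, the dense encoder/decoder stacks and the temporal decoder are all instances of this MLP-with-skip layer. A minimal sketch of its behaviour in isolation, assuming the wheel is installed so the class is importable from dsipts.models.TIDE (sizes are illustrative, not a released configuration):

    from dsipts.models.TIDE import ResidualBlock
    import torch

    block = ResidualBlock(in_size=16, out_size=32, dropout_rate=0.1, activation_fun='nn.GELU')
    x = torch.randn(8, 64, 16)                 # [batch, steps, features]
    y = block(x)                               # final LayerNorm applied -> [8, 64, 32]
    y_raw = block(x, apply_final_norm=False)   # no LayerNorm, as the temporal decoder calls it
    print(y.shape, y_raw.shape)

The projection blocks run token-wise on [B, steps, features] tensors while the encoder/decoder blocks operate on flattened [B, features] vectors; nn.Linear and nn.LayerNorm both act on the last dimension, so the same block serves both uses.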
dsipts/models/TTM.py ADDED
@@ -0,0 +1,252 @@
+ import torch
+ import numpy as np
+ from torch import nn
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except ImportError:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+
+
+ from typing import List, Union
+
+ from .utils import QuantileLossMO
+ from ..data_structure.utils import beauty_string
+ from .ttm.utils import get_model, get_frequency_token, count_parameters, RMSELoss
+
+
+ class TTM(Base):
+     def __init__(self,
+                  model_path: str,
+                  past_steps: int,
+                  future_steps: int,
+                  freq_prefix_tuning: bool,
+                  freq: str,
+                  prefer_l1_loss: bool,  # exog: set True to use l1 loss
+                  prefer_longer_context: bool,
+                  loss_type: str,
+                  num_input_channels,
+                  prediction_channel_indices,
+                  exogenous_channel_indices,
+                  decoder_mode,
+                  fcm_context_length,
+                  fcm_use_mixer,
+                  fcm_mix_layers,
+                  fcm_prepend_past,
+                  enable_forecast_channel_mixing,
+                  out_channels: int,
+                  embs: List[int],
+                  remove_last=False,
+                  optim: Union[str, None] = None,
+                  optim_config: dict = None,
+                  scheduler_config: dict = None,
+                  verbose=False,
+                  use_quantiles=False,
+                  persistence_weight: float = 0.0,
+                  quantiles: List[int] = [],
+                  **kwargs) -> None:
+         """TODO and FIX for future and past categorical variables
+
+         Args:
+             model_path (str): _description_
+             past_steps (int): _description_
+             future_steps (int): _description_
+             freq_prefix_tuning (bool): _description_
+             freq (str): _description_
+             prefer_l1_loss (bool): _description_
+             loss_type (str): _description_
+             num_input_channels (_type_): _description_
+             prediction_channel_indices (_type_): _description_
+             exogenous_channel_indices (_type_): _description_
+             decoder_mode (_type_): _description_
+             fcm_context_length (_type_): _description_
+             fcm_use_mixer (_type_): _description_
+             fcm_mix_layers (_type_): _description_
+             fcm_prepend_past (_type_): _description_
+             enable_forecast_channel_mixing (_type_): _description_
+             out_channels (int): _description_
+             embs (List[int]): _description_
+             remove_last (bool, optional): _description_. Defaults to False.
+             optim (Union[str,None], optional): _description_. Defaults to None.
+             optim_config (dict, optional): _description_. Defaults to None.
+             scheduler_config (dict, optional): _description_. Defaults to None.
+             verbose (bool, optional): _description_. Defaults to False.
+             use_quantiles (bool, optional): _description_. Defaults to False.
+             persistence_weight (float, optional): _description_. Defaults to 0.0.
+             quantiles (List[int], optional): _description_. Defaults to [].
+         """
+         super(TTM, self).__init__(verbose)
+         self.save_hyperparameters(logger=False)
+         self.future_steps = future_steps
+         self.use_quantiles = use_quantiles
+         self.optim = optim
+         self.optim_config = optim_config
+         self.scheduler_config = scheduler_config
+         self.persistence_weight = persistence_weight
+         self.loss_type = loss_type
+         self.remove_last = remove_last
+         self.embs = embs
+         self.freq = freq
+         self.extend_variables = False
+
+         # NOTE: For Hydra
+         prediction_channel_indices = list(prediction_channel_indices)
+         exogenous_channel_indices = list(exogenous_channel_indices)
+
+         if len(quantiles) > 0:
+             assert len(quantiles) == 3, beauty_string('ONLY 3 quantiles permitted', 'info', True)
+             self.use_quantiles = True
+             self.mul = len(quantiles)
+             self.loss = QuantileLossMO(quantiles)
+             self.extend_variables = True
+             if out_channels * 3 != len(prediction_channel_indices):
+                 prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
+                                                                                                                          exogenous_channel_indices,
+                                                                                                                          out_channels)
+         else:
+             self.mul = 1
+             if self.loss_type == 'mse':
+                 self.loss = nn.MSELoss(reduction="mean")
+             elif self.loss_type == 'rmse':
+                 self.loss = RMSELoss()
+             else:
+                 self.loss = nn.L1Loss()
+
+         self.model = get_model(
+             model_path=model_path,
+             context_length=past_steps,
+             prediction_length=future_steps,
+             freq_prefix_tuning=freq_prefix_tuning,
+             freq=freq,
+             prefer_l1_loss=prefer_l1_loss,
+             prefer_longer_context=prefer_longer_context,
+             num_input_channels=num_input_channels,
+             decoder_mode=decoder_mode,
+             prediction_channel_indices=list(prediction_channel_indices),
+             exogenous_channel_indices=list(exogenous_channel_indices),
+             fcm_context_length=fcm_context_length,
+             fcm_use_mixer=fcm_use_mixer,
+             fcm_mix_layers=fcm_mix_layers,
+             fcm_prepend_past=fcm_prepend_past,
+             # loss='mse',
+             enable_forecast_channel_mixing=enable_forecast_channel_mixing,
+         )
+         self.__freeze_backbone()
+
+     def __add_quantile_features(self, prediction_channel_indices, exogenous_channel_indices, out_channels):
+         prediction_channel_indices = list(range(out_channels * 3))
+         exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices) + 1)]
+         num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
+         return prediction_channel_indices, exogenous_channel_indices, num_input_channels
+
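+     # Illustration (hypothetical config): with out_channels=2 and two exogenous
+     # channels, __add_quantile_features returns prediction_channel_indices
+     # [0, 1, 2, 3, 4, 5] (one channel per target/quantile pair),
+     # exogenous_channel_indices [6, 7] and num_input_channels 8.
+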
+     def __freeze_backbone(self):
+         """
+         Freeze the backbone of the model.
+         This is useful when you want to fine-tune only the head of the model.
+         """
+         print(
+             "Number of params before freezing backbone",
+             count_parameters(self.model),
+         )
+         # Freeze the backbone of the model
+         for param in self.model.backbone.parameters():
+             param.requires_grad = False
+         # Count params
+         print(
+             "Number of params after freezing the backbone",
+             count_parameters(self.model),
+         )
+
+     def __scaler(self, input):
+         # new_data = torch.tensor([MinMaxScaler().fit_transform(step_data) for step_data in data])
+         for i, e in enumerate(self.embs):
+             input[:, :, i] = input[:, :, i] / (e - 1)
+         return input
+
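+     # __scaler maps each categorical column from integer codes 0..e-1 (e being
+     # the cardinality recorded in self.embs) into the [0, 1] range, in place,
+     # before the categoricals are concatenated to the numeric channels.
+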
+     def __build_tupla_indexes(self, size, target_idx, current_idx):
+         permute = list(range(size))
+         history = dict()
+         for j, i in enumerate(target_idx):
+             c = history.get(current_idx[j], current_idx[j])
+             permute[i], permute[c] = current_idx[j], i
+             history[i] = current_idx[j]
+         return permute
+
+     def __permute_indexes(self, values, target_idx, current_idx):
+         if current_idx is None or target_idx is None:
+             raise ValueError("Indexes cannot be None")
+         if sorted(current_idx) != sorted(target_idx):
+             return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
+         return values
+
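+     # Illustration: if the backbone expects the targets at indices [0, 1]
+     # (target_idx) while the batch carries them at [2, 3] (current_idx),
+     # __permute_indexes reorders the last dimension so the target channels
+     # land at the positions in self.model.prediction_channel_indices.
+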
186
+
187
+ def __extend_with_quantile_variables(self, x, original_indexes):
188
+ covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
189
+ covariate_tensors = x[..., covariate_indexes]
190
+
191
+ new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]
192
+
193
+ new_original_indexes = list(range(len(original_indexes) * 3))
194
+ return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes
195
+
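+     # Illustration (hypothetical shapes): with x of shape [B, T, 4] and
+     # original_indexes=[0] (one target, three covariates), the target channel
+     # is repeated three times (one copy per quantile) and the covariates are
+     # appended, giving shape [B, T, 6] and new_original_indexes=[0, 1, 2].
+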
+     def forward(self, batch):
+         x_enc = batch['x_num_past']
+         original_indexes = batch['idx_target'][0].tolist()
+         original_indexes_future = batch['idx_target_future'][0].tolist()
+
+         if self.extend_variables:
+             x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)
+
+         if 'x_cat_past' in batch.keys():
+             x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
+             x_mark_enc = self.__scaler(x_mark_enc)
+             past_values = torch.cat((x_enc, x_mark_enc), axis=-1).type(torch.float32)
+         else:
+             past_values = x_enc
+
+         x_dec = torch.tensor([]).to(self.device)
+         if 'x_num_future' in batch.keys():
+             x_dec = batch['x_num_future'].to(self.device)
+             if self.extend_variables:
+                 x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
+         if 'x_cat_future' in batch.keys():
+             x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
+             x_mark_dec = self.__scaler(x_mark_dec)
+             future_values = torch.cat((x_dec, x_mark_dec), axis=-1).type(torch.float32)
+         else:
+             future_values = x_dec
+
+         if self.remove_last:
+             idx_target = batch['idx_target'][0]
+             x_start = x_enc[:, -1, idx_target].unsqueeze(1)
+             x_enc[:, :, idx_target] -= x_start
+
+         past_values = self.__permute_indexes(past_values, self.model.prediction_channel_indices, original_indexes)
+         future_values = self.__permute_indexes(future_values, self.model.prediction_channel_indices, original_indexes_future)
+
+         freq_token = get_frequency_token(self.freq).repeat(x_enc.shape[0])
+
+         res = self.model(
+             past_values=past_values,
+             future_values=future_values,
+             past_observed_mask=None,
+             future_observed_mask=None,
+             output_hidden_states=False,
+             return_dict=False,
+             freq_token=freq_token,
+             static_categorical_values=None,
+         )
+         # args = None
+         # res = self.model(**args)
+         BS = res.shape[0]
+         return res.reshape(BS, self.future_steps, -1, self.mul)
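
When quantile loss is active, TTM triples every target channel before handing the batch to the TinyTimeMixer backbone, then permutes the channels so the targets sit where the pretrained model expects them. A standalone sketch of the channel-tripling step in plain torch (hypothetical shapes, mirroring __extend_with_quantile_variables above):

    import torch

    def extend_with_quantile_variables(x, original_indexes):
        # repeat each target channel three times (one per quantile), then append covariates
        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
        covariate_tensors = x[..., covariate_indexes]
        new_tensors = [x[..., t] for t in original_indexes for _ in range(3)]
        new_original_indexes = list(range(len(original_indexes) * 3))
        return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes

    x = torch.randn(8, 96, 4)   # [batch, past_steps, channels]; channel 0 is the target
    out, idx = extend_with_quantile_variables(x, [0])
    print(out.shape, idx)       # torch.Size([8, 96, 6]) [0, 1, 2]
    assert torch.equal(out[..., 0], out[..., 2])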