dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/Samformer.py ADDED
@@ -0,0 +1,139 @@
+ ## Copyright https://github.com/romilbert/samformer/tree/main?tab=MIT-1-ov-file#readme
+ ## Modified for notation alignment and batch structure
+ ## Extended with the contents of the samformer folder
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from .samformer.utils import scaled_dot_product_attention, RevIN
+
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except ImportError:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from .utils import QuantileLossMO, Permute, get_activation
+
+ from typing import List, Union
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+ from .utils import Embedding_cat_variables
+
+
+ class Samformer(Base):
+     handle_multivariate = True
+     handle_future_covariates = False  # or at least it seems so...
+     handle_categorical_variables = False  # only in the encoder
+     handle_quantile_loss = False  # NOT EFFICIENTLY ADDED, TODO fix this
+     description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+     def __init__(self,
+                  # specific params
+                  hidden_size: int,
+                  use_revin: bool,
+                  activation: str = '',
+                  **kwargs) -> None:
+         """Initialize the model with the specified parameters. Samformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention.
+         https://arxiv.org/pdf/2402.10198
+
+         Args:
+             hidden_size (int): The size of the hidden layer.
+             use_revin (bool): Flag indicating whether to use RevIN.
+             activation (str, optional): The activation function to use. Defaults to ''.
+             **kwargs: Additional keyword arguments passed to the parent class.
+
+         Raises:
+             ValueError: If the activation function is not recognized.
+         """
+
+         super().__init__(**kwargs)
+         if activation == 'torch.nn.SELU':
+             beauty_string('SELU does not require BN', 'info', self.verbose)
+             use_bn = False
+         if isinstance(activation, str):
+             activation = get_activation(activation)
+         self.save_hyperparameters(logger=False)
+
+         self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         emb_past_out_channel = self.emb_past.output_channels
+         emb_fut_out_channel = self.emb_fut.output_channels
+
+         self.revin = RevIN(num_features=self.past_channels + emb_past_out_channel)
+         self.compute_keys = nn.Linear(self.past_steps, hidden_size)
+         self.compute_queries = nn.Linear(self.past_steps, hidden_size)
+         self.compute_values = nn.Linear(self.past_steps, self.past_steps)
+         self.linear_forecaster = nn.Linear(self.past_steps, self.future_steps)
+         self.use_revin = use_revin
+
+         dim = emb_past_out_channel + self.past_channels + emb_fut_out_channel + self.future_channels
+         self.final_layer = nn.Sequential(activation(),
+                                          nn.Linear(dim, dim * 2),
+                                          activation(),
+                                          nn.Linear(dim * 2, self.out_channels * self.mul))
+
+     def forward(self, batch: dict) -> torch.Tensor:
+         # embed the categorical covariates (past and future), if present
+         x = batch['x_num_past'].to(self.device)
+         BS = x.shape[0]
+         if 'x_cat_future' in batch.keys():
+             emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+         else:
+             emb_fut = self.emb_fut(BS, None)
+         if 'x_cat_past' in batch.keys():
+             emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+         else:
+             emb_past = self.emb_past(BS, None)
+
+         tmp_future = [emb_fut]
+         if 'x_num_future' in batch.keys():
+             x_future = batch['x_num_future'].to(self.device)
+             tmp_future.append(x_future)
+
+         tot = [x, emb_past]
+         x = torch.cat(tot, dim=2)
+
+         # RevIN Normalization
+         if self.use_revin:
+             x_norm = self.revin(x, mode='norm').transpose(1, 2)  # (n, D, L)
+         else:
+             x_norm = x.transpose(1, 2)
+
+         # Channel-Wise Attention
+         queries = self.compute_queries(x_norm)  # (n, D, hid_dim)
+         keys = self.compute_keys(x_norm)  # (n, D, hid_dim)
+         values = self.compute_values(x_norm)  # (n, D, L)
+         if hasattr(nn.functional, 'scaled_dot_product_attention'):
+             att_score = nn.functional.scaled_dot_product_attention(queries, keys, values)  # (n, D, L)
+         else:
+             att_score = scaled_dot_product_attention(queries, keys, values)  # (n, D, L)
+         out = x_norm + att_score  # (n, D, L)
+         # Linear Forecasting
+         out = self.linear_forecaster(out)  # (n, D, H)
+         # RevIN Denormalization
+         if self.use_revin:
+             out = self.revin(out.transpose(1, 2), mode='denorm').transpose(1, 2)  # (n, D, H)
+
+         tmp_future.append(out.permute(0, 2, 1))
+         tmp_future = torch.cat(tmp_future, 2)
+         output = self.final_layer(tmp_future)
+
+         return output.reshape(BS, self.future_steps, self.out_channels, self.mul)
+
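For context, Samformer's distinctive step is that attention mixes information across channels rather than across time: queries, keys, and values are produced by linear maps over the time axis, so each of the D series attends to the others. Below is a minimal standalone sketch of that computation; the shapes (B, L, D, H, hid) and the inline layers are illustrative assumptions, not part of the package.

```python
# Sketch of Samformer-style channel-wise attention (illustrative shapes,
# not the package's API). Requires torch >= 2.0 for F.scaled_dot_product_attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, D, H, hid = 8, 96, 7, 24, 16      # batch, past steps, channels, horizon, hidden
x = torch.randn(B, L, D)

x_norm = x.transpose(1, 2)              # (B, D, L): one row per channel
queries = nn.Linear(L, hid)(x_norm)     # (B, D, hid): linear map over the time axis
keys = nn.Linear(L, hid)(x_norm)        # (B, D, hid)
values = nn.Linear(L, L)(x_norm)        # (B, D, L)

# attention weights are (B, D, D): information flows across channels, not time
att = F.scaled_dot_product_attention(queries, keys, values)  # (B, D, L)
out = x_norm + att                      # residual connection
forecast = nn.Linear(L, H)(out)         # (B, D, H): linear forecaster over time
print(forecast.shape)                   # torch.Size([8, 7, 24])
```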
dsipts/models/TFT.py ADDED
@@ -0,0 +1,269 @@
+ import torch
+ import torch.nn as nn
+ from .tft import sub_nn
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except ImportError:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from .utils import QuantileLossMO
+ from typing import List, Union
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+ from .utils import Embedding_cat_variables
+
+ class TFT(Base):
+     handle_multivariate = True
+     handle_future_covariates = True
+     handle_categorical_variables = True
+     handle_quantile_loss = True
+     description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+     def __init__(self,
+                  d_model: int,
+                  num_layers_RNN: int,
+                  d_head: int,
+                  n_head: int,
+                  dropout_rate: float,
+                  **kwargs) -> None:
+         """Initializes the model for time series forecasting with attention mechanisms and recurrent neural networks.
+
+         This model is designed for direct forecasting, allowing multi-output and multi-horizon predictions. It leverages attention mechanisms to enhance the selection of relevant past time steps and to learn long-term dependencies. The architecture includes RNN enrichment, gating mechanisms that minimize the impact of irrelevant variables, and the ability to output prediction intervals through quantile regression.
+
+         Key features:
+         - Direct model: predicts all future steps at once.
+         - Multi-output forecasting: capable of predicting one or more variables simultaneously.
+         - Multi-horizon forecasting: predicts variables at multiple future time steps.
+         - Attention-based mechanism: enhances the selection of relevant past time steps and learns long-term dependencies.
+         - RNN enrichment: uses an LSTM for an initial autoregressive approximation, refined by the rest of the network.
+         - Gating mechanisms: reduce the contribution of irrelevant variables.
+         - Prediction intervals: outputs percentiles (e.g., 10th, 50th, 90th) at each time step.
+
+         The model also facilitates interpretability by identifying:
+         - the global importance of variables for both past and future,
+         - temporal patterns,
+         - significant events.
+
+         Args:
+             d_model (int): General hidden dimension across the network, adjustable in sub-networks.
+             num_layers_RNN (int): Number of layers in the recurrent neural network (LSTM).
+             d_head (int): Dimension of each attention head.
+             n_head (int): Number of attention heads.
+             dropout_rate (float): Dropout rate applied uniformly across all dropout layers.
+             **kwargs: Additional keyword arguments for further customization.
+         """
+
+         super().__init__(**kwargs)
+         self.save_hyperparameters(logger=False)
+         # assert out_channels==1, logging.info("ONLY ONE CHANNEL IMPLEMENTED")
+         self.d_model = d_model
+         # linear layer to embed the target variable (the same for past and future: same variable)
+         self.target_linear = nn.Linear(self.out_channels, d_model)
+         # number of past variables other than the target one(s)
+         self.aux_past_channels = self.past_channels - self.out_channels
+         # one linear layer for each auxiliary past variable
+         self.linear_aux_past = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_past_channels)])
+         # number of future variables used to predict the target one(s)
+         self.aux_fut_channels = self.future_channels
+         # one linear layer for each auxiliary future variable
+         self.linear_aux_fut = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_fut_channels)])
+         # length of the full sequence, used for the embedding of all categorical variables
+         # (we assume these are either unavailable or available for both past and future)
+         seq_len = self.past_steps + self.future_steps
+
+         ## in v1.1.5 this is not working: past and future categorical variables differ
+         # self.emb_cat_var = sub_nn.embedding_cat_variables(seq_len, self.future_steps, d_model, embs, self.device)
+
+         self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+         emb_past_out_channel = self.emb_past.output_channels
+         emb_fut_out_channel = self.emb_fut.output_channels
+
+         # recurrent neural network for a first, approximate inference of the target variable(s) - it is not re-embedded yet
+         self.rnn = sub_nn.LSTM_Model(num_var=self.out_channels,
+                                      d_model=d_model,
+                                      pred_step=self.future_steps,
+                                      num_layers=num_layers_RNN,
+                                      dropout=dropout_rate)
+         # PARTS OF TFT:
+         # - Residual connections
+         # - Gated Residual Network
+         # - Interpretable MultiHead Attention
+         self.res_conn1_past = sub_nn.ResidualConnection(d_model, dropout_rate)
+         self.res_conn1_fut = sub_nn.ResidualConnection(d_model, dropout_rate)
+         self.grn1_past = sub_nn.GRN(d_model, dropout_rate)
+         self.grn1_fut = sub_nn.GRN(d_model, dropout_rate)
+         self.InterpretableMultiHead = sub_nn.InterpretableMultiHead(d_model, d_head, n_head)
+         self.res_conn2_att = sub_nn.ResidualConnection(d_model, dropout_rate)
+         self.grn2_att = sub_nn.GRN(d_model, dropout_rate)
+         self.res_conn3_out = sub_nn.ResidualConnection(d_model, dropout_rate)
+
+         self.outLinear = nn.Linear(d_model, self.out_channels * self.mul)
+
+     def forward(self, batch: dict) -> torch.Tensor:
+         """Temporal Fusion Transformer
+
+         Collecting data:
+         - extract the autoregressive variable(s)
+         - embed them and compute a first, approximate prediction
+         - 'summary_past' and 'summary_fut' collect data about past and future
+         All the different data are concatenated along dimension 2 and then mixed through a MEAN over that dimension.
+         The information comes from the other tensors of the input batch.
+
+         Actual TFT computations:
+         - residual connection for y_past and summary_past
+         - residual connection for y_fut and summary_fut
+         - GRN1 for past and for future
+         - ATTENTION(summary_fut, summary_past, y_past)
+         - residual connection for the attention itself
+         - GRN2 for the attention
+         - residual connection for the attention and summary_fut
+         - linear layer for the actual values, then reshape
+
+         Args:
+             batch (dict): keys used are ['x_num_past', 'idx_target', 'x_num_future', 'x_cat_past', 'x_cat_future']
+
+         Returns:
+             torch.Tensor: shape [B, self.future_steps, self.out_channels, self.mul] or [B, self.future_steps, self.out_channels], depending on the quantiles
+         """
+
+         num_past = batch['x_num_past'].to(self.device)
+         # PAST TARGET NUMERICAL VARIABLE
+         # always available: autoregressive variable
+         # compute the RNN prediction
+         idx_target = batch['idx_target'][0]
+         target_num_past = num_past[:, :, idx_target]
+         target_emb_num_past = self.target_linear(target_num_past)  # target variables communicating with each other
+         target_num_fut_approx = self.rnn(target_emb_num_past)
+         # embed the future predictions
+         target_emb_num_fut_approx = self.target_linear(target_num_fut_approx)
+
+         ### create the summary_past and summary_fut variables
+         # at the beginning they contain only the past and future target variable
+         summary_past = target_emb_num_past.unsqueeze(2)
+         summary_fut = target_emb_num_fut_approx.unsqueeze(2)
+         # now we look for other categorical and numerical variables!
+
+         ### PAST NUMERICAL VARIABLES
+         if self.aux_past_channels > 0:  # we have more numerical variables about the past
+             # AUX = AUXILIARY variables
+             aux_num_past = self.remove_var(num_past, idx_target, 2)  # remove the target index on the second dimension
+             assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS", 'section', True)  # check that we are using the expected number of past variables
+             aux_emb_num_past = torch.Tensor().to(aux_num_past.device)
+             for i, layer in enumerate(self.linear_aux_past):
+                 aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
+                 aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
+             ## update the summary about the past
+             summary_past = torch.cat((summary_past, aux_emb_num_past), dim=2)
+
+         ### FUTURE NUMERICAL VARIABLES
+         if self.aux_fut_channels > 0:  # we have more numerical variables about the future
+             aux_num_fut = batch['x_num_future'].to(self.device)
+             assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND {aux_num_fut.size(2)} VARS", 'section', True)  # check that we are using the expected number of future variables
+             aux_emb_num_fut = torch.Tensor().to(aux_num_fut.device)
+             for j, layer in enumerate(self.linear_aux_fut):
+                 aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
+                 aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
+             ## update the summary about the future
+             summary_fut = torch.cat((summary_fut, aux_emb_num_fut), dim=2)
+         '''
+         ### CATEGORICAL VARIABLES, changed in 1.1.5
+         if 'x_cat_past' in batch.keys() and 'x_cat_future' in batch.keys():  # if we have both
+             # HERE WE ASSUME THE SAME NUMBER AND KIND OF VARIABLES IN PAST AND FUTURE
+             cat_past = batch['x_cat_past'].to(self.device)
+             cat_fut = batch['x_cat_future'].to(self.device)
+             cat_full = torch.cat((cat_past, cat_fut), dim=1)
+             # EMBED THE CATEGORICAL VARIABLES, THEN SPLIT INTO PAST AND FUTURE
+             emb_cat_full = self.emb_cat_var(cat_full, self.device)
+         else:
+             emb_cat_full = self.emb_cat_var(num_past.shape[0], self.device)
+
+         cat_emb_past = emb_cat_full[:, :-self.future_steps, :, :]
+         cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]
+
+         ## update the summaries
+         summary_past = torch.cat((summary_past, cat_emb_past), dim=2)
+         summary_fut = torch.cat((summary_fut, cat_emb_fut), dim=2)
+         '''
+         BS = num_past.shape[0]
+         if 'x_cat_future' in batch.keys():
+             emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+         else:
+             emb_fut = self.emb_fut(BS, None)
+         if 'x_cat_past' in batch.keys():
+             emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+         else:
+             emb_past = self.emb_past(BS, None)
+
+         ## update the summaries with the categorical embeddings
+         summary_past = torch.cat((summary_past, emb_past.unsqueeze(-1).repeat((1, 1, 1, summary_past.shape[-1]))), dim=2)
+         summary_fut = torch.cat((summary_fut, emb_fut.unsqueeze(-1).repeat((1, 1, 1, summary_past.shape[-1]))), dim=2)
+
+         # >>> PAST: mix all past information with a mean over the variable dimension
+         summary_past = torch.mean(summary_past, dim=2)
+         # >>> FUTURE: same for the future information
+         summary_fut = torch.mean(summary_fut, dim=2)
+
+         ### residual connection from the LSTM
+         summary_past = self.res_conn1_past(summary_past, target_emb_num_past)
+         summary_fut = self.res_conn1_fut(summary_fut, target_emb_num_fut_approx)
+
+         ### GRN1
+         summary_past = self.grn1_past(summary_past)
+         summary_fut = self.grn1_fut(summary_fut)
+
+         ### INTERPRETABLE MULTI-HEAD ATTENTION
+         attention = self.InterpretableMultiHead(summary_fut, summary_past, target_emb_num_past)
+
+         ### residual connection from the attention
+         attention = self.res_conn2_att(attention, attention)
+
+         ### GRN2
+         attention = self.grn2_att(attention)
+
+         ### residual connection from GRN1
+         out = self.res_conn3_out(attention, summary_fut)
+
+         ### OUT
+         out = self.outLinear(out)
+
+         if self.mul > 0:
+             out = out.view(-1, self.future_steps, self.out_channels, self.mul)
+         return out
+
+     # helper to extract from batch['x_num_past'] all variables except the autoregressive one(s)
+     def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: List[int], dimension: int) -> torch.Tensor:
+         """Remove variables from a tensor at the chosen dimension and positions.
+
+         Args:
+             tensor (torch.Tensor): starting tensor
+             indexes_to_exclude (List[int]): indexes of the chosen dimension we want to exclude
+             dimension (int): dimension of the tensor on which we want to work
+
+         Returns:
+             torch.Tensor: new tensor without the chosen variables
+         """
+
+         remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
+         # select the desired sub-tensor
+         extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)
+         return extracted_subtensors
+
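The `remove_var` helper above amounts to a `torch.index_select` over the complement of the target indexes along the channel dimension. A minimal standalone check of that logic follows; the tensor sizes are illustrative assumptions, not values taken from the package.

```python
# Standalone demo of the index-exclusion logic used by TFT.remove_var:
# drop the target channel(s) from x_num_past along dimension 2.
import torch

def remove_var(tensor: torch.Tensor, indexes_to_exclude, dimension: int) -> torch.Tensor:
    # indexes of the chosen dimension that are NOT excluded
    remaining_idx = torch.tensor(
        [i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude],
        device=tensor.device)
    return torch.index_select(tensor, dim=dimension, index=remaining_idx)

num_past = torch.randn(4, 96, 5)       # (batch, past_steps, past_channels), illustrative
aux = remove_var(num_past, [0], 2)     # channel 0 plays the role of the target variable
print(aux.shape)                       # torch.Size([4, 96, 4])
```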