dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/DilatedConv.py
@@ -0,0 +1,342 @@

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
import numpy as np
torch.autograd.set_detect_anomaly(True)
from .utils import get_scope
from .utils import Embedding_cat_variables


class GLU(nn.Module):
    def __init__(self, d_model: int):
        """Gated Linear Unit, the 'Gate' block in the TFT paper.
        Sub-net of the GRN: linear(x) * sigmoid(linear(x))
        No dimension changes.

        Args:
            d_model (int): model dimension
        """
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU6()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gated Linear Unit
        Sub-net of the GRN: linear(x) * sigmoid(linear(x))
        No dimension changes: [bs, seq_len, d_model]

        Args:
            x (torch.Tensor)

        Returns:
            torch.Tensor
        """

        ## here x is something like BS x L
        x1 = (self.activation(self.linear(x.unsqueeze(2))) / 6.0).squeeze()
        out = x1 * x  # element-wise multiplication

        ## get the score
        score = torch.sign(x1).mean()
        return out, score

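# Note on the gate above: the "sigmoid" of the TFT-style GLU is approximated with
# ReLU6(linear(x)) / 6, which is bounded in [0, 1] (a hard-sigmoid). For the per-channel
# GLU(1) used below, the input is a [BS, L] slice and the returned `score` is
# sign(gate).mean(), i.e. the fraction of positions whose gate is strictly positive.
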
class Block(nn.Module):
    def __init__(self, input_channels: int, kernel_size: int, output_channels: int, input_size: int, sum_layers: bool):

        super(Block, self).__init__()

        self.dilations = nn.ModuleList()
        self.steps = int(np.floor(np.log2(input_size))) - 1

        if self.steps <= 1:
            self.steps = 1

        for i in range(self.steps):
            # dilated convolution (length preserving) followed by a strided convolution
            self.dilations.append(nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding='same', dilation=2**i))
            s = max(2**i - 1, 1)
            k = 2**(i + 1) + 1
            p = int(((s - 1) * input_size + k - 1) / 2)
            self.dilations.append(nn.Conv1d(input_channels, output_channels, k, stride=s, padding=p))

        self.sum_layers = sum_layers
        mul = 1 if sum_layers else self.steps * 2
        self.conv_final = nn.Conv1d(output_channels * mul, output_channels * mul, kernel_size, stride=1, padding='same')
        self.out_channels = output_channels * mul

    def forward(self, x: torch.tensor) -> torch.tensor:
        x = Permute()(x)
        tmp = []
        for i in range(self.steps):
            tmp.append(self.dilations[i](x))

        if self.sum_layers:
            tmp = torch.stack(tmp)
            tmp = tmp.sum(axis=0)
        else:
            tmp = torch.cat(tmp, 1)

        return Permute()(tmp)


class DilatedConv(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True

    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 sum_layers: bool,
                 hidden_RNN: int,
                 num_layers_RNN: int,
                 kind: str,
                 kernel_size: int,
                 activation: str = 'torch.nn.ReLU',
                 remove_last=False,
                 dropout_rate: float = 0.1,
                 use_bn: bool = False,
                 use_glu: bool = True,
                 glu_percentage: float = 1.0,
                 **kwargs) -> None:
        """Custom encoder-decoder based on dilated convolutions and an RNN.

        Args:
            sum_layers (bool): Flag indicating whether to sum the outputs of the dilated layers instead of concatenating them.
            hidden_RNN (int): Number of hidden units in the RNN.
            num_layers_RNN (int): Number of layers in the RNN.
            kind (str): Type of RNN to use ('lstm' or 'gru').
            kernel_size (int): Size of the convolutional kernel.
            activation (str, optional): Activation function to use. Defaults to 'torch.nn.ReLU'.
            remove_last (bool, optional): If True, subtract the last observed value of the target channels from the past window and add it back to the prediction. Defaults to False.
            dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
            use_bn (bool, optional): Flag to indicate whether to use batch normalization. Defaults to False.
            use_glu (bool, optional): Flag to indicate whether to use Gated Linear Units (GLU). Defaults to True.
            glu_percentage (float, optional): Target percentage of active GLU gates (only stored for now, see the TODO in training_step). Defaults to 1.0.
            **kwargs: Additional keyword arguments passed to Base.

        Returns:
            None
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)
        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN
        self.kind = kind
        self.use_glu = use_glu
        self.glu_percentage = torch.tensor(glu_percentage).to(self.device)
        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        if self.use_glu:
            self.past_glu = nn.ModuleList()
            self.future_glu = nn.ModuleList()
            for _ in range(self.past_channels):
                self.past_glu.append(GLU(1))

            for _ in range(self.future_channels):
                self.future_glu.append(GLU(1))

        self.initial_linear_encoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.past_channels, (self.past_channels + hidden_RNN // 4) // 2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.past_channels + hidden_RNN // 4) // 2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.past_channels + hidden_RNN // 4) // 2, hidden_RNN // 4, kernel_size, stride=1, padding='same'),
                                                    Permute())

        self.initial_linear_decoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.future_channels, (self.future_channels + hidden_RNN // 4) // 2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.future_channels + hidden_RNN // 4) // 2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.future_channels + hidden_RNN // 4) // 2, hidden_RNN // 4, kernel_size, stride=1, padding='same'),
                                                    Permute())
        self.conv_encoder = Block(emb_past_out_channel + hidden_RNN // 4, kernel_size, hidden_RNN // 2, self.past_steps, sum_layers)

        # nn.Sequential(Permute(), nn.Conv1d(emb_channels+hidden_RNN//8, hidden_RNN//8, kernel_size, stride=1,padding='same'),Permute(),nn.Dropout(0.3))

        if self.future_channels + emb_fut_out_channel == 0:
            ## careful: this means no future covariates were passed; as a stopgap, use the encoder hidden size
            self.conv_decoder = Block(hidden_RNN, kernel_size, hidden_RNN // 2, self.future_steps, sum_layers)
        else:
            self.conv_decoder = Block(self.future_channels + emb_fut_out_channel, kernel_size, hidden_RNN // 2, self.future_steps, sum_layers)
        # nn.Sequential(Permute(),nn.Linear(past_steps,past_steps*2), nn.PReLU(),nn.Dropout(0.2),nn.Linear(past_steps*2, future_steps),nn.Dropout(0.3),nn.Conv1d(hidden_RNN, hidden_RNN//8, 3, stride=1,padding='same'), Permute())
        if self.kind == 'lstm':
            self.Encoder = nn.LSTM(input_size=self.conv_encoder.out_channels,  # , hidden_RNN//4,
                                   hidden_size=hidden_RNN // 2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
            self.Decoder = nn.LSTM(input_size=self.conv_decoder.out_channels,  # , hidden_RNN//4,
                                   hidden_size=hidden_RNN // 2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
        elif self.kind == 'gru':
            self.Encoder = nn.GRU(input_size=self.conv_encoder.out_channels,  # , hidden_RNN//4,
                                  hidden_size=hidden_RNN // 2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
            self.Decoder = nn.GRU(input_size=self.conv_decoder.out_channels,  # , hidden_RNN//4,
                                  hidden_size=hidden_RNN // 2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
        else:
            beauty_string('Specify kind lstm or gru please', 'section', True)
        self.final_linear = nn.ModuleList()
        for _ in range(self.out_channels * self.mul):
            self.final_linear.append(nn.Sequential(nn.Linear(hidden_RNN + emb_fut_out_channel + self.future_channels, hidden_RNN * 2),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN * 2) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN * 2, hidden_RNN),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN, hidden_RNN // 2),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN // 2) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN // 2, hidden_RNN // 4),
                                                   activation(),
                                                   nn.Linear(hidden_RNN // 4, 1)))

    def training_step(self, batch, batch_idx):
        """
        pytorch lightning stuff

        :meta private:
        """
        y_hat, score = self(batch)
        return self.compute_loss(batch, y_hat)  # +torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating

    def validation_step(self, batch, batch_idx):
        """
        pytorch lightning stuff

        :meta private:
        """
        y_hat, score = self(batch)
        return self.compute_loss(batch, y_hat)  # +torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating

    def forward(self, batch):
        """It is mandatory to implement this method

        Args:
            batch (dict): batch of the dataloader

        Returns:
            torch.tensor: result
        """

        x = batch['x_num_past'].to(self.device)
        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
            xf = torch.clone(x_future)
        else:
            x_future = None

        if self.remove_last:
            idx_target = batch['idx_target'][0]

            x_start = x[:, -1, idx_target].unsqueeze(1)
            ## BxC
            x[:, :, idx_target] -= x_start

        ## first GLU
        score = 0
        xp = torch.clone(x)

        if self.use_glu:
            score_past_tot = 0
            score_future_tot = 0

            for i in range(len(self.past_glu)):
                x[:, :, i], score = self.past_glu[i](xp[:, :, i])
                score_past_tot += score
            score_past_tot /= len(self.past_glu)

            if x_future is not None:
                for i in range(len(self.future_glu)):
                    x_future[:, :, i], score = self.future_glu[i](xf[:, :, i])
                    score_future_tot += score
                score_future_tot /= len(self.future_glu)
            score = 0.5 * (score_past_tot + score_future_tot)
        tmp = [self.initial_linear_encoder(x), emb_past]

        tot = torch.cat(tmp, 2)
        out, hidden = self.Encoder(self.conv_encoder(tot))
        tmp = [emb_fut]
        if x_future is not None:
            tmp.append(x_future)
        tot = torch.cat(tmp, 2)
        out, _ = self.Decoder(self.conv_decoder(tot), hidden)
        res = []
        tmp = torch.cat([tot, out], axis=2)

        for j in range(self.out_channels * self.mul):
            res.append(self.final_linear[j](tmp))

        res = torch.cat(res, 2)
        ## BxLxC
        B = res.shape[0]
        res = res.reshape(B, self.future_steps, -1, self.mul)
        if self.remove_last:
            res += x_start.unsqueeze(1)

        return res, score

    def inference(self, batch: dict) -> torch.tensor:

        res, score = self(batch)
        return res
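The GLU and Block classes defined above are self-contained, so a minimal sketch of how they behave is given below. It assumes dsipts 1.1.5 is installed and that the module is importable as dsipts.models.DilatedConv (the path in the file list above, with Permute transposing the last two dimensions as implied by the Conv1d usage); the full DilatedConv model additionally needs the configuration arguments of its Base class, which are not part of this diff.

import torch
from dsipts.models.DilatedConv import GLU, Block   # import path assumed from the file list

bs, seq_len = 4, 64

# GLU(1) gates a single channel: input [bs, seq_len] -> (gated output, mean gate sign)
gate = GLU(1)
out, score = gate(torch.randn(bs, seq_len))
print(out.shape)        # torch.Size([4, 64])
print(float(score))     # fraction of positions with a strictly positive gate, in [0, 1]

# Block stacks dilated 1D convolutions over a channels-last sequence
block = Block(input_channels=8, kernel_size=3, output_channels=16, input_size=seq_len, sum_layers=True)
y = block(torch.randn(bs, seq_len, 8))
print(y.shape)          # torch.Size([4, 64, 16]) with sum_layers=True
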
dsipts/models/DilatedConvED.py
@@ -0,0 +1,310 @@

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
import numpy as np
from .utils import get_scope
from .utils import Embedding_cat_variables
torch.autograd.set_detect_anomaly(True)


class GLU(nn.Module):
    def __init__(self, d_model: int):
        """Gated Linear Unit, the 'Gate' block in the TFT paper.
        Sub-net of the GRN: linear(x) * sigmoid(linear(x))
        No dimension changes.

        Args:
            d_model (int): model dimension
        """
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU6()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gated Linear Unit
        Sub-net of the GRN: linear(x) * sigmoid(linear(x))
        No dimension changes: [bs, seq_len, d_model]

        Args:
            x (torch.Tensor)

        Returns:
            torch.Tensor
        """

        ## here x is something like BS x L
        x1 = (self.activation(self.linear(x.unsqueeze(2))) / 6.0).squeeze()
        out = x1 * x  # element-wise multiplication

        ## get the score
        score = torch.sign(x1).mean()
        return out, score


class Block(nn.Module):
    def __init__(self, input_channels: int, kernel_size: int, output_channels: int, input_size: int, sum_layers: bool):

        super(Block, self).__init__()

        self.dilations = nn.ModuleList()
        self.steps = int(np.floor(np.log2(input_size))) - 1

        if self.steps <= 1:
            self.steps = 1

        for i in range(self.steps):
            # dilated convolution (length preserving) followed by a strided convolution
            self.dilations.append(nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding='same', dilation=2**i))
            s = max(2**i - 1, 1)
            k = 2**(i + 1) + 1
            p = int(((s - 1) * input_size + k - 1) / 2)
            self.dilations.append(nn.Conv1d(input_channels, output_channels, k, stride=s, padding=p))

        self.sum_layers = sum_layers
        mul = 1 if sum_layers else self.steps * 2
        self.conv_final = nn.Conv1d(output_channels * mul, output_channels * mul, kernel_size, stride=1, padding='same')
        self.out_channels = output_channels * mul

    def forward(self, x: torch.tensor) -> torch.tensor:
        x = Permute()(x)
        tmp = []
        for i in range(self.steps):
            tmp.append(self.dilations[i](x))

        if self.sum_layers:
            tmp = torch.stack(tmp)
            tmp = tmp.sum(axis=0)
        else:
            tmp = torch.cat(tmp, 1)

        return Permute()(tmp)


class DilatedConvED(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 sum_layers: bool,
                 hidden_RNN: int,
                 num_layers_RNN: int,
                 kind: str,
                 kernel_size: int,
                 dropout_rate: float = 0.1,
                 use_bn: bool = False,
                 use_cumsum: bool = True,
                 use_bilinear: bool = False,
                 activation: str = 'torch.nn.ReLU',
                 **kwargs) -> None:
        """Initialize the model with the specified parameters.

        Args:
            sum_layers (bool): Flag indicating whether to sum layers in the encoder/decoder blocks.
            hidden_RNN (int): Number of hidden units in the RNN.
            num_layers_RNN (int): Number of layers in the RNN.
            kind (str): Type of RNN to use ('lstm' or 'gru').
            kernel_size (int): Size of the convolutional kernel.
            dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
            use_bn (bool, optional): Flag to use batch normalization. Defaults to False.
            use_cumsum (bool, optional): Flag to use a cumulative sum over the future steps. Defaults to True.
            use_bilinear (bool, optional): Flag to use a bilinear layer to combine past and future representations. Defaults to False.
            activation (str, optional): Activation function to use. Defaults to 'torch.nn.ReLU'.
            **kwargs: Additional keyword arguments passed to Base.

        Raises:
            ValueError: If the specified activation function is not recognized or if the kind is not 'lstm' or 'gru'.
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)

        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN
        self.use_cumsum = use_cumsum
        self.kind = kind
        self.use_bilinear = use_bilinear

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.initial_linear_encoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.past_channels, (self.past_channels + hidden_RNN // 4) // 2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.past_channels + hidden_RNN // 4) // 2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.past_channels + hidden_RNN // 4) // 2, hidden_RNN // 4, kernel_size, stride=1, padding='same'),
                                                    Permute())

        self.initial_linear_decoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.future_channels, (self.future_channels + hidden_RNN // 4) // 2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.future_channels + hidden_RNN // 4) // 2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.future_channels + hidden_RNN // 4) // 2, hidden_RNN // 4, kernel_size, stride=1, padding='same'),
                                                    Permute())
        self.conv_encoder = Block(emb_past_out_channel + hidden_RNN // 4, kernel_size, hidden_RNN // 2, self.past_steps, sum_layers)

        if self.future_channels + emb_fut_out_channel == 0:
            ## careful: this means no future covariates were passed; as a stopgap, use the encoder hidden size
            self.conv_decoder = Block(hidden_RNN, kernel_size, hidden_RNN // 2, self.future_steps, sum_layers)
        else:
            self.conv_decoder = Block(self.future_channels + emb_fut_out_channel, kernel_size, hidden_RNN // 2, self.future_steps, sum_layers)
        # nn.Sequential(Permute(),nn.Linear(past_steps,past_steps*2), nn.PReLU(),nn.Dropout(0.2),nn.Linear(past_steps*2, future_steps),nn.Dropout(0.3),nn.Conv1d(hidden_RNN, hidden_RNN//8, 3, stride=1,padding='same'), Permute())
        if self.kind == 'lstm':
            self.Encoder = nn.LSTM(input_size=self.conv_encoder.out_channels,  # , hidden_RNN//4,
                                   hidden_size=hidden_RNN // 2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
            self.Decoder = nn.LSTM(input_size=self.conv_decoder.out_channels,  # , hidden_RNN//4,
                                   hidden_size=hidden_RNN // 2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
        elif self.kind == 'gru':
            self.Encoder = nn.GRU(input_size=self.conv_encoder.out_channels,  # , hidden_RNN//4,
                                  hidden_size=hidden_RNN // 2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
            self.Decoder = nn.GRU(input_size=self.conv_decoder.out_channels,  # , hidden_RNN//4,
                                  hidden_size=hidden_RNN // 2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
        else:
            beauty_string('Specify kind lstm or gru please', 'section', True)
        self.final_linear_decoder = nn.Sequential(nn.Linear((hidden_RNN // 2 * 2) * num_layers_RNN, hidden_RNN * 2),
                                                  activation(),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.BatchNorm1d(hidden_RNN * 2) if use_bn else nn.Dropout(dropout_rate),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.Linear(hidden_RNN * 2, hidden_RNN),
                                                  activation(),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.Linear(hidden_RNN, self.mul))

        if use_bilinear:
            self.bilinear = torch.nn.Bilinear((hidden_RNN // 2 * 2) * num_layers_RNN, (hidden_RNN // 2 * 2) * num_layers_RNN, hidden_RNN * 2)
            self.final_linear_decoder = nn.Sequential(
                activation(),
                Permute() if use_bn else nn.Identity(),
                nn.BatchNorm1d(hidden_RNN * 2) if use_bn else nn.Dropout(dropout_rate),
                Permute() if use_bn else nn.Identity(),
                nn.Linear(hidden_RNN * 2, hidden_RNN),
                activation(),
                Permute() if use_bn else nn.Identity(),
                nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                Permute() if use_bn else nn.Identity(),
                nn.Linear(hidden_RNN, self.mul))

    def forward(self, batch):
        """It is mandatory to implement this method

        Args:
            batch (dict): batch of the dataloader

        Returns:
            torch.tensor: result
        """
        x = batch['x_num_past'].to(self.device)
        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
        else:
            x_future = None

        tmp = [self.initial_linear_encoder(x), emb_past]

        tot = torch.cat(tmp, 2)

        out_past, hidden_past = self.Encoder(self.conv_encoder(tot))

        ## hidden = 2 x bs x channels_out_encoder
        ## out = BS x len x channels_out_encoder
        tmp = [emb_fut]

        if x_future is not None:
            tmp.append(x_future)

        if len(tmp) > 0:
            tot = torch.cat(tmp, 2)
            out_future, hidden_future = self.Decoder(self.conv_decoder(tot))
        else:
            out_future, hidden_future = self.Decoder(self.conv_decoder(out_past))
            out_future = out_future[:, -1:, :].repeat(1, self.future_steps, 1)  ## workaround, to check
        ## hidden state of the past --> initial state

        if self.kind == 'lstm':
            hidden_past = hidden_past[0]

        # past = 2*num_layers_RNN x BS x hidden_RNN//2
        # future = BS x L x hidden_RNN//2 --> BSxLxC
        BS = hidden_past.shape[1]
        N = hidden_past.shape[0] // 2
        past = hidden_past.permute(1, 0, 2).reshape(BS, -1)  # BSx2NxC --> BSx2CN
        future = out_future.repeat(1, 1, N)

        if self.use_bilinear:
            final = self.bilinear(future, past.unsqueeze(2).repeat(1, 1, self.future_steps).permute(0, 2, 1)).permute(0, 2, 1)
        else:
            if self.use_cumsum:
                final = torch.cumsum(future, axis=1).permute(0, 2, 1) + past.unsqueeze(2).repeat(1, 1, self.future_steps)
            else:
                final = future.permute(0, 2, 1) + past.unsqueeze(2).repeat(1, 1, self.future_steps)

        res = self.final_linear_decoder(final.permute(0, 2, 1)).reshape(BS, self.future_steps, self.out_channels, self.mul)

        return res
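For context, both DilatedConv and DilatedConvED read the same keys from the batch dictionary in forward(). The sketch below illustrates the expected layout; the tensor sizes are illustrative, and the dtypes of the categorical tensors and the exact structure of idx_target are assumptions, since the Base class and the data pipeline are not included in this diff.

import torch

# Illustrative batch for a model configured with past_steps=96, future_steps=24,
# past_channels=3, future_channels=1, and two categorical covariates (all assumed values).
batch = {
    "x_num_past":   torch.randn(8, 96, 3),                     # [BS, past_steps, past_channels]
    "x_cat_past":   torch.zeros(8, 96, 2, dtype=torch.long),   # optional categorical covariates
    "x_num_future": torch.randn(8, 24, 1),                     # optional known future numerical covariates
    "x_cat_future": torch.zeros(8, 24, 2, dtype=torch.long),
    "idx_target":   [torch.tensor([0])],                       # target channel indices; used by DilatedConv when remove_last=True (assumed layout)
}
# DilatedConv(...)(batch)   -> (res, score), with res reshaped to [BS, future_steps, out_channels, mul]
# DilatedConvED(...)(batch) -> res of shape [BS, future_steps, out_channels, mul]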