dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,342 @@

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
import numpy as np
torch.autograd.set_detect_anomaly(True)
from .utils import get_scope
from .utils import Embedding_cat_variables


class GLU(nn.Module):
    def __init__(self, d_model: int):
        """Gated Linear Unit, 'Gate' block in TFT paper
        Sub net of GRN: linear(x) * sigmoid(linear(x))
        No dimension changes

        Args:
            d_model (int): model dimension
        """
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU6()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gated Linear Unit
        Sub net of GRN: linear(x) * sigmoid(linear(x))
        No dimension changes: [bs, seq_len, d_model]

        Args:
            x (torch.Tensor)

        Returns:
            torch.Tensor
        """

        ## here comes something like BSxL
        x1 = (self.activation(self.linear(x.unsqueeze(2))) / 6.0).squeeze()
        out = x1 * x  # element-wise multiplication

        ## get the score
        score = torch.sign(x1).mean()
        return out, score

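As written, the gate above is ReLU6(linear(x)) / 6, a piecewise-linear value clipped to [0, 1] that multiplies the input element-wise, and the returned score is the mean sign of the gate (the fraction of open positions). A minimal standalone sketch of that gating on dummy values, not using the class itself:

import torch

z = torch.linspace(-8.0, 8.0, 5)           # stand-in for the pre-activations linear(x)
gate = torch.clamp(z, 0.0, 6.0) / 6.0      # same as nn.ReLU6()(z) / 6, bounded in [0, 1]
x = torch.ones_like(z)
out = gate * x                             # element-wise gating, as in GLU.forward
score = torch.sign(gate).mean()            # fraction of positions with a non-zero gate
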
class Block(nn.Module):
    def __init__(self, input_channels: int, kernel_size: int, output_channels: int, input_size: int, sum_layers: bool):

        super(Block, self).__init__()

        self.dilations = nn.ModuleList()
        self.steps = int(np.floor(np.log2(input_size))) - 1

        if self.steps <= 1:
            self.steps = 1

        for i in range(self.steps):
            # dilation
            self.dilations.append(nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding='same', dilation=2**i))
            s = max(2**i - 1, 1)
            k = 2**(i+1) + 1
            p = int(((s-1)*input_size + k - 1)/2)
            self.dilations.append(nn.Conv1d(input_channels, output_channels, k, stride=s, padding=p))

        self.sum_layers = sum_layers
        mul = 1 if sum_layers else self.steps*2
        self.conv_final = nn.Conv1d(output_channels*mul, output_channels*mul, kernel_size, stride=1, padding='same')
        self.out_channels = output_channels*mul

    def forward(self, x: torch.tensor) -> torch.tensor:
        x = Permute()(x)
        tmp = []
        for i in range(self.steps):
            tmp.append(self.dilations[i](x))

        if self.sum_layers:
            tmp = torch.stack(tmp)
            tmp = tmp.sum(axis=0)
        else:
            tmp = torch.cat(tmp, 1)

        return Permute()(tmp)

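A minimal, self-contained sketch of the dilated branching used by Block on dummy data. Permute here is a local stand-in: the library's Permute (from .utils, not part of this diff) is assumed to swap the time and channel axes so that Conv1d sees (batch, channels, length); only the sum_layers=True path over the purely dilated convolutions is mirrored.

import numpy as np
import torch
from torch import nn

class Permute(nn.Module):          # stand-in: (B, L, C) <-> (B, C, L)
    def forward(self, x):
        return x.permute(0, 2, 1)

B, L, C_in, C_out, kernel = 4, 64, 8, 16, 3
steps = max(int(np.floor(np.log2(L))) - 1, 1)       # 5 dilated branches for L = 64
convs = nn.ModuleList(
    [nn.Conv1d(C_in, C_out, kernel, padding='same', dilation=2**i) for i in range(steps)]
)
x = torch.randn(B, L, C_in)
xt = Permute()(x)                                   # (B, C_in, L)
out = torch.stack([c(xt) for c in convs]).sum(0)    # sum over branches, as with sum_layers=True
print(Permute()(out).shape)                         # torch.Size([4, 64, 16])
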
class DilatedConv(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True

    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 sum_layers: bool,
                 hidden_RNN: int,
                 num_layers_RNN: int,
                 kind: str,
                 kernel_size: int,
                 activation: str = 'torch.nn.ReLU',
                 remove_last=False,
                 dropout_rate: float = 0.1,
                 use_bn: bool = False,
                 use_glu: bool = True,
                 glu_percentage: float = 1.0,
                 **kwargs) -> None:
        """Custom encoder-decoder

        Args:
            sum_layers (bool): Flag indicating whether to sum the layers.
            hidden_RNN (int): Number of hidden units in the RNN.
            num_layers_RNN (int): Number of layers in the RNN.
            kind (str): Type of RNN to use ('lstm' or 'gru').
            kernel_size (int): Size of the convolutional kernel.
            activation (str, optional): Activation function to use. Defaults to 'torch.nn.ReLU'.
            remove_last (bool, optional): Flag to indicate whether to remove the last element in the sequence. Defaults to False.
            dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
            use_bn (bool, optional): Flag to indicate whether to use batch normalization. Defaults to False.
            use_glu (bool, optional): Flag to indicate whether to use Gated Linear Units (GLU). Defaults to True.
            glu_percentage (float, optional): Percentage of GLU to apply. Defaults to 1.0.
            **kwargs: Additional keyword arguments.

        Returns:
            None
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)
        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN
        self.kind = kind
        self.use_glu = use_glu
        self.glu_percentage = torch.tensor(glu_percentage).to(self.device)
        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        if self.use_glu:
            self.past_glu = nn.ModuleList()
            self.future_glu = nn.ModuleList()
            for _ in range(self.past_channels):
                self.past_glu.append(GLU(1))

            for _ in range(self.future_channels):
                self.future_glu.append(GLU(1))

        self.initial_linear_encoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.past_channels, (self.past_channels+hidden_RNN//4)//2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.past_channels+hidden_RNN//4)//2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.past_channels+hidden_RNN//4)//2, hidden_RNN//4, kernel_size, stride=1, padding='same'),
                                                    Permute())

        self.initial_linear_decoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.future_channels, (self.future_channels+hidden_RNN//4)//2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.future_channels+hidden_RNN//4)//2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.future_channels+hidden_RNN//4)//2, hidden_RNN//4, kernel_size, stride=1, padding='same'),
                                                    Permute())
        self.conv_encoder = Block(emb_past_out_channel+hidden_RNN//4, kernel_size, hidden_RNN//2, self.past_steps, sum_layers)

        # nn.Sequential(Permute(), nn.Conv1d(emb_channels+hidden_RNN//8, hidden_RNN//8, kernel_size, stride=1,padding='same'),Permute(),nn.Dropout(0.3))

        if self.future_channels+emb_fut_out_channel == 0:
            ## careful: this means no future covariates were passed; for now, patch it by using the encoder hidden size
            self.conv_decoder = Block(hidden_RNN, kernel_size, hidden_RNN//2, self.future_steps, sum_layers)
        else:
            self.conv_decoder = Block(self.future_channels+emb_fut_out_channel, kernel_size, hidden_RNN//2, self.future_steps, sum_layers)
        # nn.Sequential(Permute(),nn.Linear(past_steps,past_steps*2), nn.PReLU(),nn.Dropout(0.2),nn.Linear(past_steps*2, future_steps),nn.Dropout(0.3),nn.Conv1d(hidden_RNN, hidden_RNN//8, 3, stride=1,padding='same'), Permute())
        if self.kind == 'lstm':
            self.Encoder = nn.LSTM(input_size=self.conv_encoder.out_channels,  #, hidden_RNN//4,
                                   hidden_size=hidden_RNN//2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
            self.Decoder = nn.LSTM(input_size=self.conv_decoder.out_channels,  #, hidden_RNN//4,
                                   hidden_size=hidden_RNN//2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
        elif self.kind == 'gru':
            self.Encoder = nn.GRU(input_size=self.conv_encoder.out_channels,  #, hidden_RNN//4,
                                  hidden_size=hidden_RNN//2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
            self.Decoder = nn.GRU(input_size=self.conv_decoder.out_channels,  #, hidden_RNN//4,
                                  hidden_size=hidden_RNN//2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
        else:
            beauty_string('Specify kind lstm or gru please', 'section', True)
        self.final_linear = nn.ModuleList()
        for _ in range(self.out_channels*self.mul):
            self.final_linear.append(nn.Sequential(nn.Linear(hidden_RNN+emb_fut_out_channel+self.future_channels, hidden_RNN*2),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN*2) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN*2, hidden_RNN),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN, hidden_RNN//2),
                                                   activation(),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.BatchNorm1d(hidden_RNN//2) if use_bn else nn.Dropout(dropout_rate),
                                                   Permute() if use_bn else nn.Identity(),
                                                   nn.Linear(hidden_RNN//2, hidden_RNN//4),
                                                   activation(),
                                                   nn.Linear(hidden_RNN//4, 1)))

    def training_step(self, batch, batch_idx):
        """
        pytorch lightning stuff

        :meta private:
        """
        y_hat, score = self(batch)
        return self.compute_loss(batch, y_hat)  # +torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating

    def validation_step(self, batch, batch_idx):
        """
        pytorch lightning stuff

        :meta private:
        """
        y_hat, score = self(batch)
        return self.compute_loss(batch, y_hat)  # +torch.abs(score-self.glu_percentage)*loss/5.0 ##TODO investigating

    def forward(self, batch):
        """It is mandatory to implement this method

        Args:
            batch (dict): batch of the dataloader

        Returns:
            torch.tensor: result
        """

        x = batch['x_num_past'].to(self.device)
        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
            xf = torch.clone(x_future)
        else:
            x_future = None

        if self.remove_last:
            idx_target = batch['idx_target'][0]

            x_start = x[:, -1, idx_target].unsqueeze(1)
            ## BxC
            x[:, :, idx_target] -= x_start

        ## first GLU
        score = 0
        xp = torch.clone(x)

        if self.use_glu:
            score_past_tot = 0
            score_future_tot = 0

            for i in range(len(self.past_glu)):
                x[:, :, i], score = self.past_glu[i](xp[:, :, i])
                score_past_tot += score
            score_past_tot /= len(self.past_glu)

            if x_future is not None:
                for i in range(len(self.future_glu)):
                    x_future[:, :, i], score = self.future_glu[i](xf[:, :, i])
                    score_future_tot += score
                score_future_tot /= len(self.future_glu)
            score = 0.5*(score_past_tot+score_future_tot)
        tmp = [self.initial_linear_encoder(x), emb_past]

        tot = torch.cat(tmp, 2)
        out, hidden = self.Encoder(self.conv_encoder(tot))
        tmp = [emb_fut]
        if x_future is not None:
            tmp.append(x_future)
        tot = torch.cat(tmp, 2)
        out, _ = self.Decoder(self.conv_decoder(tot), hidden)
        res = []
        tmp = torch.cat([tot, out], axis=2)

        for j in range(self.out_channels*self.mul):
            res.append(self.final_linear[j](tmp))

        res = torch.cat(res, 2)
        ## BxLxC
        B = res.shape[0]
        res = res.reshape(B, self.future_steps, -1, self.mul)
        if self.remove_last:
            res += x_start.unsqueeze(1)

        return res, score

    def inference(self, batch: dict) -> torch.tensor:

        res, score = self(batch)
        return res
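The remove_last option anchors the forecast to the last observed value of the target channels: the offset is subtracted from the past targets before encoding and added back to the prediction at the end of forward(). A minimal sketch of that anchoring; the tensor sizes and the single target index are hypothetical:

import torch

x = torch.randn(2, 10, 3)                        # (batch, past_steps, past_channels), made-up sizes
idx_target = [0]                                 # hypothetical target channel index
x_start = x[:, -1, idx_target].unsqueeze(1)      # (batch, 1, n_targets): last observed value
x[:, :, idx_target] -= x_start                   # the network sees the de-anchored history
res = torch.zeros(2, 5, 1, 1)                    # stand-in for the model output (batch, future, targets, mul)
res += x_start.unsqueeze(1)                      # offset added back, as at the end of forward()
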
@@ -0,0 +1,310 @@

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
import numpy as np
from .utils import get_scope
from .utils import Embedding_cat_variables
torch.autograd.set_detect_anomaly(True)


class GLU(nn.Module):
    def __init__(self, d_model: int):
        """Gated Linear Unit, 'Gate' block in TFT paper
        Sub net of GRN: linear(x) * sigmoid(linear(x))
        No dimension changes

        Args:
            d_model (int): model dimension
        """
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.activation = nn.ReLU6()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gated Linear Unit
        Sub net of GRN: linear(x) * sigmoid(linear(x))
        No dimension changes: [bs, seq_len, d_model]

        Args:
            x (torch.Tensor)

        Returns:
            torch.Tensor
        """

        ## here comes something like BSxL
        x1 = (self.activation(self.linear(x.unsqueeze(2))) / 6.0).squeeze()
        out = x1 * x  # element-wise multiplication

        ## get the score
        score = torch.sign(x1).mean()
        return out, score


class Block(nn.Module):
    def __init__(self, input_channels: int, kernel_size: int, output_channels: int, input_size: int, sum_layers: bool):

        super(Block, self).__init__()

        self.dilations = nn.ModuleList()
        self.steps = int(np.floor(np.log2(input_size))) - 1

        if self.steps <= 1:
            self.steps = 1

        for i in range(self.steps):
            # dilation
            self.dilations.append(nn.Conv1d(input_channels, output_channels, kernel_size, stride=1, padding='same', dilation=2**i))
            s = max(2**i - 1, 1)
            k = 2**(i+1) + 1
            p = int(((s-1)*input_size + k - 1)/2)
            self.dilations.append(nn.Conv1d(input_channels, output_channels, k, stride=s, padding=p))

        self.sum_layers = sum_layers
        mul = 1 if sum_layers else self.steps*2
        self.conv_final = nn.Conv1d(output_channels*mul, output_channels*mul, kernel_size, stride=1, padding='same')
        self.out_channels = output_channels*mul

    def forward(self, x: torch.tensor) -> torch.tensor:
        x = Permute()(x)
        tmp = []
        for i in range(self.steps):
            tmp.append(self.dilations[i](x))

        if self.sum_layers:
            tmp = torch.stack(tmp)
            tmp = tmp.sum(axis=0)
        else:
            tmp = torch.cat(tmp, 1)

        return Permute()(tmp)


class DilatedConvED(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 sum_layers: bool,
                 hidden_RNN: int,
                 num_layers_RNN: int,
                 kind: str,
                 kernel_size: int,
                 dropout_rate: float = 0.1,
                 use_bn: bool = False,
                 use_cumsum: bool = True,
                 use_bilinear: bool = False,
                 activation: str = 'torch.nn.ReLU',
                 **kwargs) -> None:
        """Initialize the model with specified parameters.

        Args:
            sum_layers (bool): Flag indicating whether to sum layers in the encoder/decoder blocks.
            hidden_RNN (int): Number of hidden units in the RNN.
            num_layers_RNN (int): Number of layers in the RNN.
            kind (str): Type of RNN to use ('lstm' or 'gru').
            kernel_size (int): Size of the convolutional kernel.
            dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
            use_bn (bool, optional): Flag to use batch normalization. Defaults to False.
            use_cumsum (bool, optional): Flag to use cumulative sum. Defaults to True.
            use_bilinear (bool, optional): Flag to use bilinear layers. Defaults to False.
            activation (str, optional): Activation function to use. Defaults to 'torch.nn.ReLU'.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If the specified activation function is not recognized or if the kind is not 'lstm' or 'gru'.
        """
        super().__init__(**kwargs)
        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightning, the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)

        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN
        self.use_cumsum = use_cumsum
        self.kind = kind
        self.use_bilinear = use_bilinear

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.initial_linear_encoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.past_channels, (self.past_channels+hidden_RNN//4)//2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.past_channels+hidden_RNN//4)//2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.past_channels+hidden_RNN//4)//2, hidden_RNN//4, kernel_size, stride=1, padding='same'),
                                                    Permute())

        self.initial_linear_decoder = nn.Sequential(Permute(),
                                                    nn.Conv1d(self.future_channels, (self.future_channels+hidden_RNN//4)//2, kernel_size, stride=1, padding='same'),
                                                    activation(),
                                                    nn.BatchNorm1d((self.future_channels+hidden_RNN//4)//2) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Conv1d((self.future_channels+hidden_RNN//4)//2, hidden_RNN//4, kernel_size, stride=1, padding='same'),
                                                    Permute())
        self.conv_encoder = Block(emb_past_out_channel+hidden_RNN//4, kernel_size, hidden_RNN//2, self.past_steps, sum_layers)

        if self.future_channels+emb_fut_out_channel == 0:
            ## careful: this means no future covariates were passed; for now, patch it by using the encoder hidden size
            self.conv_decoder = Block(hidden_RNN, kernel_size, hidden_RNN//2, self.future_steps, sum_layers)
        else:
            self.conv_decoder = Block(self.future_channels+emb_fut_out_channel, kernel_size, hidden_RNN//2, self.future_steps, sum_layers)
        # nn.Sequential(Permute(),nn.Linear(past_steps,past_steps*2), nn.PReLU(),nn.Dropout(0.2),nn.Linear(past_steps*2, future_steps),nn.Dropout(0.3),nn.Conv1d(hidden_RNN, hidden_RNN//8, 3, stride=1,padding='same'), Permute())
        if self.kind == 'lstm':
            self.Encoder = nn.LSTM(input_size=self.conv_encoder.out_channels,  #, hidden_RNN//4,
                                   hidden_size=hidden_RNN//2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
            self.Decoder = nn.LSTM(input_size=self.conv_decoder.out_channels,  #, hidden_RNN//4,
                                   hidden_size=hidden_RNN//2,
                                   num_layers=num_layers_RNN,
                                   batch_first=True, bidirectional=True)
        elif self.kind == 'gru':
            self.Encoder = nn.GRU(input_size=self.conv_encoder.out_channels,  #, hidden_RNN//4,
                                  hidden_size=hidden_RNN//2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
            self.Decoder = nn.GRU(input_size=self.conv_decoder.out_channels,  #, hidden_RNN//4,
                                  hidden_size=hidden_RNN//2,
                                  num_layers=num_layers_RNN,
                                  batch_first=True, bidirectional=True)
        else:
            beauty_string('Specify kind lstm or gru please', 'section', True)
        self.final_linear_decoder = nn.Sequential(nn.Linear((hidden_RNN//2*2)*num_layers_RNN, hidden_RNN*2),
                                                  activation(),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.BatchNorm1d(hidden_RNN*2) if use_bn else nn.Dropout(dropout_rate),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.Linear(hidden_RNN*2, hidden_RNN),
                                                  activation(),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                                                  Permute() if use_bn else nn.Identity(),
                                                  nn.Linear(hidden_RNN, self.mul))

        if use_bilinear:
            self.bilinear = torch.nn.Bilinear((hidden_RNN//2*2)*num_layers_RNN, (hidden_RNN//2*2)*num_layers_RNN, hidden_RNN*2)
            self.final_linear_decoder = nn.Sequential(
                activation(),
                Permute() if use_bn else nn.Identity(),
                nn.BatchNorm1d(hidden_RNN*2) if use_bn else nn.Dropout(dropout_rate),
                Permute() if use_bn else nn.Identity(),
                nn.Linear(hidden_RNN*2, hidden_RNN),
                activation(),
                Permute() if use_bn else nn.Identity(),
                nn.BatchNorm1d(hidden_RNN) if use_bn else nn.Dropout(dropout_rate),
                Permute() if use_bn else nn.Identity(),
                nn.Linear(hidden_RNN, self.mul))

    def forward(self, batch):
        """It is mandatory to implement this method

        Args:
            batch (dict): batch of the dataloader

        Returns:
            torch.tensor: result
        """
        x = batch['x_num_past'].to(self.device)
        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
        else:
            x_future = None

        tmp = [self.initial_linear_encoder(x), emb_past]

        tot = torch.cat(tmp, 2)

        out_past, hidden_past = self.Encoder(self.conv_encoder(tot))

        ## hidden = 2 x bs x channels_out_encoder
        ## out = BS x len x channels_out_encoder
        tmp = [emb_fut]

        if x_future is not None:
            tmp.append(x_future)

        if len(tmp) > 0:
            tot = torch.cat(tmp, 2)
            out_future, hidden_future = self.Decoder(self.conv_decoder(tot))
        else:
            out_future, hidden_future = self.Decoder(self.conv_decoder(out_past))
            out_future = out_future[:, -1:, :].repeat(1, self.future_steps, 1)  ## workaround to check
        ## hidden state of the past --> initial state

        if self.kind == 'lstm':
            hidden_past = hidden_past[0]

        # past = 2*num_layers_RNN x BS x hidden_RNN//2
        # future = BS x L x hidden_RNN//2 --> BSxLxC
        BS = hidden_past.shape[1]
        N = hidden_past.shape[0]//2
        past = hidden_past.permute(1, 0, 2).reshape(BS, -1)  # BSx2NxC --> BSx2CN
        future = out_future.repeat(1, 1, N)

        if self.use_bilinear:
            final = self.bilinear(future, past.unsqueeze(2).repeat(1, 1, self.future_steps).permute(0, 2, 1)).permute(0, 2, 1)
        else:
            if self.use_cumsum:
                final = torch.cumsum(future, axis=1).permute(0, 2, 1) + past.unsqueeze(2).repeat(1, 1, self.future_steps)
            else:
                final = future.permute(0, 2, 1) + past.unsqueeze(2).repeat(1, 1, self.future_steps)

        res = self.final_linear_decoder(final.permute(0, 2, 1)).reshape(BS, self.future_steps, self.out_channels, self.mul)

        return res
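Both models reshape the flat head output into (batch, future_steps, out_channels, mul). self.out_channels and self.mul are set in the Base class, which is not part of this diff; presumably out_channels is the number of target series and mul the number of predicted quantiles (1 for a point forecast). A short sketch of that layout with hypothetical sizes:

import torch

B, future_steps, out_channels, mul = 8, 12, 2, 3           # made-up sizes
flat = torch.randn(B, future_steps, out_channels * mul)    # per-step output of the linear head
res = flat.reshape(B, future_steps, out_channels, mul)     # (batch, horizon, series, quantiles)
print(res.shape)                                           # torch.Size([8, 12, 2, 3])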