dsipts 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dsipts might be problematic. Click here for more details.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
|
|
2
|
+
## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
|
|
3
|
+
## Code modified to align the notation and the batch generation
|
|
4
|
+
## extended to all present in informer, autoformer folder
|
|
5
|
+
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import lightning.pytorch as pl
|
|
11
|
+
from .base_v2 import Base
|
|
12
|
+
OLD_PL = False
|
|
13
|
+
except:
|
|
14
|
+
import pytorch_lightning as pl
|
|
15
|
+
OLD_PL = True
|
|
16
|
+
from .base import Base
|
|
17
|
+
from .utils import QuantileLossMO, get_activation
|
|
18
|
+
from typing import List, Union
|
|
19
|
+
from ..data_structure.utils import beauty_string
|
|
20
|
+
from .utils import get_scope
|
|
21
|
+
from .utils import Embedding_cat_variables
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The input is padded by replicating its first and last time step before an
    average pooling is applied along dimension 1. The padding adds exactly
    ``kernel_size - 1`` steps in total, so with ``stride=1`` the output has the
    same length as the input for both odd and even ``kernel_size``.

    Args:
        kernel_size (int): Width of the averaging window.
        stride (int): Stride of the average pooling.
    """
    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Return the moving average of ``x`` computed along dimension 1.

        Args:
            x (torch.Tensor): 3-D tensor indexed as ``x[:, step, :]``.

        Returns:
            torch.Tensor: Smoothed tensor, laid out like ``x``.
        """
        # Padding on both ends of the time series. The front gets (k-1)//2
        # copies and the end gets the remainder: for odd k the two are equal
        # (the historical behaviour); for even k the end gets one extra copy,
        # otherwise the output would be one step shorter than the input.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dimension, hence the permute round trip
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class series_decomp(nn.Module):
    """
    Series decomposition block.

    Splits a series into a trend component (its moving average) and the
    residual obtained by subtracting that trend from the series.
    """
    def __init__(self, kernel_size):
        super().__init__()
        # the smoothing module is always built with stride 1
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return the pair ``(residual, trend)`` computed from ``x``."""
        trend = self.moving_avg(x)
        residual = x - trend
        return residual, trend
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class LinearTS(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
    description+='\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'

    def __init__(self,
                 kernel_size:int,
                 hidden_size:int,
                 dropout_rate:float=0.1,
                 activation:str='torch.nn.ReLU',
                 kind:str='linear',
                 use_bn:bool=False,
                 simple:bool=False,
                 **kwargs)->None:
        """Initialize the model with specified parameters. Linear model from https://github.com/cure-lab/LTSF-Linear/blob/main/run_longExp.py

        Args:
            kernel_size (int): Kernel dimension for the initial moving average (used only when kind is 'dlinear').
            hidden_size (int): Hidden size of the linear block (used only when simple is False).
            dropout_rate (float, optional): Dropout rate in Dropout layers. Default is 0.1.
            activation (str, optional): Activation function in PyTorch. Default is 'torch.nn.ReLU'.
            kind (str, optional): Type of model: 'linear', 'dlinear' (de-trending), 'nlinear' (the last
                observed target value is subtracted from the input and added back to the output) or
                'alinear' (the target columns of the past input are set to zero). Any other value
                behaves like 'linear'. Defaults to 'linear'.
            use_bn (bool, optional): If True, Batch Normalization layers will be added and Dropouts will be removed. Default is False.
            simple (bool, optional): If True, the model used is the same as illustrated in the paper; otherwise, a more complex model with the same idea is used. Default is False.
            **kwargs: Additional keyword arguments for the parent class.
        """

        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU do not require BN','info',self.verbose)
            use_bn = False

        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightening, the constructior is called twice','info',self.verbose)

        # must run after `activation` / `use_bn` have been rewritten above,
        # exactly like in the original constructor
        self.save_hyperparameters(logger=False)

        self.kind = kind
        self.simple = simple

        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        ## one model for each output channel
        self.linear = nn.ModuleList()

        if kind=='dlinear':
            # NOTE: the attribute name (typo included) is kept for backward compatibility
            self.decompsition = series_decomp(kernel_size)
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.out_channels):
                self.Linear_Trend.append(nn.Linear(self.past_steps,self.future_steps))

        # width of the flattened input used by the non-simple model:
        # past (numerical + embeddings) followed by future (embeddings + numerical)
        flat_input_size = (emb_past_out_channel*self.past_steps
                           +emb_fut_out_channel*self.future_steps
                           +self.past_steps*self.past_channels
                           +self.future_channels*self.future_steps)

        for _ in range(self.out_channels):
            if simple:
                self.linear.append(nn.Linear(self.past_steps,self.future_steps*self.mul))
            else:
                self.linear.append(nn.Sequential(nn.Linear(flat_input_size,hidden_size),
                                                 activation(),
                                                 nn.BatchNorm1d(hidden_size) if use_bn else nn.Dropout(dropout_rate),
                                                 nn.Linear(hidden_size,hidden_size//2),
                                                 activation(),
                                                 nn.BatchNorm1d(hidden_size//2) if use_bn else nn.Dropout(dropout_rate),
                                                 nn.Linear(hidden_size//2,hidden_size//4),
                                                 activation(),
                                                 nn.BatchNorm1d(hidden_size//4) if use_bn else nn.Dropout(dropout_rate),
                                                 nn.Linear(hidden_size//4,hidden_size//8),
                                                 activation(),
                                                 nn.BatchNorm1d(hidden_size//8) if use_bn else nn.Dropout(dropout_rate),
                                                 nn.Linear(hidden_size//8,self.future_steps*self.mul)))

    def forward(self, batch):
        """Compute the forecast for one batch.

        Args:
            batch (dict): Must contain 'x_num_past' and 'idx_target'; may contain
                'x_cat_past', 'x_cat_future' and 'x_num_future'.

        Returns:
            torch.Tensor: Tensor obtained by stacking, on dimension 2, the output of each
            per-channel model reshaped to ``(batch, -1, self.mul)``.
        """

        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]

        BS = x.shape[0]
        if 'x_cat_future' in batch:
            emb_fut = self.emb_fut(BS,batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS,None)
        if 'x_cat_past' in batch:
            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS,None)

        if self.kind in ('nlinear','alinear','dlinear'):
            # these kinds overwrite the target columns: work on a copy so that the
            # tensor stored in the caller's batch is left untouched
            x = x.clone()

        if self.kind=='nlinear':
            # last observed value of every target, B x 1 x C
            x_start = x[:,-1,idx_target].unsqueeze(1)
            x[:,:,idx_target]-=x_start

        if self.kind=='alinear':
            x[:,:,idx_target] = 0

        if self.kind=='dlinear':
            seasonal_init, trend_init = self.decompsition(x[:,:,idx_target])
            # the model input keeps only the seasonal part of the targets
            x[:,:,idx_target] = seasonal_init
            trend_init = trend_init.permute(0,2,1)
            # one linear map per target applied to its own trend; stacked as B x L x C
            trend = torch.stack([layer(trend_init[:,j,:]) for j,layer in enumerate(self.Linear_Trend)],2)

        if not self.simple:
            future_parts = [emb_fut]
            if 'x_num_future' in batch:
                future_parts.append(batch['x_num_future'].to(self.device))

            tot_past = torch.cat([x,emb_past],2).flatten(1)
            tot_future = torch.cat(future_parts,2).flatten(1)
            tot = torch.cat([tot_past,tot_future],1)
            # every per-channel model reads the same flattened input
            inputs = [tot]*len(self.linear)
        else:
            # the j-th linear map forecasts the j-th target from its own history
            # (assumes len(idx_target) == out_channels — TODO confirm against the data loader)
            past_targets = x[:,:,idx_target].permute(0,2,1)
            inputs = [past_targets[:,j,:] for j in range(len(self.linear))]

        ## BxLxCxMUL
        res = torch.stack([layer(inp).reshape(BS,-1,self.mul) for layer,inp in zip(self.linear,inputs)],2)

        if self.kind=='nlinear':
            # x_start is B x 1 x C: add the trailing axis so that it broadcasts over
            # the steps and over the MUL dimension of res
            res = res+x_start.unsqueeze(-1)

        if self.kind=='dlinear':
            res = res+trend.unsqueeze(3)

        return res
|
|
222
|
+
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
## Copyright https://github.com/yuqinie98/PatchTST/blob/main/LICENSE
|
|
2
|
+
## Modified for notation alignment and batch structure
|
|
3
|
+
## extended to what is inside the patchtst folder
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from torch import nn
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import lightning.pytorch as pl
|
|
12
|
+
from .base_v2 import Base
|
|
13
|
+
OLD_PL = False
|
|
14
|
+
except:
|
|
15
|
+
import pytorch_lightning as pl
|
|
16
|
+
OLD_PL = True
|
|
17
|
+
from .base import Base
|
|
18
|
+
from typing import List,Union
|
|
19
|
+
from ..data_structure.utils import beauty_string
|
|
20
|
+
from .utils import get_scope
|
|
21
|
+
from .utils import get_activation
|
|
22
|
+
from .patchtst.layers import series_decomp, PatchTST_backbone
|
|
23
|
+
from .utils import Embedding_cat_variables
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PatchTST(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

    def __init__(self,
                 d_model:int,
                 patch_len:int,
                 kernel_size:int,
                 decomposition:bool=True,
                 activation:str='torch.nn.ReLU',
                 n_head:int=1,
                 n_layer:int=2,
                 stride:int=8,
                 remove_last:bool = False,
                 hidden_size:int=1048,
                 dropout_rate:float=0.1,
                 **kwargs)->None:
        """Initializes the model with specified parameters. https://github.com/yuqinie98/PatchTST/blob/main/

        Args:
            d_model (int): The dimensionality of the model (forwarded to the backbone as ``d_model``).
            patch_len (int): The length of the patches (forwarded to the backbone as ``patch_len``).
            kernel_size (int): Kernel size given to the series decomposition module; used only when
                ``decomposition`` is True.
            decomposition (bool, optional): If True the input is split by the decomposition module and the
                two components are processed by two separate backbones whose outputs are summed.
                Defaults to True.
            activation (str, optional): The activation function to use. Defaults to 'torch.nn.ReLU'.
            n_head (int, optional): The number of attention heads. Defaults to 1.
            n_layer (int, optional): The number of layers of the backbone. Defaults to 2.
            stride (int, optional): Stride forwarded to the backbone together with ``patch_len``. Defaults to 8.
            remove_last (bool, optional): Forwarded to the backbone as ``subtract_last``. Defaults to False.
            hidden_size (int, optional): Forwarded to the backbone as ``d_ff``. Defaults to 1048.
            dropout_rate (float, optional): Dropout rate used for all the dropouts of the backbone
                (attention, generic, fc and head). Defaults to 0.1.
            **kwargs: Additional keyword arguments for the parent class.
        """
        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU do not require BN','info',self.verbose)
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightening, the constructior is called twice ','info',self.verbose)

        # must run after `activation` has been rewritten above, as in the original constructor
        self.save_hyperparameters(logger=False)

        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        # the backbone receives the numerical past concatenated with the past embeddings
        self.past_channels+=emb_past_out_channel

        # model
        backbone_args = dict(patch_len=patch_len, stride=stride, n_layer=n_layer, d_model=d_model,
                             n_head=n_head, hidden_size=hidden_size, dropout_rate=dropout_rate,
                             activation=activation, remove_last=remove_last)
        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(kernel_size)
            # creation order (trend first, then residual) is kept as in the original code
            self.model_trend = self._make_backbone(**backbone_args)
            self.model_res = self._make_backbone(**backbone_args)
        else:
            self.model = self._make_backbone(**backbone_args)

        # the head reads, for every future step, the backbone output together with
        # the future embeddings and the numerical future covariates
        dim = self.past_channels+emb_fut_out_channel+self.future_channels
        self.final_layer = nn.Sequential(activation(),
                                         nn.Linear(dim, dim*2),
                                         activation(),
                                         nn.Linear(dim*2,self.out_channels*self.mul ))

    def _make_backbone(self, patch_len, stride, n_layer, d_model, n_head, hidden_size, dropout_rate, activation, remove_last):
        """Build one PatchTST_backbone; each call instantiates its own activation module."""
        return PatchTST_backbone(c_in=self.past_channels, context_window = self.past_steps, target_window=self.future_steps, patch_len=patch_len, stride=stride,
                                 max_seq_len=self.past_steps+self.future_steps, n_layers=n_layer, d_model=d_model,
                                 n_heads=n_head, d_k=None, d_v=None, d_ff=hidden_size, norm='BatchNorm', attn_dropout=dropout_rate,
                                 dropout=dropout_rate, act=activation(), key_padding_mask='auto', padding_var=None,
                                 attn_mask=None, res_attention=True, pre_norm=False, store_attn=False,
                                 pe='zeros', learn_pe=True, fc_dropout=dropout_rate, head_dropout=dropout_rate, padding_patch = 'end',
                                 pretrain_head=False, head_type='flatten', individual=False, revin=True, affine=False,
                                 subtract_last=remove_last, verbose=False)

    def forward(self, batch):
        """Compute the forecast for one batch.

        Args:
            batch (dict): Must contain 'x_num_past'; may contain 'x_cat_past',
                'x_cat_future' and 'x_num_future'.

        Returns:
            torch.Tensor: Output of the final layer reshaped to
            ``(batch, future_steps, out_channels, mul)``.
        """
        x_seq = batch['x_num_past'].to(self.device)  # [Batch, Input length, Channel]
        BS = x_seq.shape[0]
        if 'x_cat_future' in batch:
            emb_fut = self.emb_fut(BS,batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS,None)
        if 'x_cat_past' in batch:
            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS,None)

        tmp_future = [emb_fut]
        if 'x_num_future' in batch:
            tmp_future.append(batch['x_num_future'].to(self.device))

        x_seq = torch.cat([x_seq,emb_past],dim=2)

        if self.decomposition:
            res_init, trend_init = self.decomp_module(x_seq)
            # the backbones expect [Batch, Channel, Input length]
            res = self.model_res(res_init.permute(0,2,1))
            trend = self.model_trend(trend_init.permute(0,2,1))
            x = (res + trend).permute(0,2,1)
        else:
            x = self.model(x_seq.permute(0,2,1)).permute(0,2,1)

        tmp_future.append(x)
        output = self.final_layer(torch.cat(tmp_future,2))
        return output.reshape(BS,self.future_steps,self.out_channels,self.mul)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
import lightning.pytorch as pl
|
|
6
|
+
from .base_v2 import Base
|
|
7
|
+
OLD_PL = False
|
|
8
|
+
except:
|
|
9
|
+
import pytorch_lightning as pl
|
|
10
|
+
OLD_PL = True
|
|
11
|
+
from .base import Base
|
|
12
|
+
from .utils import L1Loss
|
|
13
|
+
from ..data_structure.utils import beauty_string
|
|
14
|
+
from .utils import get_scope
|
|
15
|
+
|
|
16
|
+
class Persistent(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

    def __init__(self,
                 **kwargs)->None:
        """
        Persistence baseline exposing the same interface as the other models.

        Args:
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)
        # dummy layer, never used in forward
        # (presumably there so that the module owns at least one parameter — TODO confirm)
        self.fake = nn.Linear(1,1)
        self.use_quantiles = False

    def forward(self, batch):
        """Repeat the last observed value of every target for all the future steps.

        Args:
            batch (dict): Must contain 'x_num_past' and 'idx_target'.

        Returns:
            torch.Tensor: The last observed target values repeated ``future_steps``
            times on dimension 1, with a trailing axis of size 1.
        """
        past = batch['x_num_past'].to(self.device)
        target_idx = batch['idx_target'][0]
        # B x 1 x C
        last_seen = past[:,-1,target_idx].unsqueeze(1)
        # [B, L, C, 1]: remember the output size
        repeated = last_seen.repeat(1,self.future_steps,1)
        return repeated.unsqueeze(3)
|
|
44
|
+
|
dsipts/models/RNN.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
|
|
2
|
+
from torch import nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import lightning.pytorch as pl
|
|
7
|
+
from .base_v2 import Base
|
|
8
|
+
OLD_PL = False
|
|
9
|
+
except:
|
|
10
|
+
import pytorch_lightning as pl
|
|
11
|
+
OLD_PL = True
|
|
12
|
+
from .base import Base
|
|
13
|
+
from .utils import QuantileLossMO,Permute,get_activation
|
|
14
|
+
from typing import List,Union
|
|
15
|
+
from ..data_structure.utils import beauty_string
|
|
16
|
+
from .utils import get_scope
|
|
17
|
+
from .xlstm.xLSTM import xLSTM
|
|
18
|
+
from .utils import Embedding_cat_variables
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MyBN(nn.Module):
    """BatchNorm1d wrapper normalizing the last dimension of a 3-D tensor.

    ``nn.BatchNorm1d`` treats dimension 1 as the feature dimension, so the
    input is permuted before the normalization and permuted back afterwards.
    """
    def __init__(self,channels):
        super().__init__()
        self.BN = nn.BatchNorm1d(channels)

    def forward(self,x):
        """Normalize ``x`` and return a tensor with the same layout as ``x``."""
        channels_first = x.permute(0,2,1)
        normalized = self.BN(channels_first)
        return normalized.permute(0,2,1)
|
|
27
|
+
|
|
28
|
+
class RNN(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)

    def __init__(self,
                 hidden_RNN:int,
                 num_layers_RNN:int,
                 kind:str,
                 kernel_size:int,
                 activation:str='torch.nn.ReLU',
                 remove_last = False,
                 dropout_rate:float=0.1,
                 use_bn:bool=False,
                 num_blocks:int=4,
                 bidirectional:bool=True,
                 lstm_type:str='slstm',
                 **kwargs)->None:
        """Initialize a recurrent model with an encoder-decoder structure.

        Args:
            hidden_RNN (int): Hidden size of the RNN block.
            num_layers_RNN (int): Number of RNN layers.
            kind (str): Type of RNN to use, either 'gru' or 'lstm' or `xlstm`.
            kernel_size (int): Kernel size in the encoder and decoder convolutional blocks.
            activation (str, optional): Activation function from PyTorch. Default is 'torch.nn.ReLU'.
            remove_last (bool, optional): If True, the model learns the difference with respect to the last seen point. Default is False.
            dropout_rate (float, optional): Dropout rate in Dropout layers. Default is 0.1.
            use_bn (bool, optional): If True, Batch Normalization layers will be added and Dropouts will be removed. Default is False.
            num_blocks (int, optional): Number of xLSTM blocks (only for xLSTM). Default is 4.
            bidirectional (bool, optional): Forwarded to xLSTM (only for xLSTM). Default is True.
            lstm_type (str, optional): Type of LSTM to use (only for xLSTM), either 'slstm' or 'mlstm'. Default is 'slstm'.
            **kwargs: Additional keyword arguments for the parent class.

        Raises:
            ValueError: If the specified kind is not 'lstm', 'gru', or 'xlstm'.
        """

        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU do not require BN','info',self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in pytorch lightening, the constructior is called twice ','info',self.verbose)

        # must run after `activation` / `use_bn` have been rewritten above,
        # exactly like in the original constructor
        self.save_hyperparameters(logger=False)

        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN

        self.kind = kind
        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps,self.emb_dim,self.embs_past, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps,self.emb_dim,self.embs_fut, reduction_mode=self.reduction_mode,use_classical_positional_encoder=self.use_classical_positional_encoder,device = self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.initial_linear_encoder = nn.Sequential(nn.Linear(self.past_channels,4),
                                                    activation(),
                                                    MyBN(4) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(4,8),
                                                    activation(),
                                                    MyBN(8) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(8,hidden_RNN//8))
        # NOTE(review): this module is created but never called in forward;
        # it is kept so that the set of registered parameters does not change
        self.initial_linear_decoder = nn.Sequential(nn.Linear(self.future_channels,4),
                                                    activation(),
                                                    MyBN(4) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(4,8),
                                                    activation(),
                                                    MyBN(8) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(8,hidden_RNN//8))

        self.conv_encoder = nn.Sequential(Permute(), nn.Conv1d(emb_past_out_channel+hidden_RNN//8, hidden_RNN//8, kernel_size, stride=1,padding='same'),Permute(),nn.Dropout(0.3))

        if self.future_channels+emb_fut_out_channel==0:
            ## no future information at all: as a workaround the decoder is fed
            ## with the output of the encoder (hidden_RNN channels)
            self.conv_decoder = nn.Sequential(Permute(),nn.Conv1d(hidden_RNN, hidden_RNN//8, kernel_size=kernel_size, stride=1,padding='same'), Permute())
        else:
            self.conv_decoder = nn.Sequential(Permute(),nn.Conv1d(self.future_channels+emb_fut_out_channel, hidden_RNN//8, kernel_size=kernel_size, stride=1,padding='same'), Permute())

        if self.kind=='lstm':
            self.Encoder = nn.LSTM(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,batch_first=True)
            self.Decoder = nn.LSTM(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,batch_first=True)
        elif self.kind=='gru':
            self.Encoder = nn.GRU(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,batch_first=True)
            self.Decoder = nn.GRU(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,batch_first=True)
        elif self.kind=='xlstm':
            self.Encoder = xLSTM(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,num_blocks=num_blocks,dropout=dropout_rate, bidirectional=bidirectional, lstm_type=lstm_type)
            self.Decoder = xLSTM(input_size= hidden_RNN//8,hidden_size=hidden_RNN,num_layers = num_layers_RNN,num_blocks=num_blocks,dropout=dropout_rate, bidirectional=bidirectional, lstm_type=lstm_type)
        else:
            # fail at construction time, as documented, instead of leaving a model
            # without Encoder/Decoder that breaks at the first forward call
            message = "Specify kind='lstm', 'gru' or 'xlstm' please"
            beauty_string(message,'section',True)
            raise ValueError(f"{message} (got {kind!r})")

        # one head for every (output channel, quantile) pair
        self.final_linear = nn.ModuleList()
        for _ in range(self.out_channels*self.mul):
            self.final_linear.append(nn.Sequential(nn.Linear(hidden_RNN,hidden_RNN//2),
                                                   activation(),
                                                   MyBN(hidden_RNN//2) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN//2,hidden_RNN//4),
                                                   activation(),
                                                   MyBN(hidden_RNN//4) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN//4,hidden_RNN//8),
                                                   activation(),
                                                   MyBN(hidden_RNN//8) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN//8,1)))

    def forward(self, batch):
        """Compute the forecast for one batch.

        Args:
            batch (dict): Must contain 'x_num_past' (and 'idx_target' when
                ``remove_last`` is True); may contain 'x_cat_past',
                'x_cat_future' and 'x_num_future'.

        Returns:
            torch.Tensor: Concatenated head outputs reshaped to
            ``(batch, steps, -1, self.mul)``.
        """

        x = batch['x_num_past'].to(self.device)

        BS = x.shape[0]
        if 'x_cat_future' in batch:
            emb_fut = self.emb_fut(BS,batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS,None)
        if 'x_cat_past' in batch:
            emb_past = self.emb_past(BS,batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS,None)

        if 'x_num_future' in batch:
            x_future = batch['x_num_future'].to(self.device)
        else:
            x_future = None

        if self.remove_last:
            idx_target = batch['idx_target'][0]
            # last observed value of every target, B x 1 x C
            x_start = x[:,-1,idx_target].unsqueeze(1)
            # work on a copy so that the tensor stored in the caller's batch is left untouched
            x = x.clone()
            x[:,:,idx_target]-=x_start

        tot = torch.cat([self.initial_linear_encoder(x),emb_past],2)

        out, hidden = self.Encoder(self.conv_encoder(tot))

        decoder_inputs = [emb_fut]
        if x_future is not None:
            decoder_inputs.append(x_future)
        tot = torch.cat(decoder_inputs,2)
        if tot.shape[2]==0:
            # no future channel at all: fall back to the encoder output, matching
            # the conv_decoder built for this case in __init__
            tot = out

        # NOTE(review): only the last step of `tot` is used, repeated future_steps
        # times — confirm this is intended when future covariates are available
        out, _ = self.Decoder(self.conv_decoder(tot[:,-1:,:].repeat(1,self.future_steps,1)),hidden)

        ## B x L x (C*MUL)
        res = torch.cat([head(out) for head in self.final_linear],2)
        B,L,_ = res.shape
        res = res.reshape(B,L,-1,self.mul)

        if self.remove_last:
            # x_start is B x 1 x C: add the trailing axis so that it broadcasts over
            # the steps and over the MUL dimension of res
            res = res+x_start.unsqueeze(-1)

        return res
|
|
212
|
+
|
|
213
|
+
|