dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/LinearTS.py ADDED
@@ -0,0 +1,222 @@

## Copyright 2022 DLinear Authors (https://github.com/cure-lab/LTSF-Linear/tree/main?tab=Apache-2.0-1-ov-file#readme)
## Code modified to align the notation and the batch generation
## Extended to everything present in the informer and autoformer folders

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except ImportError:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .utils import Embedding_cat_variables


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of a time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # pad both ends of the time series by repeating the boundary values
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block: returns (residual, moving-average trend)
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class LinearTS(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
    description += '\n THE SIMPLE IMPLEMENTATION DOES NOT USE CATEGORICAL NOR FUTURE VARIABLES'

    def __init__(self,
                 kernel_size: int,
                 hidden_size: int,
                 dropout_rate: float = 0.1,
                 activation: str = 'torch.nn.ReLU',
                 kind: str = 'linear',
                 use_bn: bool = False,
                 simple: bool = False,
                 **kwargs) -> None:
        """Initialize the model with the specified parameters. Linear model from https://github.com/cure-lab/LTSF-Linear/blob/main/run_longExp.py

        Args:
            kernel_size (int): Kernel dimension for the initial moving average.
            hidden_size (int): Hidden size of the linear block.
            dropout_rate (float, optional): Dropout rate in Dropout layers. Default is 0.1.
            activation (str, optional): Activation function in PyTorch. Default is 'torch.nn.ReLU'.
            kind (str, optional): Type of model: 'linear', 'dlinear' (de-trending), 'nlinear' (differential), or 'alinear' (target history zeroed). Defaults to 'linear'.
            use_bn (bool, optional): If True, Batch Normalization layers will be added and Dropouts removed. Default is False.
            simple (bool, optional): If True, the model is the one illustrated in the paper; otherwise a more complex model built on the same idea is used. Default is False.
            **kwargs: Additional keyword arguments for the parent class.

        Raises:
            ValueError: If an invalid activation function is provided.
        """

        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False

        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in PyTorch Lightning: the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)

        self.kind = kind
        self.simple = simple

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        ## one linear block per output channel
        self.linear = nn.ModuleList()

        if kind == 'dlinear':
            self.decomposition = series_decomp(kernel_size)
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.out_channels):
                self.Linear_Trend.append(nn.Linear(self.past_steps, self.future_steps))

        for _ in range(self.out_channels):
            if simple:
                self.linear.append(nn.Linear(self.past_steps, self.future_steps * self.mul))
            else:
                self.linear.append(nn.Sequential(
                    nn.Linear(emb_past_out_channel * self.past_steps + emb_fut_out_channel * self.future_steps + self.past_steps * self.past_channels + self.future_channels * self.future_steps, hidden_size),
                    activation(),
                    nn.BatchNorm1d(hidden_size) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size, hidden_size // 2),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 2) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 2, hidden_size // 4),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 4) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 4, hidden_size // 8),
                    activation(),
                    nn.BatchNorm1d(hidden_size // 8) if use_bn else nn.Dropout(dropout_rate),
                    nn.Linear(hidden_size // 8, self.future_steps * self.mul)))

    def forward(self, batch):
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]

        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if self.kind == 'nlinear':
            # subtract the last observed value; it is added back to the output below
            x_start = x[:, -1, idx_target].unsqueeze(1)
            x[:, :, idx_target] -= x_start

        if self.kind == 'alinear':
            x[:, :, idx_target] = 0

        if self.kind == 'dlinear':
            x_start = x[:, :, idx_target]
            seasonal_init, trend_init = self.decomposition(x_start)
            seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
            x[:, :, idx_target] = seasonal_init.permute(0, 2, 1)
            tmp = []
            for j in range(len(self.Linear_Trend)):
                tmp.append(self.Linear_Trend[j](trend_init[:, j, :]))
            trend = torch.stack(tmp, 2)

        if self.simple is False:
            if 'x_num_future' in batch.keys():
                x_future = batch['x_num_future'].to(self.device)
            else:
                x_future = None

            tmp = [x, emb_past]
            tot_past = torch.cat(tmp, 2).flatten(1)

            tmp = [emb_fut]
            if x_future is not None:
                tmp.append(x_future)

            tot_future = torch.cat(tmp, 2).flatten(1)
            tot = torch.cat([tot_past, tot_future], 1)
            tot = tot.unsqueeze(2).repeat(1, 1, len(self.linear)).permute(0, 2, 1)
        else:
            tot = x.permute(0, 2, 1)
        res = []

        for j in range(len(self.linear)):
            res.append(self.linear[j](tot[:, j, :]).reshape(BS, -1, self.mul))
        ## res: B x L x C x MUL
        res = torch.stack(res, 2)

        if self.kind == 'nlinear':
            # res: B x L x C x MUL, x_start: B x 1 x C
            res += x_start.unsqueeze(1)

        if self.kind == 'dlinear':
            res = res + trend.unsqueeze(3)

        return res
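For reference, a minimal sketch of what series_decomp computes on a toy batch. The import path is assumed from the file list above, and the tensor layout follows the (batch, length, channels) convention used in forward:

    import torch
    from dsipts.models.LinearTS import series_decomp  # path assumed from the file list above

    B, L, C = 2, 24, 3
    x = torch.randn(B, L, C).cumsum(dim=1)      # random walk: noise on top of a visible trend
    res, trend = series_decomp(kernel_size=5)(x)
    assert res.shape == trend.shape == x.shape  # edge padding keeps the length unchanged
    assert torch.allclose(res + trend, x)       # the decomposition is exactly additive

The 'dlinear' variant above regresses the trend and the residual separately and sums the two forecasts; 'nlinear' instead subtracts the last observed value before the linear map and adds it back to the output.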
dsipts/models/PatchTST.py ADDED
@@ -0,0 +1,181 @@
## Copyright https://github.com/yuqinie98/PatchTST/blob/main/LICENSE
## Modified for notation alignment and batch structure
## Extended to what is inside the patchtst folder

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except ImportError:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .utils import get_activation
from .patchtst.layers import series_decomp, PatchTST_backbone
from .utils import Embedding_cat_variables


class PatchTST(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 d_model: int,
                 patch_len: int,
                 kernel_size: int,
                 decomposition: bool = True,
                 activation: str = 'torch.nn.ReLU',
                 n_head: int = 1,
                 n_layer: int = 2,
                 stride: int = 8,
                 remove_last: bool = False,
                 hidden_size: int = 1048,
                 dropout_rate: float = 0.1,
                 **kwargs) -> None:
        """Initializes the model with the specified parameters. https://github.com/yuqinie98/PatchTST/blob/main/

        Args:
            d_model (int): The dimensionality of the model.
            patch_len (int): The length of the patches.
            kernel_size (int): Kernel size of the moving average used for series decomposition.
            decomposition (bool, optional): Whether to use trend/residual decomposition. Defaults to True.
            activation (str, optional): The activation function to use. Defaults to 'torch.nn.ReLU'.
            n_head (int, optional): The number of attention heads. Defaults to 1.
            n_layer (int, optional): The number of layers in the model. Defaults to 2.
            stride (int, optional): The stride between consecutive patches. Defaults to 8.
            remove_last (bool, optional): If True, the last observed value is subtracted before the backbone (passed as subtract_last). Defaults to False.
            hidden_size (int, optional): The size of the hidden layers. Defaults to 1048.
            dropout_rate (float, optional): The dropout rate for regularization. Defaults to 0.1.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If the activation function is not recognized.
        """
        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in PyTorch Lightning: the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)

        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.past_channels += emb_past_out_channel

        # model
        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = series_decomp(kernel_size)
            self.model_trend = PatchTST_backbone(c_in=self.past_channels, context_window=self.past_steps, target_window=self.future_steps, patch_len=patch_len, stride=stride,
                                                 max_seq_len=self.past_steps + self.future_steps, n_layers=n_layer, d_model=d_model,
                                                 n_heads=n_head, d_k=None, d_v=None, d_ff=hidden_size, norm='BatchNorm', attn_dropout=dropout_rate,
                                                 dropout=dropout_rate, act=activation(), key_padding_mask='auto', padding_var=None,
                                                 attn_mask=None, res_attention=True, pre_norm=False, store_attn=False,
                                                 pe='zeros', learn_pe=True, fc_dropout=dropout_rate, head_dropout=dropout_rate, padding_patch='end',
                                                 pretrain_head=False, head_type='flatten', individual=False, revin=True, affine=False,
                                                 subtract_last=remove_last, verbose=False)
            self.model_res = PatchTST_backbone(c_in=self.past_channels, context_window=self.past_steps, target_window=self.future_steps, patch_len=patch_len, stride=stride,
                                               max_seq_len=self.past_steps + self.future_steps, n_layers=n_layer, d_model=d_model,
                                               n_heads=n_head, d_k=None, d_v=None, d_ff=hidden_size, norm='BatchNorm', attn_dropout=dropout_rate,
                                               dropout=dropout_rate, act=activation(), key_padding_mask='auto', padding_var=None,
                                               attn_mask=None, res_attention=True, pre_norm=False, store_attn=False,
                                               pe='zeros', learn_pe=True, fc_dropout=dropout_rate, head_dropout=dropout_rate, padding_patch='end',
                                               pretrain_head=False, head_type='flatten', individual=False, revin=True, affine=False,
                                               subtract_last=remove_last, verbose=False)
        else:
            self.model = PatchTST_backbone(c_in=self.past_channels, context_window=self.past_steps, target_window=self.future_steps, patch_len=patch_len, stride=stride,
                                           max_seq_len=self.past_steps + self.future_steps, n_layers=n_layer, d_model=d_model,
                                           n_heads=n_head, d_k=None, d_v=None, d_ff=hidden_size, norm='BatchNorm', attn_dropout=dropout_rate,
                                           dropout=dropout_rate, act=activation(), key_padding_mask='auto', padding_var=None,
                                           attn_mask=None, res_attention=True, pre_norm=False, store_attn=False,
                                           pe='zeros', learn_pe=True, fc_dropout=dropout_rate, head_dropout=dropout_rate, padding_patch='end',
                                           pretrain_head=False, head_type='flatten', individual=False, revin=True, affine=False,
                                           subtract_last=remove_last, verbose=False)

        dim = self.past_channels + emb_fut_out_channel + self.future_channels
        self.final_layer = nn.Sequential(activation(),
                                         nn.Linear(dim, dim * 2),
                                         activation(),
                                         nn.Linear(dim * 2, self.out_channels * self.mul))

        # self.final_linear = nn.Sequential(nn.Linear(past_channels, past_channels//2), activation(), nn.Dropout(dropout_rate), nn.Linear(past_channels//2, out_channels))

    def forward(self, batch):  # x: [Batch, Input length, Channel]
        x_seq = batch['x_num_past'].to(self.device)  # [:,:,idx_target]
        BS = x_seq.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        tmp_future = [emb_fut]
        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
            tmp_future.append(x_future)

        tot = [x_seq, emb_past]
        x_seq = torch.cat(tot, axis=2)

        if self.decomposition:
            res_init, trend_init = self.decomp_module(x_seq)
            res_init, trend_init = res_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)  # x: [Batch, Channel, Input length]
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
            x = x.permute(0, 2, 1)  # x: [Batch, Output length, Channel]
        else:
            x = x_seq.permute(0, 2, 1)  # x: [Batch, Channel, Input length]
            x = self.model(x)
            x = x.permute(0, 2, 1)  # x: [Batch, Output length, Channel]

        tmp_future.append(x)
        tmp_future = torch.cat(tmp_future, 2)
        output = self.final_layer(tmp_future)
        return output.reshape(BS, self.future_steps, self.out_channels, self.mul)
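The heart of PatchTST_backbone is the patching step: each channel is cut into overlapping sub-series that become the transformer's tokens, so attention cost scales with the number of patches rather than the raw sequence length. A minimal sketch of that idea with an 'end'-style replication padding (an illustration with assumed shapes, not the backbone's actual code, which lives in dsipts/models/patchtst/layers.py):

    import torch

    B, C, L = 2, 3, 96              # batch, channels, context window
    patch_len, stride = 16, 8       # as in the constructor arguments above
    x = torch.randn(B, C, L)        # [Batch, Channel, Input length]

    # 'end' padding replicates the last value so the final patch is complete
    x_pad = torch.nn.ReplicationPad1d((0, stride))(x)  # [B, C, L + stride]
    patches = x_pad.unfold(dimension=-1, size=patch_len, step=stride)
    print(patches.shape)            # torch.Size([2, 3, 12, 16]): 12 tokens of length 16 per channel

Each patch is then linearly projected to d_model before entering the encoder.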
dsipts/models/Persistent.py ADDED
@@ -0,0 +1,44 @@

from torch import nn

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except ImportError:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import L1Loss
from ..data_structure.utils import beauty_string
from .utils import get_scope

class Persistent(Base):
    handle_multivariate = True
    handle_future_covariates = False
    handle_categorical_variables = False
    handle_quantile_loss = False
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 **kwargs) -> None:
        """
        Simple persistence model aligned with all the others
        """
        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)
        self.fake = nn.Linear(1, 1)
        self.use_quantiles = False

    def forward(self, batch):
        x = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        x_start = x[:, -1, idx_target].unsqueeze(1)
        # x_start is B x 1 x C
        # the output must be B x L x C x 1, remember the output size
        res = x_start.repeat(1, self.future_steps, 1).unsqueeze(3)
        return res
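As a sanity check, a small standalone sketch of what this forward pass produces, with made-up shapes (B batch, C target channels, an 8-step horizon):

    import torch

    B, L_past, C, future_steps = 4, 16, 2, 8
    x = torch.randn(B, L_past, C)
    idx_target = torch.tensor([0, 1])            # both channels are targets
    x_start = x[:, -1, idx_target].unsqueeze(1)  # B x 1 x C: last observed values
    res = x_start.repeat(1, future_steps, 1).unsqueeze(3)
    print(res.shape)                             # torch.Size([4, 8, 2, 1])
    assert torch.equal(res[:, 0], res[:, -1])    # every step repeats the last observation

The fake nn.Linear(1, 1) is presumably there only so the module exposes at least one trainable parameter to the Lightning training loop; the forecast itself involves no learning.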
dsipts/models/RNN.py ADDED
@@ -0,0 +1,213 @@

from torch import nn
import torch

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except ImportError:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO, Permute, get_activation
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .xlstm.xLSTM import xLSTM
from .utils import Embedding_cat_variables


class MyBN(nn.Module):
    def __init__(self, channels):
        super(MyBN, self).__init__()
        self.BN = nn.BatchNorm1d(channels)

    def forward(self, x):
        # BatchNorm1d expects B x C x L while the data is B x L x C
        return self.BN(x.permute(0, 2, 1)).permute(0, 2, 1)

class RNN(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True

    def __init__(self,
                 hidden_RNN: int,
                 num_layers_RNN: int,
                 kind: str,
                 kernel_size: int,
                 activation: str = 'torch.nn.ReLU',
                 remove_last: bool = False,
                 dropout_rate: float = 0.1,
                 use_bn: bool = False,
                 num_blocks: int = 4,
                 bidirectional: bool = True,
                 lstm_type: str = 'slstm',
                 **kwargs) -> None:
        """Initialize a recurrent model with an encoder-decoder structure.

        Args:
            hidden_RNN (int): Hidden size of the RNN block.
            num_layers_RNN (int): Number of RNN layers.
            kind (str): Type of RNN to use: 'gru', 'lstm', or 'xlstm'.
            kernel_size (int): Kernel size in the encoder convolutional block.
            activation (str, optional): Activation function from PyTorch. Default is 'torch.nn.ReLU'.
            remove_last (bool, optional): If True, the model learns the difference with respect to the last observed point. Default is False.
            dropout_rate (float, optional): Dropout rate in Dropout layers. Default is 0.1.
            use_bn (bool, optional): If True, Batch Normalization layers will be added and Dropouts removed. Default is False.
            num_blocks (int, optional): Number of xLSTM blocks (only for xLSTM). Default is 4.
            bidirectional (bool, optional): If True, the RNN is bidirectional. Default is True.
            lstm_type (str, optional): Type of LSTM to use (only for xLSTM), either 'slstm' or 'mlstm'. Default is 'slstm'.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If the specified kind is not 'lstm', 'gru', or 'xlstm'.
        """

        super().__init__(**kwargs)

        if activation == 'torch.nn.SELU':
            beauty_string('SELU does not require BN', 'info', self.verbose)
            use_bn = False
        if isinstance(activation, str):
            activation = get_activation(activation)
        else:
            beauty_string('There is a bug in PyTorch Lightning: the constructor is called twice', 'info', self.verbose)

        self.save_hyperparameters(logger=False)

        self.num_layers_RNN = num_layers_RNN
        self.hidden_RNN = hidden_RNN
        self.kind = kind
        self.remove_last = remove_last

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        self.initial_linear_encoder = nn.Sequential(nn.Linear(self.past_channels, 4),
                                                    activation(),
                                                    MyBN(4) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(4, 8),
                                                    activation(),
                                                    MyBN(8) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(8, hidden_RNN // 8))
        self.initial_linear_decoder = nn.Sequential(nn.Linear(self.future_channels, 4),
                                                    activation(),
                                                    MyBN(4) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(4, 8),
                                                    activation(),
                                                    MyBN(8) if use_bn else nn.Dropout(dropout_rate),
                                                    nn.Linear(8, hidden_RNN // 8))

        self.conv_encoder = nn.Sequential(Permute(), nn.Conv1d(emb_past_out_channel + hidden_RNN // 8, hidden_RNN // 8, kernel_size, stride=1, padding='same'), Permute(), nn.Dropout(0.3))

        if self.future_channels + emb_fut_out_channel == 0:
            ## no future information at all: as a workaround, feed the decoder with the encoder hidden size
            self.conv_decoder = nn.Sequential(Permute(), nn.Conv1d(hidden_RNN, hidden_RNN // 8, kernel_size=kernel_size, stride=1, padding='same'), Permute())
        else:
            self.conv_decoder = nn.Sequential(Permute(), nn.Conv1d(self.future_channels + emb_fut_out_channel, hidden_RNN // 8, kernel_size=kernel_size, stride=1, padding='same'), Permute())

        if self.kind == 'lstm':
            self.Encoder = nn.LSTM(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, batch_first=True)
            self.Decoder = nn.LSTM(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, batch_first=True)
        elif self.kind == 'gru':
            self.Encoder = nn.GRU(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, batch_first=True)
            self.Decoder = nn.GRU(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, batch_first=True)
        elif self.kind == 'xlstm':
            self.Encoder = xLSTM(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, num_blocks=num_blocks, dropout=dropout_rate, bidirectional=bidirectional, lstm_type=lstm_type)
            self.Decoder = xLSTM(input_size=hidden_RNN // 8, hidden_size=hidden_RNN, num_layers=num_layers_RNN, num_blocks=num_blocks, dropout=dropout_rate, bidirectional=bidirectional, lstm_type=lstm_type)
        else:
            beauty_string("Specify kind='lstm', 'gru' or 'xlstm' please", 'section', True)
        self.final_linear = nn.ModuleList()
        for _ in range(self.out_channels * self.mul):
            self.final_linear.append(nn.Sequential(nn.Linear(hidden_RNN, hidden_RNN // 2),
                                                   activation(),
                                                   MyBN(hidden_RNN // 2) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN // 2, hidden_RNN // 4),
                                                   activation(),
                                                   MyBN(hidden_RNN // 4) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN // 4, hidden_RNN // 8),
                                                   activation(),
                                                   MyBN(hidden_RNN // 8) if use_bn else nn.Dropout(dropout_rate),
                                                   nn.Linear(hidden_RNN // 8, 1)))

    def forward(self, batch):
        x = batch['x_num_past'].to(self.device)

        BS = x.shape[0]
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(BS, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(BS, None)

        if 'x_num_future' in batch.keys():
            x_future = batch['x_num_future'].to(self.device)
        else:
            x_future = None

        if self.remove_last:
            idx_target = batch['idx_target'][0]
            x_start = x[:, -1, idx_target].unsqueeze(1)
            ## B x C
            x[:, :, idx_target] -= x_start

        tmp = [self.initial_linear_encoder(x), emb_past]
        tot = torch.cat(tmp, 2)

        out, hidden = self.Encoder(self.conv_encoder(tot))

        tmp = [emb_fut]
        if x_future is not None:
            tmp.append(x_future)

        if len(tmp) > 0:  # always true here: tmp contains at least emb_fut
            tot = torch.cat(tmp, 2)
        else:
            tot = out
        out, _ = self.Decoder(self.conv_decoder(tot[:, -1:, :].repeat(1, self.future_steps, 1)), hidden)
        res = []

        for j in range(len(self.final_linear)):
            res.append(self.final_linear[j](out))

        res = torch.cat(res, 2)
        ## B x L x C
        B, L, _ = res.shape
        res = res.reshape(B, L, -1, self.mul)

        if self.remove_last:
            res += x_start.unsqueeze(1)

        return res
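A minimal sketch of the encoder-decoder handoff this forward pass implements, shown for the GRU case with made-up sizes (an illustration, not the class's code): the encoder's final hidden state initializes the decoder, whose input is a single (conv-processed) future token repeated over the horizon.

    import torch
    from torch import nn

    B, past_steps, future_steps, D, H = 4, 32, 8, 16, 64
    encoder = nn.GRU(input_size=D, hidden_size=H, num_layers=2, batch_first=True)
    decoder = nn.GRU(input_size=D, hidden_size=H, num_layers=2, batch_first=True)

    past = torch.randn(B, past_steps, D)                   # stands in for conv_encoder's output
    _, hidden = encoder(past)                              # hidden: [num_layers, B, H]
    fut = torch.randn(B, 1, D).repeat(1, future_steps, 1)  # last future token, repeated
    out, _ = decoder(fut, hidden)                          # out: [B, future_steps, H]
    print(out.shape)                                       # torch.Size([4, 8, 64])

In the class above, out then feeds one small MLP head per output channel and quantile, and the heads' outputs are concatenated and reshaped into the B x L x C x mul result.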