dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/TIDE.py
ADDED
@@ -0,0 +1,296 @@
import torch
import torch.nn as nn
import numpy as np
from .tft import sub_nn

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base
from .utils import QuantileLossMO
from typing import List, Union
from ..data_structure.utils import beauty_string
from .utils import get_scope
from .utils import Embedding_cat_variables


class TIDE(Base):
    handle_multivariate = True
    handle_future_covariates = True
    handle_categorical_variables = True
    handle_quantile_loss = True
    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)

    def __init__(self,
                 hidden_size: int,
                 d_model: int,
                 n_add_enc: int,
                 n_add_dec: int,
                 dropout_rate: float,
                 activation: str = '',
                 **kwargs) -> None:
        """Initializes the model with the specified parameters. Long-term Forecasting with TiDE: Time-series Dense Encoder
        https://arxiv.org/abs/2304.08424

        Args:
            hidden_size (int): The size of the hidden layers.
            d_model (int): The dimensionality of the model.
            n_add_enc (int): The number of additional encoder layers.
            n_add_dec (int): The number of additional decoder layers.
            dropout_rate (float): The dropout rate to be applied in the layers.
            activation (str, optional): The activation function to be used. Defaults to an empty string.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters(logger=False)

        self.hidden_size = hidden_size  # r
        self.d_model = d_model  # r_tilde

        # for other numerical variables in the past
        self.aux_past_channels = self.past_channels - self.out_channels
        self.linear_aux_past = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_past_channels)])

        # for numerical variables in the future
        self.aux_fut_channels = self.future_channels
        self.linear_aux_fut = nn.ModuleList([nn.Linear(1, self.hidden_size) for _ in range(self.aux_fut_channels)])

        # changed from 1.1.5:
        # embedding categorical for both past and future (ASSUMING BOTH AVAILABLE OR NONE)
        # self.seq_len = self.past_steps + self.future_steps
        # self.emb_cat_var = sub_nn.embedding_cat_variables(self.seq_len, future_steps, hidden_size, embs, self.device)

        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past,
                                                reduction_mode=self.reduction_mode,
                                                use_classical_positional_encoder=self.use_classical_positional_encoder,
                                                device=self.device)
        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut,
                                               reduction_mode=self.reduction_mode,
                                               use_classical_positional_encoder=self.use_classical_positional_encoder,
                                               device=self.device)
        emb_past_out_channel = self.emb_past.output_channels
        emb_fut_out_channel = self.emb_fut.output_channels

        ## FEATURE PROJECTION
        # past
        if self.aux_past_channels > 0:
            self.feat_proj_past = ResidualBlock(hidden_size + emb_past_out_channel, d_model, dropout_rate, activation)
        else:
            self.feat_proj_past = ResidualBlock(emb_past_out_channel, d_model, dropout_rate, activation)
        # future
        if self.aux_fut_channels > 0:
            self.feat_proj_fut = ResidualBlock(hidden_size + emb_fut_out_channel, d_model, dropout_rate, activation)
        else:
            self.feat_proj_fut = ResidualBlock(emb_fut_out_channel, d_model, dropout_rate, activation)

        ## ENCODER
        self.enc_dim_input = self.past_steps * self.out_channels + (self.past_steps + self.future_steps) * d_model
        self.enc_dim_output = self.future_steps * d_model
        self.first_encoder = ResidualBlock(self.enc_dim_input, self.enc_dim_output, dropout_rate, activation)
        self.aux_encoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_enc)])

        ## DECODER
        self.first_decoder = ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation)
        self.aux_decoder = nn.ModuleList([ResidualBlock(self.enc_dim_output, self.enc_dim_output, dropout_rate, activation) for _ in range(1, n_add_dec)])

        ## TEMPORAL DECODER
        self.temporal_decoder = ResidualBlock(2 * d_model, self.out_channels * self.mul, dropout_rate, activation)

        # linear for Y lookback
        self.linear_target = nn.Linear(self.past_steps * self.out_channels, self.future_steps * self.out_channels * self.mul)

    def forward(self, batch: dict) -> torch.Tensor:
        """Forward pass of the TiDE network

        Args:
            batch (dict): dictionary of input variables

        Returns:
            torch.Tensor: predictions of shape [B, future_steps, out_channels, mul]
        """

        # LOADING AUTOREGRESSIVE CONTEXT OF TARGET VARIABLES
        num_past = batch['x_num_past'].to(self.device)
        idx_target = batch['idx_target'][0]
        y_past = num_past[:, :, idx_target]
        B = y_past.shape[0]

        # LOADING EMBEDDING CATEGORICAL VARIABLES
        # emb_cat_past, emb_cat_fut = self.cat_categorical_vars(batch)
        if 'x_cat_future' in batch.keys():
            emb_fut = self.emb_fut(B, batch['x_cat_future'].to(self.device))
        else:
            emb_fut = self.emb_fut(B, None)
        if 'x_cat_past' in batch.keys():
            emb_past = self.emb_past(B, batch['x_cat_past'].to(self.device))
        else:
            emb_past = self.emb_past(B, None)

        # emb_cat_past = torch.mean(emb_cat_past, dim=2)
        # emb_cat_fut = torch.mean(emb_cat_fut, dim=2)

        ### LOADING PAST AND FUTURE NUMERICAL VARIABLES
        # load the auxiliary numerical variables into the model
        if self.aux_past_channels > 0:  # if we have more numerical variables about the past
            aux_num_past = self.remove_var(num_past, idx_target, 2)  # remove the autoregressive variable
            assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS", 'section', True)  # check that we are using the expected number of past variables
            # concat all embedded vars and take their mean
            aux_emb_num_past = torch.Tensor().to(self.device)
            for i, layer in enumerate(self.linear_aux_past):
                aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
                aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
            aux_emb_num_past = torch.mean(aux_emb_num_past, dim=2)
        else:
            aux_emb_num_past = None  # no available vars

        if self.aux_fut_channels > 0:  # if we have more numerical variables about the future
            # AUX means AUXILIARY variables
            aux_num_fut = batch['x_num_future'].to(self.device)
            assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND {aux_num_fut.size(2)} VARS", 'section', True)  # check that we are using the expected number of future variables
            # concat all embedded vars and take their mean
            aux_emb_num_fut = torch.Tensor().to(self.device)
            for j, layer in enumerate(self.linear_aux_fut):
                aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
                aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
            aux_emb_num_fut = torch.mean(aux_emb_num_fut, dim=2)
        else:
            aux_emb_num_fut = None  # no available vars

        # past_tilde
        if self.aux_past_channels > 0:
            emb_past = torch.cat((emb_past, aux_emb_num_past), dim=2)  # [B, L, 2R]
            proj_past = self.feat_proj_past(emb_past, True)  # [B, L, R_tilde]
        else:
            proj_past = self.feat_proj_past(emb_past, True)  # [B, L, R_tilde]

        # fut_tilde
        if self.aux_fut_channels > 0:
            emb_fut = torch.cat((emb_fut, aux_emb_num_fut), dim=2)  # [B, H, 2R]
            proj_fut = self.feat_proj_fut(emb_fut, True)  # [B, H, R_tilde]
        else:
            proj_fut = self.feat_proj_fut(emb_fut, True)  # [B, H, R_tilde]

        concat = torch.cat((y_past.view(B, -1), proj_past.view(B, -1), proj_fut.view(B, -1)), dim=1)  # [B, L*out_channels + (L+H)*R_tilde]
        dense_enc = self.first_encoder(concat)
        for lay_enc in self.aux_encoder:
            dense_enc = lay_enc(dense_enc)

        dense_dec = self.first_decoder(dense_enc)
        for lay_dec in self.aux_decoder:
            dense_dec = lay_dec(dense_dec)

        temp_dec_input = torch.cat((dense_dec.view(B, self.future_steps, self.d_model), proj_fut), dim=2)
        temp_dec_output = self.temporal_decoder(temp_dec_input, False)
        temp_dec_output = temp_dec_output.view(B, self.future_steps, self.out_channels, self.mul)

        linear_regr = self.linear_target(y_past.view(B, -1))
        linear_output = linear_regr.view(B, self.future_steps, self.out_channels, self.mul)

        output = temp_dec_output + linear_output
        return output

    '''
    # function to concat embedded categorical variables
    def cat_categorical_vars(self, batch: dict):
        """Extracting categorical context about past and future

        Args:
            batch (dict): keys checked -> ['x_cat_past', 'x_cat_future']

        Returns:
            List[torch.Tensor, torch.Tensor]: cat_emb_past, cat_emb_fut
        """
        cat_past = None
        cat_fut = None
        # GET AVAILABLE CATEGORICAL CONTEXT
        if 'x_cat_past' in batch.keys():
            cat_past = batch['x_cat_past'].to(self.device)
        if 'x_cat_future' in batch.keys():
            cat_fut = batch['x_cat_future'].to(self.device)
        # CONCAT THEM, according to self.emb_cat_var usage
        if cat_past is None:
            emb_cat_full = self.emb_cat_var(batch['x_num_past'].shape[0], self.device)
        else:
            cat_full = torch.cat((cat_past, cat_fut), dim=1)
            emb_cat_full = self.emb_cat_var(cat_full, self.device)
        cat_emb_past = emb_cat_full[:, :self.past_steps, :, :]
        cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]

        return cat_emb_past, cat_emb_fut

    # function to extract from batch['x_num_past'] all variables except the autoregressive one
    '''

    def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: list, dimension: int) -> torch.Tensor:
        """Remove variables from a tensor at the chosen dimension and positions

        Args:
            tensor (torch.Tensor): starting tensor
            indexes_to_exclude (list): indexes of the chosen dimension we want to exclude
            dimension (int): dimension of the tensor on which we want to work (not a list of dims!)

        Returns:
            torch.Tensor: new tensor without the chosen variables
        """
        remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
        # Select the desired sub-tensor
        extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)

        return extracted_subtensors


class ResidualBlock(nn.Module):
    def __init__(self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ''):
        """Residual Block, the basic layer of the architecture.

        MLP with one hidden layer, activation, dropout and a skip connection.
        Conceptually the dimension is d_model, but it is better to keep input and output dims explicit:
        in_size and out_size handle the dimensions at different stages of the network.

        Args:
            in_size (int): input dimension
            out_size (int): output dimension
            dropout_rate (float): dropout rate
            activation_fun (str, optional): activation function to use in the Residual Block. Defaults to nn.ReLU.
        """
        super().__init__()

        self.direct_linear = nn.Linear(in_size, out_size, bias=False)

        if activation_fun == '':
            self.act = nn.ReLU()
        else:
            activation = eval(activation_fun)
            self.act = activation()
        self.lin = nn.Linear(in_size, out_size)
        self.dropout = nn.Dropout(dropout_rate)

        self.final_norm = nn.LayerNorm(out_size)

    def forward(self, x, apply_final_norm=True):
        direct_x = self.direct_linear(x)

        x = self.dropout(self.lin(self.act(x)))

        out = x + direct_x
        if apply_final_norm:
            return self.final_norm(out)
        return out
dsipts/models/TTM.py
ADDED
@@ -0,0 +1,252 @@
import torch
import numpy as np
from torch import nn

try:
    import lightning.pytorch as pl
    from .base_v2 import Base
    OLD_PL = False
except:
    import pytorch_lightning as pl
    OLD_PL = True
    from .base import Base


from typing import List, Union

from .utils import QuantileLossMO
from ..data_structure.utils import beauty_string
from .ttm.utils import get_model, get_frequency_token, count_parameters, RMSELoss


class TTM(Base):
    def __init__(self,
                 model_path: str,
                 past_steps: int,
                 future_steps: int,
                 freq_prefix_tuning: bool,
                 freq: str,
                 prefer_l1_loss: bool,  # exog: set true to use l1 loss
                 prefer_longer_context: bool,
                 loss_type: str,
                 num_input_channels,
                 prediction_channel_indices,
                 exogenous_channel_indices,
                 decoder_mode,
                 fcm_context_length,
                 fcm_use_mixer,
                 fcm_mix_layers,
                 fcm_prepend_past,
                 enable_forecast_channel_mixing,
                 out_channels: int,
                 embs: List[int],
                 remove_last=False,
                 optim: Union[str, None] = None,
                 optim_config: dict = None,
                 scheduler_config: dict = None,
                 verbose=False,
                 use_quantiles=False,
                 persistence_weight: float = 0.0,
                 quantiles: List[int] = [],
                 **kwargs) -> None:
        """TODO and FIX for future and past categorical variables

        Args:
            model_path (str): _description_
            past_steps (int): _description_
            future_steps (int): _description_
            freq_prefix_tuning (bool): _description_
            freq (str): _description_
            prefer_l1_loss (bool): _description_
            loss_type (str): _description_
            num_input_channels (_type_): _description_
            prediction_channel_indices (_type_): _description_
            exogenous_channel_indices (_type_): _description_
            decoder_mode (_type_): _description_
            fcm_context_length (_type_): _description_
            fcm_use_mixer (_type_): _description_
            fcm_mix_layers (_type_): _description_
            fcm_prepend_past (_type_): _description_
            enable_forecast_channel_mixing (_type_): _description_
            out_channels (int): _description_
            embs (List[int]): _description_
            remove_last (bool, optional): _description_. Defaults to False.
            optim (Union[str, None], optional): _description_. Defaults to None.
            optim_config (dict, optional): _description_. Defaults to None.
            scheduler_config (dict, optional): _description_. Defaults to None.
            verbose (bool, optional): _description_. Defaults to False.
            use_quantiles (bool, optional): _description_. Defaults to False.
            persistence_weight (float, optional): _description_. Defaults to 0.0.
            quantiles (List[int], optional): _description_. Defaults to [].
        """
        super(TTM, self).__init__(verbose)
        self.save_hyperparameters(logger=False)
        self.future_steps = future_steps
        self.use_quantiles = use_quantiles
        self.optim = optim
        self.optim_config = optim_config
        self.scheduler_config = scheduler_config
        self.persistence_weight = persistence_weight
        self.loss_type = loss_type
        self.remove_last = remove_last
        self.embs = embs
        self.freq = freq
        self.extend_variables = False

        # NOTE: for Hydra
        prediction_channel_indices = list(prediction_channel_indices)
        exogenous_channel_indices = list(exogenous_channel_indices)

        if len(quantiles) > 0:
            assert len(quantiles) == 3, beauty_string('ONLY 3 quantiles permitted', 'info', True)
            self.use_quantiles = True
            self.mul = len(quantiles)
            self.loss = QuantileLossMO(quantiles)
            self.extend_variables = True
            if out_channels * 3 != len(prediction_channel_indices):
                prediction_channel_indices, exogenous_channel_indices, num_input_channels = self.__add_quantile_features(prediction_channel_indices,
                                                                                                                         exogenous_channel_indices,
                                                                                                                         out_channels)
        else:
            self.mul = 1
            if self.loss_type == 'mse':
                self.loss = nn.MSELoss(reduction="mean")
            elif self.loss_type == 'rmse':
                self.loss = RMSELoss()
            else:
                self.loss = nn.L1Loss()

        self.model = get_model(
            model_path=model_path,
            context_length=past_steps,
            prediction_length=future_steps,
            freq_prefix_tuning=freq_prefix_tuning,
            freq=freq,
            prefer_l1_loss=prefer_l1_loss,
            prefer_longer_context=prefer_longer_context,
            num_input_channels=num_input_channels,
            decoder_mode=decoder_mode,
            prediction_channel_indices=list(prediction_channel_indices),
            exogenous_channel_indices=list(exogenous_channel_indices),
            fcm_context_length=fcm_context_length,
            fcm_use_mixer=fcm_use_mixer,
            fcm_mix_layers=fcm_mix_layers,
            fcm_prepend_past=fcm_prepend_past,
            # loss='mse',
            enable_forecast_channel_mixing=enable_forecast_channel_mixing,
        )
        self.__freeze_backbone()

    def __add_quantile_features(self, prediction_channel_indices, exogenous_channel_indices, out_channels):
        prediction_channel_indices = list(range(out_channels * 3))
        exogenous_channel_indices = [prediction_channel_indices[-1] + i for i in range(1, len(exogenous_channel_indices) + 1)]
        num_input_channels = len(prediction_channel_indices) + len(exogenous_channel_indices)
        return prediction_channel_indices, exogenous_channel_indices, num_input_channels

    def __freeze_backbone(self):
        """
        Freeze the backbone of the model.
        This is useful when you want to fine-tune only the head of the model.
        """
        print(
            "Number of params before freezing backbone",
            count_parameters(self.model),
        )
        # Freeze the backbone of the model
        for param in self.model.backbone.parameters():
            param.requires_grad = False
        # Count params
        print(
            "Number of params after freezing the backbone",
            count_parameters(self.model),
        )

    def __scaler(self, input):
        # new_data = torch.tensor([MinMaxScaler().fit_transform(step_data) for step_data in data])
        for i, e in enumerate(self.embs):
            input[:, :, i] = input[:, :, i] / (e - 1)
        return input

    def __build_tupla_indexes(self, size, target_idx, current_idx):
        permute = list(range(size))
        history = dict()
        for j, i in enumerate(target_idx):
            c = history.get(current_idx[j], current_idx[j])
            permute[i], permute[c] = current_idx[j], i
            history[i] = current_idx[j]
        return permute

    def __permute_indexes(self, values, target_idx, current_idx):
        if current_idx is None or target_idx is None:
            raise ValueError("Indexes cannot be None")
        if sorted(current_idx) != sorted(target_idx):
            return values[..., self.__build_tupla_indexes(values.shape[-1], target_idx, current_idx)]
        return values

    def __extend_with_quantile_variables(self, x, original_indexes):
        covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
        covariate_tensors = x[..., covariate_indexes]

        new_tensors = [x[..., target_index] for target_index in original_indexes for _ in range(3)]

        new_original_indexes = list(range(len(original_indexes) * 3))
        return torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1), new_original_indexes

    def forward(self, batch):
        x_enc = batch['x_num_past']
        original_indexes = batch['idx_target'][0].tolist()
        original_indexes_future = batch['idx_target_future'][0].tolist()

        if self.extend_variables:
            x_enc, original_indexes = self.__extend_with_quantile_variables(x_enc, original_indexes)

        if 'x_cat_past' in batch.keys():
            x_mark_enc = batch['x_cat_past'].to(torch.float32).to(self.device)
            x_mark_enc = self.__scaler(x_mark_enc)
            past_values = torch.cat((x_enc, x_mark_enc), axis=-1).type(torch.float32)
        else:
            past_values = x_enc

        x_dec = torch.tensor([]).to(self.device)
        if 'x_num_future' in batch.keys():
            x_dec = batch['x_num_future'].to(self.device)
            if self.extend_variables:
                x_dec, original_indexes_future = self.__extend_with_quantile_variables(x_dec, original_indexes_future)
        if 'x_cat_future' in batch.keys():
            x_mark_dec = batch['x_cat_future'].to(torch.float32).to(self.device)
            x_mark_dec = self.__scaler(x_mark_dec)
            future_values = torch.cat((x_dec, x_mark_dec), axis=-1).type(torch.float32)
        else:
            future_values = x_dec

        if self.remove_last:
            idx_target = batch['idx_target'][0]
            x_start = x_enc[:, -1, idx_target].unsqueeze(1)
            x_enc[:, :, idx_target] -= x_start

        past_values = self.__permute_indexes(past_values, self.model.prediction_channel_indices, original_indexes)
        future_values = self.__permute_indexes(future_values, self.model.prediction_channel_indices, original_indexes_future)

        freq_token = get_frequency_token(self.freq).repeat(x_enc.shape[0])

        res = self.model(
            past_values=past_values,
            future_values=future_values,
            past_observed_mask=None,
            future_observed_mask=None,
            output_hidden_states=False,
            return_dict=False,
            freq_token=freq_token,
            static_categorical_values=None
        )
        # args = None
        # res = self.model(**args)
        BS = res.shape[0]
        return res.reshape(BS, self.future_steps, -1, self.mul)
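When quantile loss is enabled, TTM triples each target channel so the underlying TinyTimeMixer predicts one channel per quantile. The following is a minimal standalone sketch of what __extend_with_quantile_variables does, with toy shapes chosen purely for illustration (no dsipts dependency):

import torch

# 2 samples, 5 time steps, 3 channels; channel 0 is the target, channels 1-2 are covariates
x = torch.randn(2, 5, 3)
original_indexes = [0]

covariate_indexes = [i for i in range(x.shape[-1]) if i not in original_indexes]
covariate_tensors = x[..., covariate_indexes]                       # [2, 5, 2]
# replicate each target channel three times (one copy per quantile)
new_tensors = [x[..., t] for t in original_indexes for _ in range(3)]
extended = torch.cat([torch.stack(new_tensors, dim=-1), covariate_tensors], dim=-1)
print(extended.shape)  # torch.Size([2, 5, 5]): 3 target copies + 2 covariates

The forward pass then uses __permute_indexes to move the replicated target channels onto model.prediction_channel_indices before calling the backbone, and the output is reshaped to [B, future_steps, -1, mul], with mul = 3 in the quantile case.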