dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/Samformer.py
ADDED

@@ -0,0 +1,139 @@
+## Copyright https://github.com/romilbert/samformer/tree/main?tab=MIT-1-ov-file#readme
+## Modified for notation alignment and batch structure
+## extended to what is inside the samformer folder
+
+import torch
+import torch.nn as nn
+import numpy as np
+from .samformer.utils import scaled_dot_product_attention, RevIN
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from .utils import QuantileLossMO, Permute, get_activation
+
+from typing import List, Union
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+from .utils import Embedding_cat_variables
+
+
+class Samformer(Base):
+    handle_multivariate = True
+    handle_future_covariates = False  # or at least it seems...
+    handle_categorical_variables = False  # only in the encoder
+    handle_quantile_loss = False  # NOT EFFICIENTLY ADDED, TODO fix this
+    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+    def __init__(self,
+                 # specific params
+                 hidden_size: int,
+                 use_revin: bool,
+                 activation: str = '',
+                 **kwargs) -> None:
+        """Initialize the model with specified parameters. Samformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention.
+        https://arxiv.org/pdf/2402.10198
+
+        Args:
+            hidden_size (int): The size of the hidden layer.
+            use_revin (bool): Flag indicating whether to use RevIN.
+            activation (str, optional): The activation function to use. Defaults to ''.
+            **kwargs: Additional keyword arguments passed to the parent class.
+
+        Raises:
+            ValueError: If the activation function is not recognized.
+        """
+
+        super().__init__(**kwargs)
+        if activation == 'torch.nn.SELU':
+            beauty_string('SELU do not require BN', 'info', self.verbose)
+            use_bn = False
+        if isinstance(activation, str):
+            activation = get_activation(activation)
+        self.save_hyperparameters(logger=False)
+
+        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        emb_fut_out_channel = self.emb_fut.output_channels
+
+        self.revin = RevIN(num_features=self.past_channels + emb_past_out_channel)
+        self.compute_keys = nn.Linear(self.past_steps, hidden_size)
+        self.compute_queries = nn.Linear(self.past_steps, hidden_size)
+        self.compute_values = nn.Linear(self.past_steps, self.past_steps)
+        self.linear_forecaster = nn.Linear(self.past_steps, self.future_steps)
+        self.use_revin = use_revin
+
+        dim = emb_past_out_channel + self.past_channels + emb_fut_out_channel + self.future_channels
+        self.final_layer = nn.Sequential(activation(),
+                                         nn.Linear(dim, dim * 2),
+                                         activation(),
+                                         nn.Linear(dim * 2, self.out_channels * self.mul))
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        x = batch['x_num_past'].to(self.device)
+        BS = x.shape[0]
+        if 'x_cat_future' in batch.keys():
+            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+        else:
+            emb_fut = self.emb_fut(BS, None)
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS, None)
+
+        tmp_future = [emb_fut]
+        if 'x_num_future' in batch.keys():
+            x_future = batch['x_num_future'].to(self.device)
+            tmp_future.append(x_future)
+
+        tot = [x, emb_past]
+        x = torch.cat(tot, axis=2)
+
+        if self.use_revin:
+            x_norm = self.revin(x, mode='norm').transpose(1, 2)  # (n, D, L)
+        else:
+            x_norm = x.transpose(1, 2)
+        # Channel-Wise Attention
+        queries = self.compute_queries(x_norm)  # (n, D, hid_dim)
+        keys = self.compute_keys(x_norm)  # (n, D, hid_dim)
+        values = self.compute_values(x_norm)  # (n, D, L)
+        if hasattr(nn.functional, 'scaled_dot_product_attention'):
+            att_score = nn.functional.scaled_dot_product_attention(queries, keys, values)  # (n, D, L)
+        else:
+            att_score = scaled_dot_product_attention(queries, keys, values)  # (n, D, L)
+        out = x_norm + att_score  # (n, D, L)
+        # Linear Forecasting
+        out = self.linear_forecaster(out)  # (n, D, H)
+        # RevIN Denormalization
+        if self.use_revin:
+            out = self.revin(out.transpose(1, 2), mode='denorm').transpose(1, 2)  # (n, D, H)
+
+        tmp_future.append(out.permute(0, 2, 1))
+        tmp_future = torch.cat(tmp_future, 2)
+        output = self.final_layer(tmp_future)
+
+        return output.reshape(BS, self.future_steps, self.out_channels, self.mul)
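The distinctive step in Samformer's forward pass above is that attention runs across channels rather than time: queries, keys and values are linear maps of each channel's whole past window, so the attention matrix mixes series instead of time steps. Below is a minimal standalone sketch of just that step, with made-up sizes (n=4 batch, D=7 channels, L=96 past steps, H=10 future steps, none of them from the package); it assumes torch >= 2.0, where nn.functional.scaled_dot_product_attention exists, matching the hasattr fallback in the code.

import torch
import torch.nn as nn

n, D, L, hid, H = 4, 7, 96, 16, 10           # hypothetical sizes
x = torch.randn(n, L, D)                     # stand-in for batch['x_num_past']: (batch, time, channels)
x_norm = x.transpose(1, 2)                   # (n, D, L): one row per channel
queries = nn.Linear(L, hid)(x_norm)          # (n, D, hid)
keys = nn.Linear(L, hid)(x_norm)             # (n, D, hid)
values = nn.Linear(L, L)(x_norm)             # (n, D, L)
# each channel attends to every other channel over its full history
att_score = nn.functional.scaled_dot_product_attention(queries, keys, values)  # (n, D, L)
out = x_norm + att_score                     # residual connection, as in the model
print(nn.Linear(L, H)(out).shape)            # torch.Size([4, 7, 10]): per-channel forecasts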
dsipts/models/TFT.py
ADDED

@@ -0,0 +1,269 @@
+import torch
+import torch.nn as nn
+from .tft import sub_nn
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from .utils import QuantileLossMO
+from typing import List, Union
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+from .utils import Embedding_cat_variables
+
+class TFT(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+    def __init__(self,
+                 d_model: int,
+                 num_layers_RNN: int,
+                 d_head: int,
+                 n_head: int,
+                 dropout_rate: float,
+                 **kwargs) -> None:
+        """Initializes the model for time series forecasting with attention mechanisms and recurrent neural networks.
+
+        This model is designed for direct forecasting, allowing for multi-output and multi-horizon predictions. It leverages attention mechanisms to enhance the selection of relevant past time steps and learn long-term dependencies. The architecture includes RNN enrichment, gating mechanisms to minimize the impact of irrelevant variables, and the ability to output prediction intervals through quantile regression.
+
+        Key features include:
+        - Direct Model: Predicts all future steps at once.
+        - Multi-Output Forecasting: Capable of predicting one or more variables simultaneously.
+        - Multi-Horizon Forecasting: Predicts variables at multiple future time steps.
+        - Attention-Based Mechanism: Enhances the selection of relevant past time steps and learns long-term dependencies.
+        - RNN Enrichment: Utilizes LSTM for initial autoregressive approximation, which is refined by the rest of the network.
+        - Gating Mechanisms: Reduces the contribution of irrelevant variables.
+        - Prediction Intervals: Outputs percentiles (e.g., 10th, 50th, 90th) at each time step.
+
+        The model also facilitates interpretability by identifying:
+        - Global importance of variables for both past and future.
+        - Temporal patterns.
+        - Significant events.
+
+        Args:
+            d_model (int): General hidden dimension across the network, adjustable in sub-networks.
+            num_layers_RNN (int): Number of layers in the recurrent neural network (LSTM).
+            d_head (int): Dimension of each attention head.
+            n_head (int): Number of attention heads.
+            dropout_rate (float): Dropout rate applied uniformly across all dropout layers.
+            **kwargs: Additional keyword arguments for further customization.
+        """
+
+        super().__init__(**kwargs)
+        self.save_hyperparameters(logger=False)
+        # assert out_channels==1, logging.info("ONLY ONE CHANNEL IMPLEMENTED")
+        self.d_model = d_model
+        # linear to embed the target variable
+        self.target_linear = nn.Linear(self.out_channels, d_model)  # same for past and fut! (same variable)
+        # number of variables in the past different from the target one(s)
+        self.aux_past_channels = self.past_channels - self.out_channels  # minus out_channels because those channels are occupied by the target variable(s)
+        # one linear for each auxiliary past var
+        self.linear_aux_past = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_past_channels)])
+        # number of variables in the future used to predict the target one(s)
+        self.aux_fut_channels = self.future_channels
+        # one linear for each auxiliary future var
+        self.linear_aux_fut = nn.ModuleList([nn.Linear(1, d_model) for _ in range(self.aux_fut_channels)])
+        # length of the full sequence, parameter used for the embedding of all categorical variables
+        # - we assume that these are either not available or available both for past and future
+        seq_len = self.past_steps + self.future_steps
+
+        ## in v.1.1.5 this is not working, past and future are different for categorical
+        # self.emb_cat_var = sub_nn.embedding_cat_variables(seq_len, self.future_steps, d_model, embs, self.device)
+
+        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        emb_fut_out_channel = self.emb_fut.output_channels
+
+        # Recurrent Neural Network for a first approximate inference of the target variable(s) - IT IS NOT RE-EMBEDDED YET
+        self.rnn = sub_nn.LSTM_Model(num_var=self.out_channels,
+                                     d_model=d_model,
+                                     pred_step=self.future_steps,
+                                     num_layers=num_layers_RNN,
+                                     dropout=dropout_rate)
+        # PARTS OF TFT:
+        # - Residual connections
+        # - Gated Residual Network
+        # - Interpretable MultiHead Attention
+        self.res_conn1_past = sub_nn.ResidualConnection(d_model, dropout_rate)
+        self.res_conn1_fut = sub_nn.ResidualConnection(d_model, dropout_rate)
+        self.grn1_past = sub_nn.GRN(d_model, dropout_rate)
+        self.grn1_fut = sub_nn.GRN(d_model, dropout_rate)
+        self.InterpretableMultiHead = sub_nn.InterpretableMultiHead(d_model, d_head, n_head)
+        self.res_conn2_att = sub_nn.ResidualConnection(d_model, dropout_rate)
+        self.grn2_att = sub_nn.GRN(d_model, dropout_rate)
+        self.res_conn3_out = sub_nn.ResidualConnection(d_model, dropout_rate)
+
+        self.outLinear = nn.Linear(d_model, self.out_channels * self.mul)
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        """Temporal Fusion Transformer
+
+        Collecting data:
+        - Extract the autoregressive variable(s)
+        - Embed them and compute a first approximate prediction
+        - 'summary_past' and 'summary_fut' collect data about past and future
+        All the different data are concatenated on dimension 2, then mixed through a MEAN over that dimension.
+        The information comes from the other tensors of the input batch.
+
+        TFT actual computations:
+        - Residual Connection for y_past and summary_past
+        - Residual Connection for y_fut and summary_fut
+        - GRN1 for past and for fut
+        - ATTENTION(summary_fut, summary_past, y_past)
+        - Residual Connection for attention itself
+        - GRN2 for attention
+        - Residual Connection for attention and summary_fut
+        - Linear for actual values and reshape
+
+        Args:
+            batch (dict): Keys used are ['x_num_past', 'idx_target', 'x_num_future', 'x_cat_past', 'x_cat_future']
+
+        Returns:
+            torch.Tensor: shape [B, self.future_steps, self.out_channels, self.mul] or [B, self.future_steps, self.out_channels] according to quantiles
+        """
+
+        num_past = batch['x_num_past'].to(self.device)
+        # PAST TARGET NUMERICAL VARIABLE
+        # always available: autoregressive variable
+        # compute rnn prediction
+        idx_target = batch['idx_target'][0]
+        target_num_past = num_past[:, :, idx_target]
+        target_emb_num_past = self.target_linear(target_num_past)  # target variables communicating with each other
+        target_num_fut_approx = self.rnn(target_emb_num_past)
+        # embed future predictions
+        target_emb_num_fut_approx = self.target_linear(target_num_fut_approx)
+
+        ### create the variables summary_past and summary_fut
+        # at the beginning they are composed only of the past and future target variable
+        summary_past = target_emb_num_past.unsqueeze(2)
+        summary_fut = target_emb_num_fut_approx.unsqueeze(2)
+        # now we search for other categorical and numerical variables!
+
+        ### PAST NUMERICAL VARIABLES
+        if self.aux_past_channels > 0:  # so we have more numerical variables about the past
+            # AUX = AUXILIARY variables
+            aux_num_past = self.remove_var(num_past, idx_target, 2)  # remove the target index on the second dimension
+            assert self.aux_past_channels == aux_num_past.size(2), beauty_string(f"{self.aux_past_channels} LAYERS FOR PAST VARS AND {aux_num_past.size(2)} VARS", 'section', True)  # check that we are using the expected number of past variables
+            aux_emb_num_past = torch.Tensor().to(aux_num_past.device)
+            for i, layer in enumerate(self.linear_aux_past):
+                aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
+                aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
+            ## update summary about the past
+            summary_past = torch.cat((summary_past, aux_emb_num_past), dim=2)
+
+        ### FUTURE NUMERICAL VARIABLES
+        if self.aux_fut_channels > 0:  # so we have more numerical variables about the future
+            aux_num_fut = batch['x_num_future'].to(self.device)
+            assert self.aux_fut_channels == aux_num_fut.size(2), beauty_string(f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND {aux_num_fut.size(2)} VARS", 'section', True)  # check that we are using the expected number of future variables
+            aux_emb_num_fut = torch.Tensor().to(aux_num_fut.device)
+            for j, layer in enumerate(self.linear_aux_fut):
+                aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
+                aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
+            ## update summary about the future
+            summary_fut = torch.cat((summary_fut, aux_emb_num_fut), dim=2)
+        '''
+        ### CATEGORICAL VARIABLES changed in 1.1.5
+        if 'x_cat_past' in batch.keys() and 'x_cat_future' in batch.keys():  # if we have both
+            # HERE WE ASSUME SAME NUMBER AND KIND OF VARIABLES IN PAST AND FUTURE
+            cat_past = batch['x_cat_past'].to(self.device)
+            cat_fut = batch['x_cat_future'].to(self.device)
+            cat_full = torch.cat((cat_past, cat_fut), dim=1)
+            # EMB CATEGORICAL VARIABLES AND THEN SPLIT IN PAST AND FUTURE
+            emb_cat_full = self.emb_cat_var(cat_full, self.device)
+        else:
+            emb_cat_full = self.emb_cat_var(num_past.shape[0], self.device)
+
+        cat_emb_past = emb_cat_full[:, :-self.future_steps, :, :]
+        cat_emb_fut = emb_cat_full[:, -self.future_steps:, :, :]
+
+        ## update summary
+        # past
+        summary_past = torch.cat((summary_past, cat_emb_past), dim=2)
+        # future
+        summary_fut = torch.cat((summary_fut, cat_emb_fut), dim=2)
+        '''
+        BS = num_past.shape[0]
+        if 'x_cat_future' in batch.keys():
+            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+        else:
+            emb_fut = self.emb_fut(BS, None)
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS, None)
+
+        ## update summary
+        # past
+        summary_past = torch.cat((summary_past, emb_past.unsqueeze(-1).repeat((1, 1, 1, summary_past.shape[-1]))), dim=2)
+        # future
+        summary_fut = torch.cat((summary_fut, emb_fut.unsqueeze(-1).repeat((1, 1, 1, summary_past.shape[-1]))), dim=2)
+
+        # >>> PAST:
+        summary_past = torch.mean(summary_past, dim=2)
+        # >>> FUTURE:
+        summary_fut = torch.mean(summary_fut, dim=2)
+
+        ### Residual Connection from LSTM
+        summary_past = self.res_conn1_past(summary_past, target_emb_num_past)
+        summary_fut = self.res_conn1_fut(summary_fut, target_emb_num_fut_approx)
+
+        ### GRN1
+        summary_past = self.grn1_past(summary_past)
+        summary_fut = self.grn1_fut(summary_fut)
+
+        ### INTERPRETABLE MULTI HEAD ATTENTION
+        attention = self.InterpretableMultiHead(summary_fut, summary_past, target_emb_num_past)
+
+        ### Residual Connection from ATT
+        attention = self.res_conn2_att(attention, attention)
+
+        ### GRN
+        attention = self.grn2_att(attention)
+
+        ### Residual Connection from GRN1
+        out = self.res_conn3_out(attention, summary_fut)
+
+        ### OUT
+        out = self.outLinear(out)
+
+        if self.mul > 0:
+            out = out.view(-1, self.future_steps, self.out_channels, self.mul)
+        return out
+
+    # function to extract from batch['x_num_past'] all variables except the autoregressive one(s)
+    def remove_var(self, tensor: torch.Tensor, indexes_to_exclude: int, dimension: int) -> torch.Tensor:
+        """Remove variables from a tensor at the chosen dimension and positions.
+
+        Args:
+            tensor (torch.Tensor): starting tensor
+            indexes_to_exclude (int): indexes of the chosen dimension we want to exclude
+            dimension (int): dimension of the tensor on which we want to work
+
+        Returns:
+            torch.Tensor: new tensor without the chosen variables
+        """
+
+        remaining_idx = torch.tensor([i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]).to(tensor.device)
+        # Select the desired sub-tensor
+        extracted_subtensors = torch.index_select(tensor, dim=dimension, index=remaining_idx)
+        return extracted_subtensors
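The remove_var helper above is plain index selection. A toy check of the same logic, on a made-up (B, L, C) = (2, 3, 4) tensor, dropping channel 0 as if it were the target variable:

import torch

t = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # stand-in for batch['x_num_past']
indexes_to_exclude = [0]                                           # the target channel
remaining_idx = torch.tensor([i for i in range(t.size(2)) if i not in indexes_to_exclude])
aux = torch.index_select(t, dim=2, index=remaining_idx)
print(aux.shape)  # torch.Size([2, 3, 3]): target dropped, auxiliary channels kept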