dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/Duet.py ADDED
@@ -0,0 +1,197 @@
+
+## Copyright 2025 DUET (https://github.com/decisionintelligence/DUET)
+## Code modified to align the notation and the batch generation
+## extended to everything present in the duet and autoformer folders
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .duet.layers import Linear_extractor_cluster
+from .duet.masked import Mahalanobis_mask, Encoder, EncoderLayer, FullAttention, AttentionLayer
+from einops import rearrange
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except ImportError:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from .utils import QuantileLossMO, Permute, get_activation
+
+from typing import List, Union
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+from .utils import Embedding_cat_variables
+
+
+class Duet(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+    def __init__(self,
+                 factor: int,
+                 d_model: int,
+                 n_head: int,
+                 n_layer: int,
+                 CI: bool,
+                 d_ff: int,
+                 noisy_gating: bool,
+                 num_experts: int,
+                 kernel_size: int,
+                 hidden_size: int,
+                 k: int,
+                 dropout_rate: float = 0.1,
+                 activation: str = '',
+                 **kwargs) -> None:
+        """Initializes the model with the specified parameters. https://github.com/decisionintelligence/DUET
+
+        Args:
+            factor (int): The factor for attention scaling. NOT USED, but present in the original implementation.
+            d_model (int): The dimensionality of the model.
+            n_head (int): The number of attention heads.
+            n_layer (int): The number of layers in the encoder.
+            CI (bool): Perform channel-independent operations.
+            d_ff (int): The dimensionality of the feedforward layer.
+            noisy_gating (bool): Flag to indicate if noisy gating is used.
+            num_experts (int): The number of experts in the mixture of experts.
+            kernel_size (int): The size of the convolutional kernel.
+            hidden_size (int): The size of the hidden layer.
+            k (int): The number of clusters for the linear extractor.
+            dropout_rate (float, optional): The dropout rate. Defaults to 0.1.
+            activation (str, optional): The activation function to use. Defaults to ''.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            ValueError: If the activation function is not recognized.
+        """
+
+        super().__init__(**kwargs)
+        if activation == 'torch.nn.SELU':
+            beauty_string('SELU does not require BN', 'info', self.verbose)
+            use_bn = False
+        if isinstance(activation, str):
+            activation = get_activation(activation)
+        self.save_hyperparameters(logger=False)
+
+        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        emb_fut_out_channel = self.emb_fut.output_channels
+
+        self.cluster = Linear_extractor_cluster(noisy_gating,
+                                                num_experts,
+                                                self.past_steps,
+                                                k,
+                                                d_model,
+                                                self.past_channels + emb_past_out_channel,
+                                                CI, kernel_size,
+                                                hidden_size)
+        self.CI = CI
+        self.n_vars = self.out_channels
+        self.mask_generator = Mahalanobis_mask(self.future_steps)
+        self.Channel_transformer = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        FullAttention(
+                            True,
+                            factor,
+                            attention_dropout=dropout_rate,
+                            output_attention=0,
+                        ),
+                        d_model,
+                        n_head,
+                    ),
+                    d_model,
+                    d_ff,
+                    dropout=dropout_rate,
+                    activation=activation,
+                )
+                for _ in range(n_layer)
+            ],
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+
+        self.linear_head = nn.Sequential(nn.Linear(d_model, self.future_steps), nn.Dropout(dropout_rate))
+
+        dim = self.past_channels + emb_past_out_channel + emb_fut_out_channel + self.future_channels
+        self.final_layer = nn.Sequential(activation(),
+                                         nn.Linear(dim, dim * 2),
+                                         activation(),
+                                         nn.Linear(dim * 2, self.out_channels * self.mul))
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        # x: [Batch, Input length, Channel]
+        x_enc = batch['x_num_past'].to(self.device)
+        idx_target = batch['idx_target'][0]
+        BS = x_enc.shape[0]
+
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS, None)
+
+        if 'x_cat_future' in batch.keys():
+            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+        else:
+            emb_fut = self.emb_fut(BS, None)
+
+        tmp_future = [emb_fut]
+        if 'x_num_future' in batch.keys():
+            x_future = batch['x_num_future'].to(self.device)
+            tmp_future.append(x_future)
+
+        x_enc = torch.concat([x_enc, emb_past], axis=-1)
+
+        if self.CI:
+            channel_independent_input = rearrange(x_enc, 'b l n -> (b n) l 1')
+            reshaped_output, _ = self.cluster(channel_independent_input)
+            temporal_feature = rearrange(reshaped_output, '(b n) l 1 -> b l n', b=x_enc.shape[0])
+        else:
+            temporal_feature, _ = self.cluster(x_enc)
+
+        # B x d_model x n_vars -> B x n_vars x d_model
+        temporal_feature = rearrange(temporal_feature, 'b d n -> b n d')
+        if self.n_vars > 1:
+            changed_input = rearrange(x_enc, 'b l n -> b n l')
+            channel_mask = self.mask_generator(changed_input)
+            channel_group_feature, _ = self.Channel_transformer(x=temporal_feature, attn_mask=channel_mask)
+            output = self.linear_head(channel_group_feature)
+        else:
+            output = temporal_feature
+            output = self.linear_head(output)
+
+        output = rearrange(output, 'b n d -> b d n')
+        output = self.cluster.revin(output, "denorm")
+        tmp_future.append(output)
+        tmp_future = torch.cat(tmp_future, 2)
+        output = self.final_layer(tmp_future)
+
+        return output.reshape(BS, self.future_steps, self.n_vars, self.mul)
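
Not part of the package: a minimal, torch-and-einops-only sketch of the channel-independent (CI) reshape that Duet.forward wraps around the cluster extractor; the batch/length/channel sizes are illustrative.

import torch
from einops import rearrange

B, L, N = 4, 96, 7                           # batch, input length, channels
x = torch.randn(B, L, N)

# Fold the channels into the batch so each series is processed independently.
ci = rearrange(x, 'b l n -> (b n) l 1')      # (B*N, L, 1)
assert ci.shape == (B * N, L, 1)

# Unfold back to (B, L, N); in the model this happens after the per-channel
# extractor (whose output length is d_model). The round trip is exact.
restored = rearrange(ci, '(b n) l 1 -> b l n', b=B)
assert torch.equal(restored, x)

Folding channels into the batch dimension lets a single univariate extractor serve every channel with shared weights, which is exactly what the CI flag toggles.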
dsipts/models/ITransformer.py ADDED
@@ -0,0 +1,167 @@
+## Copyright https://github.com/thuml/iTransformer?tab=MIT-1-ov-file#readme
+## Modified for notation alignment and batch structure
+## extended to everything inside the itransformer folder
+
+import torch
+import torch.nn as nn
+import numpy as np
+from .itransformer.Transformer_EncDec import Encoder, EncoderLayer
+from .itransformer.SelfAttention_Family import FullAttention, AttentionLayer
+from .itransformer.Embed import DataEmbedding_inverted
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except ImportError:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from .utils import QuantileLossMO, Permute, get_activation
+
+from typing import List, Union
+from ..data_structure.utils import beauty_string
+from .utils import get_scope
+from .utils import Embedding_cat_variables
+
+
+class ITransformer(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+    def __init__(self,
+                 # specific params
+                 hidden_size: int,
+                 d_model: int,
+                 n_head: int,
+                 n_layer_decoder: int,
+                 use_norm: bool,
+                 class_strategy: str = 'projection',  # projection/average/cls_token
+                 dropout_rate: float = 0.1,
+                 activation: str = '',
+                 **kwargs) -> None:
+        """Initialize the ITransformer model for time series forecasting.
+
+        This class implements the Inverted Transformer architecture described in the paper
+        "ITRANSFORMER: INVERTED TRANSFORMERS ARE EFFECTIVE FOR TIME SERIES FORECASTING"
+        (https://arxiv.org/pdf/2310.06625).
+
+        Args:
+            hidden_size (int): The first embedding size of the model ('r' in the paper).
+            d_model (int): The second embedding size (r-tilde in the paper). Should be smaller than hidden_size.
+            n_head (int): The number of attention heads.
+            n_layer_decoder (int): The number of layers in the decoder.
+            use_norm (bool): Flag to indicate whether to use normalization.
+            class_strategy (str, optional): The strategy for classification, one of 'projection', 'average', or 'cls_token'. Defaults to 'projection'.
+            dropout_rate (float, optional): The dropout rate for regularization. Defaults to 0.1.
+            activation (str, optional): The activation function to be used. Defaults to ''.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            ValueError: If the activation function is not recognized.
+        """
+
+        super().__init__(**kwargs)
+        if activation == 'torch.nn.SELU':
+            beauty_string('SELU does not require BN', 'info', self.verbose)
+            use_bn = False
+        if isinstance(activation, str):
+            activation = get_activation(activation)
+        self.save_hyperparameters(logger=False)
+
+        self.emb_past = Embedding_cat_variables(self.past_steps, self.emb_dim, self.embs_past, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        self.emb_fut = Embedding_cat_variables(self.future_steps, self.emb_dim, self.embs_fut, reduction_mode=self.reduction_mode, use_classical_positional_encoder=self.use_classical_positional_encoder, device=self.device)
+        emb_past_out_channel = self.emb_past.output_channels
+        emb_fut_out_channel = self.emb_fut.output_channels
+
+        self.output_attention = False  # output attention is not needed
+        self.use_norm = use_norm
+        # Embedding
+        self.enc_embedding = DataEmbedding_inverted(self.past_steps, d_model, embed_type='what?', freq='what?', dropout=dropout_rate)  # embed_type and freq are not used inside
+        self.class_strategy = class_strategy
+        # Encoder-only architecture
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        FullAttention(False, factor=0.1, attention_dropout=dropout_rate,  # factor is not used in FullAttention
+                                      output_attention=self.output_attention), d_model, n_head),
+                    d_model,
+                    hidden_size,
+                    dropout=dropout_rate,
+                    activation=activation()
+                ) for _ in range(n_layer_decoder)
+            ],
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+        self.projector = nn.Linear(d_model, self.future_steps * self.mul, bias=True)
+
+    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        if self.use_norm:
+            # Normalization from Non-stationary Transformer
+            means = x_enc.mean(1, keepdim=True).detach()
+            x_enc = x_enc - means
+            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
+            x_enc /= stdev
+
+        _, _, N = x_enc.shape  # B L N
+        # B: batch_size; E: d_model;
+        # L: seq_len; S: pred_len;
+        # N: number of variates (tokens); can also include covariates
+
+        # Embedding
+        # B L N -> B N E (B L N -> B L E in the vanilla Transformer)
+        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # covariates (e.g. timestamps) can also be embedded as tokens
+
+        # B N E -> B N E (B L E -> B L E in the vanilla Transformer)
+        # the dimensions of the embedded time series have been inverted, then processed by native attention, layernorm and FFN modules
+        enc_out, attns = self.encoder(enc_out, attn_mask=None)
+
+        # B N E -> B N S -> B S N
+        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]  # filter the covariates
+
+        if self.use_norm:
+            # De-normalization from Non-stationary Transformer
+            dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.future_steps * self.mul, 1))
+            dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.future_steps * self.mul, 1))
+
+        return dec_out
+
+    def forward(self, batch: dict) -> torch.Tensor:
+        x_enc = batch['x_num_past'].to(self.device)
+        BS = x_enc.shape[0]
+        if 'x_cat_future' in batch.keys():
+            emb_fut = self.emb_fut(BS, batch['x_cat_future'].to(self.device))
+        else:
+            emb_fut = self.emb_fut(BS, None)
+        if 'x_cat_past' in batch.keys():
+            emb_past = self.emb_past(BS, batch['x_cat_past'].to(self.device))
+        else:
+            emb_past = self.emb_past(BS, None)
+
+        ## row 124 of Transformer/experiments/exp_long_term_forecasting.py, but actually NOT USED!
+        x_dec = torch.zeros(x_enc.shape[0], self.past_steps, self.out_channels).float().to(self.device)
+        x_dec = torch.cat([batch['y'].to(self.device), x_dec], dim=1).float()
+
+        dec_out = self.forecast(x_enc, emb_past, x_dec, emb_fut)
+        idx_target = batch['idx_target'][0]
+        return dec_out[:, :, idx_target].reshape(BS, self.future_steps, self.out_channels, self.mul)
+
+        # return dec_out[:, -self.pred_len:, :]  # [B, L, D]
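
Not part of the package: a minimal, torch-only sketch of the normalization / de-normalization pair that forecast applies when use_norm is set (the "Non-stationary Transformer" trick); shapes are illustrative.

import torch

B, L, N, S = 4, 96, 7, 24                    # batch, past steps, variates, future steps
x_enc = torch.randn(B, L, N)

# Normalize each series over the time axis with its own statistics.
means = x_enc.mean(1, keepdim=True).detach()                                     # (B, 1, N)
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)  # (B, 1, N)
x_norm = (x_enc - means) / stdev

# De-normalize the forecast by applying the same per-series statistics over the horizon.
dec_out = torch.randn(B, S, N)               # stand-in for the projector output
dec_out = dec_out * stdev[:, 0, :].unsqueeze(1) + means[:, 0, :].unsqueeze(1)
assert dec_out.shape == (B, S, N)

The source expands the statistics explicitly with .repeat over the horizon; plain broadcasting, as above, yields the same result.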
dsipts/models/Informer.py ADDED
@@ -0,0 +1,180 @@
+
+## Copyright 2020 Informer (https://github.com/zhouhaoyi/Informer2020/tree/main/models)
+## Code modified to align the notation and the batch generation
+## extended to everything present in the informer and autoformer folders
+from torch import nn
+import torch
+
+try:
+    import lightning.pytorch as pl
+    from .base_v2 import Base
+    OLD_PL = False
+except ImportError:
+    import pytorch_lightning as pl
+    OLD_PL = True
+    from .base import Base
+from typing import List, Union
+
+from .informer.encoder import Encoder, EncoderLayer, ConvLayer
+from .informer.decoder import Decoder, DecoderLayer
+from .informer.attn import FullAttention, ProbAttention, AttentionLayer
+from .informer.embed import DataEmbedding
+from ..data_structure.utils import beauty_string
+# from .utils import Embedding_cat_variables -- not used here, custom categorical embedding
+from .utils import get_scope
+
+
+class Informer(Base):
+    handle_multivariate = True
+    handle_future_covariates = True
+    handle_categorical_variables = True
+    handle_quantile_loss = True
+    description = get_scope(handle_multivariate, handle_future_covariates, handle_categorical_variables, handle_quantile_loss)
+
+    def __init__(self,
+                 d_model: int,
+                 hidden_size: int,
+                 n_layer_encoder: int,
+                 n_layer_decoder: int,
+                 mix: bool = True,
+                 activation: str = 'torch.nn.ReLU',
+                 remove_last: bool = False,
+                 attn: str = 'prob',
+                 distil: bool = True,
+                 factor: int = 5,
+                 n_head: int = 1,
+                 dropout_rate: float = 0.1,
+                 **kwargs) -> None:
+        """Initialize the model with the specified parameters. https://github.com/zhouhaoyi/Informer2020/tree/main/models
+
+        Args:
+            d_model (int): The dimensionality of the model.
+            hidden_size (int): The size of the hidden layers.
+            n_layer_encoder (int): The number of layers in the encoder.
+            n_layer_decoder (int): The number of layers in the decoder.
+            mix (bool, optional): Whether to use mixed attention. Defaults to True.
+            activation (str, optional): The activation function to use. Defaults to 'torch.nn.ReLU'.
+            remove_last (bool, optional): Whether to subtract the last observed target value before encoding. Defaults to False.
+            attn (str, optional): The type of attention mechanism to use. Defaults to 'prob'.
+            distil (bool, optional): Whether to use distillation. Defaults to True.
+            factor (int, optional): The factor for attention. Defaults to 5.
+            n_head (int, optional): The number of attention heads. Defaults to 1.
+            dropout_rate (float, optional): The dropout rate. Defaults to 0.1.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            ValueError: If any of the parameters are invalid.
+
+        Notes:
+            Be sure to set split_params: shift: ${model_configs.future_steps}, as it is required!
+        """
+
+        super().__init__(**kwargs)
+        self.save_hyperparameters(logger=False)
+        beauty_string("BE SURE TO SETUP split_params: shift: ${model_configs.future_steps} BECAUSE IT IS REQUIRED", 'info', True)
+
+        self.remove_last = remove_last
+
+        self.enc_embedding = DataEmbedding(self.past_channels, d_model, self.embs_past, dropout_rate)
+        self.dec_embedding = DataEmbedding(self.future_channels, d_model, self.embs_fut, dropout_rate)
+        # Attention
+        Attn = ProbAttention if attn == 'prob' else FullAttention
+        # Encoder
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(Attn(False, factor, attention_dropout=dropout_rate, output_attention=False),
+                                   d_model, n_head, mix=False),
+                    d_model,
+                    hidden_size,
+                    dropout=dropout_rate,
+                    activation=activation
+                ) for _ in range(n_layer_encoder)
+            ],
+            [
+                ConvLayer(
+                    d_model
+                ) for _ in range(n_layer_encoder - 1)
+            ] if distil else None,
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+        # Decoder
+        self.decoder = Decoder(
+            [
+                DecoderLayer(
+                    AttentionLayer(Attn(True, factor, attention_dropout=dropout_rate, output_attention=False),
+                                   d_model, n_head, mix=mix),
+                    AttentionLayer(FullAttention(False, factor, attention_dropout=dropout_rate, output_attention=False),
+                                   d_model, n_head, mix=False),
+                    d_model,
+                    hidden_size,
+                    dropout=dropout_rate,
+                    activation=activation,
+                )
+                for _ in range(n_layer_decoder)
+            ],
+            norm_layer=torch.nn.LayerNorm(d_model)
+        )
+
+        self.projection = nn.Linear(d_model, self.out_channels * self.mul, bias=True)
+
+    def forward(self, batch):
+        # x_enc, x_mark_enc, x_dec, x_mark_dec, enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None
+
+        x_enc = batch['x_num_past'].to(self.device)
+        idx_target_future = batch['idx_target_future'][0]
+
+        if 'x_cat_past' in batch.keys():
+            x_mark_enc = batch['x_cat_past'].to(self.device)
+        else:
+            x_mark_enc = None
+
+        enc_self_mask = None
+
+        x_dec = batch['x_num_future'].to(self.device)
+        x_dec[:, -self.future_steps:, idx_target_future] = 0
+
+        if 'x_cat_future' in batch.keys():
+            x_mark_dec = batch['x_cat_future'].to(self.device)
+        else:
+            x_mark_dec = None
+        dec_self_mask = None
+        dec_enc_mask = None
+
+        if self.remove_last:
+            idx_target = batch['idx_target'][0]
+            x_start = x_enc[:, -1, idx_target].unsqueeze(1)
+            x_enc[:, :, idx_target] -= x_start
+
+        enc_out = self.enc_embedding(x_enc, x_mark_enc)
+        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
+
+        dec_out = self.dec_embedding(x_dec, x_mark_dec)
+        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
+        dec_out = self.projection(dec_out)
+
+        # dec_out = self.end_conv1(dec_out)
+        # dec_out = self.end_conv2(dec_out.transpose(2, 1)).transpose(1, 2)
+
+        res = dec_out[:, -self.future_steps:, :].unsqueeze(3)
+        if self.remove_last:
+            res += x_start.unsqueeze(1)
+        BS = res.shape[0]
+        return res.reshape(BS, self.future_steps, -1, self.mul)
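
Not part of the package: a minimal, torch-only sketch of the remove_last detrending used in Informer.forward, which subtracts the last observed target value before encoding and adds it back to the prediction; shapes and the target indices are illustrative.

import torch

B, L, S = 4, 96, 24                               # batch, past steps, future steps
idx_target = torch.tensor([0, 2])                 # target columns among 5 channels
x_enc = torch.randn(B, L, 5)

x_start = x_enc[:, -1, idx_target].unsqueeze(1)   # (B, 1, n_targets): last observed level
x_enc[:, :, idx_target] -= x_start                # detrend the encoder input in place

pred = torch.randn(B, S, len(idx_target))         # stand-in for the decoder output
pred = pred + x_start                             # re-add the level across the horizon
assert pred.shape == (B, S, len(idx_target))

This gives the model a zero-anchored input and turns the prediction into a delta over the last observation, which often stabilizes training on trending series.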