dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/VVA.py ADDED
@@ -0,0 +1,247 @@
+
+ from torch import nn
+ import torch
+
+ try:
+     import lightning.pytorch as pl
+     from .base_v2 import Base
+     OLD_PL = False
+ except:
+     import pytorch_lightning as pl
+     OLD_PL = True
+     from .base import Base
+ from typing import List, Union
+ from .vva.minigpt import Block
+ import math
+ from torch.nn import functional as F
+ from ..data_structure.utils import beauty_string
+ from .utils import get_scope
+
+ torch.autograd.set_detect_anomaly(True)
+
+
+ class VVA(Base):
+     handle_multivariate = False
+     handle_future_covariates = False
+     handle_categorical_variables = False
+     handle_quantile_loss = False
+     description = get_scope(handle_multivariate,handle_future_covariates,handle_categorical_variables,handle_quantile_loss)
+
+
+     def __init__(self,
+                  past_steps:int,
+                  future_steps:int,
+                  past_channels:int,
+                  future_channels:int,
+                  embs:List[int],
+                  d_model:int,
+                  max_voc_size:int,
+                  token_split: int,
+                  num_layers:int,
+                  dropout_rate:float,
+                  n_heads:int,
+                  out_channels:int,
+                  persistence_weight:float=0.0,
+                  loss_type: str='l1',
+                  quantiles:List[int]=[],
+                  optim:Union[str,None]=None,
+                  optim_config:dict=None,
+                  scheduler_config:dict=None,
+                  **kwargs)->None:
+         """ Custom autoregressive (GPT-like) model working on tokenized sequences
+
+         Args:
+             past_steps (int): number of past datapoints used
+             future_steps (int): number of future lags to predict
+             past_channels (int): number of numeric past variables, must be >0
+             future_channels (int): number of future numeric variables
+             embs (List): list of the initial dimensions of the categorical variables
+             d_model (int): hidden dimension of the transformer blocks
+             max_voc_size (int): size of the token vocabulary
+             token_split (int): number of time steps grouped into one token
+             num_layers (int): number of transformer blocks
+             dropout_rate (float): dropout rate in Dropout layers
+             n_heads (int): number of attention heads
+             out_channels (int): number of output channels
+             persistence_weight (float): weight controlling the divergence from the persistence model. Default 0
+             loss_type (str, optional): this model uses custom losses or l1 or mse. Custom losses can be linear_penalization or exponential_penalization. Default l1
+             quantiles (List[int], optional): quantile loss is used if len(quantiles)>0 (usually 0.1, 0.5, 0.9), otherwise L1Loss is used when len(quantiles)==0. Defaults to [].
+             optim (str, optional): if not None it expects a pytorch optim method. Defaults to None, which is mapped to Adam.
+             optim_config (dict, optional): configuration for the Adam optimizer. Defaults to None.
+             scheduler_config (dict, optional): configuration for the stepLR scheduler. Defaults to None.
+
+         """
+
+
+         super().__init__(**kwargs)
+         self.block_size = past_steps//token_split + future_steps//token_split -1
+         self.save_hyperparameters(logger=False)
+         self.sentence_length = future_steps//token_split
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(max_voc_size, d_model),
+             wpe = nn.Embedding(self.block_size, d_model),
+             drop = nn.Dropout(dropout_rate),
+             h = nn.ModuleList([Block( d_model,dropout_rate,n_heads,dropout_rate,self.block_size) for _ in range(num_layers)]), ##care can be different dropouts
+             ln_f = nn.LayerNorm(d_model),
+         ))
+         self.lm_head = nn.Linear(d_model, max_voc_size, bias=False)
+
+
+         for pn, p in self.named_parameters():
+             if pn.endswith('c_proj.weight'):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layers))
+
+         # report number of parameters (note we don't count the decoder parameters in lm_head)
+         n_params = sum(p.numel() for p in self.transformer.parameters())
+         beauty_string("number of parameters: %.2fM" % (n_params/1e6,),'info',self.verbose)
+
+         self.use_quantiles = True
+         self.is_classification = True
+         self.scheduler_config = scheduler_config
+         self.optim_config = optim_config
+         self.optim = self.scheduler_config = self.configure_optimizers()
+
+
+     def configure_optimizers(self):
+         """
+         This long function is unfortunately doing something very simple and is being very defensive:
+         We are separating out all parameters of the model into two buckets: those that will experience
+         weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+         We are then returning the PyTorch optimizer object.
+         """
+
+         # separate out all parameters to those that will and won't experience regularizing weight decay
+         decay = set()
+         no_decay = set()
+         whitelist_weight_modules = (torch.nn.Linear, )
+         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+         for mn, m in self.named_modules():
+             for pn, p in m.named_parameters():
+                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                 # random note: because named_modules and named_parameters are recursive
+                 # we will see the same tensors p many many times. but doing it this way
+                 # allows us to know which parent module any tensor p belongs to...
+                 if pn.endswith('bias'):
+                     # all biases will not be decayed
+                     no_decay.add(fpn)
+                 elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                     # weights of whitelist modules will be weight decayed
+                     decay.add(fpn)
+                 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                     # weights of blacklist modules will NOT be weight decayed
+                     no_decay.add(fpn)
+
+         # validate that we considered every parameter
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         inter_params = decay & no_decay
+         union_params = decay | no_decay
+         assert len(inter_params) == 0, beauty_string(f"parameters {inter_params} made it into both decay/no_decay sets!", 'section', True)
+         assert len(param_dict.keys() - union_params) == 0, beauty_string(f"parameters {param_dict.keys() - union_params} were not separated into either decay/no_decay set!", 'section', True)
+
+         # create the pytorch optimizer object
+         optim_groups = [
+             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.optim_config.weight_decay},
+             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+         ]
+         optimizer = torch.optim.AdamW(optim_groups, lr=self.optim_config.lr, betas=self.optim_config.betas)
+         return optimizer
+
+     def compute_loss(self, batch, y_hat):
+         """
+         custom loss calculation
+
+         :meta private:
+         """
+         return F.cross_entropy(y_hat.view(-1, y_hat.size(-1)), batch['y_emb'].view(-1), ignore_index=-1)
+
+
+     def forward(self, batch):
+         b, t = batch['x_emb'].size()
+         assert t <= self.block_size, beauty_string(f"Cannot forward sequence of length {t}, block size is only {self.block_size}", 'section', True)
+         pos = torch.arange(0, t, dtype=torch.long, device=self.device).unsqueeze(0) # shape (1, t)
+
+         # forward the GPT model itself
+         tok_emb = self.transformer.wte(batch['x_emb']) # token embeddings of shape (b, t, n_embd)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+         x = self.transformer.drop(tok_emb + pos_emb)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+
+         return logits
+
+
+     def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, num_samples=100):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         if do_sample:
+             idx = idx.repeat(num_samples,1,1)
+             for _ in range(max_new_tokens):
+                 tmp = []
+                 for i in range(num_samples):
+                     idx_cond = idx[i,:,:] if idx.size(2) <= self.block_size else idx[i,:, -self.block_size:]
+                     logits = self({'x_emb':idx_cond})
+                     logits = logits[:, -1, :] / temperature
+                     if top_k is not None:
+                         v, _ = torch.topk(logits, top_k)
+                         logits[logits < v[:, [-1]]] = -float('Inf')
+                     probs = F.softmax(logits, dim=-1)
+                     idx_next = torch.multinomial(probs, num_samples=1, replacement=True)
+                     tmp.append(idx_next)
+                 tmp = torch.cat(tmp,dim=1).T.unsqueeze(2)
+                 idx = torch.cat((idx, tmp), dim=2)
+             return idx
+         else:
+             for _ in range(max_new_tokens):
+
+                 # if the sequence context is growing too long we must crop it at block_size
+                 idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
+                 # forward the model to get the logits for the index in the sequence
+                 logits = self({'x_emb':idx_cond})
+                 # pluck the logits at the final step and scale by desired temperature
+                 logits = logits[:, -1, :] / temperature
+                 # optionally crop the logits to only the top k options
+                 if top_k is not None:
+                     v, _ = torch.topk(logits, top_k)
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+                 # apply softmax to convert logits to (normalized) probabilities
+                 probs = F.softmax(logits, dim=-1)
+                 # take the most likely element (greedy decoding)
+                 _, idx_next = torch.topk(probs, k=1, dim=-1)
+                 # append the selected index to the running sequence and continue
+                 idx = torch.cat((idx, idx_next), dim=1)
+
+             return idx.unsqueeze(0)
+
+     def inference(self, batch:dict)->torch.tensor:
+         x = batch['x_emb'].to(self.device)
+
+         # isolate the input pattern alone
+         inp = x[:, :self.sentence_length]
+
+         # let the model sample num_samples candidate continuations of the sequence
+         cat = self.generate(inp, self.sentence_length, do_sample=True, num_samples=3) ##todo here add sampling options
+         sol_candidate = cat[:,:, self.sentence_length:]
+
+
+         return sol_candidate.permute(1,2,0)
+
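Note on the sampling loop above: VVA.generate follows the familiar minGPT recipe of scaling the last-position logits by a temperature, optionally masking everything outside the top-k candidates, and then drawing the next token from the resulting distribution. The standalone sketch below illustrates that single step; the function name and shapes are illustrative only and are not part of the package.

    import torch
    import torch.nn.functional as F

    def sample_next_token(logits, temperature=1.0, top_k=None):
        # logits: (batch, vocab_size) scores for the last position of the sequence
        logits = logits / temperature
        if top_k is not None:
            # keep only the top_k largest logits; everything else gets probability ~0
            v, _ = torch.topk(logits, top_k)
            logits = logits.masked_fill(logits < v[:, [-1]], -float('Inf'))
        probs = F.softmax(logits, dim=-1)                 # (batch, vocab_size)
        return torch.multinomial(probs, num_samples=1)    # (batch, 1) sampled token ids

    # toy usage with random scores over a 10-token vocabulary
    next_tok = sample_next_token(torch.randn(2, 10), temperature=0.8, top_k=5)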
dsipts/models/__init__.py ADDED: File without changes
dsipts/models/autoformer/__init__.py ADDED: File without changes
dsipts/models/autoformer/layers.py ADDED
@@ -0,0 +1,352 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class AutoCorrelation(nn.Module):
+     """
+     AutoCorrelation Mechanism with the following two phases:
+     (1) period-based dependencies discovery
+     (2) time delay aggregation
+     This block can replace the self-attention family mechanism seamlessly.
+     """
+     def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
+         super(AutoCorrelation, self).__init__()
+         self.factor = factor
+         self.scale = scale
+         self.mask_flag = mask_flag
+         self.output_attention = output_attention
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def time_delay_agg_training(self, values, corr):
+         """
+         SpeedUp version of Autocorrelation (a batch-normalization style design)
+         This is for the training phase.
+         """
+         head = values.shape[1]
+         channel = values.shape[2]
+         length = values.shape[3]
+         # find top k
+         top_k = int(self.factor * math.log(length))
+         mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
+         index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
+         weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
+         # update corr
+         tmp_corr = torch.softmax(weights, dim=-1)
+         # aggregation
+         tmp_values = values
+         delays_agg = torch.zeros_like(values).float()
+         for i in range(top_k):
+             pattern = torch.roll(tmp_values, -int(index[i]), -1)
+             delays_agg = delays_agg + pattern * \
+                 (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
+         return delays_agg
+
+     def time_delay_agg_inference(self, values, corr):
+         """
+         SpeedUp version of Autocorrelation (a batch-normalization style design)
+         This is for the inference phase.
+         """
+         batch = values.shape[0]
+         head = values.shape[1]
+         channel = values.shape[2]
+         length = values.shape[3]
+         # index init
+         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(self.device)
+         # find top k
+         top_k = int(self.factor * math.log(length))
+         mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
+         weights = torch.topk(mean_value, top_k, dim=-1)[0]
+         delay = torch.topk(mean_value, top_k, dim=-1)[1]
+         # update corr
+         tmp_corr = torch.softmax(weights, dim=-1)
+         # aggregation
+         tmp_values = values.repeat(1, 1, 1, 2)
+         delays_agg = torch.zeros_like(values).float()
+         for i in range(top_k):
+             tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
+             pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
+             delays_agg = delays_agg + pattern * \
+                 (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
+         return delays_agg
+
+     def time_delay_agg_full(self, values, corr):
+         """
+         Standard version of Autocorrelation
+         """
+         batch = values.shape[0]
+         head = values.shape[1]
+         channel = values.shape[2]
+         length = values.shape[3]
+         # index init
+         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(self.device)
+         # find top k
+         top_k = int(self.factor * math.log(length))
+         weights = torch.topk(corr, top_k, dim=-1)[0]
+         delay = torch.topk(corr, top_k, dim=-1)[1]
+         # update corr
+         tmp_corr = torch.softmax(weights, dim=-1)
+         # aggregation
+         tmp_values = values.repeat(1, 1, 1, 2)
+         delays_agg = torch.zeros_like(values).float()
+         for i in range(top_k):
+             tmp_delay = init_index + delay[..., i].unsqueeze(-1)
+             pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
+             delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
+         return delays_agg
+
+     def forward(self, queries, keys, values, attn_mask):
+         B, L, H, E = queries.shape
+         _, S, _, D = values.shape
+         if L > S:
+             zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
+             values = torch.cat([values, zeros], dim=1)
+             keys = torch.cat([keys, zeros], dim=1)
+         else:
+             values = values[:, :L, :, :]
+             keys = keys[:, :L, :, :]
+
+         # period-based dependencies
+         q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
+         k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
+         res = q_fft * torch.conj(k_fft)
+         corr = torch.fft.irfft(res, dim=-1)
+
+         # time delay agg
+         if self.training:
+             V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
+         else:
+             V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
+
+         if self.output_attention:
+             return (V.contiguous(), corr.permute(0, 3, 1, 2))
+         else:
+             return (V.contiguous(), None)
+
+
+ class AutoCorrelationLayer(nn.Module):
+     def __init__(self, correlation, d_model, n_heads, d_keys=None,
+                  d_values=None):
+         super(AutoCorrelationLayer, self).__init__()
+
+         d_keys = d_keys or (d_model // n_heads)
+         d_values = d_values or (d_model // n_heads)
+
+         self.inner_correlation = correlation
+         self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+         self.value_projection = nn.Linear(d_model, d_values * n_heads)
+         self.out_projection = nn.Linear(d_values * n_heads, d_model)
+         self.n_heads = n_heads
+
+     def forward(self, queries, keys, values, attn_mask):
+         self.inner_correlation.device = queries.device
+         B, L, _ = queries.shape
+         _, S, _ = keys.shape
+         H = self.n_heads
+
+         queries = self.query_projection(queries).view(B, L, H, -1)
+         keys = self.key_projection(keys).view(B, S, H, -1)
+         values = self.value_projection(values).view(B, S, H, -1)
+
+         out, attn = self.inner_correlation(
+             queries,
+             keys,
+             values,
+             attn_mask
+         )
+         out = out.view(B, L, -1)
+
+         return self.out_projection(out), attn
+
+
+ class my_Layernorm(nn.Module):
+     """
+     Special designed layernorm for the seasonal part
+     """
+     def __init__(self, channels):
+         super(my_Layernorm, self).__init__()
+         self.layernorm = nn.LayerNorm(channels)
+
+     def forward(self, x):
+         x_hat = self.layernorm(x)
+         bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
+         return x_hat - bias
+
+
+ class moving_avg(nn.Module):
+     """
+     Moving average block to highlight the trend of time series
+     """
+     def __init__(self, kernel_size, stride):
+         super(moving_avg, self).__init__()
+         self.kernel_size = kernel_size
+         self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
+
+     def forward(self, x):
+         # padding on the both ends of time series
+         front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+         end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+         x = torch.cat([front, x, end], dim=1)
+         x = self.avg(x.permute(0, 2, 1))
+         x = x.permute(0, 2, 1)
+         return x
+
+
+ class series_decomp(nn.Module):
+     """
+     Series decomposition block
+     """
+     def __init__(self, kernel_size):
+         super(series_decomp, self).__init__()
+         self.moving_avg = moving_avg(kernel_size, stride=1)
+
+     def forward(self, x):
+         moving_mean = self.moving_avg(x)
+         res = x - moving_mean
+         return res, moving_mean
+
+
+ class EncoderLayer(nn.Module):
+     """
+     Autoformer encoder layer with the progressive decomposition architecture
+     """
+     def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
+         super(EncoderLayer, self).__init__()
+         d_ff = d_ff or 4 * d_model
+         self.attention = attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
+         self.decomp1 = series_decomp(moving_avg)
+         self.decomp2 = series_decomp(moving_avg)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = activation()
+
+     def forward(self, x, attn_mask=None):
+         new_x, attn = self.attention(
+             x, x, x,
+             attn_mask=attn_mask
+         )
+         x = x + self.dropout(new_x)
+         x, _ = self.decomp1(x)
+         y = x
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+         y = self.dropout(self.conv2(y).transpose(-1, 1))
+         res, _ = self.decomp2(x + y)
+         return res, attn
+
+
+ class Encoder(nn.Module):
+     """
+     Autoformer encoder
+     """
+     def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+         super(Encoder, self).__init__()
+         self.attn_layers = nn.ModuleList(attn_layers)
+         self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+         self.norm = norm_layer
+
+     def forward(self, x, attn_mask=None):
+         attns = []
+         if self.conv_layers is not None:
+             for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
+                 x, attn = attn_layer(x, attn_mask=attn_mask)
+                 x = conv_layer(x)
+                 attns.append(attn)
+             x, attn = self.attn_layers[-1](x)
+             attns.append(attn)
+         else:
+             for attn_layer in self.attn_layers:
+                 x, attn = attn_layer(x, attn_mask=attn_mask)
+                 attns.append(attn)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x, attns
+
+
+ class DecoderLayer(nn.Module):
+     """
+     Autoformer decoder layer with the progressive decomposition architecture
+     """
+     def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
+                  moving_avg=25, dropout=0.1, activation="relu"):
+         super(DecoderLayer, self).__init__()
+         d_ff = d_ff or 4 * d_model
+         self.self_attention = self_attention
+         self.cross_attention = cross_attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
+         self.decomp1 = series_decomp(moving_avg)
+         self.decomp2 = series_decomp(moving_avg)
+         self.decomp3 = series_decomp(moving_avg)
+         self.dropout = nn.Dropout(dropout)
+         self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
+                                     padding_mode='circular', bias=False)
+         self.activation = activation()
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None):
+         self.self_attention.device = x.device
+         x = x + self.dropout(self.self_attention(
+             x, x, x,
+             attn_mask=x_mask
+         )[0])
+         x, trend1 = self.decomp1(x)
+         x = x + self.dropout(self.cross_attention(
+             x, cross, cross,
+             attn_mask=cross_mask
+         )[0])
+         x, trend2 = self.decomp2(x)
+         y = x
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+         y = self.dropout(self.conv2(y).transpose(-1, 1))
+         x, trend3 = self.decomp3(x + y)
+
+         residual_trend = trend1 + trend2 + trend3
+         residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
+         return x, residual_trend
+
+
+ class Decoder(nn.Module):
+     """
+     Autoformer decoder
+     """
+     def __init__(self, layers, norm_layer=None, projection=None):
+         super(Decoder, self).__init__()
+         self.layers = nn.ModuleList(layers)
+         self.norm = norm_layer
+         self.projection = projection
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
+         for layer in self.layers:
+             x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
+             trend = trend + residual_trend
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         if self.projection is not None:
+             x = self.projection(x)
+         return x, trend
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEmbedding, self).__init__()
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, d_model).float()
+         pe.requires_grad = False
+
+         position = torch.arange(0, max_len).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         return self.pe[:, :x.size(1)]
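Note on the period-based dependency step above: AutoCorrelation.forward computes the per-lag scores as a circular cross-correlation evaluated in the frequency domain (rfft of the queries multiplied by the conjugate rfft of the keys, then irfft). The identity it relies on can be checked on 1-D toy signals; the snippet below is an illustration only, with made-up variable names, and is not code from the package.

    import torch

    L = 16
    q = torch.randn(L)
    k = torch.randn(L)

    # frequency-domain circular cross-correlation, as used for the per-lag scores
    corr_fft = torch.fft.irfft(torch.fft.rfft(q) * torch.conj(torch.fft.rfft(k)), n=L)

    # brute-force check: corr[tau] = sum_t q[t] * k[(t - tau) mod L]
    corr_naive = torch.stack([(q * torch.roll(k, tau)).sum() for tau in range(L)])
    print(torch.allclose(corr_fft, corr_naive, atol=1e-4))  # expected: True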