dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/samformer/utils.py
@@ -0,0 +1,154 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from torch.optim import Optimizer
+
+
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+     """
+     A copy-paste from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+     """
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+         attn_bias.to(query.dtype)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+         else:
+             attn_bias += attn_mask
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+     return attn_weight @ value
+
+
+
+
+ class RevIN(nn.Module):
+     """
+     Reversible Instance Normalization (RevIN) https://openreview.net/pdf?id=cGDAkQo1C0p
+     https://github.com/ts-kim/RevIN
+     """
+     def __init__(self, num_features: int, eps=1e-5, affine=True):
+         """
+         :param num_features: the number of features or channels
+         :param eps: a value added for numerical stability
+         :param affine: if True, RevIN has learnable affine parameters
+         """
+         super(RevIN, self).__init__()
+         self.num_features = num_features
+         self.eps = eps
+         self.affine = affine
+         if self.affine:
+             self._init_params()
+
+     def forward(self, x, mode:str):
+         if mode == 'norm':
+             self._get_statistics(x)
+             x = self._normalize(x)
+         elif mode == 'denorm':
+             x = self._denormalize(x)
+         else: raise NotImplementedError
+         return x
+
+     def _init_params(self):
+         # initialize RevIN params: (C,)
+         self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+         self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+     def _get_statistics(self, x):
+         dim2reduce = tuple(range(1, x.ndim-1))
+         self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+         self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
+
+     def _normalize(self, x):
+         x = x - self.mean
+         x = x / self.stdev
+         if self.affine:
+             x = x * self.affine_weight
+             x = x + self.affine_bias
+         return x
+
+     def _denormalize(self, x):
+         if self.affine:
+             x = x - self.affine_bias
+             x = x / (self.affine_weight + self.eps*self.eps)
+         x = x * self.stdev
+         x = x + self.mean
+         return x
+
+
+ class SAM(torch.optim.Optimizer):
+     def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
+         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+
+         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
+         super(SAM, self).__init__(params, defaults)
+
+         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+         self.param_groups = self.base_optimizer.param_groups
+         self.defaults.update(self.base_optimizer.defaults)
+
+     @torch.no_grad()
+     def first_step(self, zero_grad=False):
+         grad_norm = self._grad_norm()
+         for group in self.param_groups:
+             scale = group["rho"] / (grad_norm + 1e-12)
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 self.state[p]["old_p"] = p.data.clone()
+                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
+                 p.add_(e_w) # Perturb weights in the gradient direction
+
+         if zero_grad:
+             self.zero_grad()
+
+     @torch.no_grad()
+     def second_step(self, zero_grad=False):
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 p.data = self.state[p]["old_p"] # Restore original weights
+
+         self.base_optimizer.step() # Apply the sharpness-aware update
+
+         if zero_grad:
+             self.zero_grad()
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
+
+         with torch.enable_grad():
+             closure() # First forward-backward pass
+
+         self.first_step(zero_grad=True)
+
+         with torch.enable_grad():
+             closure() # Second forward-backward pass
+
+         self.second_step()
+
+     def _grad_norm(self):
+         shared_device = self.param_groups[0]["params"][0].device
+         grads = [
+             ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
+             for group in self.param_groups for p in group["params"]
+             if p.grad is not None
+         ]
+         return torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0, device=shared_device)
+
+     def load_state_dict(self, state_dict):
+         super().load_state_dict(state_dict)
+         if hasattr(self, "base_optimizer"): # Ensure base optimizer exists
+             self.base_optimizer.load_state_dict(state_dict["base_optimizer"])
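
The hunk above (dsipts/models/samformer/utils.py) bundles a reference scaled dot-product attention, RevIN instance normalization, and a SAM optimizer wrapper whose step() requires a closure. The following is only an illustrative sketch, not code shipped in the wheel; it assumes these classes are importable from dsipts.models.samformer.utils and uses a stand-in linear model with random data.

    import torch
    import torch.nn as nn
    from dsipts.models.samformer.utils import RevIN, SAM  # assumed import path

    model = nn.Linear(8, 8)                    # stand-in model acting on the channel axis
    revin = RevIN(num_features=8)              # one statistic per channel
    optimizer = SAM(model.parameters(), torch.optim.Adam, rho=0.05, lr=1e-3)

    x = torch.randn(4, 16, 8)                  # [batch, seq_len, channels]
    target = torch.randn(4, 16, 8)

    def closure():
        optimizer.zero_grad()
        x_norm = revin(x, mode='norm')                # normalize per instance
        pred = revin(model(x_norm), mode='denorm')    # undo the normalization on the output
        loss = nn.functional.mse_loss(pred, target)
        loss.backward()
        return loss

    optimizer.step(closure)   # SAM runs the closure twice: perturb, then base-optimizer update

SAM.step evaluates the closure twice on purpose: the first backward pass supplies the gradients used to perturb the weights, and the second is taken at the perturbed point before the base optimizer applies its update. In a real training loop the RevIN affine parameters would normally be handed to the optimizer as well.
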
dsipts/models/tft/__init__.py: file without changes
dsipts/models/tft/sub_nn.py
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+ from typing import Union
+
+ class embedding_cat_variables(nn.Module):
+     # at the moment cat_past and cat_fut together
+     def __init__(self, seq_len: int, lag: int, d_model: int, emb_dims: list, device):
+         """Class for embedding categorical variables, adding 3 positional variables during forward
+
+         Args:
+             seq_len (int): length of the sequence (sum of past and future steps)
+             lag (int): number of future step to be predicted
+             hiden_size (int): dimension of all variables after they are embedded
+             emb_dims (list): size of the dictionary for embedding. One dimension for each categorical variable
+             device : -
+         """
+         super().__init__()
+         self.seq_len = seq_len
+         self.lag = lag
+         self.device = device
+         self.cat_embeds = emb_dims + [seq_len, lag+1, 2] #
+         self.cat_n_embd = nn.ModuleList([
+             nn.Embedding(emb_dim, d_model) for emb_dim in self.cat_embeds
+         ])
+
+     def forward(self, x: Union[torch.Tensor,int],device:torch.device) -> torch.Tensor:
+         """All components of x are concatenated with 3 new variables for data augmentation, in the order:
+         - pos_seq: assign at each step its time-position
+         - pos_fut: assign at each step its future position. 0 if it is a past step
+         - is_fut: explicit for each step if it is a future(1) or past one(0)
+
+         Args:
+             x (torch.Tensor): [bs, seq_len, num_vars]
+
+         Returns:
+             torch.Tensor: [bs, seq_len, num_vars+3, n_embd]
+         """
+         if isinstance(x, int):
+             no_emb = True
+             B = x
+         else:
+             no_emb = False
+             B, _, _ = x.shape
+
+         pos_seq = self.get_pos_seq(bs=B).to(device)
+         pos_fut = self.get_pos_fut(bs=B).to(device)
+         is_fut = self.get_is_fut(bs=B).to(device)
+
+         if no_emb:
+             cat_vars = torch.cat((pos_seq, pos_fut, is_fut), dim=2)
+         else:
+             cat_vars = torch.cat((x, pos_seq, pos_fut, is_fut), dim=2)
+
+         cat_n_embd = self.get_cat_n_embd(cat_vars)
+         return cat_n_embd
+
+     def get_pos_seq(self, bs):
+         pos_seq = torch.arange(0, self.seq_len)
+         pos_seq = pos_seq.repeat(bs,1).unsqueeze(2).to(self.device)
+         return pos_seq
+
+     def get_pos_fut(self, bs):
+         pos_fut = torch.cat((torch.zeros((self.seq_len-self.lag), dtype=torch.long),torch.arange(1,self.lag+1)))
+         pos_fut = pos_fut.repeat(bs,1).unsqueeze(2).to(self.device)
+         return pos_fut
+
+     def get_is_fut(self, bs):
+         is_fut = torch.cat((torch.zeros((self.seq_len-self.lag), dtype=torch.long),torch.ones((self.lag), dtype=torch.long)))
+         is_fut = is_fut.repeat(bs,1).unsqueeze(2).to(self.device)
+         return is_fut
+
+     def get_cat_n_embd(self, cat_vars):
+         cat_n_embd = torch.Tensor().to(cat_vars.device)
+         for index, layer in enumerate(self.cat_n_embd):
+             emb = layer(cat_vars[:, :, index])
+             cat_n_embd = torch.cat((cat_n_embd, emb.unsqueeze(2)),dim=2)
+         return cat_n_embd
+
+ class LSTM_Model(nn.Module):
+     def __init__(self, num_var: int, d_model: int, pred_step: int, num_layers: int, dropout: float):
+         """LSTM from [..., d_model] to [..., predicted_step, num_of_vars]
+
+         Args:
+             num_var (int): number of variables encoded in the input tensor
+             d_model (int): encoding dimension of the tensor
+             pred_step (int): step to be predicted by LSTM
+             num_layers (int): number of layers of LSTM
+             dropout (float):
+         """
+
+         super().__init__()
+         self.num_var = num_var
+         self.d_model = d_model
+         self.num_layers = num_layers
+         self.pred_step = pred_step
+
+         self.lstm = nn.LSTM(d_model, d_model, num_layers=num_layers, batch_first=True, dropout=dropout)
+         self.linear = nn.Linear(d_model, pred_step*num_var)
+
+     def forward(self, x):
+         """LSTM process over the x tensor and reshaping according to pred_step and num_var to be predicted
+
+         Args:
+             x (torch.Tensor): input tensor
+
+         Returns:
+             torch.Tensor: tensor resized to [B, pred_step, num_var]
+         """
+
+         h0 = torch.zeros(self.num_layers, x.size(0), self.d_model).to(x.device)
+         c0 = torch.zeros(self.num_layers, x.size(0), self.d_model).to(x.device)
+         out, _ = self.lstm(x, (h0, c0))
+         out = self.linear(out[:, -1, :]) # Take the last output of the sequence
+         out = out.view(-1, self.pred_step, self.num_var)
+         return out
+
+ class GLU(nn.Module):
+     def __init__(self, d_model: int):
+         """Gated Linear Unit
+
+         Auxiliary subnet for sigmoid element-wise multiplication
+
+         Args:
+             d_model (int): dimension of operations
+         """
+
+         super().__init__()
+         self.linear1 = nn.Linear(d_model, d_model, bias = False)
+         self.linear2 = nn.Linear(d_model, d_model, bias = False)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x1 = self.sigmoid(self.linear1(x))
+         x2 = self.linear2(x)
+         out = x1*x2 #element-wise multiplication
+         return out
+
+ class GRN(nn.Module):
+     def __init__(self, d_model: int, dropout_rate: float):
+         """Gated Residual Network
+
+         Auxiliary subnet for gating residual connections
+
+         Args:
+             d_model (int):
+             dropout_rate (float):
+         """
+
+         super().__init__()
+         self.linear1 = nn.Linear(d_model, d_model)
+         self.elu = nn.ELU()
+         self.linear2 = nn.Linear(d_model, d_model)
+         self.res_conn = ResidualConnection(d_model, dropout_rate)
+
+     def forward(self, x: torch.Tensor, using_norm:bool = True) -> torch.Tensor:
+         eta1 = self.elu(self.linear1(x))
+         eta2 = self.linear2(eta1)
+         out = self.res_conn(eta2, x, using_norm)
+         return out
+
+ class ResidualConnection(nn.Module):
+     def __init__(self, d_model, dropout_rate):
+         """Residual Connection of res_conn with GLU(x)
+
+         Auxiliary subnet for residual connections
+
+         Args:
+             d_model (int):
+             dropout_rate (float):
+         """
+
+         super().__init__()
+         self.dropout = nn.Dropout(dropout_rate)
+         self.glu = GLU(d_model)
+         self.norm = nn.LayerNorm(d_model)
+
+     def forward(self, x: torch.Tensor, res_conn: torch.Tensor, using_norm:bool = True) -> torch.Tensor:
+         """Res Cionnection using normalizing computatiion on 'x' and strict 'res_conn'
+
+         Args:
+             x (torch.Tensor): GLU(dropout(x))
+             res_conn (torch.Tensor): tensor summed to x before normalization
+             using_norm (bool, optional): _description_. Defaults to True.
+
+         Returns:
+             torch.Tensor:
+         """
+         x = self.glu(self.dropout(x))
+         out = res_conn + x
+         if using_norm:
+             out = self.norm(out)
+         return out
+
+ class InterpretableMultiHead(nn.Module):
+     def __init__(self, d_model, d_head, n_head):
+         """Interpretable MultiHead Attention
+
+         Similar to canonical MultiHead Attention with Query-Keys-Value structure
+         Particularities are:
+         - Only one common "Value"-Linear layer for all heads
+         - output of all heads are summed together and then rescaled over the number of heads
+         The final output tensor is re-embedded in the initial dimension
+
+         Args:
+             d_model (int): starting and ending dimension of the net
+             d_head (int): hidden dimension of all heads
+             n_head (int): number of heads
+         """
+
+         super().__init__()
+         self.d_head = d_head
+         self.n_head = n_head
+         self.Q_layers = nn.ModuleList([nn.Linear(d_model,d_head) for _ in range(n_head)])
+         self.K_layers = nn.ModuleList([nn.Linear(d_model,d_head) for _ in range(n_head)])
+         self.Softmax_layers = nn.ModuleList([nn.Softmax(dim=-1) for _ in range(n_head)])
+         self.V_layer = nn.Linear(d_model, d_head)
+         self.out_layer = nn.Linear(d_head, d_model)
+
+     def forward(self, query:torch.Tensor, key:torch.Tensor, value:torch.Tensor) -> torch.Tensor:
+         out = torch.Tensor()
+         for (q_layer, k_layer, softmax) in zip(self.Q_layers, self.K_layers, self.Softmax_layers):
+             Q = q_layer(query)
+             K = k_layer(key)
+             wei = Q @ K.transpose(-2,-1) * (self.d_head**-0.5)
+             wei = softmax(wei)
+             V = self.V_layer(value)
+             out_h = wei @ V
+             if out.shape[0]>0:
+                 out = out + out_h # sum the result of the head attention
+             else:
+                 out = out_h # out is not modifies/initialized yet
+         out = out / self.n_head
+         out = self.out_layer(out) # comeback to d_model dimension
+         return out
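
As a quick shape check for the TFT sub-networks above, here is an illustrative sketch (not part of the package) that assumes the classes are importable from dsipts.models.tft.sub_nn and feeds two made-up categorical columns with cardinalities 12 and 7:

    import torch
    from dsipts.models.tft.sub_nn import embedding_cat_variables, InterpretableMultiHead  # assumed import path

    device = torch.device('cpu')
    seq_len, lag, d_model = 96, 24, 32
    emb_dims = [12, 7]                         # e.g. month and day-of-week cardinalities

    embedder = embedding_cat_variables(seq_len, lag, d_model, emb_dims, device)

    bs = 8
    cat_x = torch.stack([torch.randint(0, 12, (bs, seq_len)),
                         torch.randint(0, 7, (bs, seq_len))], dim=-1)   # [bs, seq_len, 2]

    emb = embedder(cat_x, device)              # [bs, seq_len, 2 + 3, d_model]
    print(emb.shape)                           # torch.Size([8, 96, 5, 32])

    attn = InterpretableMultiHead(d_model=d_model, d_head=16, n_head=4)
    tokens = emb[:, :, 0, :]                   # self-attention over time for one embedded variable
    out = attn(tokens, tokens, tokens)         # [bs, seq_len, d_model]

The three extra columns are the positional variables (pos_seq, pos_fut, is_fut) that the embedder appends before everything is embedded to d_model.
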
dsipts/models/timexer/Layers.py
@@ -0,0 +1,127 @@
+ import torch.nn as nn
+ import torch
+ import math
+ import torch.nn.functional as F
+
+
+ class PositionalEmbedding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(PositionalEmbedding, self).__init__()
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, d_model).float()
+         pe.require_grad = False
+
+         position = torch.arange(0, max_len).float().unsqueeze(1)
+         div_term = (torch.arange(0, d_model, 2).float()
+                     * -(math.log(10000.0) / d_model)).exp()
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         return self.pe[:, :x.size(1)]
+
+ class EnEmbedding(nn.Module):
+     def __init__(self, n_vars, d_model, patch_len, dropout):
+         super(EnEmbedding, self).__init__()
+         # Patching
+         self.patch_len = patch_len
+
+         self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
+         self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
+         self.position_embedding = PositionalEmbedding(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # do patching
+         n_vars = x.shape[1]
+         glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))
+
+         x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
+         x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
+         # Input encoding
+         x = self.value_embedding(x) + self.position_embedding(x)
+         x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
+         x = torch.cat([x, glb], dim=2)
+         x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
+         return self.dropout(x), n_vars
+
+ class FlattenHead(nn.Module):
+     def __init__(self, n_vars, nf, target_window, head_dropout=0):
+         super().__init__()
+         self.n_vars = n_vars
+         self.flatten = nn.Flatten(start_dim=-2)
+         self.linear = nn.Linear(nf, target_window)
+         self.dropout = nn.Dropout(head_dropout)
+
+     def forward(self, x): # x: [bs x nvars x d_model x patch_num]
+         x = self.flatten(x)
+         x = self.linear(x)
+         x = self.dropout(x)
+         return x
+
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
+                  dropout=0.1, activation="relu"):
+         super(EncoderLayer, self).__init__()
+         d_ff = d_ff or 4 * d_model
+         self.self_attention = self_attention
+         self.cross_attention = cross_attention
+         self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+         self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = F.relu if activation == "relu" else F.gelu
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+         B, L, D = cross.shape
+         x = x + self.dropout(self.self_attention(
+             x, x, x,
+             attn_mask=x_mask,
+             tau=tau, delta=None
+         )[0])
+         x = self.norm1(x)
+
+         x_glb_ori = x[:, -1, :].unsqueeze(1)
+         x_glb = torch.reshape(x_glb_ori, (B, -1, D))
+         x_glb_attn = self.dropout(self.cross_attention(
+             x_glb, cross, cross,
+             attn_mask=cross_mask,
+             tau=tau, delta=delta
+         )[0])
+         x_glb_attn = torch.reshape(x_glb_attn,
+                                    (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])).unsqueeze(1)
+         x_glb = x_glb_ori + x_glb_attn
+         x_glb = self.norm2(x_glb)
+
+         y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
+
+         y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+         y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+         return self.norm3(x + y)
+
+ class Encoder(nn.Module):
+     def __init__(self, layers, norm_layer=None, projection=None):
+         super(Encoder, self).__init__()
+         self.layers = nn.ModuleList(layers)
+         self.norm = norm_layer
+         self.projection = projection
+
+     def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+         for layer in self.layers:
+             x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         if self.projection is not None:
+             x = self.projection(x)
+         return x
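
To make the tensor shapes of the TimeXer layers concrete, here is an illustrative sketch (not part of the package). It assumes the classes import from dsipts.models.timexer.Layers and that the series is passed channel-first, which is what the unfold over the last dimension implies:

    import torch
    from dsipts.models.timexer.Layers import EnEmbedding, FlattenHead  # assumed import path

    bs, n_vars, seq_len = 4, 7, 96
    patch_len, d_model, pred_len = 16, 64, 24

    embed = EnEmbedding(n_vars=n_vars, d_model=d_model, patch_len=patch_len, dropout=0.1)

    x = torch.randn(bs, n_vars, seq_len)          # channel-first input series
    tokens, nv = embed(x)                         # patch tokens plus one global token per variable
    print(tokens.shape)                           # torch.Size([28, 7, 64]) = [bs*n_vars, seq_len//patch_len + 1, d_model]

    # after the encoder, the head expects [bs, n_vars, d_model, patch_num]
    patch_num = tokens.shape[1]
    enc_out = tokens.reshape(bs, nv, patch_num, d_model).permute(0, 1, 3, 2)
    head = FlattenHead(n_vars=nv, nf=d_model * patch_num, target_window=pred_len)
    forecast = head(enc_out)                      # [bs, n_vars, pred_len]

Each variable contributes seq_len // patch_len patch tokens plus one learnable global token, which is why the token count is 7 here; FlattenHead then maps each variable's flattened d_model x patch_num features to the pred_len forecast window.
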
dsipts/models/timexer/__init__.py: file without changes
dsipts/models/ttm/__init__.py: file without changes