dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts has been flagged as possibly problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/patchtst/layers.py
@@ -0,0 +1,569 @@

import torch
from torch import nn
import math
from typing import Optional
from torch import Tensor
import torch.nn.functional as F
import numpy as np


class Transpose(nn.Module):
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous
    def forward(self, x):
        if self.contiguous:
            return x.transpose(*self.dims).contiguous()
        else:
            return x.transpose(*self.dims)


# decomposition

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


# pos_encoding

def PositionalEncoding(q_len, d_model, normalize=True):
    pe = torch.zeros(q_len, d_model)
    position = torch.arange(0, q_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    if normalize:
        pe = pe - pe.mean()
        pe = pe / (pe.std() * 10)
    return pe

SinCosPosEncoding = PositionalEncoding

def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    x = .5 if exponential else 1
    i = 0
    for i in range(100):
        cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
        if abs(cpe.mean()) <= eps:
            break
        elif cpe.mean() > eps:
            x += .001
        else:
            x -= .001
        i += 1
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe

def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
    cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe

def positional_encoding(pe, learn_pe, q_len, d_model):
    # Positional encoding
    if pe is None:
        W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else:
        raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \
        'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)")
    return nn.Parameter(W_pos, requires_grad=learn_pe)


class PatchTST_backbone(nn.Module):
    def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024,
                 n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto',
                 padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,
                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,
                 verbose:bool=False, **kwargs):

        super().__init__()

        # RevIn
        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((context_window - patch_len)/stride + 1)
        if padding_patch == 'end': # can be modified to general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            patch_num += 1

        # Backbone
        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
                                    n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                    attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                                    attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                                    pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)

        # Head
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.pretrain_head = pretrain_head
        self.head_type = head_type
        self.individual = individual

        if self.pretrain_head:
            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs
        elif head_type == 'flatten':
            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)


    def forward(self, z): # z: [bs x nvars x seq_len]
        # norm
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'norm')
            z = z.permute(0,2,1)

        # do patching
        if self.padding_patch == 'end':
            z = self.padding_patch_layer(z)
        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len]
        z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num]

        # model
        z = self.backbone(z) # z: [bs x nvars x d_model x patch_num]
        z = self.head(z) # z: [bs x nvars x target_window]

        # denorm
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'denorm')
            z = z.permute(0,2,1)
        return z

    def create_pretrain_head(self, head_nf, vars, dropout):
        return nn.Sequential(nn.Dropout(dropout),
                             nn.Conv1d(head_nf, vars, 1)
                             )


class Flatten_Head(nn.Module):
    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
        super().__init__()

        self.individual = individual
        self.n_vars = n_vars

        if self.individual:
            self.linears = nn.ModuleList()
            self.dropouts = nn.ModuleList()
            self.flattens = nn.ModuleList()
            for i in range(self.n_vars):
                self.flattens.append(nn.Flatten(start_dim=-2))
                self.linears.append(nn.Linear(nf, target_window))
                self.dropouts.append(nn.Dropout(head_dropout))
        else:
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)

    def forward(self, x): # x: [bs x nvars x d_model x patch_num]
        if self.individual:
            x_out = []
            for i in range(self.n_vars):
                z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num]
                z = self.linears[i](z) # z: [bs x target_window]
                z = self.dropouts[i](z)
                x_out.append(z)
            x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window]
        else:
            x = self.flatten(x)
            x = self.linear(x)
            x = self.dropout(x)
        return x


class TSTiEncoder(nn.Module): #i means channel-independent
    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
                 pe='zeros', learn_pe=True, verbose=False, **kwargs):

        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding
        q_len = patch_num
        self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space
        self.seq_len = q_len

        # Positional encoding
        self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)


    def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num]

        n_vars = x.shape[1]
        # Input encoding
        x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len]
        x = self.W_P(x) # x: [bs x nvars x patch_num x d_model]

        u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model]
        u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model]

        # Encoder
        z = self.encoder(u) # z: [bs * nvars x patch_num x d_model]
        z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model]
        z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num]

        return z


# Cell
class TSTEncoder(nn.Module):
    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                                                     attn_dropout=attn_dropout, dropout=dropout,
                                                     activation=activation, res_attention=res_attention,
                                                     pre_norm=pre_norm, store_attn=store_attn) for _ in range(n_layers)])
        self.res_attention = res_attention

    def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        output = src
        scores = None
        if self.res_attention:
            for mod in self.layers:
                output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            return output
        else:
            for mod in self.layers:
                output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            return output


class TSTEncoderLayer(nn.Module):
    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        # Multi-Head attention
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)

        # Add & Norm
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                activation,
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))

        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn


    def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:

        # Multi-Head attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        ## Multi-Head attention
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.store_attn:
            self.attn = attn
        ## Add & Norm
        src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        ## Position-wise Feed-Forward
        src2 = self.ff(src)
        ## Add & Norm
        src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        else:
            return src


class _MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
        """Multi Head Attention Layer
        Input shape:
            Q: [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            mask: [q_len x q_len]
        """
        super().__init__()
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)

        # Poject output
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))


    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):

        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Linear (+ split in multiple heads)
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2) # q_s : [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1) # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2) # v_s : [bs x n_heads x q_len x d_v]

        # Apply Scaled Dot-Product Attention (multiple heads)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]

        # back to the original inputs dimensions
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        else:
            return output, attn_weights


class _ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021)"""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q : [bs x n_heads x max_q_len x d_k]
            k : [bs x n_heads x d_k x seq_len]
            v : [bs x n_heads x seq_len x d_v]
            prev : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask : [1 x seq_len x seq_len]
        Output shape:
            output: [bs x n_heads x q_len x d_v]
            attn : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''

        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale # attn_scores : [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None:
            attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None: # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention:
            return output, attn_weights, attn_scores
        else:
            return output, attn_weights


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        if self.subtract_last:
            self.last = x[:,-1,:].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
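For orientation, here is a minimal usage sketch (not part of the wheel) that wires the layers above together and checks the shape flow documented in the code comments: a batch shaped [bs x nvars x seq_len] passes through RevIN normalization, patching, the channel-independent Transformer encoder, and the flatten head, and comes out as [bs x nvars x target_window]. The import path is inferred from the file list above, the hyperparameters are arbitrary, and `act=nn.GELU()` is passed explicitly because `TSTEncoderLayer` places `act` directly inside an `nn.Sequential`, which expects a module rather than the `"gelu"` string default.

```python
import torch
from torch import nn

# Assumed import path, inferred from the wheel's file list above.
from dsipts.models.patchtst.layers import PatchTST_backbone

backbone = PatchTST_backbone(
    c_in=7,                 # number of channels / variables
    context_window=96,      # input window length
    target_window=24,       # forecast horizon
    patch_len=16,
    stride=8,
    n_layers=2,
    d_model=64,
    n_heads=4,
    padding_patch='end',    # pad the series by `stride` on the right before unfolding
    act=nn.GELU(),          # nn.Sequential needs a module here, not the "gelu" string
    revin=True,             # normalize per sample on the way in, de-normalize on the way out
)

x = torch.randn(32, 7, 96)  # [bs x nvars x seq_len]
y = backbone(x)             # [bs x nvars x target_window]
print(y.shape)              # torch.Size([32, 7, 24])
```

With these settings the backbone computes `int((96 - 16)/8 + 1) = 11` patches, and `padding_patch='end'` adds one more, so the encoder sees 12 patch tokens per channel.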