dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/patchtst/layers.py
@@ -0,0 +1,569 @@
import torch
from torch import nn
import math
from typing import Optional
from torch import Tensor
import torch.nn.functional as F
import numpy as np


class Transpose(nn.Module):
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous
    def forward(self, x):
        if self.contiguous:
            return x.transpose(*self.dims).contiguous()
        else:
            return x.transpose(*self.dims)


# decomposition

class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on both ends of the time series
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean

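A minimal usage sketch for the decomposition blocks above (batch size, length and kernel are arbitrary choices for illustration). series_decomp returns the residual first and the moving-average trend second; with stride=1, an odd kernel_size leaves the sequence length unchanged:

    decomp = series_decomp(kernel_size=25)
    x = torch.randn(32, 96, 7)             # [bs x seq_len x nvars]
    res, trend = decomp(x)                 # both [32, 96, 7]
    assert torch.allclose(res + trend, x)  # the decomposition is exact by construction
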
# pos_encoding

def PositionalEncoding(q_len, d_model, normalize=True):
    pe = torch.zeros(q_len, d_model)
    position = torch.arange(0, q_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    if normalize:
        pe = pe - pe.mean()
        pe = pe / (pe.std() * 10)
    return pe

SinCosPosEncoding = PositionalEncoding

def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    x = .5 if exponential else 1
    i = 0
    for i in range(100):
        cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
        if abs(cpe.mean()) <= eps:
            break
        elif cpe.mean() > eps:
            x += .001
        else:
            x -= .001
        i += 1
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe

def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
    cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe

def positional_encoding(pe, learn_pe, q_len, d_model):
    # Positional encoding
    if pe is None:
        W_pos = torch.empty((q_len, d_model))  # pe = None and learn_pe = False can be used to measure impact of pe
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else:
        raise ValueError(f"{pe} is not a valid pe (positional encoder). Available types: 'gauss'=='normal', \
            'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.")
    return nn.Parameter(W_pos, requires_grad=learn_pe)

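A short sketch of how positional_encoding is used further down (values are arbitrary). It returns an nn.Parameter whose requires_grad flag follows learn_pe; with pe='sincos' the table is the standard sinusoidal encoding of shape [q_len, d_model], while several of the other options ('zero', 'gauss', 'uniform', 'lin1d', 'exp1d') produce a single column of shape [q_len, 1]:

    W_pos = positional_encoding('sincos', learn_pe=False, q_len=42, d_model=128)
    print(W_pos.shape, W_pos.requires_grad)   # torch.Size([42, 128]) False
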
class PatchTST_backbone(nn.Module):
    def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024,
                 n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None,
                 d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:str='auto',
                 padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None,
                 pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False,
                 verbose:bool=False, **kwargs):

        super().__init__()

        # RevIN
        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((context_window - patch_len)/stride + 1)
        if padding_patch == 'end':  # can be modified to general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            patch_num += 1

        # Backbone
        self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len,
                                    n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff,
                                    attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
                                    attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
                                    pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs)

        # Head
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.pretrain_head = pretrain_head
        self.head_type = head_type
        self.individual = individual

        if self.pretrain_head:
            self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout)  # custom head passed as a partial func with all its kwargs
        elif head_type == 'flatten':
            self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout)


    def forward(self, z):                                               # z: [bs x nvars x seq_len]
        # norm
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'norm')
            z = z.permute(0,2,1)

        # do patching
        if self.padding_patch == 'end':
            z = self.padding_patch_layer(z)
        z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride)  # z: [bs x nvars x patch_num x patch_len]
        z = z.permute(0,1,3,2)                                          # z: [bs x nvars x patch_len x patch_num]

        # model
        z = self.backbone(z)                                            # z: [bs x nvars x d_model x patch_num]
        z = self.head(z)                                                # z: [bs x nvars x target_window]

        # denorm
        if self.revin:
            z = z.permute(0,2,1)
            z = self.revin_layer(z, 'denorm')
            z = z.permute(0,2,1)
        return z

    def create_pretrain_head(self, head_nf, vars, dropout):
        return nn.Sequential(nn.Dropout(dropout),
                             nn.Conv1d(head_nf, vars, 1)
                             )

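An illustrative end-to-end sketch for the backbone (all hyperparameters below are arbitrary; it assumes the rest of this module, namely TSTiEncoder, Flatten_Head and RevIN defined further down, is imported alongside it). The shapes follow the comments in forward:

    backbone = PatchTST_backbone(c_in=7, context_window=96, target_window=24,
                                 patch_len=16, stride=8, n_layers=2,
                                 d_model=64, n_heads=4, padding_patch='end')
    z = torch.randn(32, 7, 96)    # [bs x nvars x seq_len]
    y = backbone(z)               # [bs x nvars x target_window] -> [32, 7, 24]
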
class Flatten_Head(nn.Module):
    def __init__(self, individual, n_vars, nf, target_window, head_dropout=0):
        super().__init__()

        self.individual = individual
        self.n_vars = n_vars

        if self.individual:
            self.linears = nn.ModuleList()
            self.dropouts = nn.ModuleList()
            self.flattens = nn.ModuleList()
            for i in range(self.n_vars):
                self.flattens.append(nn.Flatten(start_dim=-2))
                self.linears.append(nn.Linear(nf, target_window))
                self.dropouts.append(nn.Dropout(head_dropout))
        else:
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(nf, target_window)
            self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):                          # x: [bs x nvars x d_model x patch_num]
        if self.individual:
            x_out = []
            for i in range(self.n_vars):
                z = self.flattens[i](x[:,i,:,:])   # z: [bs x d_model * patch_num]
                z = self.linears[i](z)             # z: [bs x target_window]
                z = self.dropouts[i](z)
                x_out.append(z)
            x = torch.stack(x_out, dim=1)          # x: [bs x nvars x target_window]
        else:
            x = self.flatten(x)
            x = self.linear(x)
            x = self.dropout(x)
        return x

class TSTiEncoder(nn.Module):  # i means channel-independent
    def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024,
                 n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None,
                 d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False,
                 key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False,
                 pe='zeros', learn_pe=True, verbose=False, **kwargs):

        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding
        q_len = patch_num
        self.W_P = nn.Linear(patch_len, d_model)  # Eq 1: projection of feature vectors onto a d-dim vector space
        self.seq_len = q_len

        # Positional encoding
        self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn)


    def forward(self, x) -> Tensor:                                     # x: [bs x nvars x patch_len x patch_num]

        n_vars = x.shape[1]
        # Input encoding
        x = x.permute(0,1,3,2)                                          # x: [bs x nvars x patch_num x patch_len]
        x = self.W_P(x)                                                 # x: [bs x nvars x patch_num x d_model]

        u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3]))  # u: [bs * nvars x patch_num x d_model]
        u = self.dropout(u + self.W_pos)                                # u: [bs * nvars x patch_num x d_model]

        # Encoder
        z = self.encoder(u)                                             # z: [bs * nvars x patch_num x d_model]
        z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1]))       # z: [bs x nvars x patch_num x d_model]
        z = z.permute(0,1,3,2)                                          # z: [bs x nvars x d_model x patch_num]

        return z


# Cell
class TSTEncoder(nn.Module):
    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()

        self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm,
                                                     attn_dropout=attn_dropout, dropout=dropout,
                                                     activation=activation, res_attention=res_attention,
                                                     pre_norm=pre_norm, store_attn=store_attn) for _ in range(n_layers)])
        self.res_attention = res_attention

    def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        output = src
        scores = None
        if self.res_attention:
            for mod in self.layers:
                output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            return output
        else:
            for mod in self.layers:
                output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
            return output

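A small sketch of the encoder stack in isolation (sizes are arbitrary; it assumes the full module, including TSTEncoderLayer below, is available). The input is the [bs * nvars x patch_num x d_model] tensor that TSTiEncoder builds, and with res_attention=True each layer hands its pre-softmax scores to the next one through prev:

    enc = TSTEncoder(q_len=12, d_model=64, n_heads=4, d_ff=256,
                     n_layers=2, res_attention=True)
    u = torch.randn(224, 12, 64)   # [bs * nvars x patch_num x d_model]
    out = enc(u)                   # same shape: [224, 12, 64]
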
class TSTEncoderLayer(nn.Module):
    def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False,
                 norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        # Multi-Head attention
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)

        # Add & Norm
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise Feed-Forward
        # activation may be passed as a module or as a name ("gelu"/"relu"); nn.Sequential needs a module
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                nn.GELU() if activation == "gelu" else nn.ReLU() if activation == "relu" else activation,
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))

        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn


    def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor:

        # Multi-Head attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        ## Multi-Head attention
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.store_attn:
            self.attn = attn
        ## Add & Norm
        src = src + self.dropout_attn(src2)  # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        ## Position-wise Feed-Forward
        src2 = self.ff(src)
        ## Add & Norm
        src = src + self.dropout_ffn(src2)  # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        else:
            return src

class _MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
        """Multi Head Attention Layer
        Input shape:
            Q:    [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            mask: [q_len x q_len]
        """
        super().__init__()
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)

        # Project output
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))


    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):

        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Linear (+ split in multiple heads)
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)    # q_s: [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)  # k_s: [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)    # v_s: [bs x n_heads x q_len x d_v]

        # Apply Scaled Dot-Product Attention (multiple heads)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]

        # back to the original inputs dimensions
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)  # output: [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        else:
            return output, attn_weights

class _ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (RealFormer: Transformer likes residual attention by He et al., 2020) and locality self-attention (Vision Transformer for Small-Size Datasets
    by Lee et al., 2021)"""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output: [bs x n_heads x q_len x d_v]
            attn  : [bs x n_heads x q_len x seq_len]
            scores: [bs x n_heads x q_len x seq_len]
        '''

        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale    # attn_scores: [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None:
            attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None:                        # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None:                 # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)    # attn_weights: [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)           # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention:
            return output, attn_weights, attn_scores
        else:
            return output, attn_weights

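A shape-level sketch of the two attention modules above (sizes are arbitrary). _MultiheadAttention falls back to self-attention when K and V are omitted, and with res_attention=True it also returns the pre-softmax scores that a following layer can feed back in as prev:

    mha = _MultiheadAttention(d_model=64, n_heads=4, res_attention=True)
    x = torch.randn(32, 12, 64)       # [bs x q_len x d_model]
    out, weights, scores = mha(x)     # K = V = Q
    # out: [32, 12, 64], weights: [32, 4, 12, 12], scores: [32, 4, 12, 12]
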
class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step of each series instead of its mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        if self.subtract_last:
            self.last = x[:,-1,:].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
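
A round-trip sketch for RevIN (sizes are arbitrary). It expects [bs x seq_len x num_features], standardizes each series with 'norm', and restores the original scale with 'denorm' using the statistics cached during normalization:

    revin = RevIN(num_features=7)
    x = torch.randn(32, 96, 7)
    x_norm = revin(x, 'norm')          # roughly zero mean / unit variance per sample and feature
    x_back = revin(x_norm, 'denorm')   # recovers x up to numerical precision
    assert torch.allclose(x_back, x, atol=1e-4)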