dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dsipts has been flagged as a potentially problematic release.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/itransformer/SelfAttention_Family.py
@@ -0,0 +1,355 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from math import sqrt
+#from utils.masking import TriangularCausalMask, ProbMask
+#from reformer_pytorch import LSHSelfAttention
+from einops import rearrange
+
+
+class TriangularCausalMask():
+    def __init__(self, B, L, device="cpu"):
+        mask_shape = [B, 1, L, L]
+        with torch.no_grad():
+            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+
+class ProbMask():
+    def __init__(self, B, H, L, index, scores, device="cpu"):
+        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
+        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
+        indicator = _mask_ex[torch.arange(B)[:, None, None],
+                             torch.arange(H)[None, :, None],
+                             index, :].to(device)
+        self._mask = indicator.view(scores.shape).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+# Code implementation from https://github.com/thuml/Flowformer
+class FlowAttention(nn.Module):
+    def __init__(self, attention_dropout=0.1):
+        super(FlowAttention, self).__init__()
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def kernel_method(self, x):
+        return torch.sigmoid(x)
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        queries = queries.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+        # kernel
+        queries = self.kernel_method(queries)
+        keys = self.kernel_method(keys)
+        # incoming and outgoing
+        normalizer_row = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + 1e-6, keys.sum(dim=2) + 1e-6))
+        normalizer_col = 1.0 / (torch.einsum("nhsd,nhd->nhs", keys + 1e-6, queries.sum(dim=2) + 1e-6))
+        # reweighting
+        normalizer_row_refine = (
+            torch.einsum("nhld,nhd->nhl", queries + 1e-6, (keys * normalizer_col[:, :, :, None]).sum(dim=2) + 1e-6))
+        normalizer_col_refine = (
+            torch.einsum("nhsd,nhd->nhs", keys + 1e-6, (queries * normalizer_row[:, :, :, None]).sum(dim=2) + 1e-6))
+        # competition and allocation
+        normalizer_row_refine = torch.sigmoid(
+            normalizer_row_refine * (float(queries.shape[2]) / float(keys.shape[2])))
+        normalizer_col_refine = torch.softmax(normalizer_col_refine, dim=-1) * keys.shape[2]  # B h L vis
+        # multiply
+        kv = keys.transpose(-2, -1) @ (values * normalizer_col_refine[:, :, :, None])
+        x = (((queries @ kv) * normalizer_row[:, :, :, None]) * normalizer_row_refine[:, :, :, None]).transpose(1,
+                                                                                                                2).contiguous()
+        return x, None
+
+
+# Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch
+class FlashAttention(nn.Module):
+    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+        super(FlashAttention, self).__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def flash_attention_forward(self, Q, K, V, mask=None):
+        BLOCK_SIZE = 32
+        NEG_INF = -1e10  # -infinity
+        EPSILON = 1e-10
+        # mask = torch.randint(0, 2, (128, 8)).to(device='cuda')
+        O = torch.zeros_like(Q, requires_grad=True)
+        l = torch.zeros(Q.shape[:-1])[..., None]
+        m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
+
+        O = O.to(device='cuda')
+        l = l.to(device='cuda')
+        m = m.to(device='cuda')
+
+        Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
+        KV_BLOCK_SIZE = BLOCK_SIZE
+
+        Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
+        K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
+        V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
+        if mask is not None:
+            mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))
+
+        Tr = len(Q_BLOCKS)
+        Tc = len(K_BLOCKS)
+
+        O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
+        l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
+        m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
+
+        for j in range(Tc):
+            Kj = K_BLOCKS[j]
+            Vj = V_BLOCKS[j]
+            if mask is not None:
+                maskj = mask_BLOCKS[j]
+
+            for i in range(Tr):
+                Qi = Q_BLOCKS[i]
+                Oi = O_BLOCKS[i]
+                li = l_BLOCKS[i]
+                mi = m_BLOCKS[i]
+
+                scale = 1 / np.sqrt(Q.shape[-1])
+                Qi_scaled = Qi * scale
+
+                S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)
+                if mask is not None:
+                    # Masking
+                    maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
+                    S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)
+
+                m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
+                P_ij = torch.exp(S_ij - m_block_ij)
+                if mask is not None:
+                    # Masking
+                    P_ij = torch.where(maskj_temp > 0, P_ij, 0.)
+
+                l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
+
+                P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
+
+                mi_new = torch.maximum(m_block_ij, mi)
+                li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij
+
+                O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (
+                        torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
+                l_BLOCKS[i] = li_new
+                m_BLOCKS[i] = mi_new
+
+        O = torch.cat(O_BLOCKS, dim=2)
+        l = torch.cat(l_BLOCKS, dim=2)
+        m = torch.cat(m_BLOCKS, dim=2)
+        return O, l, m
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        res = \
+            self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3),
+                                         attn_mask)[0]
+        return res.permute(0, 2, 1, 3).contiguous(), None
+
+
+class FullAttention(nn.Module):
+    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+        super(FullAttention, self).__init__()
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        scale = self.scale or 1. / sqrt(E)
+
+        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+
+        if self.mask_flag:
+            if attn_mask is None:
+                attn_mask = TriangularCausalMask(B, L, device=queries.device)
+
+            scores.masked_fill_(attn_mask.mask, -np.inf)
+
+        A = self.dropout(torch.softmax(scale * scores, dim=-1))
+        V = torch.einsum("bhls,bshd->blhd", A, values)
+
+        if self.output_attention:
+            return (V.contiguous(), A)
+        else:
+            return (V.contiguous(), None)
+
+
+# Code implementation from https://github.com/zhouhaoyi/Informer2020
+class ProbAttention(nn.Module):
+    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
+        super(ProbAttention, self).__init__()
+        self.factor = factor
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
+        # Q [B, H, L, D]
+        B, H, L_K, E = K.shape
+        _, _, L_Q, _ = Q.shape
+
+        # calculate the sampled Q_K
+        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
+        # real U = U_part(factor*ln(L_k))*L_q
+        index_sample = torch.randint(L_K, (L_Q, sample_k))
+        K_sample = K_expand[:, :, torch.arange(
+            L_Q).unsqueeze(1), index_sample, :]
+        Q_K_sample = torch.matmul(
+            Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
+
+        # find the Top_k query with sparsity measurement
+        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
+        M_top = M.topk(n_top, sorted=False)[1]
+
+        # use the reduced Q to calculate Q_K
+        Q_reduce = Q[torch.arange(B)[:, None, None],
+                     torch.arange(H)[None, :, None],
+                     M_top, :]  # factor*ln(L_q)
+        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
+
+        return Q_K, M_top
+
+    def _get_initial_context(self, V, L_Q):
+        B, H, L_V, D = V.shape
+        if not self.mask_flag:
+            # V_sum = V.sum(dim=-2)
+            V_sum = V.mean(dim=-2)
+            contex = V_sum.unsqueeze(-2).expand(B, H,
+                                                L_Q, V_sum.shape[-1]).clone()
+        else:  # use mask
+            # requires that L_Q == L_V, i.e. for self-attention only
+            assert (L_Q == L_V)
+            contex = V.cumsum(dim=-2)
+        return contex
+
+    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
+        B, H, L_V, D = V.shape
+
+        if self.mask_flag:
+            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
+            scores.masked_fill_(attn_mask.mask, -np.inf)
+
+        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)
+
+        context_in[torch.arange(B)[:, None, None],
+                   torch.arange(H)[None, :, None],
+                   index, :] = torch.matmul(attn, V).type_as(context_in)
+        if self.output_attention:
+            attns = (torch.ones([B, H, L_V, L_V]) /
+                     L_V).type_as(attn).to(attn.device)
+            attns[torch.arange(B)[:, None, None], torch.arange(H)[
+                None, :, None], index, :] = attn
+            return (context_in, attns)
+        else:
+            return (context_in, None)
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L_Q, H, D = queries.shape
+        _, L_K, _, _ = keys.shape
+
+        queries = queries.transpose(2, 1)
+        keys = keys.transpose(2, 1)
+        values = values.transpose(2, 1)
+
+        U_part = self.factor * \
+            np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
+        u = self.factor * \
+            np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)
+
+        U_part = U_part if U_part < L_K else L_K
+        u = u if u < L_Q else L_Q
+
+        scores_top, index = self._prob_QK(
+            queries, keys, sample_k=U_part, n_top=u)
+
+        # add scale factor
+        scale = self.scale or 1. / sqrt(D)
+        if scale is not None:
+            scores_top = scores_top * scale
+        # get the context
+        context = self._get_initial_context(values, L_Q)
+        # update the context with selected top_k queries
+        context, attn = self._update_context(
+            context, values, scores_top, index, L_Q, attn_mask)
+
+        return context.contiguous(), attn
+
+
+class AttentionLayer(nn.Module):
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None):
+        super(AttentionLayer, self).__init__()
+
+        d_keys = d_keys or (d_model // n_heads)
+        d_values = d_values or (d_model // n_heads)
+
+        self.inner_attention = attention
+        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.value_projection = nn.Linear(d_model, d_values * n_heads)
+        self.out_projection = nn.Linear(d_values * n_heads, d_model)
+        self.n_heads = n_heads
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+
+        queries = self.query_projection(queries).view(B, L, H, -1)
+        keys = self.key_projection(keys).view(B, S, H, -1)
+        values = self.value_projection(values).view(B, S, H, -1)
+
+        out, attn = self.inner_attention(
+            queries,
+            keys,
+            values,
+            attn_mask,
+            tau=tau,
+            delta=delta
+        )
+        out = out.view(B, L, -1)
+
+        return self.out_projection(out), attn
+
+'''
+class ReformerLayer(nn.Module):
+    def __init__(self, attention, d_model, n_heads, d_keys=None,
+                 d_values=None, causal=False, bucket_size=4, n_hashes=4):
+        super().__init__()
+        self.bucket_size = bucket_size
+        self.attn = LSHSelfAttention(
+            dim=d_model,
+            heads=n_heads,
+            bucket_size=bucket_size,
+            n_hashes=n_hashes,
+            causal=causal
+        )
+
+    def fit_length(self, queries):
+        # inside reformer: assert N % (bucket_size * 2) == 0
+        B, N, C = queries.shape
+        if N % (self.bucket_size * 2) == 0:
+            return queries
+        else:
+            # fill the time series
+            fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
+            return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
+
+    def forward(self, queries, keys, values, attn_mask, tau, delta):
+        # in Reformer: default queries=keys
+        B, N, C = queries.shape
+        queries = self.attn(self.fit_length(queries))[:, :N, :]
+        return queries, None
'''
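
For orientation, here is a minimal usage sketch (not part of the package) showing how the FullAttention and AttentionLayer classes added above fit together. The import path is assumed from the file list, and the shapes and hyperparameters are illustrative.

import torch
# assumed import path, based on the file list above
from dsipts.models.itransformer.SelfAttention_Family import AttentionLayer, FullAttention

d_model, n_heads, B, L = 64, 4, 2, 96           # illustrative sizes
attn = AttentionLayer(FullAttention(mask_flag=False), d_model=d_model, n_heads=n_heads)
x = torch.randn(B, L, d_model)                  # [batch, sequence length, model dim]
out, _ = attn(x, x, x, attn_mask=None)          # self-attention: queries = keys = values
print(out.shape)                                # torch.Size([2, 96, 64])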
dsipts/models/itransformer/Transformer_EncDec.py
@@ -0,0 +1,134 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, c_in):
+        super(ConvLayer, self).__init__()
+        self.downConv = nn.Conv1d(in_channels=c_in,
+                                  out_channels=c_in,
+                                  kernel_size=3,
+                                  padding=2,
+                                  padding_mode='circular')
+        self.norm = nn.BatchNorm1d(c_in)
+        self.activation = nn.ELU()
+        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+    def forward(self, x):
+        x = self.downConv(x.permute(0, 2, 1))
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.maxPool(x)
+        x = x.transpose(1, 2)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation=None):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = activation  # my change here
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        new_x, attn = self.attention(
+            x, x, x,
+            attn_mask=attn_mask,
+            tau=tau, delta=delta
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn
+
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
+                delta = delta if i == 0 else None
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
+                 dropout=0.1, activation="relu"):
+        super(DecoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.self_attention = self_attention
+        self.cross_attention = cross_attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        x = x + self.dropout(self.self_attention(
+            x, x, x,
+            attn_mask=x_mask,
+            tau=tau, delta=None
+        )[0])
+        x = self.norm1(x)
+
+        x = x + self.dropout(self.cross_attention(
+            x, cross, cross,
+            attn_mask=cross_mask,
+            tau=tau, delta=delta
+        )[0])
+
+        y = x = self.norm2(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm3(x + y)
+
+
+class Decoder(nn.Module):
+    def __init__(self, layers, norm_layer=None, projection=None):
+        super(Decoder, self).__init__()
+        self.layers = nn.ModuleList(layers)
+        self.norm = norm_layer
+        self.projection = projection
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        for layer in self.layers:
+            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        if self.projection is not None:
+            x = self.projection(x)
+        return x
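
Similarly, a minimal sketch (again an assumption, not code from the package) of stacking the EncoderLayer and Encoder classes above. Note that this variant of EncoderLayer stores the activation argument directly, so a callable such as nn.GELU() should be passed rather than the usual "relu"/"gelu" string; the import paths and sizes are illustrative.

import torch
import torch.nn as nn
# assumed import paths, based on the file list above
from dsipts.models.itransformer.Transformer_EncDec import Encoder, EncoderLayer
from dsipts.models.itransformer.SelfAttention_Family import AttentionLayer, FullAttention

d_model, n_heads, e_layers = 64, 4, 2           # illustrative sizes
encoder = Encoder(
    [EncoderLayer(AttentionLayer(FullAttention(mask_flag=False), d_model, n_heads),
                  d_model, d_ff=128, dropout=0.1, activation=nn.GELU())
     for _ in range(e_layers)],
    norm_layer=nn.LayerNorm(d_model),
)
x = torch.randn(2, 96, d_model)                 # [batch, sequence length, model dim]
enc_out, attns = encoder(x, attn_mask=None)
print(enc_out.shape)                            # torch.Size([2, 96, 64])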