dsipts-1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dsipts might be problematic.

Files changed (81)
  1. dsipts/__init__.py +48 -0
  2. dsipts/data_management/__init__.py +0 -0
  3. dsipts/data_management/monash.py +338 -0
  4. dsipts/data_management/public_datasets.py +162 -0
  5. dsipts/data_structure/__init__.py +0 -0
  6. dsipts/data_structure/data_structure.py +1167 -0
  7. dsipts/data_structure/modifiers.py +213 -0
  8. dsipts/data_structure/utils.py +173 -0
  9. dsipts/models/Autoformer.py +199 -0
  10. dsipts/models/CrossFormer.py +152 -0
  11. dsipts/models/D3VAE.py +196 -0
  12. dsipts/models/Diffusion.py +818 -0
  13. dsipts/models/DilatedConv.py +342 -0
  14. dsipts/models/DilatedConvED.py +310 -0
  15. dsipts/models/Duet.py +197 -0
  16. dsipts/models/ITransformer.py +167 -0
  17. dsipts/models/Informer.py +180 -0
  18. dsipts/models/LinearTS.py +222 -0
  19. dsipts/models/PatchTST.py +181 -0
  20. dsipts/models/Persistent.py +44 -0
  21. dsipts/models/RNN.py +213 -0
  22. dsipts/models/Samformer.py +139 -0
  23. dsipts/models/TFT.py +269 -0
  24. dsipts/models/TIDE.py +296 -0
  25. dsipts/models/TTM.py +252 -0
  26. dsipts/models/TimeXER.py +184 -0
  27. dsipts/models/VQVAEA.py +299 -0
  28. dsipts/models/VVA.py +247 -0
  29. dsipts/models/__init__.py +0 -0
  30. dsipts/models/autoformer/__init__.py +0 -0
  31. dsipts/models/autoformer/layers.py +352 -0
  32. dsipts/models/base.py +439 -0
  33. dsipts/models/base_v2.py +444 -0
  34. dsipts/models/crossformer/__init__.py +0 -0
  35. dsipts/models/crossformer/attn.py +118 -0
  36. dsipts/models/crossformer/cross_decoder.py +77 -0
  37. dsipts/models/crossformer/cross_embed.py +18 -0
  38. dsipts/models/crossformer/cross_encoder.py +99 -0
  39. dsipts/models/d3vae/__init__.py +0 -0
  40. dsipts/models/d3vae/diffusion_process.py +169 -0
  41. dsipts/models/d3vae/embedding.py +108 -0
  42. dsipts/models/d3vae/encoder.py +326 -0
  43. dsipts/models/d3vae/model.py +211 -0
  44. dsipts/models/d3vae/neural_operations.py +314 -0
  45. dsipts/models/d3vae/resnet.py +153 -0
  46. dsipts/models/d3vae/utils.py +630 -0
  47. dsipts/models/duet/__init__.py +0 -0
  48. dsipts/models/duet/layers.py +438 -0
  49. dsipts/models/duet/masked.py +202 -0
  50. dsipts/models/informer/__init__.py +0 -0
  51. dsipts/models/informer/attn.py +185 -0
  52. dsipts/models/informer/decoder.py +50 -0
  53. dsipts/models/informer/embed.py +125 -0
  54. dsipts/models/informer/encoder.py +100 -0
  55. dsipts/models/itransformer/Embed.py +142 -0
  56. dsipts/models/itransformer/SelfAttention_Family.py +355 -0
  57. dsipts/models/itransformer/Transformer_EncDec.py +134 -0
  58. dsipts/models/itransformer/__init__.py +0 -0
  59. dsipts/models/patchtst/__init__.py +0 -0
  60. dsipts/models/patchtst/layers.py +569 -0
  61. dsipts/models/samformer/__init__.py +0 -0
  62. dsipts/models/samformer/utils.py +154 -0
  63. dsipts/models/tft/__init__.py +0 -0
  64. dsipts/models/tft/sub_nn.py +234 -0
  65. dsipts/models/timexer/Layers.py +127 -0
  66. dsipts/models/timexer/__init__.py +0 -0
  67. dsipts/models/ttm/__init__.py +0 -0
  68. dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
  69. dsipts/models/ttm/consts.py +16 -0
  70. dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
  71. dsipts/models/ttm/utils.py +438 -0
  72. dsipts/models/utils.py +624 -0
  73. dsipts/models/vva/__init__.py +0 -0
  74. dsipts/models/vva/minigpt.py +83 -0
  75. dsipts/models/vva/vqvae.py +459 -0
  76. dsipts/models/xlstm/__init__.py +0 -0
  77. dsipts/models/xlstm/xLSTM.py +255 -0
  78. dsipts-1.1.5.dist-info/METADATA +31 -0
  79. dsipts-1.1.5.dist-info/RECORD +81 -0
  80. dsipts-1.1.5.dist-info/WHEEL +5 -0
  81. dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/vva/vqvae.py
@@ -0,0 +1,459 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import logging
+
+ '''
+ class VectorQuantizer(nn.Module):
+     """
+     Inspired by the Sonnet implementation of VQ-VAE (https://arxiv.org/abs/1711.00937),
+     https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py, and
+     the PyTorch port by zalandoresearch,
+     https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb.
+
+     Implements the algorithm presented in
+     'Neural Discrete Representation Learning' by van den Oord et al.
+     https://arxiv.org/abs/1711.00937
+
+     Input any tensor to be quantized. The last dimension is used as the space in
+     which to quantize; all other dimensions are flattened and treated as separate
+     examples to quantize.
+     The output tensor has the same shape as the input.
+     For example, a tensor with shape [16, 32, 32, 64] is reshaped into
+     [16384, 64] and all 16384 vectors (each of 64 dimensions) are quantized
+     independently.
+
+     Args:
+         embedding_dim: integer dimensionality of the tensors in the
+             quantized space. Inputs to the module must be in this format as well.
+         num_embeddings: integer, the number of vectors in the quantized space.
+         commitment_cost: scalar controlling the weighting of the loss terms
+             (beta in equation 4 of the paper).
+     """
+
+     # NOTE: this retired implementation also relies on
+     # `from itertools import combinations, product`, which is not imported above.
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost, device):
+         super(VectorQuantizer, self).__init__()
+
+         self._embedding_dim = embedding_dim
+         self._num_embeddings = num_embeddings
+
+         self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
+         self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
+
+         self._commitment_cost = commitment_cost
+         self._device = device
+
+     def forward(self, inputs, compute_distances_if_possible=True, record_codebook_stats=False):
+         """
+         Connects the module to some inputs.
+
+         Args:
+             inputs: Tensor, final dimension must be equal to embedding_dim. All other
+                 leading dimensions will be flattened and treated as a large batch.
+
+         Returns:
+             loss: Tensor containing the loss to optimize.
+             quantized: Tensor containing the quantized version of the input.
+             perplexity: Tensor containing the perplexity of the encodings.
+             encodings: Tensor containing the discrete encodings, i.e. which element
+                 of the quantized space each input element was mapped to.
+             distances
+         """
+
+         # Convert inputs from BCHW -> BHWC
+         inputs = inputs.permute(1, 2, 0).contiguous()
+         input_shape = inputs.shape
+         _, time, batch_size = input_shape
+
+         # Flatten input
+         flat_input = inputs.view(-1, self._embedding_dim)
+
+         # Compute distances between encoded audio frames and embedding vectors
+         distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+                     + torch.sum(self._embedding.weight**2, dim=1)
+                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
+
+         """
+         encoding_indices: Tensor containing the discrete encoding indices, i.e.
+         which element of the quantized space each input element was mapped to.
+         """
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, dtype=torch.float).to(self._device)
+         encodings.scatter_(1, encoding_indices, 1)
+
+         # Compute distances between encoding vectors
+         if not self.training and compute_distances_if_possible:
+             _encoding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(flat_input, r=2)]
+             encoding_distances = torch.tensor(_encoding_distances).to(self._device).view(batch_size, -1)
+         else:
+             encoding_distances = None
+
+         # Compute distances between embedding vectors
+         if not self.training and compute_distances_if_possible:
+             _embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(self._embedding.weight, r=2)]
+             embedding_distances = torch.tensor(_embedding_distances).to(self._device)
+         else:
+             embedding_distances = None
+
+         # Sample nearest embedding
+         if not self.training and compute_distances_if_possible:
+             _frames_vs_embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in product(flat_input, self._embedding.weight.detach())]
+             frames_vs_embedding_distances = torch.tensor(_frames_vs_embedding_distances).to(self._device).view(batch_size, time, -1)
+         else:
+             frames_vs_embedding_distances = None
+
+         # Quantize and unflatten
+         quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
+         # TODO: Check if the more readable self._embedding.weight.index_select(dim=1, index=encoding_indices) works better
+
+         concatenated_quantized = self._embedding.weight[torch.argmin(distances, dim=1).detach().cpu()] if not self.training or record_codebook_stats else None
+
+         # Losses
+         e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
+         q_latent_loss = torch.mean((quantized - inputs.detach())**2)
+         commitment_loss = self._commitment_cost * e_latent_loss
+         vq_loss = q_latent_loss + commitment_loss
+
+         quantized = inputs + (quantized - inputs).detach()  # Trick to prevent backpropagation through the quantization step
+         avg_probs = torch.mean(encodings, dim=0)
+
+         """
+         The perplexity is a useful value to track during training:
+         it indicates how many codes are 'active' on average.
+         """
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))  # Exponential entropy
+
+         # Convert quantized from BHWC -> BCHW
+         return vq_loss, quantized.permute(2, 0, 1).contiguous(), \
+                perplexity, encodings.view(batch_size, time, -1), \
+                distances.view(batch_size, time, -1), encoding_indices, \
+                {'e_latent_loss': e_latent_loss.item(), 'q_latent_loss': q_latent_loss.item(),
+                 'commitment_loss': commitment_loss.item(), 'vq_loss': vq_loss.item()}, \
+                encoding_distances, embedding_distances, frames_vs_embedding_distances, concatenated_quantized
+
+     @property
+     def embedding(self):
+         return self._embedding
+ '''
+
+ class VectorQuantizer(nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost):
+         super(VectorQuantizer, self).__init__()
+
+         self._embedding_dim = embedding_dim
+         self._num_embeddings = num_embeddings
+
+         self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
+         self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
+         self._commitment_cost = commitment_cost
+
+     def forward(self, inputs):
+         # Convert inputs from (B, C, L) -> (B, L, C); the last dimension must equal embedding_dim
+         inputs = inputs.permute(0, 2, 1).contiguous()
+         input_shape = inputs.shape
+
+         # Flatten input to (B*L, embedding_dim)
+         flat_input = inputs.view(-1, self._embedding_dim)
+
+         # Squared Euclidean distances between inputs and all codebook vectors
+         distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+                     + torch.sum(self._embedding.weight**2, dim=1)
+                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
+
+         # Encoding: one-hot assignment to the nearest codebook vector
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
+         encodings.scatter_(1, encoding_indices, 1)
+
+         # Quantize and unflatten
+         quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
+
+         # Loss: codebook term plus weighted commitment term (equations 3-4 of the paper)
+         e_latent_loss = F.mse_loss(quantized.detach(), inputs)
+         q_latent_loss = F.mse_loss(quantized, inputs.detach())
+         loss = q_latent_loss + self._commitment_cost * e_latent_loss
+
+         # Straight-through estimator: gradients flow through as if quantization were identity
+         quantized = inputs + (quantized - inputs).detach()
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         # Convert quantized back from (B, L, C) -> (B, C, L)
+         return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings.view(input_shape[0], -1, encodings.shape[1])
+
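As a quick orientation (an editorial sketch, not part of the wheel): the module expects (batch, embedding_dim, length) inputs and returns same-shaped, codebook-snapped outputs. The import path is assumed from the file list above.

    import torch
    from dsipts.models.vva.vqvae import VectorQuantizer

    vq = VectorQuantizer(num_embeddings=64, embedding_dim=16, commitment_cost=0.25)
    x = torch.randn(8, 16, 100)                     # (B, C=embedding_dim, L)
    loss, quantized, perplexity, encodings = vq(x)
    assert quantized.shape == x.shape               # same shape, values snapped to the codebook
    assert encodings.shape == (8, 100, 64)          # one-hot code per timestep
    print(loss.item(), perplexity.item())           # perplexity is at most num_embeddings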
+ class VectorQuantizerEMA(nn.Module):
+
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
+         super(VectorQuantizerEMA, self).__init__()
+
+         self._embedding_dim = embedding_dim
+         self._num_embeddings = num_embeddings
+
+         self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
+         self._embedding.weight.data.normal_()
+         self._commitment_cost = commitment_cost
+
+         # Exponential moving averages of cluster sizes and of assigned-vector sums
+         self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
+         self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
+         self._ema_w.data.normal_()
+
+         self._decay = decay
+         self._epsilon = epsilon
+
+     def forward(self, inputs):
+         # Convert inputs from (B, C, L) -> (B, L, C)
+         inputs = inputs.permute(0, 2, 1).contiguous()
+         input_shape = inputs.shape
+
+         # Flatten input to (B*L, embedding_dim)
+         flat_input = inputs.view(-1, self._embedding_dim)
+
+         # Squared Euclidean distances between inputs and all codebook vectors
+         distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+                     + torch.sum(self._embedding.weight**2, dim=1)
+                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
+
+         # Encoding: one-hot assignment to the nearest codebook vector
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
+         encodings.scatter_(1, encoding_indices, 1)
+
+         # Quantize and unflatten
+         quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
+
+         # Use EMA to update the embedding vectors (appendix A.1 of the paper)
+         if self.training:
+             self._ema_cluster_size = self._ema_cluster_size * self._decay + \
+                                      (1 - self._decay) * torch.sum(encodings, 0)
+
+             # Laplace smoothing of the cluster sizes, so no code gets a zero count
+             n = torch.sum(self._ema_cluster_size.data)
+             self._ema_cluster_size = (
+                 (self._ema_cluster_size + self._epsilon)
+                 / (n + self._num_embeddings * self._epsilon) * n)
+
+             dw = torch.matmul(encodings.t(), flat_input)
+             self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
+
+             self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
+
+         # Loss: only the commitment term, since the codebook is updated by EMA, not by gradients
+         e_latent_loss = F.mse_loss(quantized.detach(), inputs)
+         loss = self._commitment_cost * e_latent_loss
+
+         # Straight Through Estimator
+         quantized = inputs + (quantized - inputs).detach()
+         avg_probs = torch.mean(encodings, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         # Convert quantized back from (B, L, C) -> (B, C, L)
+         return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings.view(input_shape[0], -1, encodings.shape[1])
+
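The training branch above follows the EMA update rule from appendix A.1 of the VQ-VAE paper: per-code counts and per-code vector sums are smoothed over steps, and each embedding is their ratio. A standalone sketch of the same arithmetic, with toy values assumed for illustration:

    import torch

    decay, eps, K, D = 0.99, 1e-5, 4, 2
    N = torch.zeros(K)                              # EMA cluster sizes
    m = torch.randn(K, D)                           # EMA of sums of assigned vectors
    flat = torch.randn(6, D)                        # six encoder outputs
    assign = torch.eye(K)[torch.randint(K, (6,))]   # one-hot nearest-code assignments

    N = decay * N + (1 - decay) * assign.sum(0)
    n = N.sum()
    N = (N + eps) / (n + K * eps) * n               # Laplace smoothing: no zero counts
    m = decay * m + (1 - decay) * assign.t() @ flat
    codebook = m / N.unsqueeze(1)                   # updated embedding vectors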
+
+ class Residual(nn.Module):
+
+     def __init__(self, in_channels, hidden_channels, num_residual_hiddens):
+         super(Residual, self).__init__()
+
+         relu_1 = nn.ReLU(True)
+         conv_1 = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=num_residual_hiddens,
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             bias=False
+         )
+
+         relu_2 = nn.ReLU(True)
+         conv_2 = nn.Conv1d(
+             in_channels=num_residual_hiddens,
+             out_channels=hidden_channels,
+             kernel_size=1,
+             stride=1,
+             bias=False
+         )
+
+         # All parameters as specified in the paper
+         self._block = nn.Sequential(
+             relu_1,
+             conv_1,
+             relu_2,
+             conv_2
+         )
+
+     def forward(self, x):
+         return x + self._block(x)
+
+ class ResidualStack(nn.Module):
+
+     def __init__(self, in_channels, hidden_channels, num_residual_layers, num_residual_hiddens):
+         super(ResidualStack, self).__init__()
+
+         self._num_residual_layers = num_residual_layers
+         # Build independent blocks: replicating one instance with `*` would share
+         # the same weights across every layer of the stack.
+         self._layers = nn.ModuleList([Residual(in_channels, hidden_channels, num_residual_hiddens)
+                                       for _ in range(self._num_residual_layers)])
+
+     def forward(self, x):
+         for i in range(self._num_residual_layers):
+             x = self._layers[i](x)
+         return F.relu(x)
+
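Both convolutions in a block are length-preserving (kernel 3 with padding 1, then kernel 1), so each block maps (B, C, L) to (B, C, L) and the stack composes freely. A toy check, with assumed sizes and the classes above in scope:

    import torch

    stack = ResidualStack(in_channels=32, hidden_channels=32,
                          num_residual_layers=3, num_residual_hiddens=16)
    h = torch.randn(2, 32, 50)
    assert stack(h).shape == h.shape                   # shape preserved
    assert len({id(m) for m in stack._layers}) == 3    # three distinct blocks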
+ class Encoder(nn.Module):
+     def __init__(self, in_channels, hidden_channels, num_residual_layers=3):
+         super(Encoder, self).__init__()
+
+         self._conv_1 = nn.Conv1d(in_channels=in_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=3, padding=1)
+
+         self._conv_2 = nn.Conv1d(in_channels=hidden_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=3, padding=1)
+
+         # The only strided convolution: halves the temporal length
+         self._conv_3 = nn.Conv1d(in_channels=hidden_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=4,
+                                  stride=2, padding=1)
+
+         self._conv_4 = nn.Conv1d(in_channels=hidden_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=3,
+                                  padding=1)
+
+         self._conv_5 = nn.Conv1d(in_channels=hidden_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=3, padding=1)
+
+         self._residual_stack = ResidualStack(
+             in_channels=hidden_channels,
+             hidden_channels=hidden_channels,
+             num_residual_layers=num_residual_layers,
+             num_residual_hiddens=hidden_channels//2
+         )
+
+     def forward(self, inputs):
+         x_conv_1 = F.relu(self._conv_1(inputs))
+         x = F.relu(self._conv_2(x_conv_1)) + x_conv_1
+         x_conv_3 = F.relu(self._conv_3(x))
+         x_conv_4 = F.relu(self._conv_4(x_conv_3)) + x_conv_3
+         x_conv_5 = F.relu(self._conv_5(x_conv_4)) + x_conv_4
+         x = self._residual_stack(x_conv_5) + x_conv_5
+         return x
+
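Net effect: 2x temporal downsampling. conv_3 (kernel 4, stride 2, padding 1) maps an even length L to L/2 and every other layer is length-preserving. A toy check with assumed sizes:

    import torch

    enc = Encoder(in_channels=1, hidden_channels=32)
    x = torch.randn(2, 1, 64)
    z = enc(x)
    assert z.shape == (2, 32, 32)    # channels widened, length halved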
+ class Jitter(nn.Module):
+     """
+     Jitter implementation from [Chorowski et al., 2019].
+     During training, each latent vector can replace either one or both of
+     its neighbors. As in dropout, this prevents the model from
+     relying on consistency across groups of tokens. Additionally,
+     this regularization also promotes latent representation stability
+     over time: a latent vector extracted at time step t must strive
+     to also be useful at time steps t - 1 or t + 1.
+     """
+
+     def __init__(self, probability=0.12):
+         super(Jitter, self).__init__()
+
+         self._probability = probability
+
+     def forward(self, quantized):
+         original_quantized = quantized.detach().clone()
+         length = original_quantized.size(2)
+         for i in range(length):
+             # Each latent vector is replaced with one of its neighbors with a
+             # certain probability (0.12 in the paper).
+             replace = np.random.random() < self._probability
+             if replace:
+                 if i == 0:
+                     neighbor_index = i + 1
+                 elif i == length - 1:
+                     neighbor_index = i - 1
+                 else:
+                     # "We independently sample whether it is to be replaced with
+                     # the token right after or before it."
+                     neighbor_index = i + np.random.choice([-1, 1], p=[0.5, 0.5])
+                 quantized[:, :, i] = original_quantized[:, :, neighbor_index]
+
+         return quantized
+
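A toy illustration (assumed values, probability exaggerated so replacements are visible): each output timestep either keeps its own vector or copies a neighbor's, and the shape is unchanged. Note that forward mutates its argument in place, hence the clone:

    import torch

    jitter = Jitter(probability=0.5)
    z = torch.arange(6.).view(1, 1, 6)    # one channel, six timesteps
    out = jitter(z.clone())               # e.g. tensor([[[1., 1., 2., 3., 3., 4.]]])
    assert out.shape == z.shape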
+ class Decoder(nn.Module):
+     def __init__(self, in_channels, hidden_channels, out_channels, num_residual_layers=3):
+         super(Decoder, self).__init__()
+
+         self._jitter = Jitter(0.125)
+
+         self._conv_1 = nn.Conv1d(in_channels=in_channels,
+                                  out_channels=hidden_channels,
+                                  kernel_size=3,
+                                  padding=1)
+
+         # Undoes the encoder's stride-2 downsampling
+         self._upsample = nn.Upsample(scale_factor=2)
+
+         self._residual_stack = ResidualStack(
+             in_channels=hidden_channels,
+             hidden_channels=hidden_channels,
+             num_residual_layers=num_residual_layers,
+             num_residual_hiddens=hidden_channels//2
+         )
+
+         self._conv_trans_1 = nn.ConvTranspose1d(in_channels=hidden_channels,
+                                                 out_channels=hidden_channels,
+                                                 kernel_size=3, padding=1)
+
+         self._conv_trans_2 = nn.ConvTranspose1d(in_channels=hidden_channels,
+                                                 out_channels=hidden_channels,
+                                                 kernel_size=4, padding=2)
+
+         self._conv_trans_3 = nn.ConvTranspose1d(in_channels=hidden_channels,
+                                                 out_channels=out_channels,
+                                                 kernel_size=4, padding=1)
+
+     def forward(self, x, is_training=True):
+         # Jitter regularization is currently disabled:
+         # if is_training:
+         #     x = self._jitter(x)
+         x = self._conv_1(x)
+         x = self._upsample(x)
+         x = self._residual_stack(x)
+         x = F.relu(self._conv_trans_1(x))
+         x = F.relu(self._conv_trans_2(x))
+         x = self._conv_trans_3(x)
+         return x
+
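For nn.ConvTranspose1d, L_out = (L_in - 1) * stride - 2 * padding + kernel_size, so after the 2x upsample the three transposed convolutions map 2L -> 2L -> 2L-1 -> 2L: the decoder doubles the latent length overall. A toy check with assumed sizes:

    import torch

    dec = Decoder(in_channels=16, hidden_channels=32, out_channels=1)
    z = torch.randn(2, 16, 32)      # latent sequence at half resolution
    y = dec(z)
    assert y.shape == (2, 1, 64)    # original resolution restored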
+ class VQVAE(nn.Module):
+     def __init__(self, in_channels, hidden_channels, out_channels, num_embeddings, embedding_dim, commitment_cost, decay):
+         super(VQVAE, self).__init__()
+
+         self._encoder = Encoder(in_channels, hidden_channels)
+         # 1x1 convolution projecting encoder features to the codebook dimension
+         self._pre_vq_conv = nn.Conv1d(in_channels=hidden_channels,
+                                       out_channels=embedding_dim,
+                                       kernel_size=1,
+                                       stride=1)
+         if decay > 0.0:
+             self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
+                                               commitment_cost, decay)
+         else:
+             logging.info('CARE NOT TESTED')
+             self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
+                                            commitment_cost)
+         self._decoder = Decoder(in_channels=embedding_dim, hidden_channels=hidden_channels, out_channels=out_channels)
+
+     def forward(self, x, is_training=True):
+         z = self._encoder(x)
+         z = self._pre_vq_conv(z)
+         loss, quantized, perplexity, encodings = self._vq_vae(z)
+         x_recon = self._decoder(quantized, is_training)
+         return loss, x_recon, perplexity, quantized, encodings
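Putting it together, an editorial end-to-end sketch (hyperparameters are assumed for illustration, not taken from dsipts defaults): the encoder halves the length and the decoder doubles it back, so reconstruction matches any even-length input, and the returned VQ loss is added to a reconstruction term for training:

    import torch
    import torch.nn.functional as F

    model = VQVAE(in_channels=1, hidden_channels=32, out_channels=1,
                  num_embeddings=128, embedding_dim=16,
                  commitment_cost=0.25, decay=0.99)     # decay > 0 -> EMA codebook
    x = torch.randn(4, 1, 64)
    vq_loss, x_recon, perplexity, quantized, encodings = model(x)
    assert x_recon.shape == x.shape
    total = F.mse_loss(x_recon, x) + vq_loss            # typical VQ-VAE objective
    total.backward()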