dsipts-1.1.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dsipts might be problematic.
- dsipts/__init__.py +48 -0
- dsipts/data_management/__init__.py +0 -0
- dsipts/data_management/monash.py +338 -0
- dsipts/data_management/public_datasets.py +162 -0
- dsipts/data_structure/__init__.py +0 -0
- dsipts/data_structure/data_structure.py +1167 -0
- dsipts/data_structure/modifiers.py +213 -0
- dsipts/data_structure/utils.py +173 -0
- dsipts/models/Autoformer.py +199 -0
- dsipts/models/CrossFormer.py +152 -0
- dsipts/models/D3VAE.py +196 -0
- dsipts/models/Diffusion.py +818 -0
- dsipts/models/DilatedConv.py +342 -0
- dsipts/models/DilatedConvED.py +310 -0
- dsipts/models/Duet.py +197 -0
- dsipts/models/ITransformer.py +167 -0
- dsipts/models/Informer.py +180 -0
- dsipts/models/LinearTS.py +222 -0
- dsipts/models/PatchTST.py +181 -0
- dsipts/models/Persistent.py +44 -0
- dsipts/models/RNN.py +213 -0
- dsipts/models/Samformer.py +139 -0
- dsipts/models/TFT.py +269 -0
- dsipts/models/TIDE.py +296 -0
- dsipts/models/TTM.py +252 -0
- dsipts/models/TimeXER.py +184 -0
- dsipts/models/VQVAEA.py +299 -0
- dsipts/models/VVA.py +247 -0
- dsipts/models/__init__.py +0 -0
- dsipts/models/autoformer/__init__.py +0 -0
- dsipts/models/autoformer/layers.py +352 -0
- dsipts/models/base.py +439 -0
- dsipts/models/base_v2.py +444 -0
- dsipts/models/crossformer/__init__.py +0 -0
- dsipts/models/crossformer/attn.py +118 -0
- dsipts/models/crossformer/cross_decoder.py +77 -0
- dsipts/models/crossformer/cross_embed.py +18 -0
- dsipts/models/crossformer/cross_encoder.py +99 -0
- dsipts/models/d3vae/__init__.py +0 -0
- dsipts/models/d3vae/diffusion_process.py +169 -0
- dsipts/models/d3vae/embedding.py +108 -0
- dsipts/models/d3vae/encoder.py +326 -0
- dsipts/models/d3vae/model.py +211 -0
- dsipts/models/d3vae/neural_operations.py +314 -0
- dsipts/models/d3vae/resnet.py +153 -0
- dsipts/models/d3vae/utils.py +630 -0
- dsipts/models/duet/__init__.py +0 -0
- dsipts/models/duet/layers.py +438 -0
- dsipts/models/duet/masked.py +202 -0
- dsipts/models/informer/__init__.py +0 -0
- dsipts/models/informer/attn.py +185 -0
- dsipts/models/informer/decoder.py +50 -0
- dsipts/models/informer/embed.py +125 -0
- dsipts/models/informer/encoder.py +100 -0
- dsipts/models/itransformer/Embed.py +142 -0
- dsipts/models/itransformer/SelfAttention_Family.py +355 -0
- dsipts/models/itransformer/Transformer_EncDec.py +134 -0
- dsipts/models/itransformer/__init__.py +0 -0
- dsipts/models/patchtst/__init__.py +0 -0
- dsipts/models/patchtst/layers.py +569 -0
- dsipts/models/samformer/__init__.py +0 -0
- dsipts/models/samformer/utils.py +154 -0
- dsipts/models/tft/__init__.py +0 -0
- dsipts/models/tft/sub_nn.py +234 -0
- dsipts/models/timexer/Layers.py +127 -0
- dsipts/models/timexer/__init__.py +0 -0
- dsipts/models/ttm/__init__.py +0 -0
- dsipts/models/ttm/configuration_tinytimemixer.py +307 -0
- dsipts/models/ttm/consts.py +16 -0
- dsipts/models/ttm/modeling_tinytimemixer.py +2099 -0
- dsipts/models/ttm/utils.py +438 -0
- dsipts/models/utils.py +624 -0
- dsipts/models/vva/__init__.py +0 -0
- dsipts/models/vva/minigpt.py +83 -0
- dsipts/models/vva/vqvae.py +459 -0
- dsipts/models/xlstm/__init__.py +0 -0
- dsipts/models/xlstm/xLSTM.py +255 -0
- dsipts-1.1.5.dist-info/METADATA +31 -0
- dsipts-1.1.5.dist-info/RECORD +81 -0
- dsipts-1.1.5.dist-info/WHEEL +5 -0
- dsipts-1.1.5.dist-info/top_level.txt +1 -0
dsipts/models/vva/vqvae.py

@@ -0,0 +1,459 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging

'''
class VectorQuantizer(nn.Module):
    """
    Inspired from Sonnet implementation of VQ-VAE https://arxiv.org/abs/1711.00937,
    in https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py and
    pytorch implementation of it from zalandoresearch in https://github.com/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb.

    Implements the algorithm presented in
    'Neural Discrete Representation Learning' by van den Oord et al.
    https://arxiv.org/abs/1711.00937

    Input any tensor to be quantized. Last dimension will be used as space in
    which to quantize. All other dimensions will be flattened and will be seen
    as different examples to quantize.
    The output tensor will have the same shape as the input.
    For example a tensor with shape [16, 32, 32, 64] will be reshaped into
    [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized
    independently.

    Args:
        embedding_dim: integer representing the dimensionality of the tensors in the
            quantized space. Inputs to the modules must be in this format as well.
        num_embeddings: integer, the number of vectors in the quantized space.
        commitment_cost: scalar which controls the weighting of the loss terms
            (see equation 4 in the paper - this variable is Beta).
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, device):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)

        self._commitment_cost = commitment_cost
        self._device = device

    def forward(self, inputs, compute_distances_if_possible=True, record_codebook_stats=False):
        """
        Connects the module to some inputs.

        Args:
            inputs: Tensor, final dimension must be equal to embedding_dim. All other
                leading dimensions will be flattened and treated as a large batch.

        Returns:
            loss: Tensor containing the loss to optimize.
            quantize: Tensor containing the quantized version of the input.
            perplexity: Tensor containing the perplexity of the encodings.
            encodings: Tensor containing the discrete encodings, ie which element
                of the quantized space each input element was mapped to.
            distances
        """

        # Convert inputs from BCHW -> BHWC
        inputs = inputs.permute(1, 2, 0).contiguous()
        input_shape = inputs.shape
        _, time, batch_size = input_shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Compute distances between encoded audio frames and embedding vectors
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        """
        encoding_indices: Tensor containing the discrete encoding indices, ie
        which element of the quantized space each input element was mapped to.
        """
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, dtype=torch.float).to(self._device)
        encodings.scatter_(1, encoding_indices, 1)

        # Compute distances between encoding vectors
        if not self.training and compute_distances_if_possible:
            _encoding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(flat_input, r=2)]
            encoding_distances = torch.tensor(_encoding_distances).to(self._device).view(batch_size, -1)
        else:
            encoding_distances = None

        # Compute distances between embedding vectors
        if not self.training and compute_distances_if_possible:
            _embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in combinations(self._embedding.weight, r=2)]
            embedding_distances = torch.tensor(_embedding_distances).to(self._device)
        else:
            embedding_distances = None

        # Sample nearest embedding
        if not self.training and compute_distances_if_possible:
            _frames_vs_embedding_distances = [torch.dist(items[0], items[1], 2).to(self._device) for items in product(flat_input, self._embedding.weight.detach())]
            frames_vs_embedding_distances = torch.tensor(_frames_vs_embedding_distances).to(self._device).view(batch_size, time, -1)
        else:
            frames_vs_embedding_distances = None

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # TODO: Check if the more readable self._embedding.weight.index_select(dim=1, index=encoding_indices) works better

        concatenated_quantized = self._embedding.weight[torch.argmin(distances, dim=1).detach().cpu()] if not self.training or record_codebook_stats else None

        # Losses
        e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
        q_latent_loss = torch.mean((quantized - inputs.detach())**2)
        commitment_loss = self._commitment_cost * e_latent_loss
        vq_loss = q_latent_loss + commitment_loss

        quantized = inputs + (quantized - inputs).detach()  # Trick to prevent backpropagation of quantized
        avg_probs = torch.mean(encodings, dim=0)

        """
        The perplexity is a useful value to track during training.
        It indicates how many codes are 'active' on average.
        """
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))  # Exponential entropy

        # Convert quantized from BHWC -> BCHW
        return vq_loss, quantized.permute(2, 0, 1).contiguous(), \
            perplexity, encodings.view(batch_size, time, -1), \
            distances.view(batch_size, time, -1), encoding_indices, \
            {'e_latent_loss': e_latent_loss.item(), 'q_latent_loss': q_latent_loss.item(),
             'commitment_loss': commitment_loss.item(), 'vq_loss': vq_loss.item()}, \
            encoding_distances, embedding_distances, frames_vs_embedding_distances, concatenated_quantized

    @property
    def embedding(self):
        return self._embedding
'''

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        #return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings.view(input_shape[0], -1, encodings.shape[1])

class VectorQuantizerEMA(nn.Module):

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings.view(input_shape[0], -1, encodings.shape[1])


class Residual(nn.Module):

    def __init__(self, in_channels, hidden_channels, num_residual_hiddens):
        super(Residual, self).__init__()

        relu_1 = nn.ReLU(True)
        conv_1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=num_residual_hiddens,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False
        )

        relu_2 = nn.ReLU(True)
        conv_2 = nn.Conv1d(
            in_channels=num_residual_hiddens,
            out_channels=hidden_channels,
            kernel_size=1,
            stride=1,
            bias=False
        )

        # All parameters same as specified in the paper
        self._block = nn.Sequential(
            relu_1,
            conv_1,
            relu_2,
            conv_2
        )

    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):

    def __init__(self, in_channels, hidden_channels, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()

        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, hidden_channels, num_residual_hiddens)] * self._num_residual_layers)

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_residual_layers=3):
        super(Encoder, self).__init__()

        self._conv_1 = nn.Conv1d(in_channels=in_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3, padding=1)

        self._conv_2 = nn.Conv1d(in_channels=hidden_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3, padding=1)

        self._conv_3 = nn.Conv1d(in_channels=hidden_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=4,
                                 stride=2, padding=1)

        self._conv_4 = nn.Conv1d(in_channels=hidden_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3,
                                 padding=1)

        self._conv_5 = nn.Conv1d(in_channels=hidden_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3, padding=1)

        self._residual_stack = ResidualStack(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=hidden_channels//2
        )

    def forward(self, inputs):
        x_conv_1 = F.relu(self._conv_1(inputs))
        x = F.relu(self._conv_2(x_conv_1)) + x_conv_1
        x_conv_3 = F.relu(self._conv_3(x))
        x_conv_4 = F.relu(self._conv_4(x_conv_3)) + x_conv_3
        x_conv_5 = F.relu(self._conv_5(x_conv_4)) + x_conv_4
        x = self._residual_stack(x_conv_5) + x_conv_5

        return x


class Jitter(nn.Module):
    """
    Jitter implementation from [Chorowski et al., 2019].
    During training, each latent vector can replace either one or both of
    its neighbors. As in dropout, this prevents the model from
    relying on consistency across groups of tokens. Additionally,
    this regularization also promotes latent representation stability
    over time: a latent vector extracted at time step t must strive
    to also be useful at time steps t - 1 or t + 1.
    """

    def __init__(self, probability=0.12):
        super(Jitter, self).__init__()

        self._probability = probability

    def forward(self, quantized):
        original_quantized = quantized.detach().clone()
        length = original_quantized.size(2)
        for i in range(length):
            """
            Each latent vector is replaced with either of its neighbors with a certain probability
            (0.12 from the paper).
            """
            replace = [True, False][np.random.choice([1, 0], p=[self._probability, 1 - self._probability])]
            if replace:
                if i == 0:
                    neighbor_index = i + 1
                elif i == length - 1:
                    neighbor_index = i - 1
                else:
                    """
                    "We independently sample whether it is to
                    be replaced with the token right after
                    or before it."
                    """
                    neighbor_index = i + np.random.choice([-1, 1], p=[0.5, 0.5])
                quantized[:, :, i] = original_quantized[:, :, neighbor_index]

        return quantized

class Decoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_residual_layers=3):
        super(Decoder, self).__init__()

        self._jitter = Jitter(0.125)

        self._conv_1 = nn.Conv1d(in_channels=in_channels,
                                 out_channels=hidden_channels,
                                 kernel_size=3,
                                 padding=1)

        self._upsample = nn.Upsample(scale_factor=2)

        self._residual_stack = ResidualStack(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_residual_layers=num_residual_layers,
            num_residual_hiddens=hidden_channels//2
        )

        self._conv_trans_1 = nn.ConvTranspose1d(in_channels=hidden_channels,
                                                out_channels=hidden_channels,
                                                kernel_size=3, padding=1)

        self._conv_trans_2 = nn.ConvTranspose1d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=4, padding=2)

        self._conv_trans_3 = nn.ConvTranspose1d(in_channels=hidden_channels,
                                                out_channels=out_channels,
                                                kernel_size=4, padding=1)

    def forward(self, x, is_training=True):
        #if is_training:
        #    x = self._jitter(x)
        x = self._conv_1(x)
        x = self._upsample(x)
        x = self._residual_stack(x)
        x = F.relu(self._conv_trans_1(x))
        x = F.relu(self._conv_trans_2(x))
        x = self._conv_trans_3(x)
        return x

class VQVAE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_embeddings, embedding_dim, commitment_cost, decay):
        super(VQVAE, self).__init__()

        self._encoder = Encoder(in_channels, hidden_channels)
        self._pre_vq_conv = nn.Conv1d(in_channels=hidden_channels,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            logging.info('CARE NOT TESTED')
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(in_channels=embedding_dim, hidden_channels=hidden_channels, out_channels=out_channels)

    def forward(self, x, is_training=True):
        z = self._encoder(x)

        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, encodings = self._vq_vae(z)
        x_recon = self._decoder(quantized, is_training)

        return loss, x_recon, perplexity, quantized, encodings
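For orientation, a minimal usage sketch of the VQVAE module follows. It is not part of the package: the import path is taken from the file listing above, the hyperparameter values (hidden_channels=32, embedding_dim=16, num_embeddings=128, and so on) are arbitrary illustration choices, and it assumes inputs laid out as (batch, channels, time) with an even time length, since the encoder halves the temporal resolution with a strided convolution.

# Hypothetical usage sketch (not from the package); shapes assume the
# (batch, channels, time) layout expected by the Conv1d layers above.
import torch
import torch.nn.functional as F
from dsipts.models.vva.vqvae import VQVAE  # module path taken from the file listing

model = VQVAE(in_channels=1, hidden_channels=32, out_channels=1,
              num_embeddings=128, embedding_dim=16,
              commitment_cost=0.25, decay=0.99)  # decay > 0 selects VectorQuantizerEMA

x = torch.randn(8, 1, 64)  # batch of 8 univariate series, 64 time steps (even length)
vq_loss, x_recon, perplexity, quantized, encodings = model(x)

print(x_recon.shape)    # torch.Size([8, 1, 64])   reconstruction, same length as the input
print(quantized.shape)  # torch.Size([8, 16, 32])  quantized latents at half temporal resolution
print(encodings.shape)  # torch.Size([8, 32, 128]) one-hot code assignments per latent step

# The model only returns the VQ/commitment loss; the reconstruction term
# would be added by the training loop, e.g.:
recon_loss = F.mse_loss(x_recon, x)
total_loss = recon_loss + vq_loss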