torchaudio 2.9.1__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +86 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/_hdemucs.py
@@ -0,0 +1,1008 @@
# *****************************************************************************
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************


import math
import typing as tp
from typing import Any, Dict, List, Optional

import torch
from torch import nn
from torch.nn import functional as F


class _ScaledEmbedding(torch.nn.Module):
    r"""Make continuous embeddings and boost the learning rate.

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale the learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussians, the scale rises as sqrt(n), so we normalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self) -> torch.Tensor:
        return self.embedding.weight * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for the embedding with scale.

        Args:
            x (torch.Tensor): input tensor of shape `(num_embeddings)`

        Returns:
            (Tensor):
                Embedding output of shape `(num_embeddings, embedding_dim)`
        """
        out = self.embedding(x) * self.scale
        return out
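
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. The class above stores the weight
# divided by `scale` and multiplies it back on the way out, so optimizer steps
# on the raw parameter move the effective weight faster, which is the boosted
# learning rate the docstring describes.
def _example_scaled_embedding():
    emb = _ScaledEmbedding(num_embeddings=16, embedding_dim=4, scale=10.0)
    idx = torch.arange(16)
    out = emb(idx)  # shape: (16, 4)
    # The `weight` property re-applies the scale, so it matches the forward output.
    assert torch.allclose(out, emb.weight)
    return out
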
class _HEncLayer(torch.nn.Module):

    r"""Encoder layer. This is used by both the time and the frequency branches.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): kernel size for the encoder (Default: 8)
        stride (int, optional): stride for the encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. This is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): whether the conv layer is for the frequency domain (Default: ``True``)
        norm_type (str, optional): norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        pad_val = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = pad_val
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad_val = [pad_val, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad_val)
        self.norm1 = norm_fn(chout)

        if self.empty:
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
            self.dconv = _DConv(chout, **dconv_kw)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for the encoding layer.

        Size depends on whether frequency or time.

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on the last layer, combine the frequency and time branches
                through the inject param, same shape as x (Default: ``None``)

        Returns:
            Tensor
                output tensor after the encoder layer, of shape `(B, C, F / stride, T)` for frequency
                and shape `(B, C, ceil(T / stride))` for time
        """

        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.freq:
            B, C, Fr, T = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        z = self.norm2(self.rewrite(y))
        z = F.glu(z, dim=1)
        return z
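
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. With `pad=True`, the time-branch
# output length is exactly `ceil(T / stride)`, while the frequency branch
# strides only over the frequency axis, as the docstring above states.
def _example_henc_layer_shapes():
    time_enc = _HEncLayer(chin=2, chout=48, freq=False)
    x = torch.randn(1, 2, 44100)  # (B, C, T)
    y = time_enc(x)
    assert y.shape == (1, 48, math.ceil(44100 / 4))

    freq_enc = _HEncLayer(chin=4, chout=48, freq=True)
    z = torch.randn(1, 4, 2048, 100)  # (B, C, F, T)
    w = freq_enc(z)
    assert w.shape == (1, 48, 2048 // 4, 100)
    return y, w
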
class _HDecLayer(torch.nn.Module):
    r"""Decoder layer. This is used by both the time and the frequency branches.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether the current layer is the final layer (Default: ``False``)
        kernel_size (int, optional): kernel size for the decoder (Default: 8)
        stride (int, optional): stride for the decoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. This is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): whether the conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        last: bool = False,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 1,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 1,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            if (kernel_size - stride) % 2 != 0:
                raise ValueError("Kernel size and stride do not align")
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            self.norm1 = norm_fn(2 * chin)

    def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
        r"""Forward pass for the decoding layer.

        Size depends on whether frequency or time.

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): on the first layer, separate the frequency and time branches
                using this param (Default: ``None``)
            length (int): size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after the decoder layer, of shape `(B, C, F * stride, T)` for the
                    frequency domain, except for the last frequency layer where the shape is
                    `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)` for the time domain.
                Tensor
                    contains the output just before the final transposed convolution, which is used when
                    the freq. and time branches separate. Otherwise, it does not matter. Shape is
                    `(B, C, F, T)` for frequency and `(B, C, T)` for time.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1)
        else:
            y = x
            if skip is not None:
                raise ValueError("Skip must be none when empty is true.")

        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            if z.shape[-1] != length:
                raise ValueError("Last index of z must be equal to length")
        if not self.last:
            z = F.gelu(z)

        return z, y
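
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. A decoder layer undoes the length
# change of the matching encoder layer: given the target `length`, the output
# is cropped back to the encoder's input size.
def _example_hdec_layer_roundtrip():
    enc = _HEncLayer(chin=2, chout=48, freq=False)
    dec = _HDecLayer(chin=48, chout=2, freq=False, last=True)
    x = torch.randn(1, 2, 44100)
    y = enc(x)  # (1, 48, 11025)
    # A zero tensor stands in for the u-net skip connection here.
    z, _ = dec(y, torch.zeros_like(y), length=44100)
    assert z.shape == x.shape
    return z
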
class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. The list can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in the encoder and decoder (Default: 6)
        freq_emb (float, optional): add a frequency embedding after the first frequency layer if > 0;
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for the encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for the encoder and decoder layers. (Default: 4)
        context (int, optional): context for the 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for the 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            Decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of the residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of the DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in the DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds an LSTM layer in the DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    """

    def __init__(
        self,
        sources: List[str],
        audio_channels: int = 2,
        channels: int = 48,
        growth: int = 2,
        nfft: int = 4096,
        depth: int = 6,
        freq_emb: float = 0.2,
        emb_scale: int = 10,
        emb_smooth: bool = True,
        kernel_size: int = 8,
        time_stride: int = 2,
        stride: int = 4,
        context: int = 1,
        context_enc: int = 0,
        norm_starts: int = 4,
        norm_groups: int = 4,
        dconv_depth: int = 2,
        dconv_comp: int = 4,
        dconv_attn: int = 4,
        dconv_lstm: int = 4,
        dconv_init: float = 1e-4,
    ):
        super().__init__()
        self.depth = depth
        self.nfft = nfft
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.channels = channels

        self.hop_length = self.nfft // 4
        self.freq_emb = None

        self.freq_encoder = nn.ModuleList()
        self.freq_decoder = nn.ModuleList()

        self.time_encoder = nn.ModuleList()
        self.time_decoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin * 2  # number of channels for the freq branch
        chout = channels
        chout_z = channels
        freqs = self.nfft // 2

        for index in range(self.depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm_type = "group_norm" if index >= norm_starts else "none"
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                if freqs != 1:
                    raise ValueError("When freq is false, freqs must be 1.")
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm_type": norm_type,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "lstm": lstm,
                    "attn": attn,
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                },
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
            if freq:
                if last_freq is True and nfft == 2048:
                    kwt["stride"] = 2
                    kwt["kernel_size"] = 4
                tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
                self.time_encoder.append(tenc)

            self.freq_encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin * 2
            dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
            if freq:
                tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
                self.time_decoder.insert(0, tdec)
            self.freq_decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        _rescale_module(self)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft),
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allows
        # easily aligning the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")

        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length=None):
        hl = self.hop_length
        z = F.pad(z, [0, 0, 0, 1])
        z = F.pad(z, [2, 2])
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad that allows reflect padding when num_frames is shorter than max_pad.
        Adds extra zero padding around the input so that the padding does not break."""
        length = x.shape[-1]
        if mode == "reflect":
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)

    def _magnitude(self, z):
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        # `m` is a full spectrogram and `z` is ignored.
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, input: torch.Tensor):

        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        """

        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")

        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of the input Tensor must match `audio_channels` of the HDemucs model. "
                f"Found: {input.shape[1]}."
            )

        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for the time branch.

        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now the time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        xt = torch.zeros_like(x)
        # initialize everything to zero (the signal will go through the u-net skips).

        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before the final transposed convolution,
            # which is used when the freq. and time branches separate.
            offset = self.depth - len(self.time_decoder)
            if idx >= offset:
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        return x
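
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. End-to-end shapes of the model:
# a `(batch, channels, frames)` mixture comes back as
# `(batch, len(sources), channels, frames)`, one waveform per source.
def _example_hdemucs_forward():
    model = hdemucs_high(["drums", "bass", "other", "vocals"])
    model.eval()
    mix = torch.randn(1, 2, 44100 * 4)  # about 4 seconds of stereo at 44.1 kHz
    with torch.no_grad():
        separated = model(mix)
    assert separated.shape == (1, 4, 2, 44100 * 4)
    return separated
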
class _DConv(torch.nn.Module):
    r"""
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also, before entering each residual branch, the dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for the residual branch.
        compress (float, optional): amount of channel compression inside the branch. (Default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention. (Default: 2)
        init (float, optional): initial scale for the LayerScale. (Default: 1e-4)
        norm_type (str, optional): norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention. (Default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (Default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (Default: 3)
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):

        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act = nn.GELU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = pow(2, d) if dilate else 1
            padding = dilation * (kernel_size // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        r"""DConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through the layers.
        """
        for layer in self.layers:
            x = x + layer(x)
        return x
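
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. Each branch compresses channels by
# `compress`, doubles the dilation per layer, and is added back residually, so
# the input shape is preserved.
def _example_dconv_residual():
    dconv = _DConv(channels=48, compress=4, depth=2)
    x = torch.randn(1, 48, 1000)
    y = dconv(x)
    assert y.shape == x.shape
    return y
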
class _BLSTM(torch.nn.Module):
    r"""
    BiLSTM with the same number of hidden units as the input dim.
    If `max_steps` is not None, the input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.

    Args:
        dim (int): dimensions at the LSTM layer.
        layers (int, optional): number of LSTM layers. (Default: 1)
        skip (bool, optional): (Default: ``False``)
    """

    def __init__(self, dim, layers: int = 1, skip: bool = False):
        super().__init__()
        self.max_steps = 200
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM; shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through the bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        """
        B, C, T = x.shape
        y = x
        framed = False
        width = 0
        stride = 0
        nframes = 0
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = _unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y

        return x
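
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. For inputs longer than `max_steps`
# (200), the LSTM runs on half-overlapping chunks of 200 steps that are then
# trimmed and re-concatenated, keeping the output length equal to the input.
def _example_blstm_chunks():
    blstm = _BLSTM(dim=32, layers=2, skip=True)
    x = torch.randn(2, 32, 1000)  # T > 200 triggers the chunked path
    y = blstm(x)
    assert y.shape == x.shape
    return y
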
class _LocalState(nn.Module):
    """Local state allows attention based only on the data (no positional embedding),
    while setting a constraint on the time window (e.g. a decaying penalty term).

    Also a failed experiment in trying to provide some frequency-based attention.
    """

    def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
        r"""
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional): (Default: 4)
            ndecay (int, optional): (Default: 4)
        """
        super(_LocalState, self).__init__()
        if channels % heads != 0:
            raise ValueError("Channels must be divisible by heads.")
        self.heads = heads
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)

        self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
        if ndecay:
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            if self.query_decay.bias is None:
                raise ValueError("bias must not be None.")
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * 0, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LocalState forward call

        Args:
            x (torch.Tensor): input tensor for LocalState

        Returns:
            Tensor
                Output after being run through the LocalState layer.
        """
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= math.sqrt(keys.shape[2])
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)
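
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. The layer is residual and shape
# preserving; the learned decays above penalize attention to distant frames.
def _example_local_state():
    attn = _LocalState(channels=48, heads=4, ndecay=4)
    x = torch.randn(1, 48, 128)
    y = attn(x)
    assert y.shape == x.shape
    return y
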
class _LayerScale(nn.Module):
    """Layer scale from [Touvron et al. 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales the residual outputs diagonally, close to 0 initially, and then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        r"""
        Args:
            channels (int): Size of rescaling
            init (float, optional): Scale to default to (Default: 0)
        """
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LayerScale forward call

        Args:
            x (torch.Tensor): input tensor for LayerScale

        Returns:
            Tensor
                Output after rescaling the tensor.
        """
        return self.scale[:, None] * x


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given an input of size `[*OT, T]`, output a Tensor of size `[*OT, F, K]`,
    with `K` the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    See https://github.com/pytorch/pytorch/issues/60466
    """
    shape = list(a.shape[:-1])
    length = int(a.shape[-1])
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(input=a, pad=[0, tgt_length - length])
    strides = [a.stride(dim) for dim in range(a.dim())]
    if strides[-1] != 1:
        raise ValueError("Data should be contiguous.")
    strides = strides[:-1] + [stride, 1]
    shape.append(n_frames)
    shape.append(kernel_size)
    return a.as_strided(shape, strides)
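
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. `_unfold` zero-pads on the right
# and returns `ceil(T / stride)` frames of `kernel_size` samples as a strided
# view of the padded input.
def _example_unfold_frames():
    a = torch.arange(10.0)[None]  # (1, 10)
    frames = _unfold(a, kernel_size=4, stride=2)
    assert frames.shape == (1, 5, 4)  # ceil(10 / 2) = 5 frames, hop of 2
    return frames
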
def _rescale_module(module):
    r"""
    Rescales the initial weight scale for all the models within the module.
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            std = sub.weight.std().detach()
            scale = (std / 0.1) ** 0.5
            sub.weight.data /= scale
            if sub.bias is not None:
                sub.bias.data /= scale


def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(x.shape[:-1])
    length = int(x.shape[-1])
    x = x.reshape(-1, length)
    z = torch.stft(
        x,
        n_fft * (1 + pad),
        hop_length,
        window=torch.hann_window(n_fft).to(x),
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode="reflect",
    )
    _, freqs, frame = z.shape
    other.extend([freqs, frame])
    return z.view(other)


def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
    other = list(z.shape[:-2])
    freqs = int(z.shape[-2])
    frames = int(z.shape[-1])

    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)
    x = torch.istft(
        z,
        n_fft,
        hop_length,
        window=torch.hann_window(win_length).to(z.real),
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, length = x.shape
    other.append(length)
    return x.view(other)
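
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. `_spectro` and `_ispectro` are
# thin, normalized STFT/ISTFT wrappers that round-trip the input length; the
# model itself additionally drops the Nyquist bin (`[..., :-1, :]`) in `_spec`.
def _example_spectro_roundtrip():
    x = torch.randn(1, 2, 8192)
    z = _spectro(x, n_fft=512, hop_length=128)  # (1, 2, 257, 65)
    y = _ispectro(z, hop_length=128, length=8192)
    assert y.shape == x.shape
    return y
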
def hdemucs_low(sources: List[str]) -> HDemucs:
    """Builds the low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """

    return HDemucs(sources=sources, nfft=1024, depth=5)


def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Builds the medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs, as this nfft and depth configuration
        is not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """

    return HDemucs(sources=sources, nfft=2048, depth=6)


def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Builds the high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    """

    return HDemucs(sources=sources, nfft=4096, depth=6)
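
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged source; the
# `_example_*` helper name is hypothetical. It only restates the sample-rate
# guidance from the three builder docstrings above.
def _example_pick_builder(sample_rate: int) -> HDemucs:
    sources = ["drums", "bass", "other", "vocals"]
    if sample_rate <= 8000:
        return hdemucs_low(sources)
    if sample_rate <= 32000:
        return hdemucs_medium(sources)
    return hdemucs_high(sources)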