torchaudio-2.9.1-cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. torchaudio/__init__.py +204 -0
  2. torchaudio/_extension/__init__.py +61 -0
  3. torchaudio/_extension/utils.py +133 -0
  4. torchaudio/_internal/__init__.py +10 -0
  5. torchaudio/_internal/module_utils.py +171 -0
  6. torchaudio/_torchcodec.py +340 -0
  7. torchaudio/compliance/__init__.py +5 -0
  8. torchaudio/compliance/kaldi.py +813 -0
  9. torchaudio/datasets/__init__.py +47 -0
  10. torchaudio/datasets/cmuarctic.py +157 -0
  11. torchaudio/datasets/cmudict.py +186 -0
  12. torchaudio/datasets/commonvoice.py +86 -0
  13. torchaudio/datasets/dr_vctk.py +121 -0
  14. torchaudio/datasets/fluentcommands.py +108 -0
  15. torchaudio/datasets/gtzan.py +1118 -0
  16. torchaudio/datasets/iemocap.py +147 -0
  17. torchaudio/datasets/librilight_limited.py +111 -0
  18. torchaudio/datasets/librimix.py +133 -0
  19. torchaudio/datasets/librispeech.py +174 -0
  20. torchaudio/datasets/librispeech_biasing.py +189 -0
  21. torchaudio/datasets/libritts.py +168 -0
  22. torchaudio/datasets/ljspeech.py +107 -0
  23. torchaudio/datasets/musdb_hq.py +139 -0
  24. torchaudio/datasets/quesst14.py +136 -0
  25. torchaudio/datasets/snips.py +157 -0
  26. torchaudio/datasets/speechcommands.py +183 -0
  27. torchaudio/datasets/tedlium.py +218 -0
  28. torchaudio/datasets/utils.py +54 -0
  29. torchaudio/datasets/vctk.py +143 -0
  30. torchaudio/datasets/voxceleb1.py +309 -0
  31. torchaudio/datasets/yesno.py +89 -0
  32. torchaudio/functional/__init__.py +130 -0
  33. torchaudio/functional/_alignment.py +128 -0
  34. torchaudio/functional/filtering.py +1685 -0
  35. torchaudio/functional/functional.py +2505 -0
  36. torchaudio/lib/__init__.py +0 -0
  37. torchaudio/lib/_torchaudio.so +0 -0
  38. torchaudio/lib/libtorchaudio.so +0 -0
  39. torchaudio/models/__init__.py +85 -0
  40. torchaudio/models/_hdemucs.py +1008 -0
  41. torchaudio/models/conformer.py +293 -0
  42. torchaudio/models/conv_tasnet.py +330 -0
  43. torchaudio/models/decoder/__init__.py +64 -0
  44. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  45. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  46. torchaudio/models/deepspeech.py +84 -0
  47. torchaudio/models/emformer.py +884 -0
  48. torchaudio/models/rnnt.py +816 -0
  49. torchaudio/models/rnnt_decoder.py +339 -0
  50. torchaudio/models/squim/__init__.py +11 -0
  51. torchaudio/models/squim/objective.py +326 -0
  52. torchaudio/models/squim/subjective.py +150 -0
  53. torchaudio/models/tacotron2.py +1046 -0
  54. torchaudio/models/wav2letter.py +72 -0
  55. torchaudio/models/wav2vec2/__init__.py +45 -0
  56. torchaudio/models/wav2vec2/components.py +1167 -0
  57. torchaudio/models/wav2vec2/model.py +1579 -0
  58. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  59. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  60. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  61. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  62. torchaudio/models/wavernn.py +409 -0
  63. torchaudio/pipelines/__init__.py +102 -0
  64. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  65. torchaudio/pipelines/_squim_pipeline.py +156 -0
  66. torchaudio/pipelines/_tts/__init__.py +16 -0
  67. torchaudio/pipelines/_tts/impl.py +385 -0
  68. torchaudio/pipelines/_tts/interface.py +255 -0
  69. torchaudio/pipelines/_tts/utils.py +230 -0
  70. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  71. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  72. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  73. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  74. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  75. torchaudio/transforms/__init__.py +78 -0
  76. torchaudio/transforms/_multi_channel.py +467 -0
  77. torchaudio/transforms/_transforms.py +2138 -0
  78. torchaudio/utils/__init__.py +4 -0
  79. torchaudio/utils/download.py +89 -0
  80. torchaudio/version.py +2 -0
  81. torchaudio-2.9.1.dist-info/METADATA +133 -0
  82. torchaudio-2.9.1.dist-info/RECORD +85 -0
  83. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  84. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  85. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1008 @@
+ # *****************************************************************************
+ # MIT License
+ #
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ # *****************************************************************************
+
+
+ import math
+ import typing as tp
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class _ScaledEmbedding(torch.nn.Module):
+     r"""Make continuous embeddings and boost the learning rate.
+
+     Args:
+         num_embeddings (int): number of embeddings
+         embedding_dim (int): embedding dimensions
+         scale (float, optional): amount to scale learning rate (Default: 10.0)
+         smooth (bool, optional): choose to apply smoothing (Default: ``False``)
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
+         super().__init__()
+         self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+         if smooth:
+             weight = torch.cumsum(self.embedding.weight.data, dim=0)
+             # when summing gaussians, the scale grows as sqrt(n), so we normalize by that.
+             weight = weight / torch.arange(1, num_embeddings + 1).sqrt()[:, None]
+             self.embedding.weight.data[:] = weight
+         self.embedding.weight.data /= scale
+         self.scale = scale
+
+     @property
+     def weight(self) -> torch.Tensor:
+         return self.embedding.weight * self.scale
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         r"""Forward pass for embedding with scale.
+         Args:
+             x (torch.Tensor): input tensor of shape `(num_embeddings)`
+
+         Returns:
+             (Tensor):
+                 Embedding output of shape `(num_embeddings, embedding_dim)`
+         """
+         out = self.embedding(x) * self.scale
+         return out
+
+
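The division by ``scale`` at init together with the multiplication in ``forward`` leaves the output values unchanged but makes the stored parameter smaller, so optimizer steps act on the embedding as if its learning rate were boosted, which is what the ``scale`` argument documents. A minimal sketch (illustrative only; ``_ScaledEmbedding`` lives in the private module ``torchaudio.models._hdemucs``):

import torch
from torchaudio.models._hdemucs import _ScaledEmbedding

emb = _ScaledEmbedding(num_embeddings=16, embedding_dim=8, scale=10.0)
idx = torch.arange(16)
out = emb(idx)                          # (16, 8): embedding lookup re-multiplied by scale
print(torch.allclose(out, emb.weight))  # True, the `weight` property applies the same scale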
+ class _HEncLayer(torch.nn.Module):
+
+     r"""Encoder layer. This is used by both the time and the frequency branches.
+     Args:
+         chin (int): number of input channels.
+         chout (int): number of output channels.
+         kernel_size (int, optional): Kernel size for encoder (Default: 8)
+         stride (int, optional): Stride for encoder layer (Default: 4)
+         norm_groups (int, optional): number of groups for group norm. (Default: 4)
+         empty (bool, optional): used to make a layer with just the first conv. this is used
+             before merging the time and freq. branches. (Default: ``False``)
+         freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
+         norm_type (string, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
+         context (int, optional): context size for the 1x1 conv. (Default: 0)
+         dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
+         pad (bool, optional): true to pad the input. Padding is done so that the output size is
+             always the input size / stride. (Default: ``True``)
+     """
+
+     def __init__(
+         self,
+         chin: int,
+         chout: int,
+         kernel_size: int = 8,
+         stride: int = 4,
+         norm_groups: int = 4,
+         empty: bool = False,
+         freq: bool = True,
+         norm_type: str = "group_norm",
+         context: int = 0,
+         dconv_kw: Optional[Dict[str, Any]] = None,
+         pad: bool = True,
+     ):
+         super().__init__()
+         if dconv_kw is None:
+             dconv_kw = {}
+         norm_fn = lambda d: nn.Identity()  # noqa
+         if norm_type == "group_norm":
+             norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
+         pad_val = kernel_size // 4 if pad else 0
+         klass = nn.Conv1d
+         self.freq = freq
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.empty = empty
+         self.pad = pad_val
+         if freq:
+             kernel_size = [kernel_size, 1]
+             stride = [stride, 1]
+             pad_val = [pad_val, 0]
+             klass = nn.Conv2d
+         self.conv = klass(chin, chout, kernel_size, stride, pad_val)
+         self.norm1 = norm_fn(chout)
+
+         if self.empty:
+             self.rewrite = nn.Identity()
+             self.norm2 = nn.Identity()
+             self.dconv = nn.Identity()
+         else:
+             self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
+             self.norm2 = norm_fn(2 * chout)
+             self.dconv = _DConv(chout, **dconv_kw)
+
+     def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
+         r"""Forward pass for encoding layer.
+
+         Size depends on whether frequency or time
+
+         Args:
+             x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
+                 `(B, C, T)` for time
+             inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
+                 same shape as x (default: ``None``)
+
+         Returns:
+             Tensor
+                 output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
+                 and shape `(B, C, ceil(T / stride))` for time
+         """
+
+         if not self.freq and x.dim() == 4:
+             B, C, Fr, T = x.shape
+             x = x.view(B, -1, T)
+
+         if not self.freq:
+             le = x.shape[-1]
+             if not le % self.stride == 0:
+                 x = F.pad(x, (0, self.stride - (le % self.stride)))
+         y = self.conv(x)
+         if self.empty:
+             return y
+         if inject is not None:
+             if inject.shape[-1] != y.shape[-1]:
+                 raise ValueError("Injection shapes do not align")
+             if inject.dim() == 3 and y.dim() == 4:
+                 inject = inject[:, :, None]
+             y = y + inject
+         y = F.gelu(self.norm1(y))
+         if self.freq:
+             B, C, Fr, T = y.shape
+             y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
+             y = self.dconv(y)
+             y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
+         else:
+             y = self.dconv(y)
+         z = self.norm2(self.rewrite(y))
+         z = F.glu(z, dim=1)
+         return z
+
+
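The two shape regimes described in the forward docstring can be checked directly: with ``freq=True`` the layer is a 2D conv that strides over the frequency axis, with ``freq=False`` it is a 1D conv that strides over time. A small sketch (illustrative only; the layer is a private helper):

import torch
from torchaudio.models._hdemucs import _HEncLayer

freq_enc = _HEncLayer(chin=4, chout=48)              # frequency branch: Conv2d, stride on F
time_enc = _HEncLayer(chin=2, chout=48, freq=False)  # time branch: Conv1d, stride on T

spec = torch.randn(1, 4, 512, 50)   # (B, C, F, T)
wave = torch.randn(1, 2, 4000)      # (B, C, T)
print(freq_enc(spec).shape)         # torch.Size([1, 48, 128, 50]): F divided by the stride of 4
print(time_enc(wave).shape)         # torch.Size([1, 48, 1000]): T divided by the stride of 4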
+ class _HDecLayer(torch.nn.Module):
+     r"""Decoder layer. This is used by both the time and the frequency branches.
+     Args:
+         chin (int): number of input channels.
+         chout (int): number of output channels.
+         last (bool, optional): whether current layer is final layer (Default: ``False``)
+         kernel_size (int, optional): Kernel size for decoder (Default: 8)
+         stride (int, optional): Stride for decoder layer (Default: 4)
+         norm_groups (int, optional): number of groups for group norm. (Default: 1)
+         empty (bool, optional): used to make a layer with just the first conv. this is used
+             before merging the time and freq. branches. (Default: ``False``)
+         freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
+         norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
+         context (int, optional): context size for the 1x1 conv. (Default: 1)
+         dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
+         pad (bool, optional): true to pad the input. Padding is done so that the output size is
+             always the input size / stride. (Default: ``True``)
+     """
+
+     def __init__(
+         self,
+         chin: int,
+         chout: int,
+         last: bool = False,
+         kernel_size: int = 8,
+         stride: int = 4,
+         norm_groups: int = 1,
+         empty: bool = False,
+         freq: bool = True,
+         norm_type: str = "group_norm",
+         context: int = 1,
+         dconv_kw: Optional[Dict[str, Any]] = None,
+         pad: bool = True,
+     ):
+         super().__init__()
+         if dconv_kw is None:
+             dconv_kw = {}
+         norm_fn = lambda d: nn.Identity()  # noqa
+         if norm_type == "group_norm":
+             norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
+         if pad:
+             if (kernel_size - stride) % 2 != 0:
+                 raise ValueError("Kernel size and stride do not align")
+             pad = (kernel_size - stride) // 2
+         else:
+             pad = 0
+         self.pad = pad
+         self.last = last
+         self.freq = freq
+         self.chin = chin
+         self.empty = empty
+         self.stride = stride
+         self.kernel_size = kernel_size
+         klass = nn.Conv1d
+         klass_tr = nn.ConvTranspose1d
+         if freq:
+             kernel_size = [kernel_size, 1]
+             stride = [stride, 1]
+             klass = nn.Conv2d
+             klass_tr = nn.ConvTranspose2d
+         self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
+         self.norm2 = norm_fn(chout)
+         if self.empty:
+             self.rewrite = nn.Identity()
+             self.norm1 = nn.Identity()
+         else:
+             self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
+             self.norm1 = norm_fn(2 * chin)
+
+     def forward(self, x: torch.Tensor, skip: Optional[torch.Tensor], length):
+         r"""Forward pass for decoding layer.
+
+         Size depends on whether frequency or time
+
+         Args:
+             x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
+                 `(B, C, T)` for time
+             skip (torch.Tensor, optional): on first layer, separate frequency and time branches using param
+                 (default: ``None``)
+             length (int): Size of tensor for output
+
+         Returns:
+             (Tensor, Tensor):
+                 Tensor
+                     output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
+                     frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
+                     for time domain.
+                 Tensor
+                     contains the output just before final transposed convolution, which is used when the
+                     freq. and time branch separate. Otherwise, does not matter. Shape is
+                     `(B, C, F, T)` for frequency and `(B, C, T)` for time.
+         """
+         if self.freq and x.dim() == 3:
+             B, C, T = x.shape
+             x = x.view(B, self.chin, -1, T)
+
+         if not self.empty:
+             x = x + skip
+             y = F.glu(self.norm1(self.rewrite(x)), dim=1)
+         else:
+             y = x
+             if skip is not None:
+                 raise ValueError("Skip must be none when empty is true.")
+
+         z = self.norm2(self.conv_tr(y))
+         if self.freq:
+             if self.pad:
+                 z = z[..., self.pad : -self.pad, :]
+         else:
+             z = z[..., self.pad : self.pad + length]
+             if z.shape[-1] != length:
+                 raise ValueError("Last index of z must be equal to length")
+         if not self.last:
+             z = F.gelu(z)
+
+         return z, y
+
+
+ class HDemucs(torch.nn.Module):
+     r"""Hybrid Demucs model from
+     *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
+
+     See Also:
+         * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
+
+     Args:
+         sources (List[str]): list of source names. List can contain the following source
+             options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
+         audio_channels (int, optional): input/output audio channels. (Default: 2)
+         channels (int, optional): initial number of hidden channels. (Default: 48)
+         growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
+         nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
+             various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
+         depth (int, optional): number of layers in encoder and decoder (Default: 6)
+         freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
+             the actual value controls the weight of the embedding. (Default: 0.2)
+         emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
+         emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
+             (Default: ``True``)
+         kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
+         time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
+         stride (int, optional): stride for encoder and decoder layers. (Default: 4)
+         context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
+         context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
+         norm_starts (int, optional): layer at which group norm starts being used.
+             decoder layers are numbered in reverse order. (Default: 4)
+         norm_groups (int, optional): number of groups for group norm. (Default: 4)
+         dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
+         dconv_comp (int, optional): compression of DConv branch. (Default: 4)
+         dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
+         dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
+         dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
+     """
+
+     def __init__(
+         self,
+         sources: List[str],
+         audio_channels: int = 2,
+         channels: int = 48,
+         growth: int = 2,
+         nfft: int = 4096,
+         depth: int = 6,
+         freq_emb: float = 0.2,
+         emb_scale: int = 10,
+         emb_smooth: bool = True,
+         kernel_size: int = 8,
+         time_stride: int = 2,
+         stride: int = 4,
+         context: int = 1,
+         context_enc: int = 0,
+         norm_starts: int = 4,
+         norm_groups: int = 4,
+         dconv_depth: int = 2,
+         dconv_comp: int = 4,
+         dconv_attn: int = 4,
+         dconv_lstm: int = 4,
+         dconv_init: float = 1e-4,
+     ):
+         super().__init__()
+         self.depth = depth
+         self.nfft = nfft
+         self.audio_channels = audio_channels
+         self.sources = sources
+         self.kernel_size = kernel_size
+         self.context = context
+         self.stride = stride
+         self.channels = channels
+
+         self.hop_length = self.nfft // 4
+         self.freq_emb = None
+
+         self.freq_encoder = nn.ModuleList()
+         self.freq_decoder = nn.ModuleList()
+
+         self.time_encoder = nn.ModuleList()
+         self.time_decoder = nn.ModuleList()
+
+         chin = audio_channels
+         chin_z = chin * 2  # number of channels for the freq branch
+         chout = channels
+         chout_z = channels
+         freqs = self.nfft // 2
+
+         for index in range(self.depth):
+             lstm = index >= dconv_lstm
+             attn = index >= dconv_attn
+             norm_type = "group_norm" if index >= norm_starts else "none"
+             freq = freqs > 1
+             stri = stride
+             ker = kernel_size
+             if not freq:
+                 if freqs != 1:
+                     raise ValueError("When freq is false, freqs must be 1.")
+                 ker = time_stride * 2
+                 stri = time_stride
+
+             pad = True
+             last_freq = False
+             if freq and freqs <= kernel_size:
+                 ker = freqs
+                 pad = False
+                 last_freq = True
+
+             kw = {
+                 "kernel_size": ker,
+                 "stride": stri,
+                 "freq": freq,
+                 "pad": pad,
+                 "norm_type": norm_type,
+                 "norm_groups": norm_groups,
+                 "dconv_kw": {
+                     "lstm": lstm,
+                     "attn": attn,
+                     "depth": dconv_depth,
+                     "compress": dconv_comp,
+                     "init": dconv_init,
+                 },
+             }
+             kwt = dict(kw)
+             kwt["freq"] = 0
+             kwt["kernel_size"] = kernel_size
+             kwt["stride"] = stride
+             kwt["pad"] = True
+             kw_dec = dict(kw)
+
+             if last_freq:
+                 chout_z = max(chout, chout_z)
+                 chout = chout_z
+
+             enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
+             if freq:
+                 if last_freq is True and nfft == 2048:
+                     kwt["stride"] = 2
+                     kwt["kernel_size"] = 4
+                 tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
+                 self.time_encoder.append(tenc)
+
+             self.freq_encoder.append(enc)
+             if index == 0:
+                 chin = self.audio_channels * len(self.sources)
+                 chin_z = chin * 2
+             dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
+             if freq:
+                 tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
+                 self.time_decoder.insert(0, tdec)
+             self.freq_decoder.insert(0, dec)
+
+             chin = chout
+             chin_z = chout_z
+             chout = int(growth * chout)
+             chout_z = int(growth * chout_z)
+             if freq:
+                 if freqs <= kernel_size:
+                     freqs = 1
+                 else:
+                     freqs //= stride
+             if index == 0 and freq_emb:
+                 self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
+                 self.freq_emb_scale = freq_emb
+
+         _rescale_module(self)
+
+     def _spec(self, x):
+         hl = self.hop_length
+         nfft = self.nfft
+         x0 = x  # noqa
+
+         # We re-pad the signal in order to keep the property
+         # that the size of the output is exactly the size of the input
+         # divided by the stride (here hop_length), when divisible.
+         # This is achieved by padding by 1/4th of the kernel size (here nfft),
+         # which is not supported by torch.stft.
+         # Having all convolution operations follow this convention allows us to easily
+         # align the time and frequency branches later on.
+         if hl != nfft // 4:
+             raise ValueError("Hop length must be nfft // 4")
+         le = int(math.ceil(x.shape[-1] / hl))
+         pad = hl // 2 * 3
+         x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")
+
+         z = _spectro(x, nfft, hl)[..., :-1, :]
+         if z.shape[-1] != le + 4:
+             raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
+         z = z[..., 2 : 2 + le]
+         return z
+
+     def _ispec(self, z, length=None):
+         hl = self.hop_length
+         z = F.pad(z, [0, 0, 0, 1])
+         z = F.pad(z, [2, 2])
+         pad = hl // 2 * 3
+         le = hl * int(math.ceil(length / hl)) + 2 * pad
+         x = _ispectro(z, hl, length=le)
+         x = x[..., pad : pad + length]
+         return x
+
+     def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
+         """Wrapper around F.pad, to allow reflect padding when num_frames is shorter than max_pad.
+         Extra zero padding is added first so that the reflect padding does not break."""
+         length = x.shape[-1]
+         if mode == "reflect":
+             max_pad = max(padding_left, padding_right)
+             if length <= max_pad:
+                 x = F.pad(x, (0, max_pad - length + 1))
+         return F.pad(x, (padding_left, padding_right), mode, value)
+
+     def _magnitude(self, z):
+         # move the complex dimension to the channel one.
+         B, C, Fr, T = z.shape
+         m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
+         m = m.reshape(B, C * 2, Fr, T)
+         return m
+
+     def _mask(self, m):
+         # `m` is a full spectrogram and `z` is ignored.
+         B, S, C, Fr, T = m.shape
+         out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
+         out = torch.view_as_complex(out.contiguous())
+         return out
+
+     def forward(self, input: torch.Tensor):
+
+         r"""HDemucs forward call
+
+         Args:
+             input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`
+
+         Returns:
+             Tensor
+                 output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
+         """
+
+         if input.ndim != 3:
+             raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")
+
+         if input.shape[1] != self.audio_channels:
+             raise ValueError(
+                 f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
+                 f"Found: {input.shape[1]}."
+             )
+
+         x = input
+         length = x.shape[-1]
+
+         z = self._spec(input)
+         mag = self._magnitude(z)
+         x = mag
+
+         B, C, Fq, T = x.shape
+
+         # unlike previous Demucs, we always normalize because it is easier.
+         mean = x.mean(dim=(1, 2, 3), keepdim=True)
+         std = x.std(dim=(1, 2, 3), keepdim=True)
+         x = (x - mean) / (1e-5 + std)
+         # x will be the freq. branch input.
+
+         # Prepare the time branch input.
+         xt = input
+         meant = xt.mean(dim=(1, 2), keepdim=True)
+         stdt = xt.std(dim=(1, 2), keepdim=True)
+         xt = (xt - meant) / (1e-5 + stdt)
+
+         saved = []  # skip connections, freq.
+         saved_t = []  # skip connections, time.
+         lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
+         lengths_t: List[int] = []  # saved lengths for time branch.
+
+         for idx, encode in enumerate(self.freq_encoder):
+             lengths.append(x.shape[-1])
+             inject = None
+             if idx < len(self.time_encoder):
+                 # we have not yet merged branches.
+                 lengths_t.append(xt.shape[-1])
+                 tenc = self.time_encoder[idx]
+                 xt = tenc(xt)
+                 if not tenc.empty:
+                     # save for skip connection
+                     saved_t.append(xt)
+                 else:
+                     # tenc contains just the first conv., so that now time and freq.
+                     # branches have the same shape and can be merged.
+                     inject = xt
+             x = encode(x, inject)
+             if idx == 0 and self.freq_emb is not None:
+                 # add frequency embedding to allow for non equivariant convolutions
+                 # over the frequency axis.
+                 frs = torch.arange(x.shape[-2], device=x.device)
+                 emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
+                 x = x + self.freq_emb_scale * emb
+
+             saved.append(x)
+
+         x = torch.zeros_like(x)
+         xt = torch.zeros_like(x)
+         # initialize everything to zero (signal will go through u-net skips).
+
+         for idx, decode in enumerate(self.freq_decoder):
+             skip = saved.pop(-1)
+             x, pre = decode(x, skip, lengths.pop(-1))
+             # `pre` contains the output just before final transposed convolution,
+             # which is used when the freq. and time branch separate.
+             offset = self.depth - len(self.time_decoder)
+             if idx >= offset:
+                 tdec = self.time_decoder[idx - offset]
+                 length_t = lengths_t.pop(-1)
+                 if tdec.empty:
+                     if pre.shape[2] != 1:
+                         raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
+                     pre = pre[:, :, 0]
+                     xt, _ = tdec(pre, None, length_t)
+                 else:
+                     skip = saved_t.pop(-1)
+                     xt, _ = tdec(xt, skip, length_t)
+
+         if len(saved) != 0:
+             raise AssertionError("saved is not empty")
+         if len(lengths_t) != 0:
+             raise AssertionError("lengths_t is not empty")
+         if len(saved_t) != 0:
+             raise AssertionError("saved_t is not empty")
+
+         S = len(self.sources)
+         x = x.view(B, S, -1, Fq, T)
+         x = x * std[:, None] + mean[:, None]
+
+         zout = self._mask(x)
+         x = self._ispec(zout, length)
+
+         xt = xt.view(B, S, -1, length)
+         xt = xt * stdt[:, None] + meant[:, None]
+         x = xt + x
+         return x
+
+
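The shape contract of ``forward`` can be exercised end to end with randomly initialized weights (a minimal sketch; trained weights come from the pipelines mentioned in the class docstring):

import torch
from torchaudio.models import HDemucs

model = HDemucs(sources=["drums", "bass", "other", "vocals"])
model.eval()
mix = torch.randn(1, 2, 44100)   # (batch, audio_channels, num_frames), ~1 second at 44.1 kHz
with torch.no_grad():
    out = model(mix)
print(out.shape)                 # torch.Size([1, 4, 2, 44100]): one waveform per source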
+ class _DConv(torch.nn.Module):
+     r"""
+     New residual branches in each encoder layer.
+     This alternates dilated convolutions, potentially with LSTMs and attention.
+     Also before entering each residual branch, dimension is projected on a smaller subspace,
+     e.g. of dim `channels // compress`.
+
+     Args:
+         channels (int): input/output channels for residual branch.
+         compress (float, optional): amount of channel compression inside the branch. (default: 4)
+         depth (int, optional): number of layers in the residual branch. Each layer has its own
+             projection, and potentially LSTM and attention. (default: 2)
+         init (float, optional): initial scale for LayerScale. (default: 1e-4)
+         norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
+         attn (bool, optional): use LocalAttention. (Default: ``False``)
+         heads (int, optional): number of heads for the LocalAttention. (default: 4)
+         ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
+         lstm (bool, optional): use LSTM. (Default: ``False``)
+         kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         compress: float = 4,
+         depth: int = 2,
+         init: float = 1e-4,
+         norm_type: str = "group_norm",
+         attn: bool = False,
+         heads: int = 4,
+         ndecay: int = 4,
+         lstm: bool = False,
+         kernel_size: int = 3,
+     ):
+
+         super().__init__()
+         if kernel_size % 2 == 0:
+             raise ValueError("Kernel size should not be divisible by 2")
+         self.channels = channels
+         self.compress = compress
+         self.depth = abs(depth)
+         dilate = depth > 0
+
+         norm_fn: tp.Callable[[int], nn.Module]
+         norm_fn = lambda d: nn.Identity()  # noqa
+         if norm_type == "group_norm":
+             norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa
+
+         hidden = int(channels / compress)
+
+         act = nn.GELU
+
+         self.layers = nn.ModuleList([])
+         for d in range(self.depth):
+             dilation = pow(2, d) if dilate else 1
+             padding = dilation * (kernel_size // 2)
+             mods = [
+                 nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=padding),
+                 norm_fn(hidden),
+                 act(),
+                 nn.Conv1d(hidden, 2 * channels, 1),
+                 norm_fn(2 * channels),
+                 nn.GLU(1),
+                 _LayerScale(channels, init),
+             ]
+             if attn:
+                 mods.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
+             if lstm:
+                 mods.insert(3, _BLSTM(hidden, layers=2, skip=True))
+             layer = nn.Sequential(*mods)
+             self.layers.append(layer)
+
+     def forward(self, x):
+         r"""DConv forward call
+
+         Args:
+             x (torch.Tensor): input tensor for convolution
+
+         Returns:
+             Tensor
+                 Output after being run through layers.
+         """
+         for layer in self.layers:
+             x = x + layer(x)
+         return x
+
+
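Because every layer in the branch above is added back onto its input, the block preserves the `(batch, channels, time)` shape regardless of depth. A tiny illustrative check (private helper, shown only to make the residual structure concrete):

import torch
from torchaudio.models._hdemucs import _DConv

branch = _DConv(channels=48, depth=2)   # dilations 1 and 2, hidden width 48 // 4 = 12
y = branch(torch.randn(1, 48, 1000))
print(y.shape)                          # torch.Size([1, 48, 1000]): residual branch keeps the shape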
+ class _BLSTM(torch.nn.Module):
+     r"""
+     BiLSTM with same hidden units as input dim.
+     If `max_steps` is not None, input will be split into overlapping
+     chunks and the LSTM applied separately on each chunk.
+     Args:
+         dim (int): dimensions at LSTM layer.
+         layers (int, optional): number of LSTM layers. (default: 1)
+         skip (bool, optional): (default: ``False``)
+     """
+
+     def __init__(self, dim, layers: int = 1, skip: bool = False):
+         super().__init__()
+         self.max_steps = 200
+         self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+         self.linear = nn.Linear(2 * dim, dim)
+         self.skip = skip
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         r"""BLSTM forward call
+
+         Args:
+             x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`
+
+         Returns:
+             Tensor
+                 Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
+         """
+         B, C, T = x.shape
+         y = x
+         framed = False
+         width = 0
+         stride = 0
+         nframes = 0
+         if self.max_steps is not None and T > self.max_steps:
+             width = self.max_steps
+             stride = width // 2
+             frames = _unfold(x, width, stride)
+             nframes = frames.shape[2]
+             framed = True
+             x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
+
+         x = x.permute(2, 0, 1)
+
+         x = self.lstm(x)[0]
+         x = self.linear(x)
+         x = x.permute(1, 2, 0)
+         if framed:
+             out = []
+             frames = x.reshape(B, -1, C, width)
+             limit = stride // 2
+             for k in range(nframes):
+                 if k == 0:
+                     out.append(frames[:, k, :, :-limit])
+                 elif k == nframes - 1:
+                     out.append(frames[:, k, :, limit:])
+                 else:
+                     out.append(frames[:, k, :, limit:-limit])
+             out = torch.cat(out, -1)
+             out = out[..., :T]
+             x = out
+         if self.skip:
+             x = x + y
+
+         return x
+
+
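When the input is longer than ``max_steps`` (200 frames) the forward above splits it into half-overlapping chunks, runs the LSTM per chunk, and stitches the chunk centers back together, so the output always matches the input shape. A small illustrative check (private helper):

import torch
from torchaudio.models._hdemucs import _BLSTM

blstm = _BLSTM(dim=12, layers=2, skip=True)
short = torch.randn(2, 12, 150)    # fits in a single pass
long = torch.randn(2, 12, 1000)    # framed into overlapping chunks of width 200
print(blstm(short).shape, blstm(long).shape)   # both keep (batch, dim, time_steps)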
+ class _LocalState(nn.Module):
+     """Local state allows attention based only on data (no positional embedding),
+     but while setting a constraint on the time window (e.g. decaying penalty term).
+     Also a failed experiment with trying to provide some frequency-based attention.
+     """
+
+     def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
+         r"""
+         Args:
+             channels (int): Size of Conv1d layers.
+             heads (int, optional): (default: 4)
+             ndecay (int, optional): (default: 4)
+         """
+         super(_LocalState, self).__init__()
+         if channels % heads != 0:
+             raise ValueError("Channels must be divisible by heads.")
+         self.heads = heads
+         self.ndecay = ndecay
+         self.content = nn.Conv1d(channels, channels, 1)
+         self.query = nn.Conv1d(channels, channels, 1)
+         self.key = nn.Conv1d(channels, channels, 1)
+
+         self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
+         if ndecay:
+             # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
+             self.query_decay.weight.data *= 0.01
+             if self.query_decay.bias is None:
+                 raise ValueError("bias must not be None.")
+             self.query_decay.bias.data[:] = -2
+         self.proj = nn.Conv1d(channels + heads * 0, channels, 1)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         r"""LocalState forward call
+
+         Args:
+             x (torch.Tensor): input tensor for LocalState
+
+         Returns:
+             Tensor
+                 Output after being run through LocalState layer.
+         """
+         B, C, T = x.shape
+         heads = self.heads
+         indexes = torch.arange(T, device=x.device, dtype=x.dtype)
+         # left index are keys, right index are queries
+         delta = indexes[:, None] - indexes[None, :]
+
+         queries = self.query(x).view(B, heads, -1, T)
+         keys = self.key(x).view(B, heads, -1, T)
+         # t are keys, s are queries
+         dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
+         dots /= math.sqrt(keys.shape[2])
+         if self.ndecay:
+             decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
+             decay_q = self.query_decay(x).view(B, heads, -1, T)
+             decay_q = torch.sigmoid(decay_q) / 2
+             decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / math.sqrt(self.ndecay)
+             dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
+
+         # Kill self reference.
+         dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
+         weights = torch.softmax(dots, dim=2)
+
+         content = self.content(x).view(B, heads, -1, T)
+         result = torch.einsum("bhts,bhct->bhcs", weights, content)
+         result = result.reshape(B, -1, T)
+         return x + self.proj(result)
+
+
+ class _LayerScale(nn.Module):
+     """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+     This rescales the residual outputs diagonally; the scale starts close to 0 and is then learnt.
+     """
+
+     def __init__(self, channels: int, init: float = 0):
+         r"""
+         Args:
+             channels (int): Size of rescaling
+             init (float, optional): Scale to default to (default: 0)
+         """
+         super().__init__()
+         self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
+         self.scale.data[:] = init
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         r"""LayerScale forward call
+
+         Args:
+             x (torch.Tensor): input tensor for LayerScale
+
+         Returns:
+             Tensor
+                 Output after rescaling tensor.
+         """
+         return self.scale[:, None] * x
+
+
+ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
+     """Given input of size `[*OT, T]`, output Tensor of size `[*OT, F, K]`
+     with K the kernel size, by extracting frames with the given stride.
+     This will pad the input so that `F = ceil(T / stride)`.
+     See https://github.com/pytorch/pytorch/issues/60466
+     """
+     shape = list(a.shape[:-1])
+     length = int(a.shape[-1])
+     n_frames = math.ceil(length / stride)
+     tgt_length = (n_frames - 1) * stride + kernel_size
+     a = F.pad(input=a, pad=[0, tgt_length - length])
+     strides = [a.stride(dim) for dim in range(a.dim())]
+     if strides[-1] != 1:
+         raise ValueError("Data should be contiguous.")
+     strides = strides[:-1] + [stride, 1]
+     shape.append(n_frames)
+     shape.append(kernel_size)
+     return a.as_strided(shape, strides)
+
+
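For example, a kernel of 4 with a stride of 2 yields half-overlapping frames and zero-pads the tail (an illustrative check of the docstring above; private helper):

import torch
from torchaudio.models._hdemucs import _unfold

a = torch.arange(10.0)[None]              # shape (1, 10)
frames = _unfold(a, kernel_size=4, stride=2)
print(frames.shape)                       # torch.Size([1, 5, 4]): F = ceil(10 / 2) = 5 frames
print(frames[0, 0], frames[0, 1])         # tensor([0., 1., 2., 3.]) tensor([2., 3., 4., 5.])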
+ def _rescale_module(module):
+     r"""
+     Rescales initial weight scale for all models within the module.
+     """
+     for sub in module.modules():
+         if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
+             std = sub.weight.std().detach()
+             scale = (std / 0.1) ** 0.5
+             sub.weight.data /= scale
+             if sub.bias is not None:
+                 sub.bias.data /= scale
+
+
+ def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
+     other = list(x.shape[:-1])
+     length = int(x.shape[-1])
+     x = x.reshape(-1, length)
+     z = torch.stft(
+         x,
+         n_fft * (1 + pad),
+         hop_length,
+         window=torch.hann_window(n_fft).to(x),
+         win_length=n_fft,
+         normalized=True,
+         center=True,
+         return_complex=True,
+         pad_mode="reflect",
+     )
+     _, freqs, frame = z.shape
+     other.extend([freqs, frame])
+     return z.view(other)
+
+
+ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
+     other = list(z.shape[:-2])
+     freqs = int(z.shape[-2])
+     frames = int(z.shape[-1])
+
+     n_fft = 2 * freqs - 2
+     z = z.view(-1, freqs, frames)
+     win_length = n_fft // (1 + pad)
+     x = torch.istft(
+         z,
+         n_fft,
+         hop_length,
+         window=torch.hann_window(win_length).to(z.real),
+         win_length=win_length,
+         normalized=True,
+         length=length,
+         center=True,
+     )
+     _, length = x.shape
+     other.append(length)
+     return x.view(other)
+
+
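With the hop set to a quarter of the FFT size, which is the convention `_spec`/`_ispec` enforce above, the two helpers act as approximate inverses (a quick sanity sketch; private helpers):

import torch
from torchaudio.models._hdemucs import _spectro, _ispectro

x = torch.randn(2, 2, 4096)                     # (batch, channels, time)
z = _spectro(x, n_fft=512, hop_length=128)      # (2, 2, 257, 33) complex spectrogram
x_hat = _ispectro(z, hop_length=128, length=4096)
print(z.shape, (x - x_hat).abs().max())         # reconstruction error on the order of 1e-6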
+ def hdemucs_low(sources: List[str]) -> HDemucs:
+     """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.
+
+     Args:
+         sources (List[str]): See :py:func:`HDemucs`.
+
+     Returns:
+         HDemucs:
+             HDemucs model.
+     """
+
+     return HDemucs(sources=sources, nfft=1024, depth=5)
+
+
+ def hdemucs_medium(sources: List[str]) -> HDemucs:
+     r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.
+
+     .. note::
+
+         Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
+         not compatible with the original implementation in https://github.com/facebookresearch/demucs
+
+     Args:
+         sources (List[str]): See :py:func:`HDemucs`.
+
+     Returns:
+         HDemucs:
+             HDemucs model.
+     """
+
+     return HDemucs(sources=sources, nfft=2048, depth=6)
+
+
+ def hdemucs_high(sources: List[str]) -> HDemucs:
+     r"""Builds high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.
+
+     Args:
+         sources (List[str]): See :py:func:`HDemucs`.
+
+     Returns:
+         HDemucs:
+             HDemucs model.
+     """
+
+     return HDemucs(sources=sources, nfft=4096, depth=6)
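A short usage sketch of the factories above; the weights here are randomly initialized, while pre-trained weights are distributed through :class:`torchaudio.pipelines.SourceSeparationBundle`:

import torch
from torchaudio.models import hdemucs_low

# Pick the factory that matches the working sample rate of the material.
model = hdemucs_low(sources=["drums", "bass", "other", "vocals"])   # nfft=1024, depth=5
model.eval()
mixture = torch.randn(1, 2, 8000)        # one second of stereo audio at 8 kHz
with torch.no_grad():
    stems = model(mixture)
separated = dict(zip(model.sources, stems[0]))   # e.g. separated["vocals"] has shape (2, 8000)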