torchaudio-2.9.1-cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (86)
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.1.dist-info/METADATA +133 -0
  83. torchaudio-2.9.1.dist-info/RECORD +86 -0
  84. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  85. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  86. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
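Each hunk below adds a complete file, so its `@@ -0,0 +N @@` header matches the per-file line count in the listing above. As a quick editorial sketch (not part of the diff), the listing can be cross-checked against a locally downloaded copy of the wheel; the filename and the `pip download` step are assumptions about how the artifact is obtained.

    # Hypothetical check: reproduce the 86-file listing from a local copy of the wheel.
    # Fetch it first, e.g.:  pip download torchaudio==2.9.1 --no-deps --only-binary=:all:
    import zipfile

    WHEEL = "torchaudio-2.9.1-cp310-cp310-macosx_11_0_arm64.whl"  # assumed local path

    with zipfile.ZipFile(WHEEL) as whl:
        names = sorted(whl.namelist())
        print(len(names), "files")  # expected to match the count reported above
        for name in names:
            print(name)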
torchaudio/models/wavernn.py
@@ -0,0 +1,409 @@
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+
+ __all__ = [
+     "ResBlock",
+     "MelResNet",
+     "Stretch2d",
+     "UpsampleNetwork",
+     "WaveRNN",
+ ]
+
+
+ class ResBlock(nn.Module):
+     r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`.
+
+     Args:
+         n_freq: the number of bins in a spectrogram. (Default: ``128``)
+
+     Examples
+         >>> resblock = ResBlock()
+         >>> input = torch.rand(10, 128, 512) # a random spectrogram
+         >>> output = resblock(input) # shape: (10, 128, 512)
+     """
+
+     def __init__(self, n_freq: int = 128) -> None:
+         super().__init__()
+
+         self.resblock_model = nn.Sequential(
+             nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
+             nn.BatchNorm1d(n_freq),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
+             nn.BatchNorm1d(n_freq),
+         )
+
+     def forward(self, specgram: Tensor) -> Tensor:
+         r"""Pass the input through the ResBlock layer.
+         Args:
+             specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).
+
+         Return:
+             Tensor shape: (n_batch, n_freq, n_time)
+         """
+
+         return self.resblock_model(specgram) + specgram
+
+
+ class MelResNet(nn.Module):
+     r"""MelResNet layer uses a stack of ResBlocks on spectrogram.
+
+     Args:
+         n_res_block: the number of ResBlock in stack. (Default: ``10``)
+         n_freq: the number of bins in a spectrogram. (Default: ``128``)
+         n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+         n_output: the number of output dimensions of melresnet. (Default: ``128``)
+         kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
+
+     Examples
+         >>> melresnet = MelResNet()
+         >>> input = torch.rand(10, 128, 512) # a random spectrogram
+         >>> output = melresnet(input) # shape: (10, 128, 508)
+     """
+
+     def __init__(
+         self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
+     ) -> None:
+         super().__init__()
+
+         ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]
+
+         self.melresnet_model = nn.Sequential(
+             nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
+             nn.BatchNorm1d(n_hidden),
+             nn.ReLU(inplace=True),
+             *ResBlocks,
+             nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
+         )
+
+     def forward(self, specgram: Tensor) -> Tensor:
+         r"""Pass the input through the MelResNet layer.
+         Args:
+             specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).
+
+         Return:
+             Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
+         """
+
+         return self.melresnet_model(specgram)
+
+
+ class Stretch2d(nn.Module):
+     r"""Upscale the frequency and time dimensions of a spectrogram.
+
+     Args:
+         time_scale: the scale factor in time dimension
+         freq_scale: the scale factor in frequency dimension
+
+     Examples
+         >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)
+
+         >>> input = torch.rand(10, 100, 512) # a random spectrogram
+         >>> output = stretch2d(input) # shape: (10, 500, 5120)
+     """
+
+     def __init__(self, time_scale: int, freq_scale: int) -> None:
+         super().__init__()
+
+         self.freq_scale = freq_scale
+         self.time_scale = time_scale
+
+     def forward(self, specgram: Tensor) -> Tensor:
+         r"""Pass the input through the Stretch2d layer.
+
+         Args:
+             specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).
+
+         Return:
+             Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
+         """
+
+         return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)
+
+
+ class UpsampleNetwork(nn.Module):
+     r"""Upscale the dimensions of a spectrogram.
+
+     Args:
+         upsample_scales: the list of upsample scales.
+         n_res_block: the number of ResBlock in stack. (Default: ``10``)
+         n_freq: the number of bins in a spectrogram. (Default: ``128``)
+         n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+         n_output: the number of output dimensions of melresnet. (Default: ``128``)
+         kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
+
+     Examples
+         >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
+         >>> input = torch.rand(10, 128, 10) # a random spectrogram
+         >>> output = upsamplenetwork(input) # shape: (10, 128, 1536), (10, 128, 1536)
+     """
+
+     def __init__(
+         self,
+         upsample_scales: List[int],
+         n_res_block: int = 10,
+         n_freq: int = 128,
+         n_hidden: int = 128,
+         n_output: int = 128,
+         kernel_size: int = 5,
+     ) -> None:
+         super().__init__()
+
+         total_scale = 1
+         for upsample_scale in upsample_scales:
+             total_scale *= upsample_scale
+         self.total_scale: int = total_scale
+
+         self.indent = (kernel_size - 1) // 2 * total_scale
+         self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
+         self.resnet_stretch = Stretch2d(total_scale, 1)
+
+         up_layers = []
+         for scale in upsample_scales:
+             stretch = Stretch2d(scale, 1)
+             conv = nn.Conv2d(
+                 in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
+             )
+             torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
+             up_layers.append(stretch)
+             up_layers.append(conv)
+         self.upsample_layers = nn.Sequential(*up_layers)
+
+     def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
+         r"""Pass the input through the UpsampleNetwork layer.
+
+         Args:
+             specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)
+
+         Return:
+             Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
+             (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+         where total_scale is the product of all elements in upsample_scales.
+         """
+
+         resnet_output = self.resnet(specgram).unsqueeze(1)
+         resnet_output = self.resnet_stretch(resnet_output)
+         resnet_output = resnet_output.squeeze(1)
+
+         specgram = specgram.unsqueeze(1)
+         upsampling_output = self.upsample_layers(specgram)
+         upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent]
+
+         return upsampling_output, resnet_output
+
+
+ class WaveRNN(nn.Module):
+     r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
+     based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.
+
+     The original implementation was introduced in *Efficient Neural Audio Synthesis*
+     :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
+     The product of `upsample_scales` must equal `hop_length`.
+
+     See Also:
+         * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
+         * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
+
+     Args:
+         upsample_scales: the list of upsample scales.
+         n_classes: the number of output classes.
+         hop_length: the number of samples between the starts of consecutive frames.
+         n_res_block: the number of ResBlock in stack. (Default: ``10``)
+         n_rnn: the dimension of RNN layer. (Default: ``512``)
+         n_fc: the dimension of fully connected layer. (Default: ``512``)
+         kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
+         n_freq: the number of bins in a spectrogram. (Default: ``128``)
+         n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
+         n_output: the number of output dimensions of melresnet. (Default: ``128``)
+
+     Example
+         >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
+         >>> waveform, sample_rate = torchaudio.load(file)
+         >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
+         >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time)
+         >>> output = wavernn(waveform, specgram)
+         >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
+     """
+
+     def __init__(
+         self,
+         upsample_scales: List[int],
+         n_classes: int,
+         hop_length: int,
+         n_res_block: int = 10,
+         n_rnn: int = 512,
+         n_fc: int = 512,
+         kernel_size: int = 5,
+         n_freq: int = 128,
+         n_hidden: int = 128,
+         n_output: int = 128,
+     ) -> None:
+         super().__init__()
+
+         self.kernel_size = kernel_size
+         self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
+         self.n_rnn = n_rnn
+         self.n_aux = n_output // 4
+         self.hop_length = hop_length
+         self.n_classes = n_classes
+         self.n_bits: int = int(math.log2(self.n_classes))
+
+         total_scale = 1
+         for upsample_scale in upsample_scales:
+             total_scale *= upsample_scale
+         if total_scale != self.hop_length:
+             raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")
+
+         self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
+         self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)
+
+         self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
+         self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)
+
+         self.relu1 = nn.ReLU(inplace=True)
+         self.relu2 = nn.ReLU(inplace=True)
+
+         self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
+         self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
+         self.fc3 = nn.Linear(n_fc, self.n_classes)
+
+     def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
+         r"""Pass the input through the WaveRNN model.
+
+         Args:
+             waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
+             specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)
+
+         Return:
+             Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
+         """
+
+         if waveform.size(1) != 1:
+             raise ValueError("Require the input channel of waveform is 1")
+         if specgram.size(1) != 1:
+             raise ValueError("Require the input channel of specgram is 1")
+         # remove channel dimension until the end
+         waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)
+
+         batch_size = waveform.size(0)
+         h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+         h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
+         # output of upsample:
+         # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
+         # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
+         specgram, aux = self.upsample(specgram)
+         specgram = specgram.transpose(1, 2)
+         aux = aux.transpose(1, 2)
+
+         aux_idx = [self.n_aux * i for i in range(5)]
+         a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
+         a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
+         a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
+         a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+
+         x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
+         x = self.fc(x)
+         res = x
+         x, _ = self.rnn1(x, h1)
+
+         x = x + res
+         res = x
+         x = torch.cat([x, a2], dim=-1)
+         x, _ = self.rnn2(x, h2)
+
+         x = x + res
+         x = torch.cat([x, a3], dim=-1)
+         x = self.fc1(x)
+         x = self.relu1(x)
+
+         x = torch.cat([x, a4], dim=-1)
+         x = self.fc2(x)
+         x = self.relu2(x)
+         x = self.fc3(x)
+
+         # bring back channel dimension
+         return x.unsqueeze(1)
+
+     @torch.jit.export
+     def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+         r"""Inference method of WaveRNN.
+
+         This function currently only supports multinomial sampling, which assumes the
+         network is trained on cross entropy loss.
+
+         Args:
+             specgram (Tensor):
+                 Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
+             lengths (Tensor or None, optional):
+                 Indicates the valid length of each audio in the batch.
+                 Shape: `(batch, )`.
+                 When the ``specgram`` contains spectrograms with different durations,
+                 by providing ``lengths`` argument, the model will compute
+                 the corresponding valid output lengths.
+                 If ``None``, it is assumed that all the audio in ``waveforms``
+                 have valid length. Default: ``None``.
+
+         Returns:
+             (Tensor, Optional[Tensor]):
+             Tensor
+                 The inferred waveform of size `(n_batch, 1, n_time)`.
+                 1 stands for a single channel.
+             Tensor or None
+                 If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
+                 is returned.
+                 It indicates the valid length in time axis of the output Tensor.
+         """
+
+         device = specgram.device
+         dtype = specgram.dtype
+
+         specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
+         specgram, aux = self.upsample(specgram)
+         if lengths is not None:
+             lengths = lengths * self.upsample.total_scale
+
+         output: List[Tensor] = []
+         b_size, _, seq_len = specgram.size()
+
+         h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+         h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+         x = torch.zeros((b_size, 1), device=device, dtype=dtype)
+
+         aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]
+
+         for i in range(seq_len):
+
+             m_t = specgram[:, :, i]
+
+             a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
+
+             x = torch.cat([x, m_t, a1_t], dim=1)
+             x = self.fc(x)
+             _, h1 = self.rnn1(x.unsqueeze(1), h1)
+
+             x = x + h1[0]
+             inp = torch.cat([x, a2_t], dim=1)
+             _, h2 = self.rnn2(inp.unsqueeze(1), h2)
+
+             x = x + h2[0]
+             x = torch.cat([x, a3_t], dim=1)
+             x = F.relu(self.fc1(x))
+
+             x = torch.cat([x, a4_t], dim=1)
+             x = F.relu(self.fc2(x))
+
+             logits = self.fc3(x)
+
+             posterior = F.softmax(logits, dim=1)
+
+             x = torch.multinomial(posterior, 1).float()
+             # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+             x = 2 * x / (2**self.n_bits - 1.0) - 1.0
+
+             output.append(x)
+
+         return torch.stack(output).permute(1, 2, 0), lengths
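For orientation, a brief sketch of driving the `forward`/`infer` methods defined above (editorial, not part of the wheel contents). The shapes and hyper-parameters are illustrative assumptions; the only constraint the module enforces is that the product of `upsample_scales` equals `hop_length`, and the randomly initialized weights used here would produce noise rather than speech.

    import torch
    from torchaudio.models import WaveRNN

    # prod([5, 5, 8]) == 200 == hop_length, as required by WaveRNN.__init__
    model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200).eval()

    specgram = torch.rand(2, 128, 30)  # (n_batch, n_freq, n_time) dummy mel spectrograms
    lengths = torch.tensor([30, 24])   # valid frames per batch element

    with torch.inference_mode():
        waveform, out_lengths = model.infer(specgram, lengths)

    print(waveform.shape)  # (2, 1, 6000): n_time * hop_length samples per example
    print(out_lengths)     # tensor([6000, 4800]): input lengths scaled by total_scale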
torchaudio/pipelines/__init__.py
@@ -0,0 +1,102 @@
+ from ._source_separation_pipeline import (
+     CONVTASNET_BASE_LIBRI2MIX,
+     HDEMUCS_HIGH_MUSDB,
+     HDEMUCS_HIGH_MUSDB_PLUS,
+     SourceSeparationBundle,
+ )
+ from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
+ from ._tts import (
+     TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+     TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+     TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+     TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+     Tacotron2TTSBundle,
+ )
+ from ._wav2vec2.impl import (
+     HUBERT_ASR_LARGE,
+     HUBERT_ASR_XLARGE,
+     HUBERT_BASE,
+     HUBERT_LARGE,
+     HUBERT_XLARGE,
+     MMS_FA,
+     VOXPOPULI_ASR_BASE_10K_DE,
+     VOXPOPULI_ASR_BASE_10K_EN,
+     VOXPOPULI_ASR_BASE_10K_ES,
+     VOXPOPULI_ASR_BASE_10K_FR,
+     VOXPOPULI_ASR_BASE_10K_IT,
+     WAV2VEC2_ASR_BASE_100H,
+     WAV2VEC2_ASR_BASE_10M,
+     WAV2VEC2_ASR_BASE_960H,
+     WAV2VEC2_ASR_LARGE_100H,
+     WAV2VEC2_ASR_LARGE_10M,
+     WAV2VEC2_ASR_LARGE_960H,
+     WAV2VEC2_ASR_LARGE_LV60K_100H,
+     WAV2VEC2_ASR_LARGE_LV60K_10M,
+     WAV2VEC2_ASR_LARGE_LV60K_960H,
+     WAV2VEC2_BASE,
+     WAV2VEC2_LARGE,
+     WAV2VEC2_LARGE_LV60K,
+     WAV2VEC2_XLSR53,
+     WAV2VEC2_XLSR_1B,
+     WAV2VEC2_XLSR_2B,
+     WAV2VEC2_XLSR_300M,
+     Wav2Vec2ASRBundle,
+     Wav2Vec2Bundle,
+     Wav2Vec2FABundle,
+     WAVLM_BASE,
+     WAVLM_BASE_PLUS,
+     WAVLM_LARGE,
+ )
+ from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
+
+
+ __all__ = [
+     "Wav2Vec2Bundle",
+     "Wav2Vec2ASRBundle",
+     "Wav2Vec2FABundle",
+     "WAV2VEC2_BASE",
+     "WAV2VEC2_LARGE",
+     "WAV2VEC2_LARGE_LV60K",
+     "WAV2VEC2_ASR_BASE_10M",
+     "WAV2VEC2_ASR_BASE_100H",
+     "WAV2VEC2_ASR_BASE_960H",
+     "WAV2VEC2_ASR_LARGE_10M",
+     "WAV2VEC2_ASR_LARGE_100H",
+     "WAV2VEC2_ASR_LARGE_960H",
+     "WAV2VEC2_ASR_LARGE_LV60K_10M",
+     "WAV2VEC2_ASR_LARGE_LV60K_100H",
+     "WAV2VEC2_ASR_LARGE_LV60K_960H",
+     "WAV2VEC2_XLSR53",
+     "WAV2VEC2_XLSR_300M",
+     "WAV2VEC2_XLSR_1B",
+     "WAV2VEC2_XLSR_2B",
+     "VOXPOPULI_ASR_BASE_10K_EN",
+     "VOXPOPULI_ASR_BASE_10K_ES",
+     "VOXPOPULI_ASR_BASE_10K_DE",
+     "VOXPOPULI_ASR_BASE_10K_FR",
+     "VOXPOPULI_ASR_BASE_10K_IT",
+     "HUBERT_BASE",
+     "HUBERT_LARGE",
+     "HUBERT_XLARGE",
+     "HUBERT_ASR_LARGE",
+     "HUBERT_ASR_XLARGE",
+     "MMS_FA",
+     "WAVLM_BASE",
+     "WAVLM_BASE_PLUS",
+     "WAVLM_LARGE",
+     "Tacotron2TTSBundle",
+     "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+     "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+     "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+     "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+     "RNNTBundle",
+     "EMFORMER_RNNT_BASE_LIBRISPEECH",
+     "SourceSeparationBundle",
+     "CONVTASNET_BASE_LIBRI2MIX",
+     "HDEMUCS_HIGH_MUSDB_PLUS",
+     "HDEMUCS_HIGH_MUSDB",
+     "SQUIM_OBJECTIVE",
+     "SQUIM_SUBJECTIVE",
+     "SquimObjectiveBundle",
+     "SquimSubjectiveBundle",
+ ]
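The bundles re-exported by this `__init__` share a small surface: `get_model()` constructs the network and loads pretrained weights, and bundle attributes such as `sample_rate` describe the expected input. A minimal sketch with the `WAV2VEC2_ASR_BASE_960H` ASR bundle follows (editorial, not part of the package); it assumes network access for the one-time weight download and uses silence as a stand-in for real audio, so no meaningful transcript should be expected.

    import torch
    from torchaudio.pipelines import WAV2VEC2_ASR_BASE_960H as bundle

    model = bundle.get_model().eval()
    labels = bundle.get_labels()                   # token set of this ASR bundle; "-" is the CTC blank

    waveform = torch.zeros(1, bundle.sample_rate)  # one second of silence at the bundle's sample rate

    with torch.inference_mode():
        emissions, _ = model(waveform)             # (batch, frames, len(labels)) frame-level logits

    # Naive greedy CTC decode: collapse repeated frames, drop blanks, map "|" to a space.
    indices = torch.unique_consecutive(emissions[0].argmax(dim=-1))
    transcript = "".join(labels[int(i)] for i in indices if labels[int(i)] != "-").replace("|", " ")
    print(repr(transcript))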
torchaudio/pipelines/_source_separation_pipeline.py
@@ -0,0 +1,109 @@
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import Callable
+
+ import torch
+ import torchaudio
+
+ from torchaudio.models import conv_tasnet_base, hdemucs_high
+
+
+ @dataclass
+ class SourceSeparationBundle:
+     """Dataclass that bundles components for performing source separation.
+
+     Example
+         >>> import torchaudio
+         >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
+         >>> import torch
+         >>>
+         >>> # Build the separation model.
+         >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
+         >>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
+         >>>
+         >>> # Instantiate the test set of Libri2Mix dataset.
+         >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
+         >>>
+         >>> # Apply source separation on mixture audio.
+         >>> for i, data in enumerate(dataset):
+         >>>     sample_rate, mixture, clean_sources = data
+         >>>     # Make sure the shape of input suits the model requirement.
+         >>>     mixture = mixture.reshape(1, 1, -1)
+         >>>     estimated_sources = model(mixture)
+         >>>     score = si_snr_pit(estimated_sources, clean_sources) # for demonstration
+         >>>     print(f"Si-SNR score is : {score}.")
+         >>>     break
+         >>> Si-SNR score is : 16.24.
+         >>>
+     """
+
+     _model_path: str
+     _model_factory_func: Callable[[], torch.nn.Module]
+     _sample_rate: int
+
+     @property
+     def sample_rate(self) -> int:
+         """Sample rate of the audio that the model is trained on.
+
+         :type: int
+         """
+         return self._sample_rate
+
+     def get_model(self) -> torch.nn.Module:
+         """Construct the model and load the pretrained weight."""
+         model = self._model_factory_func()
+         path = torchaudio.utils._download_asset(self._model_path)
+         state_dict = torch.load(path)
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
+
+ CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
+     _model_path="models/conv_tasnet_base_libri2mix.pt",
+     _model_factory_func=partial(conv_tasnet_base, num_sources=2),
+     _sample_rate=8000,
+ )
+ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet*
+ :cite:`Luo_2019` trained on *Libri2Mix dataset* :cite:`cosentino2020librimix`.
+
+ The source separation model is constructed by :func:`~torchaudio.models.conv_tasnet_base`
+ and is trained using the training script ``lightning_train.py``
+ `here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
+ with default arguments.
+
+ Please refer to :class:`SourceSeparationBundle` for usage instructions.
+ """
+
+
+ HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
+     _model_path="models/hdemucs_high_trained.pt",
+     _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+     _sample_rate=44100,
+ )
+ HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained music source separation pipeline with
+ *Hybrid Demucs* :cite:`defossez2021hybrid` trained on both training and test sets of
+ MUSDB-HQ :cite:`MUSDB18HQ` and an additional 150 extra songs from an internal database
+ that was specifically produced for Meta.
+
+ The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+
+ Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.
+
+ Please refer to :class:`SourceSeparationBundle` for usage instructions.
+ """
+
+
+ HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
+     _model_path="models/hdemucs_high_musdbhq_only.pt",
+     _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
+     _sample_rate=44100,
+ )
+ HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
+ *Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.
+
+ The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
+ Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.
+
+ Please refer to :class:`SourceSeparationBundle` for usage instructions.
+ """