torchaudio-2.9.1-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +86 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/wavernn.py
@@ -0,0 +1,409 @@
import math
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor

__all__ = [
    "ResBlock",
    "MelResNet",
    "Stretch2d",
    "UpsampleNetwork",
    "WaveRNN",
]


class ResBlock(nn.Module):
    r"""ResNet block based on *Efficient Neural Audio Synthesis* :cite:`kalchbrenner2018efficient`.

    Args:
        n_freq: the number of bins in a spectrogram. (Default: ``128``)

    Examples
        >>> resblock = ResBlock()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = resblock(input)  # shape: (10, 128, 512)
    """

    def __init__(self, n_freq: int = 128) -> None:
        super().__init__()

        self.resblock_model = nn.Sequential(
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the ResBlock layer.
        Args:
            specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_freq, n_time)
        """

        return self.resblock_model(specgram) + specgram


class MelResNet(nn.Module):
    r"""MelResNet layer uses a stack of ResBlocks on spectrogram.

    Args:
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> melresnet = MelResNet()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = melresnet(input)  # shape: (10, 128, 508)
    """

    def __init__(
        self, n_res_block: int = 10, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128, kernel_size: int = 5
    ) -> None:
        super().__init__()

        ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]

        self.melresnet_model = nn.Sequential(
            nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            *ResBlocks,
            nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1),
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the MelResNet layer.
        Args:
            specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
        """

        return self.melresnet_model(specgram)


class Stretch2d(nn.Module):
    r"""Upscale the frequency and time dimensions of a spectrogram.

    Args:
        time_scale: the scale factor in time dimension
        freq_scale: the scale factor in frequency dimension

    Examples
        >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)

        >>> input = torch.rand(10, 100, 512)  # a random spectrogram
        >>> output = stretch2d(input)  # shape: (10, 500, 5120)
    """

    def __init__(self, time_scale: int, freq_scale: int) -> None:
        super().__init__()

        self.freq_scale = freq_scale
        self.time_scale = time_scale

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the Stretch2d layer.

        Args:
            specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

        Return:
            Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
        """

        return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)


class UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Args:
        upsample_scales: the list of upsample scales.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shape: (10, 128, 1536), (10, 128, 1536)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_res_block: int = 10,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
        kernel_size: int = 5,
    ) -> None:
        super().__init__()

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        self.total_scale: int = total_scale

        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = Stretch2d(total_scale, 1)

        up_layers = []
        for scale in upsample_scales:
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(
                in_channels=1, out_channels=1, kernel_size=(1, scale * 2 + 1), padding=(0, scale), bias=False
            )
            torch.nn.init.constant_(conv.weight, 1.0 / (scale * 2 + 1))
            up_layers.append(stretch)
            up_layers.append(conv)
        self.upsample_layers = nn.Sequential(*up_layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
                          (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        where total_scale is the product of all elements in upsample_scales.
        """

        resnet_output = self.resnet(specgram).unsqueeze(1)
        resnet_output = self.resnet_stretch(resnet_output)
        resnet_output = resnet_output.squeeze(1)

        specgram = specgram.unsqueeze(1)
        upsampling_output = self.upsample_layers(specgram)
        upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent : -self.indent]

        return upsampling_output, resnet_output


class WaveRNN(nn.Module):
    r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
    based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in *Efficient Neural Audio Synthesis*
    :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
    The product of `upsample_scales` must equal `hop_length`.

    See Also:
        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        upsample_scales: the list of upsample scales.
        n_classes: the number of output classes.
        hop_length: the number of samples between the starts of consecutive frames.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_rnn: the dimension of RNN layer. (Default: ``512``)
        n_fc: the dimension of fully connected layer. (Default: ``512``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)

    Example
        >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
    """

    def __init__(
        self,
        upsample_scales: List[int],
        n_classes: int,
        hop_length: int,
        n_res_block: int = 10,
        n_rnn: int = 512,
        n_fc: int = 512,
        kernel_size: int = 5,
        n_freq: int = 128,
        n_hidden: int = 128,
        n_output: int = 128,
    ) -> None:
        super().__init__()

        self.kernel_size = kernel_size
        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
        self.n_rnn = n_rnn
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.n_classes = n_classes
        self.n_bits: int = int(math.log2(self.n_classes))

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

        self.upsample = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the WaveRNN model.

        Args:
            waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

        Return:
            Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
        """

        if waveform.size(1) != 1:
            raise ValueError("Require the input channel of waveform is 1")
        if specgram.size(1) != 1:
            raise ValueError("Require the input channel of specgram is 1")
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)

        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
        a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
        a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
        a4 = aux[:, :, aux_idx[3] : aux_idx[4]]

        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
        res = x
        x, _ = self.rnn1(x, h1)

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)

    @torch.jit.export
    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Inference method of WaveRNN.

        This function currently only supports multinomial sampling, which assumes the
        network is trained on cross entropy loss.

        Args:
            specgram (Tensor):
                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When the ``specgram`` contains spectrograms with different durations,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths.
                If ``None``, it is assumed that all the audio in ``waveforms``
                have valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The inferred waveform of size `(n_batch, 1, n_time)`.
                1 stands for a single channel.
            Tensor or None
                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned.
                It indicates the valid length in time axis of the output Tensor.
        """

        device = specgram.device
        dtype = specgram.dtype

        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
        specgram, aux = self.upsample(specgram)
        if lengths is not None:
            lengths = lengths * self.upsample.total_scale

        output: List[Tensor] = []
        b_size, _, seq_len = specgram.size()

        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        x = torch.zeros((b_size, 1), device=device, dtype=dtype)

        aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

        for i in range(seq_len):

            m_t = specgram[:, :, i]

            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

            x = torch.cat([x, m_t, a1_t], dim=1)
            x = self.fc(x)
            _, h1 = self.rnn1(x.unsqueeze(1), h1)

            x = x + h1[0]
            inp = torch.cat([x, a2_t], dim=1)
            _, h2 = self.rnn2(inp.unsqueeze(1), h2)

            x = x + h2[0]
            x = torch.cat([x, a3_t], dim=1)
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1)
            x = F.relu(self.fc2(x))

            logits = self.fc3(x)

            posterior = F.softmax(logits, dim=1)

            x = torch.multinomial(posterior, 1).float()
            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
            x = 2 * x / (2**self.n_bits - 1.0) - 1.0

            output.append(x)

        return torch.stack(output).permute(1, 2, 0), lengths
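The shapes implied by the docstrings above can be exercised end to end. The following is a minimal, hypothetical sketch, not part of the packaged file: random tensors stand in for real audio, and the values are chosen only to satisfy the documented constraints (the product of upsample_scales equals hop_length, and both inputs are single-channel).

import torch
from torchaudio.models import WaveRNN

# prod(upsample_scales) = 5 * 5 * 8 = 200 = hop_length, as required by __init__.
model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)

n_batch, n_freq, n_time, kernel_size, hop_length = 2, 128, 10, 5, 200
specgram = torch.rand(n_batch, 1, n_freq, n_time)
waveform = torch.rand(n_batch, 1, (n_time - kernel_size + 1) * hop_length)

# forward() scores every output sample against the n_classes quantization levels:
# shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes).
logits = model(waveform, specgram)

# infer() pads the spectrogram by _pad frames on each side, then samples one class
# per output step via torch.multinomial, so a (n_batch, n_freq, n_time) input yields
# (n_batch, 1, n_time * hop_length) samples in [-1, 1]; the second return value is
# None because no lengths argument was given.
with torch.inference_mode():
    generated, out_lengths = model.infer(specgram.squeeze(1))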
@@ -0,0 +1,102 @@
torchaudio/pipelines/__init__.py
from ._source_separation_pipeline import (
    CONVTASNET_BASE_LIBRI2MIX,
    HDEMUCS_HIGH_MUSDB,
    HDEMUCS_HIGH_MUSDB_PLUS,
    SourceSeparationBundle,
)
from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
from ._tts import (
    TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
    TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
    TACOTRON2_WAVERNN_CHAR_LJSPEECH,
    TACOTRON2_WAVERNN_PHONE_LJSPEECH,
    Tacotron2TTSBundle,
)
from ._wav2vec2.impl import (
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
    HUBERT_BASE,
    HUBERT_LARGE,
    HUBERT_XLARGE,
    MMS_FA,
    VOXPOPULI_ASR_BASE_10K_DE,
    VOXPOPULI_ASR_BASE_10K_EN,
    VOXPOPULI_ASR_BASE_10K_ES,
    VOXPOPULI_ASR_BASE_10K_FR,
    VOXPOPULI_ASR_BASE_10K_IT,
    WAV2VEC2_ASR_BASE_100H,
    WAV2VEC2_ASR_BASE_10M,
    WAV2VEC2_ASR_BASE_960H,
    WAV2VEC2_ASR_LARGE_100H,
    WAV2VEC2_ASR_LARGE_10M,
    WAV2VEC2_ASR_LARGE_960H,
    WAV2VEC2_ASR_LARGE_LV60K_100H,
    WAV2VEC2_ASR_LARGE_LV60K_10M,
    WAV2VEC2_ASR_LARGE_LV60K_960H,
    WAV2VEC2_BASE,
    WAV2VEC2_LARGE,
    WAV2VEC2_LARGE_LV60K,
    WAV2VEC2_XLSR53,
    WAV2VEC2_XLSR_1B,
    WAV2VEC2_XLSR_2B,
    WAV2VEC2_XLSR_300M,
    Wav2Vec2ASRBundle,
    Wav2Vec2Bundle,
    Wav2Vec2FABundle,
    WAVLM_BASE,
    WAVLM_BASE_PLUS,
    WAVLM_LARGE,
)
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle


__all__ = [
    "Wav2Vec2Bundle",
    "Wav2Vec2ASRBundle",
    "Wav2Vec2FABundle",
    "WAV2VEC2_BASE",
    "WAV2VEC2_LARGE",
    "WAV2VEC2_LARGE_LV60K",
    "WAV2VEC2_ASR_BASE_10M",
    "WAV2VEC2_ASR_BASE_100H",
    "WAV2VEC2_ASR_BASE_960H",
    "WAV2VEC2_ASR_LARGE_10M",
    "WAV2VEC2_ASR_LARGE_100H",
    "WAV2VEC2_ASR_LARGE_960H",
    "WAV2VEC2_ASR_LARGE_LV60K_10M",
    "WAV2VEC2_ASR_LARGE_LV60K_100H",
    "WAV2VEC2_ASR_LARGE_LV60K_960H",
    "WAV2VEC2_XLSR53",
    "WAV2VEC2_XLSR_300M",
    "WAV2VEC2_XLSR_1B",
    "WAV2VEC2_XLSR_2B",
    "VOXPOPULI_ASR_BASE_10K_EN",
    "VOXPOPULI_ASR_BASE_10K_ES",
    "VOXPOPULI_ASR_BASE_10K_DE",
    "VOXPOPULI_ASR_BASE_10K_FR",
    "VOXPOPULI_ASR_BASE_10K_IT",
    "HUBERT_BASE",
    "HUBERT_LARGE",
    "HUBERT_XLARGE",
    "HUBERT_ASR_LARGE",
    "HUBERT_ASR_XLARGE",
    "MMS_FA",
    "WAVLM_BASE",
    "WAVLM_BASE_PLUS",
    "WAVLM_LARGE",
    "Tacotron2TTSBundle",
    "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
    "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
    "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
    "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
    "RNNTBundle",
    "EMFORMER_RNNT_BASE_LIBRISPEECH",
    "SourceSeparationBundle",
    "CONVTASNET_BASE_LIBRI2MIX",
    "HDEMUCS_HIGH_MUSDB_PLUS",
    "HDEMUCS_HIGH_MUSDB",
    "SQUIM_OBJECTIVE",
    "SQUIM_SUBJECTIVE",
    "SquimObjectiveBundle",
    "SquimSubjectiveBundle",
]
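The bundle constants re-exported above are lightweight descriptors; weights are only fetched when get_model() is called. A rough, hedged sketch of that pattern for one of the ASR bundles follows; it is not part of the packaged file, the audio path is a placeholder, and the greedy CTC decode assumes the blank token sits at index 0 of get_labels() and "|" marks word boundaries, as in the torchaudio ASR tutorial.

import torch
import torchaudio
from torchaudio.pipelines import WAV2VEC2_ASR_BASE_960H

bundle = WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()        # downloads and caches the pretrained weights on first use
labels = bundle.get_labels()      # token set emitted by the CTC head

waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder path
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    emission, _ = model(waveform)  # (batch, frames, len(labels))

# Greedy CTC decoding: best token per frame, merge repeats, drop blanks (assumed index 0).
indices = torch.unique_consecutive(emission[0].argmax(dim=-1)).tolist()
transcript = "".join(labels[i] for i in indices if i != 0).replace("|", " ")
print(transcript)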
@@ -0,0 +1,109 @@
torchaudio/pipelines/_source_separation_pipeline.py
from dataclasses import dataclass
from functools import partial
from typing import Callable

import torch
import torchaudio

from torchaudio.models import conv_tasnet_base, hdemucs_high


@dataclass
class SourceSeparationBundle:
    """Dataclass that bundles components for performing source separation.

    Example
        >>> import torchaudio
        >>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
        >>> import torch
        >>>
        >>> # Build the separation model.
        >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
        >>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
        >>>
        >>> # Instantiate the test set of Libri2Mix dataset.
        >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
        >>>
        >>> # Apply source separation on mixture audio.
        >>> for i, data in enumerate(dataset):
        >>>     sample_rate, mixture, clean_sources = data
        >>>     # Make sure the shape of input suits the model requirement.
        >>>     mixture = mixture.reshape(1, 1, -1)
        >>>     estimated_sources = model(mixture)
        >>>     score = si_snr_pit(estimated_sources, clean_sources)  # for demonstration
        >>>     print(f"Si-SNR score is : {score}.)
        >>>     break
        >>> Si-SNR score is : 16.24.
        >>>
    """

    _model_path: str
    _model_factory_func: Callable[[], torch.nn.Module]
    _sample_rate: int

    @property
    def sample_rate(self) -> int:
        """Sample rate of the audio that the model is trained on.

        :type: int
        """
        return self._sample_rate

    def get_model(self) -> torch.nn.Module:
        """Construct the model and load the pretrained weight."""
        model = self._model_factory_func()
        path = torchaudio.utils._download_asset(self._model_path)
        state_dict = torch.load(path)
        model.load_state_dict(state_dict)
        model.eval()
        return model


CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
    _model_path="models/conv_tasnet_base_libri2mix.pt",
    _model_factory_func=partial(conv_tasnet_base, num_sources=2),
    _sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet*
:cite:`Luo_2019` trained on *Libri2Mix dataset* :cite:`cosentino2020librimix`.

The source separation model is constructed by :func:`~torchaudio.models.conv_tasnet_base`
and is trained using the training script ``lightning_train.py``
`here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
with default arguments.

Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""


HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
    _model_path="models/hdemucs_high_trained.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained music source separation pipeline with
*Hybrid Demucs* :cite:`defossez2021hybrid` trained on both training and test sets of
MUSDB-HQ :cite:`MUSDB18HQ` and an additional 150 extra songs from an internal database
that was specifically produced for Meta.

The model is constructed by :func:`~torchaudio.models.hdemucs_high`.

Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.

Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""


HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
    _model_path="models/hdemucs_high_musdbhq_only.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained music source separation pipeline with
*Hybrid Demucs* :cite:`defossez2021hybrid` trained on the training set of MUSDB-HQ :cite:`MUSDB18HQ`.

The model is constructed by :func:`~torchaudio.models.hdemucs_high`.
Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.

Please refer to :class:`SourceSeparationBundle` for usage instructions.
"""