torchaudio 2.0.2-cp38-cp38-manylinux1_x86_64.whl → 2.1.1-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchaudio might be problematic.
- torchaudio/__init__.py +22 -3
- torchaudio/_backend/__init__.py +55 -4
- torchaudio/_backend/backend.py +53 -0
- torchaudio/_backend/common.py +52 -0
- torchaudio/_backend/ffmpeg.py +373 -0
- torchaudio/_backend/soundfile.py +54 -0
- torchaudio/_backend/soundfile_backend.py +457 -0
- torchaudio/_backend/sox.py +91 -0
- torchaudio/_backend/utils.py +81 -323
- torchaudio/_extension/__init__.py +55 -36
- torchaudio/_extension/utils.py +109 -17
- torchaudio/_internal/__init__.py +4 -1
- torchaudio/_internal/module_utils.py +37 -6
- torchaudio/backend/__init__.py +7 -11
- torchaudio/backend/_no_backend.py +24 -0
- torchaudio/backend/_sox_io_backend.py +297 -0
- torchaudio/backend/common.py +12 -52
- torchaudio/backend/no_backend.py +11 -21
- torchaudio/backend/soundfile_backend.py +11 -448
- torchaudio/backend/sox_io_backend.py +11 -435
- torchaudio/backend/utils.py +9 -18
- torchaudio/datasets/__init__.py +2 -0
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/cmudict.py +61 -62
- torchaudio/datasets/dr_vctk.py +1 -1
- torchaudio/datasets/gtzan.py +1 -1
- torchaudio/datasets/librilight_limited.py +1 -1
- torchaudio/datasets/librispeech.py +1 -1
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +1 -1
- torchaudio/datasets/ljspeech.py +1 -1
- torchaudio/datasets/musdb_hq.py +1 -1
- torchaudio/datasets/quesst14.py +1 -1
- torchaudio/datasets/speechcommands.py +1 -1
- torchaudio/datasets/tedlium.py +1 -1
- torchaudio/datasets/vctk.py +1 -1
- torchaudio/datasets/voxceleb1.py +1 -1
- torchaudio/datasets/yesno.py +1 -1
- torchaudio/functional/__init__.py +6 -2
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +69 -92
- torchaudio/functional/functional.py +99 -148
- torchaudio/io/__init__.py +4 -1
- torchaudio/io/_effector.py +347 -0
- torchaudio/io/_stream_reader.py +158 -90
- torchaudio/io/_stream_writer.py +196 -10
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/_torchaudio_sox.so +0 -0
- torchaudio/lib/libctc_prefix_decoder.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/libtorchaudio_sox.so +0 -0
- torchaudio/lib/pybind11_prefixctc.so +0 -0
- torchaudio/models/__init__.py +14 -0
- torchaudio/models/decoder/__init__.py +22 -7
- torchaudio/models/decoder/_ctc_decoder.py +123 -69
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/rnnt_decoder.py +10 -14
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/wav2vec2/components.py +6 -10
- torchaudio/pipelines/__init__.py +9 -0
- torchaudio/pipelines/_squim_pipeline.py +176 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +198 -68
- torchaudio/pipelines/_wav2vec2/utils.py +120 -0
- torchaudio/sox_effects/sox_effects.py +7 -30
- torchaudio/transforms/__init__.py +2 -0
- torchaudio/transforms/_transforms.py +99 -54
- torchaudio/utils/download.py +2 -2
- torchaudio/utils/ffmpeg_utils.py +20 -15
- torchaudio/utils/sox_utils.py +8 -9
- torchaudio/version.py +2 -2
- torchaudio-2.1.1.dist-info/METADATA +113 -0
- torchaudio-2.1.1.dist-info/RECORD +119 -0
- torchaudio/io/_compat.py +0 -241
- torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
- torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
- torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
- torchaudio/lib/libflashlight-text.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
- torchaudio-2.0.2.dist-info/METADATA +0 -26
- torchaudio-2.0.2.dist-info/RECORD +0 -100
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
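Most of the I/O surface moves in this release: a new torchaudio/_backend/ package (dispatcher plus ffmpeg/sox/soundfile implementations) replaces much of the old torchaudio/backend/ module set. A minimal sketch of what per-call backend selection looks like for downstream code is below; the `backend` keyword and its accepted values ("ffmpeg", "sox", "soundfile") are inferred from the new _backend modules listed above, not spelled out in this diff, so treat them as assumptions.

    import torchaudio

    # Sketch of the 2.1 per-call dispatcher (assumed keyword): pick a backend
    # explicitly instead of relying on the old global torchaudio.set_audio_backend().
    waveform, sample_rate = torchaudio.load("speech.wav", backend="soundfile")
    torchaudio.save("speech_copy.wav", waveform, sample_rate, backend="soundfile")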
--- a/torchaudio/pipelines/_wav2vec2/impl.py
+++ b/torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,41 +1,12 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
-import torch
-from torch import Tensor
-from torch.nn import functional as F, Module
-from torchaudio._internal import load_state_dict_from_url
-from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+from torch.nn import Module
 
-from . import utils
+from . import aligner, utils
 
 
-__all__ = []
-
-
-class _Wav2Vec2Model(Module):
-    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
-
-    This is used for layer normalization at the input
-    """
-
-    def __init__(self, model: Wav2Vec2Model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model(waveforms, lengths)
-
-    @torch.jit.export
-    def extract_features(
-        self,
-        waveforms: Tensor,
-        lengths: Optional[Tensor] = None,
-        num_layers: Optional[int] = None,
-    ) -> Tuple[List[Tensor], Optional[Tensor]]:
-        waveforms = F.layer_norm(waveforms, waveforms.shape)
-        return self.model.extract_features(waveforms, lengths, num_layers)
+__all__ = []  # type: ignore
 
 
 @dataclass
@@ -84,10 +55,8 @@ class Wav2Vec2Bundle:
         return self._sample_rate
 
     def _get_state_dict(self, dl_kwargs):
-        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(url, **dl_kwargs)
-        return state_dict
+        # Note: This method is overridden in ASR bundle
+        return utils._get_state_dict(self._path, dl_kwargs)
 
     def get_model(self, *, dl_kwargs=None) -> Module:
         """Construct the model and load the pretrained weight.
@@ -119,13 +88,11 @@ class Wav2Vec2Bundle:
             - HUBERT_ASR_XLARGE
             - WAVLM_LARGE
         """
-
-
-
-        model = wav2vec2_model(**self._params)
-        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = self._get_state_dict(dl_kwargs)
+        model.load_state_dict(state_dict)
         if self._normalize_waveform:
-            model = _Wav2Vec2Model(model)
+            model = utils._extend_model(model, normalize_waveform=True)
         model.eval()
         return model
 
@@ -171,15 +138,15 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         >>> transcripts = ctc_decode(emissions, labels)
     """  # noqa: E501
 
-    _labels: Tuple[str]
-    _remove_aux_axis: Tuple[int] = (1, 2, 3)
+    _labels: Tuple[str, ...]
+    _remove_aux_axis: Tuple[int, ...] = (1, 2, 3)
 
     def get_labels(
         self,
         *,
         blank: str = "-",
-    ) -> Tuple[str]:
-        """The output class labels
+    ) -> Tuple[str, ...]:
+        """The output class labels.
 
         The first is blank token, and it is customizable.
 
@@ -187,35 +154,19 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
             blank (str, optional): Blank token. (default: ``'-'``)
 
         Returns:
-            Tuple[str]:
+            Tuple[str, ...]:
             For models fine-tuned on ASR, returns the tuple of strings representing
             the output class labels.
 
         Example
-            >>> import torchaudio
-            >>> torchaudio.pipelines.HUBERT_ASR_LARGE.get_labels()
+            >>> from torchaudio.pipelines import HUBERT_ASR_LARGE as bundle
+            >>> bundle.get_labels()
             ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         """  # noqa: E501
         return (blank, *self._labels)
 
     def _get_state_dict(self, dl_kwargs):
-        state_dict = super()._get_state_dict(dl_kwargs)
-        if self._remove_aux_axis:
-            # Remove the seemingly unnecessary axis
-            # For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
-            # It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
-            # but not used during the ASR training.
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
-            # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
-            #
-            # Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
-            # that resembles mistake.
-            # The label `1` shows up in the training dataset of German (1 out of 16M),
-            # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
-            for key in ["aux.weight", "aux.bias"]:
-                t = state_dict[key]
-                state_dict[key] = torch.stack([t[i] for i in range(t.size(0)) if i not in self._remove_aux_axis])
-        return state_dict
+        return utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
 
 
 WAV2VEC2_BASE = Wav2Vec2Bundle(
@@ -1399,7 +1350,7 @@ WAVLM_LARGE = Wav2Vec2Bundle(
         "encoder_ff_interm_features": 4096,
         "encoder_ff_interm_dropout": 0.0,
         "encoder_dropout": 0.1,
-        "encoder_layer_norm_first": False,
+        "encoder_layer_norm_first": True,
         "encoder_layer_drop": 0.05,
         "aux_num_out": None,
     },
@@ -1567,3 +1518,182 @@ redistributed with the same license.
 
 Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2Bundle` for usage details.
 """  # noqa: E501
+
+
+@dataclass
+class Wav2Vec2FABundle(Wav2Vec2ASRBundle):
+    """Data class that bundles associated information to use pretrained :py:class:`~torchaudio.models.Wav2Vec2Model` for forced alignment.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    Please see below for the usage and the available values.
+
+    Example - Feature Extraction
+        >>> import torchaudio
+        >>>
+        >>> bundle = torchaudio.pipelines.MMS_FA
+        >>>
+        >>> # Build the model and load pretrained weight.
+        >>> model = bundle.get_model()
+        Downloading:
+        100%|███████████████████████████████| 1.18G/1.18G [00:05<00:00, 216MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate the probability of token distribution
+        >>> emission, _ = model(waveform)
+        >>>
+        >>> # Generate frame-wise alignment
+        >>> alignment, scores = torchaudio.functional.forced_align(
+        >>>     emission, targets, input_lengths, target_lengths, blank=0)
+        >>>
+    """  # noqa: E501
+
+    class Tokenizer(aligner.ITokenizer):
+        """Interface of the tokenizer"""
+
+    class Aligner(aligner.IAligner):
+        """Interface of the aligner"""
+
+    def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str, ...]:
+        """Get the labels corresponding to the feature dimension of emission.
+
+        The first is blank token, and it is customizable.
+
+        Args:
+            star (str or None, optional): Change or disable star token. (default: ``"*"``)
+            blank (str, optional): Change the blank token. (default: ``'-'``)
+
+        Returns:
+            Tuple[str, ...]:
+            For models fine-tuned on ASR, returns the tuple of strings representing
+            the output class labels.
+
+        Example
+            >>> from torchaudio.pipelines import MMS_FA as bundle
+            >>> bundle.get_labels()
+            ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x', '*')
+            >>> bundle.get_labels(star=None)
+            ('-', 'a', 'i', 'e', 'n', 'o', 'u', 't', 's', 'r', 'm', 'k', 'l', 'd', 'g', 'h', 'y', 'b', 'p', 'w', 'c', 'v', 'j', 'z', 'f', "'", 'q', 'x')
+        """  # noqa: E501
+        labels = super().get_labels(blank=blank)
+        return labels if star is None else (*labels, star)
+
+    def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
+        """Construct the model and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            with_star (bool, optional): If enabled, the last dimension of output layer is
+                extended by one, which corresponds to `star` token.
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+            .. note::
+
+               The model created with this method returns probability in log-domain,
+               (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+               the other Wav2Vec2 models returns logit.
+        """
+        model = utils._get_model(self._model_type, self._params)
+        state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis)
+        model.load_state_dict(state_dict)
+        model = utils._extend_model(
+            model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star
+        )
+        model.eval()
+        return model
+
+    def get_dict(self, star: Optional[str] = "*", blank: str = "-") -> Dict[str, int]:
+        """Get the mapping from token to index (in emission feature dim)
+
+        Args:
+            star (str or None, optional): Change or disable star token. (default: ``"*"``)
+            blank (str, optional): Change the blank token. (default: ``'-'``)
+
+        Returns:
+            Tuple[str, ...]:
+            For models fine-tuned on ASR, returns the tuple of strings representing
+            the output class labels.
+
+        Example
+            >>> from torchaudio.pipelines import MMS_FA as bundle
+            >>> bundle.get_dict()
+            {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27, '*': 28}
+            >>> bundle.get_dict(star=None)
+            {'-': 0, 'a': 1, 'i': 2, 'e': 3, 'n': 4, 'o': 5, 'u': 6, 't': 7, 's': 8, 'r': 9, 'm': 10, 'k': 11, 'l': 12, 'd': 13, 'g': 14, 'h': 15, 'y': 16, 'b': 17, 'p': 18, 'w': 19, 'c': 20, 'v': 21, 'j': 22, 'z': 23, 'f': 24, "'": 25, 'q': 26, 'x': 27}
+        """  # noqa: E501
+        return {k: i for i, k in enumerate(self.get_labels(star=star, blank=blank))}
+
+    def get_tokenizer(self) -> Tokenizer:
+        """Instantiate a Tokenizer.
+
+        Returns:
+            Tokenizer
+        """
+        return aligner.Tokenizer(self.get_dict())
+
+    def get_aligner(self) -> Aligner:
+        """Instantiate an Aligner.
+
+        Returns:
+            Aligner
+        """
+        return aligner.Aligner(blank=0)
+
+
+MMS_FA = Wav2Vec2FABundle(
+    "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
+    {
+        "extractor_mode": "layer_norm",
+        "extractor_conv_layer_config": [
+            (512, 10, 5),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 3, 2),
+            (512, 2, 2),
+            (512, 2, 2),
+        ],
+        "extractor_conv_bias": True,
+        "encoder_embed_dim": 1024,
+        "encoder_projection_dropout": 0.0,
+        "encoder_pos_conv_kernel": 128,
+        "encoder_pos_conv_groups": 16,
+        "encoder_num_layers": 24,
+        "encoder_num_heads": 16,
+        "encoder_attention_dropout": 0.0,
+        "encoder_ff_interm_features": 4096,
+        "encoder_ff_interm_dropout": 0.1,
+        "encoder_dropout": 0.0,
+        "encoder_layer_norm_first": True,
+        "encoder_layer_drop": 0.1,
+        "aux_num_out": 28,
+    },
+    _labels=utils._get_mms_labels(),
+    _sample_rate=16000,
+    _normalize_waveform=True,
+    _model_type="Wav2Vec2",
+)
+MMS_FA.__doc__ = """
+Trained on 31K hours of data in 1,130 languages from *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling`.
+
+Published by the authors of *Scaling Speech Technology to 1,000+ Languages* :cite:`pratap2023scaling` under [`CC-BY-NC 4.0 License <https://github.com/facebookresearch/fairseq/tree/100cd91db19bb27277a06a25eb4154c805b10189/examples/mms#license>`__].
+
+Please refer to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` for usage details.
+
+.. note::
+
+   Unlike other Wav2Vec2 bundles, this model does not have a token for word boundary (like `|`). This makes the post-processing of alignments slightly different.
+"""  # noqa: E501
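Taken together, the hunk above adds the whole forced-alignment bundle (Wav2Vec2FABundle) and its MMS_FA instance. A rough end-to-end sketch, assembled from the docstrings in this hunk, is below; the input file and transcript are placeholders, and the forced_align call mirrors the signature shown in the Wav2Vec2FABundle docstring.

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.MMS_FA
    model = bundle.get_model()  # emits log-probabilities with a star token appended

    waveform, sr = torchaudio.load("speech.wav")  # placeholder input
    waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)

    with torch.inference_mode():
        emission, _ = model(waveform)

    # Map a romanized transcript to token indices via the bundle's dictionary.
    dictionary = bundle.get_dict(star=None)
    transcript = "hello"  # placeholder transcript (MMS_FA has no word-boundary token)
    targets = torch.tensor([[dictionary[c] for c in transcript]], dtype=torch.int32)

    input_lengths = torch.tensor([emission.size(1)])
    target_lengths = torch.tensor([targets.size(1)])
    alignment, scores = torchaudio.functional.forced_align(
        emission, targets, input_lengths, target_lengths, blank=0
    )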
--- a/torchaudio/pipelines/_wav2vec2/utils.py
+++ b/torchaudio/pipelines/_wav2vec2/utils.py
@@ -1,3 +1,91 @@
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
+
+
+def _get_model(type_, params):
+    factories = {
+        "Wav2Vec2": wav2vec2_model,
+        "WavLM": wavlm_model,
+    }
+    if type_ not in factories:
+        raise ValueError(f"Supported model types are {tuple(factories.keys())}. Found: {type_}")
+    factory = factories[type_]
+    return factory(**params)
+
+
+class _Wav2Vec2Model(nn.Module):
+    """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+    This is used for layer normalization at the input
+    """
+
+    def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
+        super().__init__()
+        self.model = model
+        self.normalize_waveform = normalize_waveform
+        self.apply_log_softmax = apply_log_softmax
+        self.append_star = append_star
+
+    def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        output, output_lengths = self.model(waveforms, lengths)
+        if self.apply_log_softmax:
+            output = torch.nn.functional.log_softmax(output, dim=-1)
+        if self.append_star:
+            star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
+            output = torch.cat((output, star_dim), dim=-1)
+        return output, output_lengths
+
+    @torch.jit.export
+    def extract_features(
+        self,
+        waveforms: Tensor,
+        lengths: Optional[Tensor] = None,
+        num_layers: Optional[int] = None,
+    ) -> Tuple[List[Tensor], Optional[Tensor]]:
+        if self.normalize_waveform:
+            waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
+        return self.model.extract_features(waveforms, lengths, num_layers)
+
+
+def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
+    """Add extra transformations to the model"""
+    return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star)
+
+
+def _remove_aux_axes(state_dict, axes):
+    # Remove the seemingly unnecessary axis
+    # For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
+    # It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
+    # but not used during the ASR training.
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
+    # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
+    #
+    # Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
+    # that resembles mistake.
+    # The label `1` shows up in the training dataset of German (1 out of 16M),
+    # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
+    for key in ["aux.weight", "aux.bias"]:
+        mat = state_dict[key]
+        state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
+
+
+def _get_state_dict(url, dl_kwargs, remove_axes=None):
+    if not url.startswith("https"):
+        url = f"https://download.pytorch.org/torchaudio/models/{url}"
+    dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+    state_dict = load_state_dict_from_url(url, **dl_kwargs)
+    if remove_axes:
+        _remove_aux_axes(state_dict, remove_axes)
+    return state_dict
+
+
 def _get_en_labels():
     return (
         "|",
@@ -224,3 +312,35 @@ def _get_it_labels():
         "í",
         "ï",
     )
+
+
+def _get_mms_labels():
+    return (
+        "a",
+        "i",
+        "e",
+        "n",
+        "o",
+        "u",
+        "t",
+        "s",
+        "r",
+        "m",
+        "k",
+        "l",
+        "d",
+        "g",
+        "h",
+        "y",
+        "b",
+        "p",
+        "w",
+        "c",
+        "v",
+        "j",
+        "z",
+        "f",
+        "'",
+        "q",
+        "x",
+    )
--- a/torchaudio/sox_effects/sox_effects.py
+++ b/torchaudio/sox_effects/sox_effects.py
@@ -1,5 +1,4 @@
 import os
-import warnings
 from typing import List, Optional, Tuple
 
 import torch
@@ -156,14 +155,6 @@ def apply_effects_tensor(
     return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first)
 
 
-_deprecation_message = (
-    "File-like object support in sox_io backend is deprecated, "
-    "and will be removed in v2.1. "
-    "See https://github.com/pytorch/audio/issues/2950 for the detail."
-    "Please migrate to the new dispatcher, or use soundfile backend."
-)
-
-
 @torchaudio._extension.fail_if_no_sox
 def apply_effects_file(
     path: str,
@@ -187,18 +178,8 @@ def apply_effects_file(
             rate and leave samples untouched.
 
     Args:
-        path (path-like object or file-like object):
-            Source of audio data. When the function is not compiled by TorchScript,
-            (e.g. ``torch.jit.script``), the following types are accepted:
-
-                * ``path-like``: file path
-                * ``file-like``: Object with ``read(size: int) -> bytes`` method,
-                  which returns byte string of at most ``size`` length.
-
-            When the function is compiled by TorchScript, only ``str`` type is allowed.
-
-            Note: This argument is intentionally annotated as ``str`` only for
-            TorchScript compiler compatibility.
+        path (path-like object):
+            Source of audio data.
         effects (List[List[str]]): List of effects.
         normalize (bool, optional):
             When ``True``, this function converts the native sample type to ``float32``.
@@ -283,13 +264,9 @@ def apply_effects_file(
     """
    if not torch.jit.is_scripting():
         if hasattr(path, "read"):
-
-
-
-
-            return ret
+            raise RuntimeError(
+                "apply_effects_file function does not support file-like object. "
+                "Please use torchaudio.io.AudioEffector."
+            )
         path = os.fspath(path)
-    ret = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
-    if ret is not None:
-        return ret
-    raise RuntimeError("Failed to load audio from {}".format(path))
+    return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
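With this change, passing a file-like object to apply_effects_file raises immediately and the error message points at torchaudio.io.AudioEffector (new in this release, see torchaudio/io/_effector.py above). A hedged migration sketch follows; the constructor argument (an ffmpeg filter description) and the (frames, channels) layout expected by apply() are assumptions about the new AudioEffector API rather than something stated in this diff.

    import torchaudio
    from torchaudio.io import AudioEffector

    waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder file, (channels, frames)

    # Assumed API: effects are given as an ffmpeg filter string, not a sox effect list.
    effector = AudioEffector(effect="lowpass=frequency=300")
    processed = effector.apply(waveform.T, sample_rate).T  # transpose to/from (frames, channels)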
--- a/torchaudio/transforms/__init__.py
+++ b/torchaudio/transforms/__init__.py
@@ -23,6 +23,7 @@ from ._transforms import (
     Resample,
     RNNTLoss,
     SlidingWindowCmn,
+    SpecAugment,
     SpectralCentroid,
     Spectrogram,
     Speed,
@@ -62,6 +63,7 @@ __all__ = [
     "Resample",
     "SlidingWindowCmn",
     "SoudenMVDR",
+    "SpecAugment",
     "SpectralCentroid",
     "Spectrogram",
     "Speed",