torchaudio 2.0.2-cp310-cp310-manylinux2014_aarch64.whl → 2.1.1-cp310-cp310-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchaudio has been flagged as possibly problematic.
- torchaudio/__init__.py +22 -3
- torchaudio/_backend/__init__.py +55 -4
- torchaudio/_backend/backend.py +53 -0
- torchaudio/_backend/common.py +52 -0
- torchaudio/_backend/ffmpeg.py +373 -0
- torchaudio/_backend/soundfile.py +54 -0
- torchaudio/_backend/soundfile_backend.py +457 -0
- torchaudio/_backend/sox.py +91 -0
- torchaudio/_backend/utils.py +81 -323
- torchaudio/_extension/__init__.py +55 -36
- torchaudio/_extension/utils.py +109 -17
- torchaudio/_internal/__init__.py +4 -1
- torchaudio/_internal/module_utils.py +37 -6
- torchaudio/backend/__init__.py +7 -11
- torchaudio/backend/_no_backend.py +24 -0
- torchaudio/backend/_sox_io_backend.py +297 -0
- torchaudio/backend/common.py +12 -52
- torchaudio/backend/no_backend.py +11 -21
- torchaudio/backend/soundfile_backend.py +11 -448
- torchaudio/backend/sox_io_backend.py +11 -435
- torchaudio/backend/utils.py +9 -18
- torchaudio/datasets/__init__.py +2 -0
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/cmudict.py +61 -62
- torchaudio/datasets/dr_vctk.py +1 -1
- torchaudio/datasets/gtzan.py +1 -1
- torchaudio/datasets/librilight_limited.py +1 -1
- torchaudio/datasets/librispeech.py +1 -1
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +1 -1
- torchaudio/datasets/ljspeech.py +1 -1
- torchaudio/datasets/musdb_hq.py +1 -1
- torchaudio/datasets/quesst14.py +1 -1
- torchaudio/datasets/speechcommands.py +1 -1
- torchaudio/datasets/tedlium.py +1 -1
- torchaudio/datasets/vctk.py +1 -1
- torchaudio/datasets/voxceleb1.py +1 -1
- torchaudio/datasets/yesno.py +1 -1
- torchaudio/functional/__init__.py +6 -2
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +69 -92
- torchaudio/functional/functional.py +99 -148
- torchaudio/io/__init__.py +4 -1
- torchaudio/io/_effector.py +347 -0
- torchaudio/io/_stream_reader.py +158 -90
- torchaudio/io/_stream_writer.py +196 -10
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/_torchaudio_sox.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/libtorchaudio_sox.so +0 -0
- torchaudio/models/__init__.py +14 -0
- torchaudio/models/decoder/__init__.py +22 -7
- torchaudio/models/decoder/_ctc_decoder.py +123 -69
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/rnnt_decoder.py +10 -14
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/wav2vec2/components.py +6 -10
- torchaudio/pipelines/__init__.py +9 -0
- torchaudio/pipelines/_squim_pipeline.py +176 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +198 -68
- torchaudio/pipelines/_wav2vec2/utils.py +120 -0
- torchaudio/sox_effects/sox_effects.py +7 -30
- torchaudio/transforms/__init__.py +2 -0
- torchaudio/transforms/_transforms.py +99 -54
- torchaudio/utils/download.py +2 -2
- torchaudio/utils/ffmpeg_utils.py +20 -15
- torchaudio/utils/sox_utils.py +8 -9
- torchaudio/version.py +2 -2
- torchaudio-2.1.1.dist-info/METADATA +113 -0
- torchaudio-2.1.1.dist-info/RECORD +117 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +1 -1
- torchaudio/io/_compat.py +0 -241
- torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
- torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
- torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
- torchaudio/lib/libflashlight-text.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
- torchaudio-2.0.2.dist-info/METADATA +0 -26
- torchaudio-2.0.2.dist-info/RECORD +0 -100
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
torchaudio/models/squim/subjective.py
ADDED

@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torchaudio
+
+
+class AttPool(nn.Module):
+    """Attention-Pooling module that estimates the attention score.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(AttPool, self).__init__()
+
+        self.linear1 = nn.Linear(input_dim, 1)
+        self.linear2 = nn.Linear(input_dim, att_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply attention and pooling.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
+        """
+
+        att = self.linear1(x)  # (batch, time, 1)
+        att = att.transpose(2, 1)  # (batch, 1, time)
+        att = nn.functional.softmax(att, dim=2)
+        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
+        x = self.linear2(x)  # (batch, att_dim)
+        return x
+
+
+class Predictor(nn.Module):
+    """Prediction module that apply pooling and attention, then predict subjective metric scores.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(Predictor, self).__init__()
+        self.att_pool_layer = AttPool(input_dim, att_dim)
+        self.att_dim = att_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict subjective evaluation metric score.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        x = self.att_pool_layer(x)
+        x = nn.functional.softmax(x, dim=1)
+        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
+        x = (x * B).sum(dim=1)
+        return x
+
+
+class SquimSubjective(nn.Module):
+    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
+    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adopted from *NORESQA-MOS*
+    :cite:`manocha2022speech` which predicts MOS scores given the input speech and a non-matching reference.
+
+    Args:
+        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
+        projector (torch.nn.Module): Projection layer that projects SSL feature to a lower dimension.
+        predictor (torch.nn.Module): Predict the subjective scores.
+    """
+
+    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
+        super(SquimSubjective, self).__init__()
+        self.ssl_model = ssl_model
+        self.projector = projector
+        self.predictor = predictor
+
+    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Cut or pad the reference Tensor to make it aligned with waveform Tensor.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors
+                with same dimensions `(batch, time)`.
+        """
+        T_waveform = waveform.shape[-1]
+        T_reference = reference.shape[-1]
+        if T_reference < T_waveform:
+            num_padding = T_waveform // T_reference + 1
+            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
+        return waveform, reference[:, :T_waveform]
+
+    def forward(self, waveform: torch.Tensor, reference: torch.Tensor):
+        """Predict subjective evaluation metric score.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        waveform, reference = self._align_shapes(waveform, reference)
+        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
+        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
+        concat = torch.cat((reference, waveform), dim=2)
+        score_diff = self.predictor(concat)  # Score difference compared to the reference
+        return 5 - score_diff
+
+
+def squim_subjective_model(
+    ssl_type: str,
+    feat_dim: int,
+    proj_dim: int,
+    att_dim: int,
+) -> SquimSubjective:
+    """Build a custome :class:`torchaudio.prototype.models.SquimSubjective` model.
+
+    Args:
+        ssl_type (str): Type of self-supervised learning (SSL) models.
+            Must be one of ["wav2vec2_base", "wav2vec2_large"].
+        feat_dim (int): Feature dimension of the SSL feature representation.
+        proj_dim (int): Output dimension of projection layer.
+        att_dim (int): Dimension of attention scores.
+    """
+    ssl_model = getattr(torchaudio.models, ssl_type)()
+    projector = nn.Linear(feat_dim, proj_dim)
+    predictor = Predictor(proj_dim * 2, att_dim)
+    return SquimSubjective(ssl_model, projector, predictor)
+
+
+def squim_subjective_base() -> SquimSubjective:
+    """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
+    return squim_subjective_model(
+        ssl_type="wav2vec2_base",
+        feat_dim=768,
+        proj_dim=32,
+        att_dim=5,
+    )
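
The hunk above adds the new SquimSubjective model (torchaudio/models/squim/subjective.py in the file list). A minimal sketch of exercising it with random tensors, assuming `squim_subjective_base` is importable from `torchaudio.models` as the pipeline code later in this diff does; the weights here are untrained, so the value is meaningless and only the shapes matter:

    import torch
    from torchaudio.models import squim_subjective_base

    model = squim_subjective_base().eval()

    waveform = torch.randn(1, 16000)   # (batch, time): one second of fake speech at 16 kHz
    reference = torch.randn(1, 8000)   # shorter non-matching reference; _align_shapes tiles it

    with torch.no_grad():
        mos = model(waveform, reference)

    print(mos.shape)  # torch.Size([1])

Because the predictor takes a softmax over `att_dim=5` bins weighted by `linspace(0, 4)` and the forward returns `5 - score_diff`, the estimated MOS always lands in the range [1, 5].
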
torchaudio/models/wav2vec2/components.py
CHANGED

@@ -208,18 +208,13 @@ class ConvolutionalPositionalEmbedding(Module):
             groups=groups,
         )
 
-        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+        self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
         self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
 
     def __prepare_scriptable__(self):
-        for hook in self.conv._forward_pre_hooks.values():
-            # The hook we want to remove is an instance of WeightNorm class, so
-            # normally we would do `if isinstance(...)` but this class is not accessible
-            # because of shadowing, so we check the module name directly.
-            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
-            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
-                _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
-                torch.nn.utils.remove_weight_norm(self.conv)
+        if self.conv.__class__.__name__ == "ParametrizedConv1d":
+            _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
+            torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
         return self
 
     def forward(self, x):

@@ -458,9 +453,10 @@ class Transformer(Module):
                 raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
 
         ret: List[Tensor] = []
+        position_bias = None
         x = self._preprocess(x)
         for layer in self.layers:
-            x, position_bias = layer(x, attention_mask)
+            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
             ret.append(x)
             if num_layers is not None and len(ret) >= num_layers:
                 return ret
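
The first hunk swaps the deprecated hook-based `nn.utils.weight_norm` for the parametrization-based `nn.utils.parametrizations.weight_norm` that newer PyTorch releases provide, and updates `__prepare_scriptable__` to detect the wrapper by its generated class name. A small standalone illustration of the two APIs (plain PyTorch, not torchaudio code):

    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    conv = nn.Conv1d(16, 16, kernel_size=3)

    # Old, deprecated style (what 2.0.2 used): registers a WeightNorm forward pre-hook.
    #   conv = nn.utils.weight_norm(conv, name="weight", dim=2)
    #   nn.utils.remove_weight_norm(conv)

    # New style (what 2.1.1 uses): registers a parametrization; the module's class
    # is rewritten to "ParametrizedConv1d", which __prepare_scriptable__ now checks for.
    conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
    print(conv.__class__.__name__)  # ParametrizedConv1d
    parametrize.remove_parametrizations(conv, "weight")
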
torchaudio/pipelines/__init__.py
CHANGED

@@ -4,6 +4,7 @@ from ._source_separation_pipeline import (
     HDEMUCS_HIGH_MUSDB_PLUS,
     SourceSeparationBundle,
 )
+from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
 from ._tts import (
     TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
     TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
@@ -17,6 +18,7 @@ from ._wav2vec2.impl import (
     HUBERT_BASE,
     HUBERT_LARGE,
     HUBERT_XLARGE,
+    MMS_FA,
     VOXPOPULI_ASR_BASE_10K_DE,
     VOXPOPULI_ASR_BASE_10K_EN,
     VOXPOPULI_ASR_BASE_10K_ES,
@@ -40,6 +42,7 @@ from ._wav2vec2.impl import (
     WAV2VEC2_XLSR_300M,
     Wav2Vec2ASRBundle,
     Wav2Vec2Bundle,
+    Wav2Vec2FABundle,
     WAVLM_BASE,
     WAVLM_BASE_PLUS,
     WAVLM_LARGE,
@@ -50,6 +53,7 @@ from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
 __all__ = [
     "Wav2Vec2Bundle",
     "Wav2Vec2ASRBundle",
+    "Wav2Vec2FABundle",
     "WAV2VEC2_BASE",
     "WAV2VEC2_LARGE",
     "WAV2VEC2_LARGE_LV60K",
@@ -76,6 +80,7 @@ __all__ = [
     "HUBERT_XLARGE",
     "HUBERT_ASR_LARGE",
     "HUBERT_ASR_XLARGE",
+    "MMS_FA",
     "WAVLM_BASE",
     "WAVLM_BASE_PLUS",
     "WAVLM_LARGE",
@@ -90,4 +95,8 @@ __all__ = [
     "CONVTASNET_BASE_LIBRI2MIX",
     "HDEMUCS_HIGH_MUSDB_PLUS",
     "HDEMUCS_HIGH_MUSDB",
+    "SQUIM_OBJECTIVE",
+    "SQUIM_SUBJECTIVE",
+    "SquimObjectiveBundle",
+    "SquimSubjectiveBundle",
 ]
torchaudio/pipelines/_squim_pipeline.py
ADDED

@@ -0,0 +1,176 @@
+from dataclasses import dataclass
+
+from torchaudio._internal import load_state_dict_from_url
+
+from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
+
+
+@dataclass
+class SquimObjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimObjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.
+    A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
+
+    Example: Estimate the objective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimObjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
+        100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate objective metric scores
+        >>> scores = model(waveform)
+        >>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
+    """  # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def _get_state_dict(self, dl_kwargs):
+        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        return state_dict
+
+    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
+        """Construct the SquimObjective model, and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimObjective`.
+        """
+        model = squim_objective_base()
+        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+    "squim_objective_dns2020.pth",
+    _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
+:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+The weights are under `Creative Commons Attribution 4.0 International License
+<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
+
+Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+"""
+
+
+@dataclass
+class SquimSubjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimSubjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
+    A typical use case would be a flow like `waveform -> score`. Please see below for the code example.
+
+    Example: Estimate the subjective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimSubjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
+        100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
+        >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate subjective metric scores
+        >>> score = model(waveform, reference)
+        >>> print(f"MOS: {score}.")
+    """  # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def _get_state_dict(self, dl_kwargs):
+        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        return state_dict
+
+    def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
+        """Construct the SquimSubjective model, and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimObjective`.
+        """
+        model = squim_subjective_base()
+        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
+    "squim_subjective_bvcc_daps.pth",
+    _sample_rate=16000,
+)
+SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
+as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
+on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
+
+The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
+The weights are under `Creative Commons Attribution Non Commercial 4.0 International
+<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.
+
+Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
+"""
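
The new `_squim_pipeline.py` above wires the SQUIM models into `torchaudio.pipelines`. The docstring examples assume `waveform`, `reference`, and `sample_rate` already exist; a hedged end-to-end sketch that fills those in from hypothetical local files ("speech.wav" and "clean_reference.wav" are made up, assumed mono):

    import torchaudio
    from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE

    waveform, sample_rate = torchaudio.load("speech.wav")
    reference, reference_rate = torchaudio.load("clean_reference.wav")

    # Both bundles are trained on 16 kHz audio (bundle.sample_rate == 16000).
    waveform = torchaudio.functional.resample(waveform, sample_rate, SQUIM_OBJECTIVE.sample_rate)
    reference = torchaudio.functional.resample(reference, reference_rate, SQUIM_SUBJECTIVE.sample_rate)

    # Objective scores: the model returns [STOI, PESQ, SI-SDR], each of shape (batch,).
    stoi, pesq, si_sdr = SQUIM_OBJECTIVE.get_model()(waveform)

    # Subjective score: estimated MOS against a non-matching clean reference.
    mos = SQUIM_SUBJECTIVE.get_model()(waveform, reference)

    print(f"STOI {stoi.item():.3f}  PESQ {pesq.item():.3f}  SI-SDR {si_sdr.item():.1f}  MOS {mos.item():.2f}")
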
torchaudio/pipelines/_wav2vec2/aligner.py
ADDED

@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+import torch
+import torchaudio.functional as F
+from torch import Tensor
+from torchaudio.functional import TokenSpan
+
+
+class ITokenizer(ABC):
+    @abstractmethod
+    def __call__(self, transcript: List[str]) -> List[List[str]]:
+        """Tokenize the given transcript (list of word)
+
+        .. note::
+
+            The toranscript must be normalized.
+
+        Args:
+            transcript (list of str): Transcript (list of word).
+
+        Returns:
+            (list of int): List of token sequences
+        """
+
+
+class Tokenizer(ITokenizer):
+    def __init__(self, dictionary: Dict[str, int]):
+        self.dictionary = dictionary
+
+    def __call__(self, transcript: List[str]) -> List[List[int]]:
+        return [[self.dictionary[c] for c in word] for word in transcript]
+
+
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
+    device = emission.device
+    emission = emission.unsqueeze(0)
+    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
+
+    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
+
+    scores = scores.exp()  # convert back to probability
+    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
+    return aligned_tokens, scores
+
+
+class IAligner(ABC):
+    @abstractmethod
+    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+        """Generate list of time-stamped token sequences
+
+        Args:
+            emission (Tensor): Sequence of token probability distributions in log-domain.
+                Shape: `(time, tokens)`.
+            tokens (list of integer sequence): Tokenized transcript.
+                Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.
+
+        Returns:
+            (list of TokenSpan sequence): Tokens with time stamps and scores.
+        """
+
+
+def _unflatten(list_, lengths):
+    assert len(list_) == sum(lengths)
+    i = 0
+    ret = []
+    for l in lengths:
+        ret.append(list_[i : i + l])
+        i += l
+    return ret
+
+
+def _flatten(nested_list):
+    return [item for list_ in nested_list for item in list_]
+
+
+class Aligner(IAligner):
+    def __init__(self, blank):
+        self.blank = blank
+
+    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+        if emission.ndim != 2:
+            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
+
+        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
+        spans = F.merge_tokens(aligned_tokens, scores)
+        return _unflatten(spans, [len(ts) for ts in tokens])
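
This new aligner module is private plumbing for the `MMS_FA` / `Wav2Vec2FABundle` forced-alignment pipeline: it flattens the tokenized transcript, runs `torchaudio.functional.forced_align`, merges per-frame labels into spans, and regroups them per word. A hedged sketch of that flow with a synthetic emission and a made-up dictionary (the public bundle normally supplies both):

    import torch
    import torchaudio.functional as F

    dictionary = {"-": 0, "a": 1, "b": 2, "c": 3}  # made-up; index 0 is the blank token
    transcript = ["ab", "c"]
    tokens = [[dictionary[c] for c in word] for word in transcript]  # what Tokenizer produces

    # Synthetic (time, tokens) log-probability emission; a real one comes from the acoustic model.
    emission = torch.randn(50, len(dictionary)).log_softmax(dim=-1)

    flat = [t for word in tokens for t in word]
    targets = torch.tensor([flat], dtype=torch.int32)
    aligned, scores = F.forced_align(emission.unsqueeze(0), targets, blank=0)
    spans = F.merge_tokens(aligned[0], scores[0].exp())  # one TokenSpan per target token
    # Aligner._unflatten would now regroup spans as [spans for "ab", spans for "c"].
    for span in spans:
        print(span.token, span.start, span.end, span.score)
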