torchaudio-2.9.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchaudio might be problematic.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.0.dist-info/LICENSE +25 -0
- torchaudio-2.9.0.dist-info/METADATA +122 -0
- torchaudio-2.9.0.dist-info/RECORD +86 -0
- torchaudio-2.9.0.dist-info/WHEEL +5 -0
- torchaudio-2.9.0.dist-info/top_level.txt +1 -0
torchaudio/models/squim/subjective.py

@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torchaudio
+
+
+class AttPool(nn.Module):
+    """Attention-Pooling module that estimates the attention score.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(AttPool, self).__init__()
+
+        self.linear1 = nn.Linear(input_dim, 1)
+        self.linear2 = nn.Linear(input_dim, att_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply attention and pooling.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
+        """
+
+        att = self.linear1(x)  # (batch, time, 1)
+        att = att.transpose(2, 1)  # (batch, 1, time)
+        att = nn.functional.softmax(att, dim=2)
+        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
+        x = self.linear2(x)  # (batch, att_dim)
+        return x
+
+
+class Predictor(nn.Module):
+    """Prediction module that applies pooling and attention, then predicts subjective metric scores.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(Predictor, self).__init__()
+        self.att_pool_layer = AttPool(input_dim, att_dim)
+        self.att_dim = att_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict subjective evaluation metric score.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        x = self.att_pool_layer(x)
+        x = nn.functional.softmax(x, dim=1)
+        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
+        x = (x * B).sum(dim=1)
+        return x
+
+
+class SquimSubjective(nn.Module):
+    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
+    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adapted from *NORESQA-MOS*
+    :cite:`manocha2022speech`, which predicts MOS scores given the input speech and a non-matching reference.
+
+    Args:
+        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
+        projector (torch.nn.Module): Projection layer that projects the SSL feature to a lower dimension.
+        predictor (torch.nn.Module): Predicts the subjective scores.
+    """
+
+    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
+        super(SquimSubjective, self).__init__()
+        self.ssl_model = ssl_model
+        self.projector = projector
+        self.predictor = predictor
+
+    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Cut or pad the reference Tensor to align it with the waveform Tensor.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors
+            with same dimensions `(batch, time)`.
+        """
+        T_waveform = waveform.shape[-1]
+        T_reference = reference.shape[-1]
+        if T_reference < T_waveform:
+            num_padding = T_waveform // T_reference + 1
+            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
+        return waveform, reference[:, :T_waveform]
+
+    def forward(self, waveform: torch.Tensor, reference: torch.Tensor):
+        """Predict subjective evaluation metric score.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        waveform, reference = self._align_shapes(waveform, reference)
+        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
+        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
+        concat = torch.cat((reference, waveform), dim=2)
+        score_diff = self.predictor(concat)  # Score difference compared to the reference
+        return 5 - score_diff
+
+
+def squim_subjective_model(
+    ssl_type: str,
+    feat_dim: int,
+    proj_dim: int,
+    att_dim: int,
+) -> SquimSubjective:
+    """Build a custom :class:`torchaudio.prototype.models.SquimSubjective` model.
+
+    Args:
+        ssl_type (str): Type of self-supervised learning (SSL) models.
+            Must be one of ["wav2vec2_base", "wav2vec2_large"].
+        feat_dim (int): Feature dimension of the SSL feature representation.
+        proj_dim (int): Output dimension of projection layer.
+        att_dim (int): Dimension of attention scores.
+    """
+    ssl_model = getattr(torchaudio.models, ssl_type)()
+    projector = nn.Linear(feat_dim, proj_dim)
+    predictor = Predictor(proj_dim * 2, att_dim)
+    return SquimSubjective(ssl_model, projector, predictor)


+def squim_subjective_base() -> SquimSubjective:
+    """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
+    return squim_subjective_model(
+        ssl_type="wav2vec2_base",
+        feat_dim=768,
+        proj_dim=32,
+        att_dim=5,
+    )
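For orientation: Predictor turns the attention-pooled logits into an expected value over att_dim score bins spaced across [0, 4], and SquimSubjective returns 5 minus that predicted difference, giving a MOS-like score in roughly the 1-5 range. Below is a minimal usage sketch, assuming squim_subjective_base is re-exported from torchaudio.models as in earlier torchaudio releases; the random 16 kHz inputs are illustrative only, and the model built here is untrained (real use would load pretrained weights, e.g. via the torchaudio.pipelines SQUIM_SUBJECTIVE bundle).

import torch
from torchaudio.models import squim_subjective_base

model = squim_subjective_base().eval()  # untrained weights; illustration only

# Batch of two 1-second test waveforms plus a shorter non-matching clean
# reference; _align_shapes tiles the reference up to the waveform length.
waveform = torch.randn(2, 16000)
reference = torch.randn(2, 8000)

with torch.no_grad():
    score = model(waveform, reference)

print(score.shape)  # torch.Size([2]) -- one MOS-like score per utterance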