torchaudio-2.0.2-cp311-cp311-manylinux1_x86_64.whl → torchaudio-2.1.1-cp311-cp311-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchaudio might be problematic.
- torchaudio/__init__.py +22 -3
- torchaudio/_backend/__init__.py +55 -4
- torchaudio/_backend/backend.py +53 -0
- torchaudio/_backend/common.py +52 -0
- torchaudio/_backend/ffmpeg.py +373 -0
- torchaudio/_backend/soundfile.py +54 -0
- torchaudio/_backend/soundfile_backend.py +457 -0
- torchaudio/_backend/sox.py +91 -0
- torchaudio/_backend/utils.py +81 -323
- torchaudio/_extension/__init__.py +55 -36
- torchaudio/_extension/utils.py +109 -17
- torchaudio/_internal/__init__.py +4 -1
- torchaudio/_internal/module_utils.py +37 -6
- torchaudio/backend/__init__.py +7 -11
- torchaudio/backend/_no_backend.py +24 -0
- torchaudio/backend/_sox_io_backend.py +297 -0
- torchaudio/backend/common.py +12 -52
- torchaudio/backend/no_backend.py +11 -21
- torchaudio/backend/soundfile_backend.py +11 -448
- torchaudio/backend/sox_io_backend.py +11 -435
- torchaudio/backend/utils.py +9 -18
- torchaudio/datasets/__init__.py +2 -0
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/cmudict.py +61 -62
- torchaudio/datasets/dr_vctk.py +1 -1
- torchaudio/datasets/gtzan.py +1 -1
- torchaudio/datasets/librilight_limited.py +1 -1
- torchaudio/datasets/librispeech.py +1 -1
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +1 -1
- torchaudio/datasets/ljspeech.py +1 -1
- torchaudio/datasets/musdb_hq.py +1 -1
- torchaudio/datasets/quesst14.py +1 -1
- torchaudio/datasets/speechcommands.py +1 -1
- torchaudio/datasets/tedlium.py +1 -1
- torchaudio/datasets/vctk.py +1 -1
- torchaudio/datasets/voxceleb1.py +1 -1
- torchaudio/datasets/yesno.py +1 -1
- torchaudio/functional/__init__.py +6 -2
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +69 -92
- torchaudio/functional/functional.py +99 -148
- torchaudio/io/__init__.py +4 -1
- torchaudio/io/_effector.py +347 -0
- torchaudio/io/_stream_reader.py +158 -90
- torchaudio/io/_stream_writer.py +196 -10
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/_torchaudio_sox.so +0 -0
- torchaudio/lib/libctc_prefix_decoder.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
- torchaudio/lib/libtorchaudio_sox.so +0 -0
- torchaudio/lib/pybind11_prefixctc.so +0 -0
- torchaudio/models/__init__.py +14 -0
- torchaudio/models/decoder/__init__.py +22 -7
- torchaudio/models/decoder/_ctc_decoder.py +123 -69
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/rnnt_decoder.py +10 -14
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/wav2vec2/components.py +6 -10
- torchaudio/pipelines/__init__.py +9 -0
- torchaudio/pipelines/_squim_pipeline.py +176 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +198 -68
- torchaudio/pipelines/_wav2vec2/utils.py +120 -0
- torchaudio/sox_effects/sox_effects.py +7 -30
- torchaudio/transforms/__init__.py +2 -0
- torchaudio/transforms/_transforms.py +99 -54
- torchaudio/utils/download.py +2 -2
- torchaudio/utils/ffmpeg_utils.py +20 -15
- torchaudio/utils/sox_utils.py +8 -9
- torchaudio/version.py +2 -2
- torchaudio-2.1.1.dist-info/METADATA +113 -0
- torchaudio-2.1.1.dist-info/RECORD +119 -0
- torchaudio/io/_compat.py +0 -241
- torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
- torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
- torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
- torchaudio/lib/libflashlight-text.so +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
- torchaudio-2.0.2.dist-info/METADATA +0 -26
- torchaudio-2.0.2.dist-info/RECORD +0 -100
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0

torchaudio/models/decoder/_cuda_ctc_decoder.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+import math
+
+from typing import List, NamedTuple, Union
+
+import torch
+import torchaudio
+
+torchaudio._extension._load_lib("libctc_prefix_decoder")
+import torchaudio.lib.pybind11_prefixctc as cuctc
+
+
+__all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
+
+
+def _get_vocab_list(vocab_file):
+    vocab = []
+    with open(vocab_file, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip().split()
+            vocab.append(line[0])
+    return vocab
+
+
+class CUCTCHypothesis(NamedTuple):
+    r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
+    tokens: List[int]
+    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+    words: List[str]
+    """List of predicted tokens. Algin with modeling unit.
+    """
+
+    score: float
+    """Score corresponding to hypothesis"""
+
+
+_DEFAULT_BLANK_SKIP_THREASHOLD = 0.95
+
+
+class CUCTCDecoder:
+    """CUDA CTC beam search decoder.
+
+    .. devices:: CUDA
+
+    Note:
+        To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
+    """
+
+    def __init__(
+        self,
+        vocab_list: List[str],
+        blank_id: int = 0,
+        beam_size: int = 10,
+        nbest: int = 1,
+        blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+        cuda_stream: torch.cuda.streams.Stream = None,
+    ):
+        """
+        Args:
+            blank_id (int): token id corresopnding to blank, only support 0 for now. (Default: 0)
+            vocab_list (List[str]): list of vocabulary tokens
+            beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
+            nbest (int): number of best decodings to return
+            blank_skip_threshold (float):
+                skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
+                (Default: 0.95).
+            cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)
+
+        """
+        if cuda_stream:
+            if not isinstance(cuda_stream, torch.cuda.streams.Stream):
+                raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
+        cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
+        self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
+        self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
+        if blank_id != 0:
+            raise AssertionError("blank_id must be 0")
+        self.blank_id = blank_id
+        self.vocab_list = vocab_list
+        self.space_id = 0
+        self.nbest = nbest
+        if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
+            raise AssertionError("blank_skip_threshold must be between 0 and 1")
+        self.blank_skip_threshold = math.log(blank_skip_threshold)
+        self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size
+
+    def __del__(self):
+        if cuctc is not None:
+            cuctc.prefixCTC_free(self.internal_data)
+
+    def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
+        """
+        Args:
+            log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
+                probability distribution over labels; log_softmax(output of acoustic model).
+            lengths (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of
+                in time axis of the output Tensor in each batch.
+
+        Returns:
+            List[List[CUCTCHypothesis]]:
+                List of sorted best hypotheses for each audio sequence in the batch.
+        """
+        if not encoder_out_lens.dtype == torch.int32:
+            raise AssertionError("encoder_out_lens must be torch.int32")
+        if not log_prob.dtype == torch.float32:
+            raise AssertionError("log_prob must be torch.float32")
+        if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
+            raise AssertionError("inputs must be cuda tensors")
+        if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
+            raise AssertionError("input tensors must be contiguous")
+        required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+            self.internal_data,
+            self.memory.data_ptr(),
+            self.memory.size(0),
+            log_prob.data_ptr(),
+            encoder_out_lens.data_ptr(),
+            log_prob.size(),
+            log_prob.stride(),
+            self.beam_size,
+            self.blank_id,
+            self.space_id,
+            self.blank_skip_threshold,
+        )
+        if required_size > 0:
+            self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
+            _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+                self.internal_data,
+                self.memory.data_ptr(),
+                self.memory.size(0),
+                log_prob.data_ptr(),
+                encoder_out_lens.data_ptr(),
+                log_prob.size(),
+                log_prob.stride(),
+                self.beam_size,
+                self.blank_id,
+                self.space_id,
+                self.blank_skip_threshold,
+            )
+        batch_size = len(score_hyps)
+        hypos = []
+        for i in range(batch_size):
+            hypos.append(
+                [
+                    CUCTCHypothesis(
+                        tokens=score_hyps[i][j][1],
+                        words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
+                        score=score_hyps[i][j][0],
+                    )
+                    for j in range(self.nbest)
+                ]
+            )
+        return hypos
+
+
+def cuda_ctc_decoder(
+    tokens: Union[str, List[str]],
+    nbest: int = 1,
+    beam_size: int = 10,
+    blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+) -> CUCTCDecoder:
+    """Builds an instance of :class:`CUCTCDecoder`.
+
+    Args:
+        tokens (str or List[str]): File or list containing valid tokens.
+            If using a file, the expected format is for tokens mapping to the same index to be on the same line
+        beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
+        nbest (int): The number of best decodings to return
+        blank_id (int): The token ID corresopnding to the blank symbol.
+        blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
+            (Default: 0.95).
+
+    Returns:
+        CUCTCDecoder: decoder
+
+    Example
+        >>> decoder = cuda_ctc_decoder(
+        >>>     vocab_file="tokens.txt",
+        >>>     blank_skip_threshold=0.95,
+        >>> )
+        >>> results = decoder(log_probs, encoder_out_lens)  # List of shape (B, nbest) of Hypotheses
+    """
+    if type(tokens) == str:
+        tokens = _get_vocab_list(tokens)
+
+    return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
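
Illustrative usage (not part of the diff): the cuda_ctc_decoder API added in the new file above could be exercised roughly as follows. The vocabulary file name, sizes, and random emissions are placeholders, and a CUDA-enabled build plus a CUDA device are assumed.

import torch
from torchaudio.models.decoder import cuda_ctc_decoder

# "tokens.txt" is a placeholder vocabulary file: one token per line, blank token on line 0.
# Its number of entries must match num_tokens below.
decoder = cuda_ctc_decoder("tokens.txt", nbest=3, beam_size=10, blank_skip_threshold=0.95)

num_tokens = 500
emissions = torch.randn(1, 200, num_tokens, device="cuda")         # stand-in acoustic-model output
log_probs = torch.log_softmax(emissions, dim=-1).contiguous()      # float32, contiguous, on CUDA
lengths = torch.full((1,), 200, dtype=torch.int32, device="cuda")  # valid frames per batch element

hypos = decoder(log_probs, lengths)  # List[List[CUCTCHypothesis]], one inner list per batch element
best = hypos[0][0]
print(best.score, best.tokens, best.words)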

torchaudio/models/rnnt_decoder.py
@@ -109,13 +109,9 @@ class RNNTBeamSearch(torch.nn.Module):
 
         self.step_max_tokens = step_max_tokens
 
-    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-        if hypo is not None:
-            token = _get_hypo_tokens(hypo)[-1]
-            state = _get_hypo_state(hypo)
-        else:
-            token = self.blank
-            state = None
+    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+        token = self.blank
+        state = None
 
         one_tensor = torch.tensor([1], device=device)
         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
@@ -230,14 +226,14 @@ class RNNTBeamSearch(torch.nn.Module):
     def _search(
         self,
         enc_out: torch.Tensor,
-        hypo: Optional[Hypothesis],
+        hypo: Optional[List[Hypothesis]],
         beam_width: int,
     ) -> List[Hypothesis]:
         n_time_steps = enc_out.shape[1]
         device = enc_out.device
 
         a_hypos: List[Hypothesis] = []
-        b_hypos = self._init_b_hypos(hypo, device)
+        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
         for t in range(n_time_steps):
             a_hypos = b_hypos
             b_hypos = torch.jit.annotate(List[Hypothesis], [])
@@ -263,7 +259,7 @@ class RNNTBeamSearch(torch.nn.Module):
                 if a_hypos:
                     symbols_current_t += 1
 
-            _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
             b_hypos = [b_hypos[idx] for idx in sorted_idx]
 
         return b_hypos
@@ -290,8 +286,8 @@ class RNNTBeamSearch(torch.nn.Module):
 
         if length.shape != () and length.shape != (1,):
             raise ValueError("length must be of shape () or (1,)")
-        if input.dim() == 0:
-            input = input.unsqueeze(0)
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
 
         enc_out, _ = self.model.transcribe(input, length)
         return self._search(enc_out, None, beam_width)
@@ -303,7 +299,7 @@ class RNNTBeamSearch(torch.nn.Module):
         length: torch.Tensor,
         beam_width: int,
         state: Optional[List[List[torch.Tensor]]] = None,
-        hypothesis: Optional[Hypothesis] = None,
+        hypothesis: Optional[List[Hypothesis]] = None,
     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
         r"""Performs beam search for the given input sequence in streaming mode.
 
@@ -318,7 +314,7 @@ class RNNTBeamSearch(torch.nn.Module):
             state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                 representing transcription network internal state generated in preceding
                 invocation. (Default: ``None``)
-            hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
+            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                 search with. (Default: ``None``)
 
         Returns:
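
Taken together, the hunks above change RNNTBeamSearch so that the streaming state carried between calls to infer is a List[Hypothesis] rather than a single Hypothesis (and _init_b_hypos no longer takes a seed hypothesis). A rough streaming sketch under the new signature follows; the Emformer RNN-T pipeline, chunk sizes, and random audio are illustrative stand-ins, not part of this diff.

import torch
import torchaudio

bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
feature_extractor = bundle.get_streaming_feature_extractor()
decoder = bundle.get_decoder()            # RNNTBeamSearch
token_processor = bundle.get_token_processor()

chunk_size = bundle.segment_length * bundle.hop_length            # samples per streaming chunk
waveform_chunks = torch.randn(10 * chunk_size).split(chunk_size)  # stand-in for streamed 16 kHz audio

state, hypotheses = None, None  # `hypotheses` is now a List[Hypothesis], not a single Hypothesis
for chunk in waveform_chunks:
    features, length = feature_extractor(chunk)
    hypotheses, state = decoder.infer(features, length, beam_width=10, state=state, hypothesis=hypotheses)
print(token_processor(hypotheses[0][0]))  # element 0 of a Hypothesis tuple holds the token ids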

torchaudio/models/squim/__init__.py
@@ -0,0 +1,11 @@
+from .objective import squim_objective_base, squim_objective_model, SquimObjective
+from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective
+
+__all__ = [
+    "squim_objective_base",
+    "squim_objective_model",
+    "squim_subjective_base",
+    "squim_subjective_model",
+    "SquimObjective",
+    "SquimSubjective",
+]

torchaudio/models/squim/objective.py
@@ -0,0 +1,326 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def transform_wb_pesq_range(x: float) -> float:
+    """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+    for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+    defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+    Args:
+        x (float): Narrow-band PESQ score.
+
+    Returns:
+        (float): Wide-band PESQ score.
+    """
+    return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+PESQRange: Tuple[float, float] = (
+    1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
+    # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+    # We are using 1.0 as a reasonable approximation.
+    transform_wb_pesq_range(4.5),
+)
+
+
+class RangeSigmoid(nn.Module):
+    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+        super(RangeSigmoid, self).__init__()
+        assert isinstance(val_range, tuple) and len(val_range) == 2
+        self.val_range: Tuple[float, float] = val_range
+        self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+        return out
+
+
+class Encoder(nn.Module):
+    """Encoder module that transform 1D waveform to 2D representations.
+
+    Args:
+        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+        win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+    """
+
+    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+        super(Encoder, self).__init__()
+
+        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply waveforms to convolutional layer and ReLU layer.
+
+        Args:
+            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+        Returns:
+            (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+        """
+        out = x.unsqueeze(dim=1)
+        out = F.relu(self.conv1d(out))
+        return out
+
+
+class SingleRNN(nn.Module):
+    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+        super(SingleRNN, self).__init__()
+
+        self.rnn_type = rnn_type
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+
+        self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+            input_size,
+            hidden_size,
+            1,
+            dropout=dropout,
+            batch_first=True,
+            bidirectional=True,
+        )
+
+        self.proj = nn.Linear(hidden_size * 2, input_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # input shape: batch, seq, dim
+        out, _ = self.rnn(x)
+        out = self.proj(out)
+        return out
+
+
+class DPRNN(nn.Module):
+    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+    Args:
+        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+        d_model (int, optional): The number of expected features in the input. (Default: 256)
+        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+    """
+
+    def __init__(
+        self,
+        feat_dim: int = 64,
+        hidden_dim: int = 128,
+        num_blocks: int = 6,
+        rnn_type: str = "LSTM",
+        d_model: int = 256,
+        chunk_size: int = 100,
+        chunk_stride: int = 50,
+    ) -> None:
+        super(DPRNN, self).__init__()
+
+        self.num_blocks = num_blocks
+
+        self.row_rnn = nn.ModuleList([])
+        self.col_rnn = nn.ModuleList([])
+        self.row_norm = nn.ModuleList([])
+        self.col_norm = nn.ModuleList([])
+        for _ in range(num_blocks):
+            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+        self.conv = nn.Sequential(
+            nn.Conv2d(feat_dim, d_model, 1),
+            nn.PReLU(),
+        )
+        self.chunk_size = chunk_size
+        self.chunk_stride = chunk_stride
+
+    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        # input shape: (B, N, T)
+        seq_len = x.shape[-1]
+
+        rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+        return out, rest
+
+    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        out, rest = self.pad_chunk(x)
+        batch_size, feat_dim, seq_len = out.shape
+
+        segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+        segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+        out = torch.cat([segments1, segments2], dim=3)
+        out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+        return out, rest
+
+    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+        batch_size, dim, _, _ = x.shape
+        out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+        out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+        out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+        out = out1 + out2
+        if rest > 0:
+            out = out[:, :, :-rest]
+        out = out.contiguous()
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, rest = self.chunking(x)
+        batch_size, _, dim1, dim2 = x.shape
+        out = x
+        for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+            row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+            row_out = row_rnn(row_in)
+            row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+            row_out = row_norm(row_out)
+            out = out + row_out
+
+            col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+            col_out = col_rnn(col_in)
+            col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+            col_out = col_norm(col_out)
+            out = out + col_out
+        out = self.conv(out)
+        out = self.merging(out, rest)
+        out = out.transpose(1, 2).contiguous()
+        return out
+
+
+class AutoPool(nn.Module):
+    def __init__(self, pool_dim: int = 1) -> None:
+        super(AutoPool, self).__init__()
+        self.pool_dim: int = pool_dim
+        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        weight = self.softmax(torch.mul(x, self.alpha))
+        out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+        return out
+
+
+class SquimObjective(nn.Module):
+    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+    Args:
+        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+        branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score.
+    """
+
+    def __init__(
+        self,
+        encoder: nn.Module,
+        dprnn: nn.Module,
+        branches: nn.ModuleList,
+    ):
+        super(SquimObjective, self).__init__()
+        self.encoder = encoder
+        self.dprnn = dprnn
+        self.branches = branches
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Args:
+            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+        Returns:
+            List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`.
+        """
+        if x.ndim != 2:
+            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+        out = self.encoder(x)
+        out = self.dprnn(out)
+        scores = []
+        for branch in self.branches:
+            scores.append(branch(out).squeeze(dim=1))
+        return scores
+
+
+def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+    """Create branch module after DPRNN model for predicting metric score.
+
+    Args:
+        d_model (int): The number of expected features in the input.
+        nhead (int): Number of heads in the multi-head attention model.
+        metric (str): The metric name to predict.
+
+    Returns:
+        (nn.Module): Returned module to predict corresponding metric score.
+    """
+    layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+    layer2 = AutoPool()
+    if metric == "stoi":
+        layer3 = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.PReLU(),
+            nn.Linear(d_model, 1),
+            RangeSigmoid(),
+        )
+    elif metric == "pesq":
+        layer3 = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.PReLU(),
+            nn.Linear(d_model, 1),
+            RangeSigmoid(val_range=PESQRange),
+        )
+    else:
+        layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+    return nn.Sequential(layer1, layer2, layer3)
+
+
+def squim_objective_model(
+    feat_dim: int,
+    win_len: int,
+    d_model: int,
+    nhead: int,
+    hidden_dim: int,
+    num_blocks: int,
+    rnn_type: str,
+    chunk_size: int,
+    chunk_stride: Optional[int] = None,
+) -> SquimObjective:
+    """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
+
+    Args:
+        feat_dim (int, optional): The feature dimension after Encoder module.
+        win_len (int): Kernel size in the Encoder module.
+        d_model (int): The number of expected features in the input.
+        nhead (int): Number of heads in the multi-head attention model.
+        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+        num_blocks (int): Number of DPRNN layers.
+        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+        chunk_size (int): Chunk size of input for DPRNN.
+        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+    """
+    if chunk_stride is None:
+        chunk_stride = chunk_size // 2
+    encoder = Encoder(feat_dim, win_len)
+    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+    branches = nn.ModuleList(
+        [
+            _create_branch(d_model, nhead, "stoi"),
+            _create_branch(d_model, nhead, "pesq"),
+            _create_branch(d_model, nhead, "sisdr"),
+        ]
+    )
+    return SquimObjective(encoder, dprnn, branches)
+
+
+def squim_objective_base() -> SquimObjective:
+    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
+    return squim_objective_model(
+        feat_dim=256,
+        win_len=64,
+        d_model=256,
+        nhead=4,
+        hidden_dim=256,
+        num_blocks=2,
+        rnn_type="LSTM",
+        chunk_size=71,
+    )
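
Illustrative usage (not part of the diff): the squim_objective_base factory added at the end of this file can be exercised roughly as follows, assuming it is re-exported from torchaudio.models (the torchaudio/models/__init__.py entry in the file list above suggests this; otherwise import from torchaudio.models.squim). The random waveform is a stand-in for 16 kHz speech, and the randomly initialized weights give meaningless scores; pretrained weights ship separately via the pipeline added in torchaudio/pipelines/_squim_pipeline.py.

import torch
from torchaudio.models import squim_objective_base

model = squim_objective_base().eval()
waveform = torch.randn(1, 16000)          # (batch, time); one second of 16 kHz audio
with torch.inference_mode():
    stoi, pesq, si_sdr = model(waveform)  # branches are built in the order "stoi", "pesq", "sisdr"
print(stoi.item(), pesq.item(), si_sdr.item())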