torchaudio-2.9.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchaudio might be problematic.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.0.dist-info/LICENSE +25 -0
- torchaudio-2.9.0.dist-info/METADATA +122 -0
- torchaudio-2.9.0.dist-info/RECORD +86 -0
- torchaudio-2.9.0.dist-info/WHEEL +5 -0
- torchaudio-2.9.0.dist-info/top_level.txt +1 -0
torchaudio/models/rnnt_decoder.py

@@ -0,0 +1,339 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchaudio.models import RNNT
+
+
+__all__ = ["Hypothesis", "RNNTBeamSearch"]
+
+
+Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
+Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
+represented as tuple of (tokens, prediction network output, prediction network state, score).
+"""
+
+
+def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+    return hypo[0]
+
+
+def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+    return hypo[1]
+
+
+def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+    return hypo[2]
+
+
+def _get_hypo_score(hypo: Hypothesis) -> float:
+    return hypo[3]
+
+
+def _get_hypo_key(hypo: Hypothesis) -> str:
+    return str(hypo[0])
+
+
+def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
+    states: List[List[torch.Tensor]] = []
+    for i in range(len(_get_hypo_state(hypos[0]))):
+        batched_state_components: List[torch.Tensor] = []
+        for j in range(len(_get_hypo_state(hypos[0])[i])):
+            batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
+        states.append(batched_state_components)
+    return states
+
+
+def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
+    idx_tensor = torch.tensor([idx], device=device)
+    return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
+
+
+def _default_hypo_sort_key(hypo: Hypothesis) -> float:
+    return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
+
+
+def _compute_updated_scores(
+    hypos: List[Hypothesis],
+    next_token_probs: torch.Tensor,
+    beam_width: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
+    nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
+    nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
+    nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
+    nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
+    return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
+
+
+def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
+    for i, elem in enumerate(hypo_list):
+        if _get_hypo_key(hypo) == _get_hypo_key(elem):
+            del hypo_list[i]
+            break
+
+
+class RNNTBeamSearch(torch.nn.Module):
+    r"""Beam search decoder for RNN-T model.
+
+    See Also:
+        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.
+
+    Args:
+        model (RNNT): RNN-T model to use.
+        blank (int): index of blank token in vocabulary.
+        temperature (float, optional): temperature to apply to joint network output.
+            Larger values yield more uniform samples. (Default: 1.0)
+        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
+            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
+            hypothesis score normalized by token sequence length. (Default: None)
+        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
+    """
+
+    def __init__(
+        self,
+        model: RNNT,
+        blank: int,
+        temperature: float = 1.0,
+        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
+        step_max_tokens: int = 100,
+    ) -> None:
+        super().__init__()
+        self.model = model
+        self.blank = blank
+        self.temperature = temperature
+
+        if hypo_sort_key is None:
+            self.hypo_sort_key = _default_hypo_sort_key
+        else:
+            self.hypo_sort_key = hypo_sort_key
+
+        self.step_max_tokens = step_max_tokens
+
+    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+        token = self.blank
+        state = None
+
+        one_tensor = torch.tensor([1], device=device)
+        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
+        init_hypo = (
+            [token],
+            pred_out[0].detach(),
+            pred_state,
+            0.0,
+        )
+        return [init_hypo]
+
+    def _gen_next_token_probs(
+        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
+    ) -> torch.Tensor:
+        one_tensor = torch.tensor([1], device=device)
+        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
+        joined_out, _, _ = self.model.join(
+            enc_out,
+            one_tensor,
+            predictor_out,
+            torch.tensor([1] * len(hypos), device=device),
+        )  # [beam_width, 1, 1, num_tokens]
+        joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
+        return joined_out[:, 0, 0]
+
+    def _gen_b_hypos(
+        self,
+        b_hypos: List[Hypothesis],
+        a_hypos: List[Hypothesis],
+        next_token_probs: torch.Tensor,
+        key_to_b_hypo: Dict[str, Hypothesis],
+    ) -> List[Hypothesis]:
+        for i in range(len(a_hypos)):
+            h_a = a_hypos[i]
+            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+            if _get_hypo_key(h_a) in key_to_b_hypo:
+                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
+                _remove_hypo(h_b, b_hypos)
+                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
+            else:
+                score = float(append_blank_score)
+            h_b = (
+                _get_hypo_tokens(h_a),
+                _get_hypo_predictor_out(h_a),
+                _get_hypo_state(h_a),
+                score,
+            )
+            b_hypos.append(h_b)
+            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
+        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
+        return [b_hypos[idx] for idx in sorted_idx]
+
+    def _gen_a_hypos(
+        self,
+        a_hypos: List[Hypothesis],
+        b_hypos: List[Hypothesis],
+        next_token_probs: torch.Tensor,
+        t: int,
+        beam_width: int,
+        device: torch.device,
+    ) -> List[Hypothesis]:
+        (
+            nonblank_nbest_scores,
+            nonblank_nbest_hypo_idx,
+            nonblank_nbest_token,
+        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
+
+        if len(b_hypos) < beam_width:
+            b_nbest_score = -float("inf")
+        else:
+            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
+
+        base_hypos: List[Hypothesis] = []
+        new_tokens: List[int] = []
+        new_scores: List[float] = []
+        for i in range(beam_width):
+            score = float(nonblank_nbest_scores[i])
+            if score > b_nbest_score:
+                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
+                base_hypos.append(a_hypos[a_hypo_idx])
+                new_tokens.append(int(nonblank_nbest_token[i]))
+                new_scores.append(score)
+
+        if base_hypos:
+            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
+        else:
+            new_hypos: List[Hypothesis] = []
+
+        return new_hypos
+
+    def _gen_new_hypos(
+        self,
+        base_hypos: List[Hypothesis],
+        tokens: List[int],
+        scores: List[float],
+        t: int,
+        device: torch.device,
+    ) -> List[Hypothesis]:
+        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
+        states = _batch_state(base_hypos)
+        pred_out, _, pred_states = self.model.predict(
+            tgt_tokens,
+            torch.tensor([1] * len(base_hypos), device=device),
+            states,
+        )
+        new_hypos: List[Hypothesis] = []
+        for i, h_a in enumerate(base_hypos):
+            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
+        return new_hypos
+
+    def _search(
+        self,
+        enc_out: torch.Tensor,
+        hypo: Optional[List[Hypothesis]],
+        beam_width: int,
+    ) -> List[Hypothesis]:
+        n_time_steps = enc_out.shape[1]
+        device = enc_out.device
+
+        a_hypos: List[Hypothesis] = []
+        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
+        for t in range(n_time_steps):
+            a_hypos = b_hypos
+            b_hypos = torch.jit.annotate(List[Hypothesis], [])
+            key_to_b_hypo: Dict[str, Hypothesis] = {}
+            symbols_current_t = 0
+
+            while a_hypos:
+                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
+                next_token_probs = next_token_probs.cpu()
+                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
+
+                if symbols_current_t == self.step_max_tokens:
+                    break
+
+                a_hypos = self._gen_a_hypos(
+                    a_hypos,
+                    b_hypos,
+                    next_token_probs,
+                    t,
+                    beam_width,
+                    device,
+                )
+                if a_hypos:
+                    symbols_current_t += 1
+
+            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
+            b_hypos = [b_hypos[idx] for idx in sorted_idx]
+
+        return b_hypos
+
+    def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
+        r"""Performs beam search for the given input sequence.
+
+        T: number of frames;
+        D: feature dimension of each frame.
+
+        Args:
+            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+            length (torch.Tensor): number of valid frames in input
+                sequence, with shape () or (1,).
+            beam_width (int): beam size to use during search.
+
+        Returns:
+            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
+        """
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
+        if input.dim() == 2:
+            input = input.unsqueeze(0)
+
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
+
+        enc_out, _ = self.model.transcribe(input, length)
+        return self._search(enc_out, None, beam_width)
+
+    @torch.jit.export
+    def infer(
+        self,
+        input: torch.Tensor,
+        length: torch.Tensor,
+        beam_width: int,
+        state: Optional[List[List[torch.Tensor]]] = None,
+        hypothesis: Optional[List[Hypothesis]] = None,
+    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
+        r"""Performs beam search for the given input sequence in streaming mode.
+
+        T: number of frames;
+        D: feature dimension of each frame.
+
+        Args:
+            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+            length (torch.Tensor): number of valid frames in input
+                sequence, with shape () or (1,).
+            beam_width (int): beam size to use during search.
+            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+                representing transcription network internal state generated in preceding
+                invocation. (Default: ``None``)
+            hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
+                search with. (Default: ``None``)
+
+        Returns:
+            (List[Hypothesis], List[List[torch.Tensor]]):
+                List[Hypothesis]
+                    top-``beam_width`` hypotheses found by beam search.
+                List[List[torch.Tensor]]
+                    list of lists of tensors representing transcription network
+                    internal state generated in current invocation.
+        """
+        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+            raise ValueError("input must be of shape (T, D) or (1, T, D)")
+        if input.dim() == 2:
+            input = input.unsqueeze(0)
+
+        if length.shape != () and length.shape != (1,):
+            raise ValueError("length must be of shape () or (1,)")
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
+
+        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
+        return self._search(enc_out, hypothesis, beam_width), state
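To make the decoder API above concrete, here is a minimal usage sketch. It is not part of the diff: the `model` and `blank_id` arguments stand in for whatever trained RNNT model and blank-token index the caller already has (for example, obtained through `torchaudio.pipelines.RNNTBundle`, which the class docstring references), and `beam_width=10` is an arbitrary illustrative choice.

import torch
from torchaudio.models import RNNT
from torchaudio.models.rnnt_decoder import RNNTBeamSearch

def decode_utterance(model: RNNT, blank_id: int, features: torch.Tensor):
    # features: (T, D) acoustic frames for a single utterance.
    decoder = RNNTBeamSearch(model, blank=blank_id)
    length = torch.tensor(features.shape[0])
    hypos = decoder(features, length, beam_width=10)
    # Each hypothesis is (tokens, predictor output, predictor state, score);
    # hypos[0] is the best hypothesis under the default length-normalized sort key.
    return hypos[0][0]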
torchaudio/models/squim/__init__.py

@@ -0,0 +1,11 @@
+from .objective import squim_objective_base, squim_objective_model, SquimObjective
+from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective
+
+__all__ = [
+    "squim_objective_base",
+    "squim_objective_model",
+    "squim_subjective_base",
+    "squim_subjective_model",
+    "SquimObjective",
+    "SquimSubjective",
+]
torchaudio/models/squim/objective.py

@@ -0,0 +1,326 @@
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def transform_wb_pesq_range(x: float) -> float:
+    """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+    for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+    defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+    Args:
+        x (float): Narrow-band PESQ score.
+
+    Returns:
+        (float): Wide-band PESQ score.
+    """
+    return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+PESQRange: Tuple[float, float] = (
+    1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
+    # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+    # We are using 1.0 as a reasonable approximation.
+    transform_wb_pesq_range(4.5),
+)
+
+
+class RangeSigmoid(nn.Module):
+    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+        super(RangeSigmoid, self).__init__()
+        assert isinstance(val_range, tuple) and len(val_range) == 2
+        self.val_range: Tuple[float, float] = val_range
+        self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+        return out
+
+
+class Encoder(nn.Module):
+    """Encoder module that transforms a 1D waveform into 2D representations.
+
+    Args:
+        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+        win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+    """
+
+    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+        super(Encoder, self).__init__()
+
+        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the convolutional layer and ReLU layer to waveforms.
+
+        Args:
+            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+        Returns:
+            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+        """
+        out = x.unsqueeze(dim=1)
+        out = F.relu(self.conv1d(out))
+        return out
+
+
+class SingleRNN(nn.Module):
+    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+        super(SingleRNN, self).__init__()
+
+        self.rnn_type = rnn_type
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+
+        self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+            input_size,
+            hidden_size,
+            1,
+            dropout=dropout,
+            batch_first=True,
+            bidirectional=True,
+        )
+
+        self.proj = nn.Linear(hidden_size * 2, input_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # input shape: batch, seq, dim
+        out, _ = self.rnn(x)
+        out = self.proj(out)
+        return out
+
+
+class DPRNN(nn.Module):
+    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+    Args:
+        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+        d_model (int, optional): The number of expected features in the input. (Default: 256)
+        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+    """
+
+    def __init__(
+        self,
+        feat_dim: int = 64,
+        hidden_dim: int = 128,
+        num_blocks: int = 6,
+        rnn_type: str = "LSTM",
+        d_model: int = 256,
+        chunk_size: int = 100,
+        chunk_stride: int = 50,
+    ) -> None:
+        super(DPRNN, self).__init__()
+
+        self.num_blocks = num_blocks
+
+        self.row_rnn = nn.ModuleList([])
+        self.col_rnn = nn.ModuleList([])
+        self.row_norm = nn.ModuleList([])
+        self.col_norm = nn.ModuleList([])
+        for _ in range(num_blocks):
+            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+        self.conv = nn.Sequential(
+            nn.Conv2d(feat_dim, d_model, 1),
+            nn.PReLU(),
+        )
+        self.chunk_size = chunk_size
+        self.chunk_stride = chunk_stride
+
+    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        # input shape: (B, N, T)
+        seq_len = x.shape[-1]
+
+        rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+        return out, rest
+
+    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        out, rest = self.pad_chunk(x)
+        batch_size, feat_dim, seq_len = out.shape
+
+        segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+        segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+        out = torch.cat([segments1, segments2], dim=3)
+        out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+        return out, rest
+
+    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+        batch_size, dim, _, _ = x.shape
+        out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+        out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+        out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+        out = out1 + out2
+        if rest > 0:
+            out = out[:, :, :-rest]
+        out = out.contiguous()
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, rest = self.chunking(x)
+        batch_size, _, dim1, dim2 = x.shape
+        out = x
+        for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+            row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+            row_out = row_rnn(row_in)
+            row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+            row_out = row_norm(row_out)
+            out = out + row_out
+
+            col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+            col_out = col_rnn(col_in)
+            col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+            col_out = col_norm(col_out)
+            out = out + col_out
+        out = self.conv(out)
+        out = self.merging(out, rest)
+        out = out.transpose(1, 2).contiguous()
+        return out
+
+
+class AutoPool(nn.Module):
+    def __init__(self, pool_dim: int = 1) -> None:
+        super(AutoPool, self).__init__()
+        self.pool_dim: int = pool_dim
+        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        weight = self.softmax(torch.mul(x, self.alpha))
+        out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+        return out
+
+
+class SquimObjective(nn.Module):
+    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+    Args:
+        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
+    """
+
+    def __init__(
+        self,
+        encoder: nn.Module,
+        dprnn: nn.Module,
+        branches: nn.ModuleList,
+    ):
+        super(SquimObjective, self).__init__()
+        self.encoder = encoder
+        self.dprnn = dprnn
+        self.branches = branches
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Args:
+            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+        Returns:
+            List(torch.Tensor): List of score Tensors. Each Tensor is with dimension `(batch,)`.
+        """
+        if x.ndim != 2:
+            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+        out = self.encoder(x)
+        out = self.dprnn(out)
+        scores = []
+        for branch in self.branches:
+            scores.append(branch(out).squeeze(dim=1))
+        return scores
+
+
+def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+    """Create branch module after DPRNN model for predicting metric score.
+
+    Args:
+        d_model (int): The number of expected features in the input.
+        nhead (int): Number of heads in the multi-head attention model.
+        metric (str): The metric name to predict.
+
+    Returns:
+        (nn.Module): Returned module to predict corresponding metric score.
+    """
+    layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+    layer2 = AutoPool()
+    if metric == "stoi":
+        layer3 = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.PReLU(),
+            nn.Linear(d_model, 1),
+            RangeSigmoid(),
+        )
+    elif metric == "pesq":
+        layer3 = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.PReLU(),
+            nn.Linear(d_model, 1),
+            RangeSigmoid(val_range=PESQRange),
+        )
+    else:
+        layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+    return nn.Sequential(layer1, layer2, layer3)
+
+
+def squim_objective_model(
+    feat_dim: int,
+    win_len: int,
+    d_model: int,
+    nhead: int,
+    hidden_dim: int,
+    num_blocks: int,
+    rnn_type: str,
+    chunk_size: int,
+    chunk_stride: Optional[int] = None,
+) -> SquimObjective:
+    """Build a custom :class:`torchaudio.models.squim.SquimObjective` model.
+
+    Args:
+        feat_dim (int): The feature dimension after Encoder module.
+        win_len (int): Kernel size in the Encoder module.
+        d_model (int): The number of expected features in the input.
+        nhead (int): Number of heads in the multi-head attention model.
+        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+        num_blocks (int): Number of DPRNN layers.
+        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+        chunk_size (int): Chunk size of input for DPRNN.
+        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+    """
+    if chunk_stride is None:
+        chunk_stride = chunk_size // 2
+    encoder = Encoder(feat_dim, win_len)
+    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+    branches = nn.ModuleList(
+        [
+            _create_branch(d_model, nhead, "stoi"),
+            _create_branch(d_model, nhead, "pesq"),
+            _create_branch(d_model, nhead, "sisdr"),
+        ]
+    )
+    return SquimObjective(encoder, dprnn, branches)
+
+
+def squim_objective_base() -> SquimObjective:
+    """Build :class:`torchaudio.models.squim.SquimObjective` model with default arguments."""
+    return squim_objective_model(
+        feat_dim=256,
+        win_len=64,
+        d_model=256,
+        nhead=4,
+        hidden_dim=256,
+        num_blocks=2,
+        rnn_type="LSTM",
+        chunk_size=71,
+    )
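For reference, a short sketch of how the objective model defined above can be exercised; it is not part of the diff. The random one-second waveform and the 16 kHz sample-rate choice are illustrative only, and a freshly built model has random weights, so the scores are meaningless here; in practice pretrained weights are obtained through the SQUIM pipeline bundle shipped in torchaudio/pipelines/_squim_pipeline.py.

import torch
from torchaudio.models.squim.objective import squim_objective_base

model = squim_objective_base().eval()
waveform = torch.rand(1, 16000)  # (batch, time): one second of audio at an assumed 16 kHz
with torch.no_grad():
    # The branches are built in the order stoi, pesq, sisdr, so forward returns
    # three (batch,) tensors of STOI, PESQ, and SI-SDR estimates.
    stoi, pesq, si_sdr = model(waveform)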