torchaudio-2.8.0-cp310-cp310-win_amd64.whl → torchaudio-2.9.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This release of torchaudio has been flagged as potentially problematic.
- torchaudio/__init__.py +179 -39
- torchaudio/_extension/__init__.py +1 -14
- torchaudio/_extension/utils.py +0 -47
- torchaudio/_internal/module_utils.py +12 -3
- torchaudio/_torchcodec.py +73 -85
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/utils.py +1 -1
- torchaudio/functional/__init__.py +0 -2
- torchaudio/functional/_alignment.py +1 -1
- torchaudio/functional/filtering.py +70 -55
- torchaudio/functional/functional.py +26 -60
- torchaudio/lib/_torchaudio.pyd +0 -0
- torchaudio/lib/libtorchaudio.pyd +0 -0
- torchaudio/models/decoder/__init__.py +14 -2
- torchaudio/models/decoder/_ctc_decoder.py +6 -6
- torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
- torchaudio/models/squim/objective.py +2 -2
- torchaudio/pipelines/_source_separation_pipeline.py +1 -1
- torchaudio/pipelines/_squim_pipeline.py +2 -2
- torchaudio/pipelines/_tts/utils.py +1 -1
- torchaudio/pipelines/rnnt_pipeline.py +4 -4
- torchaudio/transforms/__init__.py +1 -0
- torchaudio/transforms/_transforms.py +2 -2
- torchaudio/utils/__init__.py +2 -9
- torchaudio/utils/download.py +1 -3
- torchaudio/version.py +2 -2
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
- torchaudio-2.9.0.dist-info/RECORD +85 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
- torchaudio/_backend/__init__.py +0 -61
- torchaudio/_backend/backend.py +0 -53
- torchaudio/_backend/common.py +0 -52
- torchaudio/_backend/ffmpeg.py +0 -334
- torchaudio/_backend/soundfile.py +0 -54
- torchaudio/_backend/soundfile_backend.py +0 -457
- torchaudio/_backend/sox.py +0 -91
- torchaudio/_backend/utils.py +0 -350
- torchaudio/backend/__init__.py +0 -8
- torchaudio/backend/_no_backend.py +0 -25
- torchaudio/backend/_sox_io_backend.py +0 -294
- torchaudio/backend/common.py +0 -13
- torchaudio/backend/no_backend.py +0 -14
- torchaudio/backend/soundfile_backend.py +0 -14
- torchaudio/backend/sox_io_backend.py +0 -14
- torchaudio/io/__init__.py +0 -20
- torchaudio/io/_effector.py +0 -347
- torchaudio/io/_playback.py +0 -72
- torchaudio/kaldi_io.py +0 -150
- torchaudio/prototype/__init__.py +0 -0
- torchaudio/prototype/datasets/__init__.py +0 -4
- torchaudio/prototype/datasets/musan.py +0 -68
- torchaudio/prototype/functional/__init__.py +0 -26
- torchaudio/prototype/functional/_dsp.py +0 -441
- torchaudio/prototype/functional/_rir.py +0 -382
- torchaudio/prototype/functional/functional.py +0 -193
- torchaudio/prototype/models/__init__.py +0 -39
- torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
- torchaudio/prototype/models/_emformer_hubert.py +0 -337
- torchaudio/prototype/models/conv_emformer.py +0 -529
- torchaudio/prototype/models/hifi_gan.py +0 -342
- torchaudio/prototype/models/rnnt.py +0 -717
- torchaudio/prototype/models/rnnt_decoder.py +0 -402
- torchaudio/prototype/pipelines/__init__.py +0 -21
- torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
- torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
- torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
- torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
- torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
- torchaudio/prototype/transforms/__init__.py +0 -9
- torchaudio/prototype/transforms/_transforms.py +0 -461
- torchaudio/sox_effects/__init__.py +0 -10
- torchaudio/sox_effects/sox_effects.py +0 -275
- torchaudio/utils/ffmpeg_utils.py +0 -11
- torchaudio/utils/sox_utils.py +0 -118
- torchaudio-2.8.0.dist-info/RECORD +0 -145
- torio/__init__.py +0 -8
- torio/_extension/__init__.py +0 -13
- torio/_extension/utils.py +0 -147
- torio/io/__init__.py +0 -9
- torio/io/_streaming_media_decoder.py +0 -977
- torio/io/_streaming_media_encoder.py +0 -502
- torio/lib/__init__.py +0 -0
- torio/lib/_torio_ffmpeg4.pyd +0 -0
- torio/lib/_torio_ffmpeg5.pyd +0 -0
- torio/lib/_torio_ffmpeg6.pyd +0 -0
- torio/lib/libtorio_ffmpeg4.pyd +0 -0
- torio/lib/libtorio_ffmpeg5.pyd +0 -0
- torio/lib/libtorio_ffmpeg6.pyd +0 -0
- torio/utils/__init__.py +0 -4
- torio/utils/ffmpeg_utils.py +0 -275
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/licenses/LICENSE +0 -0
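Every entry above with a "+0 -N" count is a file present in 2.8.0 and deleted from 2.9.0: the legacy I/O layers (torchaudio/_backend, torchaudio/backend, torchaudio/io, torchaudio/sox_effects, torchaudio/kaldi_io), the whole torchaudio/prototype tree, and the bundled torio package. Downstream code that imports any of these will break on the 2.9.0 wheel. As an illustration only (the module names come from the file list above; the probe itself is not part of either package), availability can be checked before use:

    import importlib.util

    # Modules the file list above shows as deleted in the 2.9.0 wheel ("+0 -N").
    # Illustrative probe; adjust the list to whatever your code actually imports.
    removed_in_2_9 = [
        "torchaudio.sox_effects",
        "torchaudio.kaldi_io",
        "torchaudio.io",
        "torchaudio.prototype",
        "torio",
    ]

    for name in removed_in_2_9:
        try:
            present = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised when a parent package (e.g. torchaudio itself) is not installed.
            present = False
        print(f"{name}: {'available' if present else 'removed'}")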
@@ -1,402 +0,0 @@
-from typing import Callable, Dict, List, Optional, Tuple
-
-import torch
-from torchaudio.models import RNNT
-from torchaudio.prototype.models.rnnt import TrieNode
-
-from torchaudio._internal.module_utils import dropping_class_support
-
-__all__ = ["Hypothesis", "RNNTBeamSearchBiasing"]
-
-
-Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list]
-Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
-represented as tuple of (tokens, prediction network output, prediction network state, score).
-"""
-
-
-def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
-    return hypo[0]
-
-
-def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
-    return hypo[1]
-
-
-def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
-    return hypo[2]
-
-
-def _get_hypo_score(hypo: Hypothesis) -> float:
-    return hypo[3]
-
-
-def _get_hypo_trie(hypo: Hypothesis) -> TrieNode:
-    return hypo[4]
-
-
-def _set_hypo_trie(hypo: Hypothesis, trie: TrieNode) -> None:
-    hypo[4] = trie
-
-
-def _get_hypo_key(hypo: Hypothesis) -> str:
-    return str(hypo[0])
-
-
-def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
-    states: List[List[torch.Tensor]] = []
-    for i in range(len(_get_hypo_state(hypos[0]))):
-        batched_state_components: List[torch.Tensor] = []
-        for j in range(len(_get_hypo_state(hypos[0])[i])):
-            batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
-        states.append(batched_state_components)
-    return states
-
-
-def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
-    idx_tensor = torch.tensor([idx], device=device)
-    return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
-
-
-def _default_hypo_sort_key(hypo: Hypothesis) -> float:
-    return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
-
-
-def _compute_updated_scores(
-    hypos: List[Hypothesis],
-    next_token_probs: torch.Tensor,
-    beam_width: int,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
-    nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
-    nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
-    nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
-    nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
-    return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
-
-
-def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
-    for i, elem in enumerate(hypo_list):
-        if _get_hypo_key(hypo) == _get_hypo_key(elem):
-            del hypo_list[i]
-            break
-
-
-@dropping_class_support
-class RNNTBeamSearchBiasing(torch.nn.Module):
-    r"""Beam search decoder for RNN-T model with biasing support.
-
-    Args:
-        model (RNNT): RNN-T model to use.
-        blank (int): index of blank token in vocabulary.
-        temperature (float, optional): temperature to apply to joint network output.
-            Larger values yield more uniform samples. (Default: 1.0)
-        hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
-            for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
-            hypothesis score normalized by token sequence length. (Default: None)
-        step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
-        trie (list, optional): the prefix tree for TCPGen biasing
-        biasing (bool, optional): If true, do biasing, otherwise use standard RNN-T support
-    """
-
-    def __init__(
-        self,
-        model: RNNT,
-        blank: int,
-        temperature: float = 1.0,
-        hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
-        step_max_tokens: int = 100,
-        trie: TrieNode = None,
-        biasing: bool = False,
-    ) -> None:
-        super().__init__()
-        self.model = model
-        self.blank = blank
-        self.temperature = temperature
-        self.resettrie = trie or []
-        self.dobiasing = biasing
-
-        if hypo_sort_key is None:
-            self.hypo_sort_key = _default_hypo_sort_key
-        else:
-            self.hypo_sort_key = hypo_sort_key
-
-        self.step_max_tokens = step_max_tokens
-
-    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-        if hypo is not None:
-            token = _get_hypo_tokens(hypo)[-1]
-            state = _get_hypo_state(hypo)
-        else:
-            token = self.blank
-            state = None
-
-        one_tensor = torch.tensor([1], device=device)
-        pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
-        init_hypo = ([token], pred_out[0].detach(), pred_state, 0.0, self.resettrie)
-        return [init_hypo]
-
-    def _get_trie_mask(self, trie):
-        step_mask = torch.ones(len(self.model.char_list) + 1)
-        step_mask[list(trie[0].keys())] = 0
-        # step_mask[-1] = 0
-        return step_mask
-
-    def _get_generation_prob(self, trie):
-        if len(trie[0].keys()) == 0:
-            return True
-        else:
-            return False
-
-    def _gen_next_token_probs(
-        self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
-    ) -> torch.Tensor:
-        one_tensor = torch.tensor([1], device=device)
-        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
-        if self.dobiasing:
-            # Get valid subset of wordpieces
-            trie_masks = torch.stack([self._get_trie_mask(_get_hypo_trie(h)) for h in hypos], dim=0)
-            trie_masks = trie_masks.to(enc_out.device).unsqueeze(1)  # beam_width, 1, nchars
-            # Determine if there is any paths on the trie
-            genprob_masks = torch.tensor([self._get_generation_prob(_get_hypo_trie(h)) for h in hypos])  # beam_width
-            genprob_masks = genprob_masks.to(enc_out.device)
-            # Forward TCPGen component
-            last_tokens = torch.tensor([_get_hypo_tokens(h)[-1] for h in hypos]).unsqueeze(-1).to(enc_out.device)
-            hptr, tcpgen_dist = self.model.forward_tcpgen(last_tokens, trie_masks, enc_out)
-        else:
-            hptr = None
-        # hptr sent to joiner, if deepbiasing is True joiner will use it
-        joined_out, _, joined_activation = self.model.join(
-            enc_out,
-            one_tensor,
-            predictor_out,
-            torch.tensor([1] * len(hypos), device=device),
-            hptr=hptr,
-        )  # [beam_width, 1, 1, num_tokens]
-        if self.dobiasing:
-            p_gen = torch.sigmoid(self.model.pointer_gate(torch.cat((joined_activation, hptr), dim=-1)))
-            p_gen = p_gen.masked_fill(genprob_masks.view(p_gen.size(0), 1, 1, 1), 0)
-            model_tu = torch.softmax(joined_out / self.temperature, dim=3)
-            # assuming last token is blank
-            p_not_null = 1.0 - model_tu[:, :, :, -1:]
-            ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
-            ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
-            p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_tu[:, :, :, :-1] * (1 - p_gen + ptr_gen_complement)
-            p_final = torch.cat([p_partial, model_tu[:, :, :, -1:]], dim=-1)
-            joined_out = torch.log(p_final)
-        else:
-            joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
-        return joined_out[:, 0, 0]
-
-    def _gen_b_hypos(
-        self,
-        b_hypos: List[Hypothesis],
-        a_hypos: List[Hypothesis],
-        next_token_probs: torch.Tensor,
-        key_to_b_hypo: Dict[str, Hypothesis],
-    ) -> List[Hypothesis]:
-        for i in range(len(a_hypos)):
-            h_a = a_hypos[i]
-            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
-            if _get_hypo_key(h_a) in key_to_b_hypo:
-                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
-                _remove_hypo(h_b, b_hypos)
-                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
-            else:
-                score = float(append_blank_score)
-            h_b = (
-                _get_hypo_tokens(h_a),
-                _get_hypo_predictor_out(h_a),
-                _get_hypo_state(h_a),
-                score,
-                _get_hypo_trie(h_a),
-            )
-            b_hypos.append(h_b)
-            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
-        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
-        return [b_hypos[idx] for idx in sorted_idx]
-
-    def _gen_a_hypos(
-        self,
-        a_hypos: List[Hypothesis],
-        b_hypos: List[Hypothesis],
-        next_token_probs: torch.Tensor,
-        t: int,
-        beam_width: int,
-        device: torch.device,
-    ) -> List[Hypothesis]:
-        (
-            nonblank_nbest_scores,
-            nonblank_nbest_hypo_idx,
-            nonblank_nbest_token,
-        ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
-
-        if len(b_hypos) < beam_width:
-            b_nbest_score = -float("inf")
-        else:
-            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
-
-        base_hypos: List[Hypothesis] = []
-        new_tokens: List[int] = []
-        new_scores: List[float] = []
-        for i in range(beam_width):
-            score = float(nonblank_nbest_scores[i])
-            if score > b_nbest_score:
-                a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
-                base_hypos.append(a_hypos[a_hypo_idx])
-                new_tokens.append(int(nonblank_nbest_token[i]))
-                new_scores.append(score)
-
-        if base_hypos:
-            new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
-        else:
-            new_hypos: List[Hypothesis] = []
-
-        return new_hypos
-
-    def _gen_new_hypos(
-        self,
-        base_hypos: List[Hypothesis],
-        tokens: List[int],
-        scores: List[float],
-        t: int,
-        device: torch.device,
-    ) -> List[Hypothesis]:
-        tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
-        states = _batch_state(base_hypos)
-        pred_out, _, pred_states = self.model.predict(
-            tgt_tokens,
-            torch.tensor([1] * len(base_hypos), device=device),
-            states,
-        )
-        new_hypos: List[Hypothesis] = []
-        for i, h_a in enumerate(base_hypos):
-            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
-            if self.dobiasing:
-                new_trie = self.model.get_tcpgen_step(tokens[i], _get_hypo_trie(h_a), self.resettrie)
-            else:
-                new_trie = self.resettrie
-            new_hypos.append(
-                (new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i], new_trie)
-            )
-        return new_hypos
-
-    def _search(
-        self,
-        enc_out: torch.Tensor,
-        hypo: Optional[Hypothesis],
-        beam_width: int,
-    ) -> List[Hypothesis]:
-        n_time_steps = enc_out.shape[1]
-        device = enc_out.device
-
-        a_hypos: List[Hypothesis] = []
-        b_hypos = self._init_b_hypos(hypo, device)
-        for t in range(n_time_steps):
-            a_hypos = b_hypos
-            b_hypos = torch.jit.annotate(List[Hypothesis], [])
-            key_to_b_hypo: Dict[str, Hypothesis] = {}
-            symbols_current_t = 0
-
-            while a_hypos:
-                next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
-                next_token_probs = next_token_probs.cpu()
-                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
-
-                if symbols_current_t == self.step_max_tokens:
-                    break
-
-                a_hypos = self._gen_a_hypos(
-                    a_hypos,
-                    b_hypos,
-                    next_token_probs,
-                    t,
-                    beam_width,
-                    device,
-                )
-                if a_hypos:
-                    symbols_current_t += 1
-
-            _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
-            b_hypos = [b_hypos[idx] for idx in sorted_idx]
-
-        return b_hypos
-
-    def forward(
-        self,
-        input: torch.Tensor,
-        length: torch.Tensor,
-        beam_width: int,
-    ) -> List[Hypothesis]:
-        r"""Performs beam search for the given input sequence.
-
-        T: number of frames;
-        D: feature dimension of each frame.
-
-        Args:
-            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
-            length (torch.Tensor): number of valid frames in input
-                sequence, with shape () or (1,).
-            beam_width (int): beam size to use during search.
-
-        Returns:
-            List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
-        """
-        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
-            raise ValueError("input must be of shape (T, D) or (1, T, D)")
-        if input.dim() == 2:
-            input = input.unsqueeze(0)
-
-        if length.shape != () and length.shape != (1,):
-            raise ValueError("length must be of shape () or (1,)")
-        if input.dim() == 0:
-            input = input.unsqueeze(0)
-
-        enc_out, _ = self.model.transcribe(input, length)
-        return self._search(enc_out, None, beam_width)
-
-    @torch.jit.export
-    def infer(
-        self,
-        input: torch.Tensor,
-        length: torch.Tensor,
-        beam_width: int,
-        state: Optional[List[List[torch.Tensor]]] = None,
-        hypothesis: Optional[Hypothesis] = None,
-    ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
-        r"""Performs beam search for the given input sequence in streaming mode.
-
-        T: number of frames;
-        D: feature dimension of each frame.
-
-        Args:
-            input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
-            length (torch.Tensor): number of valid frames in input
-                sequence, with shape () or (1,).
-            beam_width (int): beam size to use during search.
-            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                representing transcription network internal state generated in preceding
-                invocation. (Default: ``None``)
-            hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
-                search with. (Default: ``None``)
-
-        Returns:
-            (List[Hypothesis], List[List[torch.Tensor]]):
-                List[Hypothesis]
-                    top-``beam_width`` hypotheses found by beam search.
-                List[List[torch.Tensor]]
-                    list of lists of tensors representing transcription network
-                    internal state generated in current invocation.
-        """
-        if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
-            raise ValueError("input must be of shape (T, D) or (1, T, D)")
-        if input.dim() == 2:
-            input = input.unsqueeze(0)
-
-        if length.shape != () and length.shape != (1,):
-            raise ValueError("length must be of shape () or (1,)")
-        if length.dim() == 0:
-            length = length.unsqueeze(0)
-
-        enc_out, _, state = self.model.transcribe_streaming(input, length, state)
-        return self._search(enc_out, hypothesis, beam_width), state
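The 402-line hunk above corresponds to torchaudio/prototype/models/rnnt_decoder.py in the file list: the TCPGen-biasing RNN-T beam search decoder (RNNTBeamSearchBiasing) is removed outright in 2.9.0. As the removed code shows, each hypothesis is a plain 5-tuple rather than a class, and results are ranked by the default sort key score / (len(tokens) + 1). A small self-contained sketch of that layout with dummy values (real tuples come from the decoder's forward/infer calls on torchaudio 2.8.0):

    from typing import List, Tuple

    import torch

    # Hypothesis layout used by the removed decoder:
    # (tokens, predictor output, predictor state, score, TCPGen trie)
    Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list]

    # Dummy values standing in for real beam-search output; shapes are illustrative only.
    hypo: Hypothesis = ([3, 42, 7], torch.zeros(1, 512), [[torch.zeros(1, 8)]], -3.2, [])

    tokens, pred_out, pred_state, score, trie = hypo
    normalized_score = score / (len(tokens) + 1)  # default ranking rule from the removed module
    print(tokens, round(normalized_score, 3))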
@@ -1,21 +0,0 @@
-from ._vggish import VGGISH, VGGishBundle
-from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH as _HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
-from .rnnt_pipeline import (
-    EMFORMER_RNNT_BASE_MUSTC as _EMFORMER_RNNT_BASE_MUSTC,
-    EMFORMER_RNNT_BASE_TEDLIUM3 as _EMFORMER_RNNT_BASE_TEDLIUM3
-)
-from torchaudio._internal.module_utils import dropping_const_support
-
-EMFORMER_RNNT_BASE_MUSTC = dropping_const_support(_EMFORMER_RNNT_BASE_MUSTC)
-EMFORMER_RNNT_BASE_TEDLIUM3 = dropping_const_support(_EMFORMER_RNNT_BASE_TEDLIUM3)
-HIFIGAN_VOCODER_V3_LJSPEECH = dropping_const_support(_HIFIGAN_VOCODER_V3_LJSPEECH)
-
-
-__all__ = [
-    "EMFORMER_RNNT_BASE_MUSTC",
-    "EMFORMER_RNNT_BASE_TEDLIUM3",
-    "HIFIGAN_VOCODER_V3_LJSPEECH",
-    "HiFiGANVocoderBundle",
-    "VGGISH",
-    "VGGishBundle",
-]
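This 21-line hunk is torchaudio/prototype/pipelines/__init__.py, so the VGGISH, HIFIGAN_VOCODER_V3_LJSPEECH, EMFORMER_RNNT_BASE_MUSTC and EMFORMER_RNNT_BASE_TEDLIUM3 bundles stop being importable in 2.9.0. A hedged sketch (not part of either wheel) of guarding such an import so the same code runs against both versions:

    import warnings

    try:
        with warnings.catch_warnings():
            # 2.8.0 wraps these constants with dropping_const_support, which may emit deprecation warnings.
            warnings.simplefilter("ignore")
            from torchaudio.prototype.pipelines import VGGISH
    except ImportError:  # the prototype pipeline tree is absent from the 2.9.0 wheel
        VGGISH = None

    if VGGISH is None:
        print("torchaudio.prototype.pipelines unavailable; pin torchaudio<2.9 to keep using VGGISH")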
@@ -1,236 +0,0 @@
-# Derived from torchvggish (https://github.com/harritaylor/torchvggish).
-# Copyright 2017 The TensorFlow Authors All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import math
-
-import torch
-
-from torchaudio._internal.module_utils import dropping_class_support
-
-
-_MEL_BREAK_FREQUENCY_HERTZ = 700.0
-_MEL_HIGH_FREQUENCY_Q = 1127.0
-
-
-_SAMPLE_RATE = 16000
-_STFT_WINDOW_LENGTH_SECONDS = 0.025
-_STFT_HOP_LENGTH_SECONDS = 0.010
-_MEL_MIN_HZ = 125
-_MEL_MAX_HZ = 7500
-_NUM_BANDS = 64
-_LOG_OFFSET = 0.01
-_EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
-_EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.
-
-
-def _build_features_network():
-    layers = []
-
-    for input_dim, output_dim in [(1, 64), (64, 128)]:
-        layers += [
-            torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-            torch.nn.ReLU(inplace=True),
-            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
-        ]
-
-    for input_dim, output_dim in [(128, 256), (256, 512)]:
-        layers += [
-            torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-            torch.nn.ReLU(inplace=True),
-            torch.nn.Conv2d(
-                output_dim,
-                output_dim,
-                kernel_size=(3, 3),
-                stride=(1, 1),
-                padding=(1, 1),
-            ),
-            torch.nn.ReLU(inplace=True),
-            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
-        ]
-
-    return torch.nn.Sequential(*layers)
-
-
-def _build_embedding_network():
-    return torch.nn.Sequential(
-        torch.nn.Linear(512 * 4 * 6, 4096),
-        torch.nn.ReLU(True),
-        torch.nn.Linear(4096, 4096),
-        torch.nn.ReLU(True),
-        torch.nn.Linear(4096, 128),
-        torch.nn.ReLU(True),
-    )
-
-
-def _frame(data, window_length, hop_length):
-    num_samples = data.shape[0]
-    num_frames = 1 + int(math.floor((num_samples - window_length) / hop_length))
-    shape = (num_frames, window_length) + data.shape[1:]
-    strides = (data.stride()[0] * hop_length,) + data.stride()
-    return torch.as_strided(data, shape, strides)
-
-
-def _stft_magnitude(signal, fft_length, hop_length=None, window_length=None):
-    frames = _frame(signal, window_length, hop_length)
-    window = torch.hann_window(window_length, periodic=True).to(signal.device)
-    windowed_frames = frames * window
-    return torch.abs(torch.fft.rfft(windowed_frames, int(fft_length)))
-
-
-def _hertz_to_mel(frequencies_hertz):
-    return _MEL_HIGH_FREQUENCY_Q * torch.log(1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
-
-
-def _spectrogram_to_mel_matrix(
-    num_mel_bins=20,
-    num_spectrogram_bins=129,
-    audio_sample_rate=8000,
-    lower_edge_hertz=125.0,
-    upper_edge_hertz=3800.0,
-):
-    nyquist_hertz = audio_sample_rate / 2.0
-    if lower_edge_hertz < 0.0:
-        raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
-    if lower_edge_hertz >= upper_edge_hertz:
-        raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % (lower_edge_hertz, upper_edge_hertz))
-
-    if upper_edge_hertz > nyquist_hertz:
-        raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % (upper_edge_hertz, nyquist_hertz))
-    spectrogram_bins_hertz = torch.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
-
-    spectrogram_bins_mel = _hertz_to_mel(spectrogram_bins_hertz)
-    # The i'th mel band (starting from i=1) has center frequency
-    # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
-    # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
-    # the band_edges_mel arrays.
-    band_edges_mel = torch.linspace(
-        _hertz_to_mel(torch.tensor(lower_edge_hertz)),
-        _hertz_to_mel(torch.tensor(upper_edge_hertz)),
-        num_mel_bins + 2,
-    )
-    # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
-    # of spectrogram values.
-    mel_weights_matrix = torch.empty((num_spectrogram_bins, num_mel_bins))
-    for i in range(num_mel_bins):
-        lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i : i + 3]
-        # Calculate lower and upper slopes for every spectrogram bin.
-        # Line segments are linear in the *mel* domain, not hertz.
-        lower_slope = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel)
-        upper_slope = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel)
-
-        # .. then intersect them with each other and zero.
-        mel_weights_matrix[:, i] = torch.maximum(torch.tensor(0.0), torch.minimum(lower_slope, upper_slope))
-
-    # HTK excludes the spectrogram DC bin; make sure it always gets a zero
-    # coefficient.
-    mel_weights_matrix[0, :] = 0.0
-    return mel_weights_matrix
-
-
-def _log_mel_spectrogram(
-    data,
-    audio_sample_rate=8000,
-    log_offset=0.0,
-    window_length_secs=0.025,
-    hop_length_secs=0.010,
-    **kwargs,
-):
-    window_length_samples = int(round(audio_sample_rate * window_length_secs))
-    hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
-    fft_length = 2 ** int(math.ceil(math.log(window_length_samples) / math.log(2.0)))
-
-    spectrogram = _stft_magnitude(
-        data,
-        fft_length=fft_length,
-        hop_length=hop_length_samples,
-        window_length=window_length_samples,
-    )
-    mel_spectrogram = torch.matmul(
-        spectrogram,
-        _spectrogram_to_mel_matrix(
-            num_spectrogram_bins=spectrogram.shape[1],
-            audio_sample_rate=audio_sample_rate,
-            **kwargs,
-        ).to(spectrogram),
-    )
-    return torch.log(mel_spectrogram + log_offset)
-
-
-def _waveform_to_examples(data):
-    # Compute log mel spectrogram features, with shape (n_frame, n_mel)
-    log_mel = _log_mel_spectrogram(
-        data,
-        audio_sample_rate=_SAMPLE_RATE,
-        log_offset=_LOG_OFFSET,
-        window_length_secs=_STFT_WINDOW_LENGTH_SECONDS,
-        hop_length_secs=_STFT_HOP_LENGTH_SECONDS,
-        num_mel_bins=_NUM_BANDS,
-        lower_edge_hertz=_MEL_MIN_HZ,
-        upper_edge_hertz=_MEL_MAX_HZ,
-    )
-
-    # Frame features into examples, with shape (n_example, n_frame, n_mel)
-    features_sample_rate = 1.0 / _STFT_HOP_LENGTH_SECONDS
-    example_window_length = int(round(_EXAMPLE_WINDOW_SECONDS * features_sample_rate))
-
-    example_hop_length = int(round(_EXAMPLE_HOP_SECONDS * features_sample_rate))
-    log_mel_examples = _frame(log_mel, window_length=example_window_length, hop_length=example_hop_length)
-
-    # (n_example, 1, n_frame, n_mel)
-    return log_mel_examples.unsqueeze(1)
-
-
-@dropping_class_support
-class VGGish(torch.nn.Module):
-    """Implementation of VGGish model :cite:`45611`."""
-
-    def __init__(self):
-        super().__init__()
-
-        self.features_network = _build_features_network()
-        self.embedding_network = _build_embedding_network()
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            input (torch.Tensor): batch of spectrograms, with shape `(n_example, 1, n_frame, 64)`.
-
-        Returns:
-            torch.Tensor: model output, with shape `(n_example, 128)`.
-        """
-        x = self.features_network(input)
-
-        x = x.permute(0, 2, 3, 1)
-        x = x.reshape(x.size(0), -1)
-
-        return self.embedding_network(x)
-
-@dropping_class_support
-class VGGishInputProcessor:
-    """Converts raw waveforms to batches of examples to use as inputs to VGGish."""
-
-    def __call__(self, input: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            input (torch.Tensor): waveform, with shape `(T,)`.
-            sample_rate (int): sample rate of waveform in hertz.
-
-        Returns:
-            torch.Tensor: batch of examples to pass to VGGish, with shape `(n_example, 1, n_frame, 64)`.
-        """
-        if len(input.shape) != 1:
-            raise ValueError("input waveform must have dimension of 1.")
-        return _waveform_to_examples(input)
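This final 236-line hunk is torchaudio/prototype/pipelines/_vggish/_vggish_impl.py, the VGGish input processor and embedding model behind the VGGISH bundle. The constants removed above fix the whole shape pipeline: 16 kHz audio, 25 ms windows with a 10 ms hop, 64 mel bands, and 0.96 s examples of 96 frames, which four 2x2 max-pools reduce to the 512 x 4 x 6 tensor flattened into the first Linear layer. A small sketch of that arithmetic (plain Python, independent of torchaudio):

    import math

    sample_rate = 16000
    window = int(round(sample_rate * 0.025))              # 400-sample STFT window
    hop = int(round(sample_rate * 0.010))                 # 160-sample hop
    fft_length = 2 ** int(math.ceil(math.log2(window)))   # 512-point FFT
    spectrogram_bins = fft_length // 2 + 1                # 257 bins fed into the mel matrix

    frames_per_example = int(round(0.96 / 0.010))         # 96 frames of 64 mel bands per example
    pooled = (frames_per_example // 16, 64 // 16, 512)    # four 2x2 max-pools -> (6, 4, 512)

    print(spectrogram_bins, frames_per_example, pooled, 512 * 4 * 6)

Per the docstrings above, on 2.8.0 VGGishInputProcessor turns a 1-D waveform into a batch of (n_example, 1, 96, 64) examples and VGGish maps that batch to (n_example, 128) embeddings.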