torchaudio 2.9.1__cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +85 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,568 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import itertools as it
|
|
4
|
+
|
|
5
|
+
from abc import abstractmethod
|
|
6
|
+
from collections import namedtuple
|
|
7
|
+
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from flashlight.lib.text.decoder import (
|
|
12
|
+
CriterionType as _CriterionType,
|
|
13
|
+
LexiconDecoder as _LexiconDecoder,
|
|
14
|
+
LexiconDecoderOptions as _LexiconDecoderOptions,
|
|
15
|
+
LexiconFreeDecoder as _LexiconFreeDecoder,
|
|
16
|
+
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
|
|
17
|
+
LM as _LM,
|
|
18
|
+
LMState as _LMState,
|
|
19
|
+
SmearingMode as _SmearingMode,
|
|
20
|
+
Trie as _Trie,
|
|
21
|
+
ZeroLM as _ZeroLM,
|
|
22
|
+
)
|
|
23
|
+
from flashlight.lib.text.dictionary import (
|
|
24
|
+
create_word_dict as _create_word_dict,
|
|
25
|
+
Dictionary as _Dictionary,
|
|
26
|
+
load_words as _load_words,
|
|
27
|
+
)
|
|
28
|
+
from torchaudio.utils import _download_asset
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
|
|
32
|
+
except Exception:
|
|
33
|
+
try:
|
|
34
|
+
from flashlight.lib.text.decoder import KenLM as _KenLM
|
|
35
|
+
except Exception:
|
|
36
|
+
_KenLM = None
|
|
37
|
+
|
|
38
|
+
# Names exported by ``from torchaudio.models.decoder._ctc_decoder import *``.
__all__ = [
    "CTCHypothesis",
    "CTCDecoder",
    "CTCDecoderLM",
    "CTCDecoderLMState",
    "ctc_decoder",
    "download_pretrained_files",
]

# Bundle of asset paths for a pretrained decoder setup: lexicon file, tokens
# file, and (optionally) a language-model binary.
_PretrainedFiles = namedtuple("PretrainedFiles", "lexicon tokens lm")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    """Build the smeared lexicon trie used by the lexicon-constrained decoder.

    Every spelling of every lexicon word is inserted as a path of token indices.
    Each path is labelled with the word's index in ``word_dict`` and with the
    score ``lm`` assigns to that word from the state returned by
    ``lm.start(False)``. The finished trie is smeared with ``_SmearingMode.MAX``.

    Args:
        tokens_dict: token dictionary; ``index_size()`` gives the trie's vocabulary size.
        word_dict: word dictionary used to map lexicon words to indices.
        lexicon: mapping of word -> iterable of spellings (each a sequence of tokens).
        lm: language model exposing ``start`` and ``score``.
        silence: index of the silence token.

    Returns:
        The populated ``_Trie``.
    """
    trie = _Trie(tokens_dict.index_size(), silence)
    initial_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_index = word_dict.get_index(word)
        # Only the word score is needed; the successor LM state is discarded.
        _next_state, word_score = lm.score(initial_state, word_index)
        for spelling in spellings:
            token_indices = list(map(tokens_dict.get_index, spelling))
            trie.insert(token_indices, word_index, word_score)

    trie.smear(_SmearingMode.MAX)
    return trie
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
|
|
66
|
+
word_dict = None
|
|
67
|
+
if lm_dict is not None:
|
|
68
|
+
word_dict = _Dictionary(lm_dict)
|
|
69
|
+
|
|
70
|
+
if lexicon and word_dict is None:
|
|
71
|
+
word_dict = _create_word_dict(lexicon)
|
|
72
|
+
elif not lexicon and word_dict is None and type(lm) is str:
|
|
73
|
+
d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
|
|
74
|
+
d[unk_word] = [[unk_word]]
|
|
75
|
+
word_dict = _create_word_dict(d)
|
|
76
|
+
|
|
77
|
+
return word_dict
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class CTCHypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`.

    Instances are returned by :meth:`CTCDecoder.__call__` (one list per batch entry)
    and by :meth:`CTCDecoder.get_final_hypothesis`.
    """
    tokens: torch.LongTensor
    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence.

    Consecutive repeated IDs are merged into one and blank IDs are removed.
    """

    words: List[str]
    """List of predicted words.

    Only non-negative word indices reported by the underlying decoder are mapped to words;
    negative indices are skipped.

    Note:
        This attribute is only applicable if a lexicon is provided to the decoder. If
        decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
    """

    score: float
    """Score corresponding to hypothesis, as reported by the underlying Flashlight decoder."""

    timesteps: torch.IntTensor
    """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence.

    Each entry is the position, in the decoder's raw token sequence, at which the corresponding
    entry of :attr:`tokens` begins.
    """
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class CTCDecoderLMState(_LMState):
    """Language model state.

    All methods delegate to the Flashlight base class ``_LMState``; this subclass adds
    documentation and type information for use with :class:`CTCDecoderLM`.
    """

    @property
    def children(self) -> Dict[int, CTCDecoderLMState]:
        """Map of indices to LM states"""
        return super().children

    def child(self, usr_index: int) -> CTCDecoderLMState:
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState) -> int:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Delegate to the base implementation, like ``children`` and ``child``. A no-op body
        # here would shadow ``_LMState.compare`` and make this method return ``None``.
        return super().compare(state)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CTCDecoderLM(_LM):
    """Language model base class for creating custom language models to use with the decoder.

    Subclasses provide :meth:`start`, :meth:`score` and :meth:`finish`; the versions
    defined here only raise :class:`NotImplementedError`. When a lexicon is used,
    :func:`ctc_decoder` also calls ``start(False)`` and ``score`` while building the
    lexicon trie.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        """Initialize or reset the language model.

        Args:
            start_with_nothing (bool): whether or not to start sentence with sil token.

        Returns:
            CTCDecoderLMState: starting state
        """
        raise NotImplementedError()

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate the language model based on the current LM state and new word.

        Args:
            state (CTCDecoderLMState): current LM state
            usr_token_idx (int): index of the word

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        """Evaluate end for language model based on current LM state.

        Args:
            state (CTCDecoderLMState): current LM state

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class CTCDecoder:
    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.

    .. devices:: CPU

    Note:
        To build the decoder, please use the factory function :func:`ctc_decoder`.
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: CTCDecoderLM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
                parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """

        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)
        transitions = []

        if lexicon:
            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
            unk_idx = word_dict.get_index(unk_word)
            token_lm = False  # use word level LM

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_idx,
                transitions,
                token_lm,
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If lm is passed like rvalue reference, the lm object gets garbage collected,
        # and later call to the lm fails.
        # This ensures that lm object is not deleted as long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
        self.lm = lm

    @staticmethod
    def _validate_emissions(emissions: torch.Tensor, ndim: int) -> None:
        """Check that ``emissions`` can be handed to the Flashlight decoder via ``data_ptr()``.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous CPU tensor with ``ndim`` dimensions.
        """
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")
        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")
        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")
        if emissions.ndim != ndim:
            raise RuntimeError(f"emissions must be {ndim}D. Found {emissions.shape}")

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse consecutive repeated IDs in ``idxs`` and drop blank IDs."""
        return torch.LongTensor([idx for idx, _ in it.groupby(idxs) if idx != self.blank])

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens.

        A position is kept when its ID is not blank and differs from the previous ID, so the
        result lines up one-to-one with :meth:`_get_tokens`.
        """
        timesteps = []
        for i, idx in enumerate(idxs):
            if idx == self.blank:
                continue
            if i == 0 or idx != idxs[i - 1]:
                timesteps.append(i)
        return torch.IntTensor(timesteps)

    def decode_begin(self):
        """Initialize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_begin()

    def decode_end(self):
        """Finalize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_end()

    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous 2D CPU tensor.

        Example:
            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
            >>> decoder.decode_begin()
            >>> decoder.decode_step(emission1)
            >>> decoder.decode_step(emission2)
            >>> decoder.decode_end()
            >>> result = decoder.get_final_hypothesis()
        """
        self._validate_emissions(emissions, ndim=2)

        T, N = emissions.size()
        self.decoder.decode_step(emissions.data_ptr(), T, N)

    def _to_hypo(self, results) -> List[CTCHypothesis]:
        """Convert raw Flashlight decode results into :class:`CTCHypothesis` objects."""
        return [
            CTCHypothesis(
                tokens=self._get_tokens(result.tokens),
                # Negative word indices carry no word and are skipped.
                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                score=result.score,
                timesteps=self._get_timesteps(result.tokens),
            )
            for result in results
        ]

    def get_final_hypothesis(self) -> List[CTCHypothesis]:
        """Get the final hypothesis

        Returns:
            List[CTCHypothesis]:
                List of sorted best hypotheses.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        results = self.decoder.get_all_final_hypothesis()
        return self._to_hypo(results[: self.nbest])

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        """
        Performs batched offline decoding.

        .. note::

           This method performs offline decoding in one go. To perform incremental decoding,
           please refer to :py:meth:`decode_step`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length
                along the time axis of each sequence in the batch. Every value must lie in `[0, frame]`.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous 3D CPU tensor, or if ``lengths``
                is not a CPU tensor of shape `(batch, )` with values in `[0, frame]`.
        """
        self._validate_emissions(emissions, ndim=3)

        if lengths is not None and not lengths.is_cpu:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            lengths = torch.full((B,), T)
        else:
            if lengths.ndim != 1 or lengths.numel() != B:
                raise RuntimeError(f"lengths must be a 1D tensor of shape ({B},). Found {tuple(lengths.shape)}")
            # The native decoder only receives a raw pointer to sequence ``b`` plus
            # ``lengths[b]`` and ``N``; presumably it reads ``lengths[b] * N`` floats, so a
            # length above ``T`` would read past that sequence's data.
            if bool((lengths < 0).any()) or bool((lengths > T).any()):
                raise RuntimeError(f"lengths must be between 0 and {T}. Found {lengths.tolist()}")

        # Byte offset between consecutive batch entries of the contiguous emissions buffer.
        base_ptr = emissions.data_ptr()
        batch_stride_bytes = emissions.element_size() * emissions.stride(0)

        hypos = []
        for b in range(B):
            results = self.decoder.decode(base_ptr + b * batch_stride_bytes, int(lengths[b]), N)
            hypos.append(self._to_hypo(results[: self.nbest]))
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder (a sequence of Python ints
                is also accepted)

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(int(idx)) for idx in idxs]
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Union[str, CTCDecoderLM, None] = None,
    lm_dict: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """Builds an instance of :class:`CTCDecoder`.

    Args:
        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
            decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
            (Default: None)
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    File path arguments (``lexicon``, ``tokens``, ``lm``, ``lm_dict``) also accept
    ``os.PathLike`` objects such as :class:`pathlib.Path`.

    Returns:
        CTCDecoder: decoder

    Raises:
        ValueError: if ``lm_dict`` is neither `None` nor a path.
        RuntimeError: if ``lm`` is a path but KenLM support is not available.

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
    """
    import os  # local import: only needed to normalize path-like arguments

    # Convert path-like objects (e.g. pathlib.Path) to plain paths, so the ``str``
    # checks below (and the KenLM branch) treat them like string paths.
    lexicon, tokens, lm, lm_dict = (
        os.fspath(arg) if isinstance(arg, os.PathLike) else arg for arg in (lexicon, tokens, lm, lm_dict)
    )

    if lm_dict is not None and not isinstance(lm_dict, str):
        raise ValueError("lm_dict must be None or str type.")

    tokens_dict = _Dictionary(tokens)

    # decoder options; the lexicon decoder additionally takes word/unk scores
    common_options = dict(
        beam_size=beam_size,
        beam_size_token=beam_size_token or tokens_dict.index_size(),
        beam_threshold=beam_threshold,
        lm_weight=lm_weight,
        sil_score=sil_score,
        log_add=log_add,
        criterion_type=_CriterionType.CTC,
    )
    if lexicon:
        lexicon = _load_words(lexicon)
        decoder_options = _LexiconDecoderOptions(word_score=word_score, unk_score=unk_score, **common_options)
    else:
        decoder_options = _LexiconFreeDecoderOptions(**common_options)

    # construct word dict and language model
    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)

    if isinstance(lm, str):
        if _KenLM is None:
            raise RuntimeError(
                "flashlight-text is installed, but KenLM is not installed. "
                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
            )
        lm = _KenLM(lm, word_dict)
    elif lm is None:
        lm = _ZeroLM()

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _get_filenames(model: str) -> _PretrainedFiles:
|
|
526
|
+
if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
|
|
527
|
+
raise ValueError(
|
|
528
|
+
f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
prefix = f"decoder-assets/{model}"
|
|
532
|
+
return _PretrainedFiles(
|
|
533
|
+
lexicon=f"{prefix}/lexicon.txt",
|
|
534
|
+
tokens=f"{prefix}/tokens.txt",
|
|
535
|
+
lm=f"{prefix}/lm.bin" if model != "librispeech" else None,
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for :func:`ctc_decoder`.

    Args:
        model (str): pretrained language model to download.
            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

    Returns:
        Object with the following attributes

            * ``lm``: path corresponding to downloaded language model,
              or ``None`` if the model is not associated with an lm
            * ``lexicon``: path corresponding to downloaded lexicon file
            * ``tokens``: path corresponding to downloaded tokens file

    Raises:
        ValueError: if ``model`` is not one of the valid values.
    """
    remote = _get_filenames(model)

    # Download order: lexicon, tokens, then the (optional) language model.
    lexicon_path = _download_asset(remote.lexicon)
    tokens_path = _download_asset(remote.tokens)
    lm_path = None if remote.lm is None else _download_asset(remote.lm)

    return _PretrainedFiles(lexicon=lexicon_path, tokens=tokens_path, lm=lm_path)
|