torchaudio 2.9.1__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. torchaudio/__init__.py +204 -0
  2. torchaudio/_extension/__init__.py +61 -0
  3. torchaudio/_extension/utils.py +133 -0
  4. torchaudio/_internal/__init__.py +10 -0
  5. torchaudio/_internal/module_utils.py +171 -0
  6. torchaudio/_torchcodec.py +340 -0
  7. torchaudio/compliance/__init__.py +5 -0
  8. torchaudio/compliance/kaldi.py +813 -0
  9. torchaudio/datasets/__init__.py +47 -0
  10. torchaudio/datasets/cmuarctic.py +157 -0
  11. torchaudio/datasets/cmudict.py +186 -0
  12. torchaudio/datasets/commonvoice.py +86 -0
  13. torchaudio/datasets/dr_vctk.py +121 -0
  14. torchaudio/datasets/fluentcommands.py +108 -0
  15. torchaudio/datasets/gtzan.py +1118 -0
  16. torchaudio/datasets/iemocap.py +147 -0
  17. torchaudio/datasets/librilight_limited.py +111 -0
  18. torchaudio/datasets/librimix.py +133 -0
  19. torchaudio/datasets/librispeech.py +174 -0
  20. torchaudio/datasets/librispeech_biasing.py +189 -0
  21. torchaudio/datasets/libritts.py +168 -0
  22. torchaudio/datasets/ljspeech.py +107 -0
  23. torchaudio/datasets/musdb_hq.py +139 -0
  24. torchaudio/datasets/quesst14.py +136 -0
  25. torchaudio/datasets/snips.py +157 -0
  26. torchaudio/datasets/speechcommands.py +183 -0
  27. torchaudio/datasets/tedlium.py +218 -0
  28. torchaudio/datasets/utils.py +54 -0
  29. torchaudio/datasets/vctk.py +143 -0
  30. torchaudio/datasets/voxceleb1.py +309 -0
  31. torchaudio/datasets/yesno.py +89 -0
  32. torchaudio/functional/__init__.py +130 -0
  33. torchaudio/functional/_alignment.py +128 -0
  34. torchaudio/functional/filtering.py +1685 -0
  35. torchaudio/functional/functional.py +2505 -0
  36. torchaudio/lib/__init__.py +0 -0
  37. torchaudio/lib/_torchaudio.so +0 -0
  38. torchaudio/lib/libtorchaudio.so +0 -0
  39. torchaudio/models/__init__.py +85 -0
  40. torchaudio/models/_hdemucs.py +1008 -0
  41. torchaudio/models/conformer.py +293 -0
  42. torchaudio/models/conv_tasnet.py +330 -0
  43. torchaudio/models/decoder/__init__.py +64 -0
  44. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  45. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  46. torchaudio/models/deepspeech.py +84 -0
  47. torchaudio/models/emformer.py +884 -0
  48. torchaudio/models/rnnt.py +816 -0
  49. torchaudio/models/rnnt_decoder.py +339 -0
  50. torchaudio/models/squim/__init__.py +11 -0
  51. torchaudio/models/squim/objective.py +326 -0
  52. torchaudio/models/squim/subjective.py +150 -0
  53. torchaudio/models/tacotron2.py +1046 -0
  54. torchaudio/models/wav2letter.py +72 -0
  55. torchaudio/models/wav2vec2/__init__.py +45 -0
  56. torchaudio/models/wav2vec2/components.py +1167 -0
  57. torchaudio/models/wav2vec2/model.py +1579 -0
  58. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  59. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  60. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  61. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  62. torchaudio/models/wavernn.py +409 -0
  63. torchaudio/pipelines/__init__.py +102 -0
  64. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  65. torchaudio/pipelines/_squim_pipeline.py +156 -0
  66. torchaudio/pipelines/_tts/__init__.py +16 -0
  67. torchaudio/pipelines/_tts/impl.py +385 -0
  68. torchaudio/pipelines/_tts/interface.py +255 -0
  69. torchaudio/pipelines/_tts/utils.py +230 -0
  70. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  71. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  72. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  73. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  74. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  75. torchaudio/transforms/__init__.py +78 -0
  76. torchaudio/transforms/_multi_channel.py +467 -0
  77. torchaudio/transforms/_transforms.py +2138 -0
  78. torchaudio/utils/__init__.py +4 -0
  79. torchaudio/utils/download.py +89 -0
  80. torchaudio/version.py +2 -0
  81. torchaudio-2.9.1.dist-info/METADATA +133 -0
  82. torchaudio-2.9.1.dist-info/RECORD +85 -0
  83. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  84. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  85. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,568 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools as it
4
+
5
+ from abc import abstractmethod
6
+ from collections import namedtuple
7
+ from typing import Dict, List, NamedTuple, Optional, Tuple, Union
8
+
9
+ import torch
10
+
11
+ from flashlight.lib.text.decoder import (
12
+ CriterionType as _CriterionType,
13
+ LexiconDecoder as _LexiconDecoder,
14
+ LexiconDecoderOptions as _LexiconDecoderOptions,
15
+ LexiconFreeDecoder as _LexiconFreeDecoder,
16
+ LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
17
+ LM as _LM,
18
+ LMState as _LMState,
19
+ SmearingMode as _SmearingMode,
20
+ Trie as _Trie,
21
+ ZeroLM as _ZeroLM,
22
+ )
23
+ from flashlight.lib.text.dictionary import (
24
+ create_word_dict as _create_word_dict,
25
+ Dictionary as _Dictionary,
26
+ load_words as _load_words,
27
+ )
28
+ from torchaudio.utils import _download_asset
29
+
30
+ try:
31
+ from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
32
+ except Exception:
33
+ try:
34
+ from flashlight.lib.text.decoder import KenLM as _KenLM
35
+ except Exception:
36
+ _KenLM = None
37
+
38
+ __all__ = [
39
+ "CTCHypothesis",
40
+ "CTCDecoder",
41
+ "CTCDecoderLM",
42
+ "CTCDecoderLMState",
43
+ "ctc_decoder",
44
+ "download_pretrained_files",
45
+ ]
46
+
47
# Record of the three asset paths (lexicon, tokens, language model) that make up a pretrained decoder bundle.
_PretrainedFiles = namedtuple("PretrainedFiles", "lexicon tokens lm")
48
+
49
+
50
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    """Build the trie used by the lexicon-based decoder.

    Each spelling of each lexicon word is inserted as a sequence of token
    indices, labelled with the word's index and the score that ``lm`` gives
    the word when scored from the LM start state. The trie is smeared with
    ``_SmearingMode.MAX`` before it is returned.

    Args:
        tokens_dict: token dictionary; supplies ``index_size()`` and ``get_index(token)``.
        word_dict: word dictionary; supplies ``get_index(word)``.
        lexicon: mapping from a word to an iterable of spellings, each spelling an iterable of tokens.
        lm: language model; supplies ``start(bool)`` and ``score(state, word_idx)``.
        silence: index of the silence token, forwarded to the trie constructor.

    Returns:
        The populated and smeared trie.
    """
    trie = _Trie(tokens_dict.index_size(), silence)
    initial_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_idx = word_dict.get_index(word)
        # Only the score is needed; the LM state after the word is discarded.
        _, word_score = lm.score(initial_state, word_idx)
        for spelling in spellings:
            token_path = list(map(tokens_dict.get_index, spelling))
            trie.insert(token_path, word_idx, word_score)

    trie.smear(_SmearingMode.MAX)
    return trie
63
+
64
+
65
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
    """Select the word dictionary for the decoder.

    Sources are tried in this order:

    1. ``lm_dict`` -- when given, a dictionary is built from it.
    2. ``lexicon`` -- when non-empty, the dictionary is derived from its words.
    3. No lexicon, but ``lm`` is exactly a ``str`` (a language model path):
       every token is treated as a one-token "word" and ``unk_word`` is added.

    Args:
        lexicon: loaded lexicon mapping, or a falsy value for lexicon-free decoding.
        lm: language model argument as supplied by the caller (path, object or None).
        lm_dict: path to the LM dictionary file, or None.
        tokens_dict: token dictionary; supplies ``index_size()`` and ``get_entry(i)``.
        unk_word: word that represents unknown entries.

    Returns:
        The word dictionary, or None when none of the cases above applies.
    """
    if lm_dict is not None:
        return _Dictionary(lm_dict)

    if lexicon:
        return _create_word_dict(lexicon)

    if type(lm) is str:
        entries = [tokens_dict.get_entry(i) for i in range(tokens_dict.index_size())]
        pseudo_lexicon = {entry: [[entry]] for entry in entries}
        pseudo_lexicon[unk_word] = [[unk_word]]
        return _create_word_dict(pseudo_lexicon)

    return None
78
+
79
+
80
class CTCHypothesis(NamedTuple):
    r"""A single hypothesis produced by the CTC beam search decoder :class:`CTCDecoder`."""

    tokens: torch.LongTensor
    """Token IDs of the predicted sequence. Shape `(L, )`, with `L` the length of the output sequence"""

    words: List[str]
    """Words of the predicted sequence.

    Note:
        Only meaningful when the decoder was given a lexicon. For lexicon-free
        decoding this is blank; use :attr:`tokens` together with
        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
    """

    score: float
    """Score assigned to the hypothesis"""

    timesteps: torch.IntTensor
    """Timesteps at which the tokens occur. Shape `(L, )`, with `L` the length of the output sequence"""
99
+
100
+
101
class CTCDecoderLMState(_LMState):
    """Language model state.

    Python-facing subclass of the Flashlight LM state. Every member simply
    forwards to the base-class implementation.
    """

    @property
    def children(self) -> Dict[int, CTCDecoderLMState]:
        """Map of indices to LM states"""
        return super().children

    def child(self, usr_index: int) -> CTCDecoderLMState:
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState) -> int:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Delegate to the base class, like ``children`` and ``child`` do. An
        # empty body here would shadow the base implementation and always
        # return ``None`` instead of the documented int.
        return super().compare(state)
131
+
132
+
133
class CTCDecoderLM(_LM):
    """Base class to subclass when plugging a custom language model into the decoder.

    Subclasses provide :meth:`start`, :meth:`score` and :meth:`finish`; the
    versions defined here only raise ``NotImplementedError``.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        """Initialize or reset the language model.

        Args:
            start_with_nothing (bool): whether or not to start sentence with sil token.

        Returns:
            CTCDecoderLMState: the starting state
        """
        raise NotImplementedError()

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        """Score a new word given the current language model state.

        Args:
            state (CTCDecoderLMState): current LM state
            usr_token_idx (int): index of the word

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        """Score the end of the sequence given the current language model state.

        Args:
            state (CTCDecoderLMState): current LM state

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()
180
+
181
+
182
class CTCDecoder:
    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.

    .. devices:: CPU

    Note:
        To build the decoder, please use the factory function :func:`ctc_decoder`.
    """

    def __init__(
        self,
        nbest: int,
        lexicon: Optional[Dict],
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        lm: CTCDecoderLM,
        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
                parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """

        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)
        transitions = []

        if lexicon:
            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
            unk_idx = word_dict.get_index(unk_word)
            token_lm = False  # use word level LM

            self.decoder = _LexiconDecoder(
                decoder_options,
                trie,
                lm,
                silence,
                self.blank,
                unk_idx,
                transitions,
                token_lm,
            )
        else:
            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
        # https://github.com/pytorch/audio/issues/3218
        # If lm is passed like rvalue reference, the lm object gets garbage collected,
        # and later call to the lm fails.
        # This ensures that lm object is not deleted as long as the decoder is alive.
        # https://github.com/pybind/pybind11/discussions/4013
        self.lm = lm

    @staticmethod
    def _validate_emissions(emissions: torch.FloatTensor, ndim: int) -> None:
        """Check that ``emissions`` can be handed to the native decoder by raw pointer.

        Args:
            emissions (torch.FloatTensor): tensor to validate
            ndim (int): required number of dimensions

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not on CPU, is not contiguous,
                or does not have ``ndim`` dimensions.
        """
        if emissions.dtype != torch.float32:
            raise ValueError("emissions must be float32.")

        if not emissions.is_cpu:
            raise RuntimeError("emissions must be a CPU tensor.")

        if not emissions.is_contiguous():
            raise RuntimeError("emissions must be contiguous.")

        if emissions.ndim != ndim:
            raise RuntimeError(f"emissions must be {ndim}D. Found {emissions.shape}")

    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse consecutive repeats in ``idxs`` and drop blanks."""
        collapsed = [key for key, _ in it.groupby(idxs) if key != self.blank]
        return torch.LongTensor(collapsed)

    def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
        """Returns frame numbers corresponding to non-blank tokens."""

        # A frame counts when it holds a non-blank token that differs from the
        # previous frame, i.e. it is the first frame of a run.
        timesteps = [
            i for i, idx in enumerate(idxs) if idx != self.blank and (i == 0 or idx != idxs[i - 1])
        ]
        return torch.IntTensor(timesteps)

    def decode_begin(self):
        """Initialize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_begin()

    def decode_end(self):
        """Finalize the internal state of the decoder.

        See :py:meth:`decode_step` for the usage.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        self.decoder.decode_end()

    def decode_step(self, emissions: torch.FloatTensor):
        """Perform incremental decoding on top of the current internal state.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.

        Raises:
            ValueError: if ``emissions`` is not float32.
            RuntimeError: if ``emissions`` is not a contiguous 2D CPU tensor.

        Example:
            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
            >>> decoder.decode_begin()
            >>> decoder.decode_step(emission1)
            >>> decoder.decode_step(emission2)
            >>> decoder.decode_end()
            >>> result = decoder.get_final_hypothesis()
        """
        self._validate_emissions(emissions, ndim=2)

        T, N = emissions.size()
        self.decoder.decode_step(emissions.data_ptr(), T, N)

    def _to_hypo(self, results) -> List[CTCHypothesis]:
        """Convert raw decoder results into :class:`CTCHypothesis` objects."""
        return [
            CTCHypothesis(
                tokens=self._get_tokens(result.tokens),
                # Negative word indices are skipped; presumably they mark
                # "no word" -- confirm against Flashlight.
                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                score=result.score,
                timesteps=self._get_timesteps(result.tokens),
            )
            for result in results
        ]

    def get_final_hypothesis(self) -> List[CTCHypothesis]:
        """Get the final hypothesis

        Returns:
            List[CTCHypothesis]:
                List of sorted best hypotheses.

        .. note::

           This method is required only when performing online decoding.
           It is not necessary when performing batch decoding with :py:meth:`__call__`.
        """
        results = self.decoder.get_all_final_hypothesis()
        return self._to_hypo(results[: self.nbest])

    def __call__(
        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
    ) -> List[List[CTCHypothesis]]:
        """
        Performs batched offline decoding.

        .. note::

           This method performs offline decoding in one go. To perform incremental decoding,
           please refer to :py:meth:`decode_step`.

        Args:
            emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distribution over labels; output of acoustic model.
            lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length
                along the time axis of each sequence in the batch. Every entry must lie in `[0, frame]`.
                If `None`, every sequence is decoded over all `frame` steps.

        Returns:
            List[List[CTCHypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.

        Raises:
            ValueError: if ``emissions`` is not float32, or an entry of ``lengths`` is
                negative or larger than the number of frames.
            RuntimeError: if ``emissions`` is not a contiguous 3D CPU tensor, or
                ``lengths`` is not a CPU tensor.
        """

        self._validate_emissions(emissions, ndim=3)

        if lengths is not None and not lengths.is_cpu:
            raise RuntimeError("lengths must be a CPU tensor.")

        B, T, N = emissions.size()
        if lengths is None:
            frame_counts = [T] * B
        else:
            frame_counts = [int(lengths[b]) for b in range(B)]
            # The decoder is given a raw pointer, so a length outside [0, T]
            # would make it read outside this sequence's emission buffer.
            for b, n_frames in enumerate(frame_counts):
                if not 0 <= n_frames <= T:
                    raise ValueError(
                        f"lengths[{b}] = {n_frames} is out of range; expected a value in [0, {T}]."
                    )

        # Distance in bytes between the first elements of consecutive batch items.
        batch_stride_bytes = emissions.element_size() * emissions.stride(0)
        base_ptr = emissions.data_ptr()
        hypos = []

        for b, n_frames in enumerate(frame_counts):
            results = self.decoder.decode(base_ptr + b * batch_stride_bytes, n_frames, N)
            hypos.append(self._to_hypo(results[: self.nbest]))
        return hypos

    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
410
+
411
+
412
def ctc_decoder(
    lexicon: Optional[str],
    tokens: Union[str, List[str]],
    lm: Union[str, CTCDecoderLM] = None,
    lm_dict: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> CTCDecoder:
    """Builds an instance of :class:`CTCDecoder`.

    Args:
        lexicon (str or None): lexicon file listing the possible words and their spellings.
            Each line holds a word followed by its space separated spelling. If `None`,
            lexicon-free decoding is used.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
            format is for tokens mapping to the same index to be on the same line
        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
            (Default: None)
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If `None`, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypothesis (Default: 50)
        lm_weight (float, optional): weight of language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether or not to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        CTCDecoder: decoder

    Raises:
        ValueError: if ``lm_dict`` is neither `None` nor a `str`.
        RuntimeError: if ``lm`` is a path but KenLM support is not available.

    Example
        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
    """
    if not (lm_dict is None or type(lm_dict) is str):
        raise ValueError("lm_dict must be None or str type.")

    tokens_dict = _Dictionary(tokens)

    # The decoder flavour is decided by the lexicon *argument*; remember that
    # before the name is rebound to the loaded lexicon contents.
    with_lexicon = bool(lexicon)
    if with_lexicon:
        lexicon = _load_words(lexicon)

    # Options accepted by both decoder flavours.
    shared_options = dict(
        beam_size=beam_size,
        beam_size_token=beam_size_token or tokens_dict.index_size(),
        beam_threshold=beam_threshold,
        lm_weight=lm_weight,
        sil_score=sil_score,
        log_add=log_add,
        criterion_type=_CriterionType.CTC,
    )
    if with_lexicon:
        decoder_options = _LexiconDecoderOptions(
            word_score=word_score,
            unk_score=unk_score,
            **shared_options,
        )
    else:
        decoder_options = _LexiconFreeDecoderOptions(**shared_options)

    # The word dictionary is resolved before a path-like ``lm`` is replaced by a model object.
    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)

    if lm is None:
        lm = _ZeroLM()
    elif type(lm) is str:
        if _KenLM is None:
            raise RuntimeError(
                "flashlight-text is installed, but KenLM is not installed. "
                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
            )
        lm = _KenLM(lm, word_dict)

    return CTCDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        lm=lm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
523
+
524
+
525
def _get_filenames(model: str) -> _PretrainedFiles:
    """Return the relative asset paths that belong to a pretrained decoder bundle.

    Args:
        model (str): bundle name; one of ``"librispeech"``, ``"librispeech-3-gram"``
            or ``"librispeech-4-gram"``.

    Returns:
        _PretrainedFiles: lexicon, tokens and LM paths under ``decoder-assets/<model>``.
        ``lm`` is ``None`` for the plain ``"librispeech"`` bundle.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    prefix = f"decoder-assets/{model}"
    # The plain "librispeech" bundle has no language model file.
    lm_path = None if model == "librispeech" else f"{prefix}/lm.bin"
    return _PretrainedFiles(
        lexicon=f"{prefix}/lexicon.txt",
        tokens=f"{prefix}/tokens.txt",
        lm=lm_path,
    )
537
+
538
+
539
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for :func:`ctc_decoder`.

    Args:
        model (str): pretrained language model to download.
            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

    Returns:
        Object with the following attributes

            * ``lm``: path corresponding to downloaded language model,
              or ``None`` if the model is not associated with an lm
            * ``lexicon``: path corresponding to downloaded lexicon file
            * ``tokens``: path corresponding to downloaded tokens file
    """

    assets = _get_filenames(model)
    lexicon_path = _download_asset(assets.lexicon)
    tokens_path = _download_asset(assets.tokens)
    # Bundles without a language model have nothing further to fetch.
    lm_path = _download_asset(assets.lm) if assets.lm is not None else None

    return _PretrainedFiles(lexicon=lexicon_path, tokens=tokens_path, lm=lm_path)