torchaudio 2.9.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (86) hide show
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,255 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ from torch import Tensor
5
+ from torchaudio.models import Tacotron2
6
+
7
+
8
+ class _TextProcessor(ABC):
9
+ @property
10
+ @abstractmethod
11
+ def tokens(self):
12
+ """The tokens that the each value in the processed tensor represent.
13
+
14
+ :type: List[str]
15
+ """
16
+
17
+ @abstractmethod
18
+ def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
19
+ """Encode the given (batch of) texts into numerical tensors
20
+
21
+ Args:
22
+ text (str or list of str): The input texts.
23
+
24
+ Returns:
25
+ (Tensor, Tensor):
26
+ Tensor:
27
+ The encoded texts. Shape: `(batch, max length)`
28
+ Tensor:
29
+ The valid length of each sample in the batch. Shape: `(batch, )`.
30
+ """
31
+
32
+
33
+ class _Vocoder(ABC):
34
+ @property
35
+ @abstractmethod
36
+ def sample_rate(self):
37
+ """The sample rate of the resulting waveform
38
+
39
+ :type: float
40
+ """
41
+
42
+ @abstractmethod
43
+ def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
44
+ """Generate waveform from the given input, such as spectrogram
45
+
46
+ Args:
47
+ specgrams (Tensor):
48
+ The input spectrogram. Shape: `(batch, frequency bins, time)`.
49
+ The expected shape depends on the implementation.
50
+ lengths (Tensor, or None, optional):
51
+ The valid length of each sample in the batch. Shape: `(batch, )`.
52
+ (Default: `None`)
53
+
54
+ Returns:
55
+ (Tensor, Optional[Tensor]):
56
+ Tensor:
57
+ The generated waveform. Shape: `(batch, max length)`
58
+ Tensor or None:
59
+ The valid length of each sample in the batch. Shape: `(batch, )`.
60
+ """
61
+
62
+
63
class Tacotron2TTSBundle(ABC):
    """Data class that bundles associated information to use pretrained Tacotron2 and vocoder.

    This class provides interfaces for instantiating the pretrained model along with
    the information necessary to retrieve pretrained weights and additional data
    to be used with the model.

    Torchaudio library instantiates objects of this class, each of which represents
    a different pretrained model. Client code should access pretrained models via these
    instances.

    Please see below for the usage and the available values.

    Example - Character-based TTS pipeline with Tacotron2 and WaveRNN
        >>> import torchaudio
        >>>
        >>> text = "Hello, T T S !"
        >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
        >>>
        >>> # Build processor, Tacotron2 and WaveRNN model
        >>> processor = bundle.get_text_processor()
        >>> tacotron2 = bundle.get_tacotron2()
        Downloading:
        100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
        >>> vocoder = bundle.get_vocoder()
        Downloading:
        100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
        >>>
        >>> # Encode text
        >>> input, lengths = processor(text)
        >>>
        >>> # Generate (mel-scale) spectrogram
        >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
        >>>
        >>> # Convert spectrogram to waveform
        >>> waveforms, lengths = vocoder(specgram, lengths)
        >>>
        >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)

    Example - Phoneme-based TTS pipeline with Tacotron2 and WaveRNN
        >>>
        >>> # Note:
        >>> #     This bundle uses pre-trained DeepPhonemizer as
        >>> #     the text pre-processor.
        >>> #     Please install deep-phonemizer.
        >>> #     See https://github.com/as-ideas/DeepPhonemizer
        >>> #     The pretrained weight is automatically downloaded.
        >>>
        >>> import torchaudio
        >>>
        >>> text = "Hello, TTS!"
        >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
        >>>
        >>> # Build processor, Tacotron2 and WaveRNN model
        >>> processor = bundle.get_text_processor()
        Downloading:
        100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
        >>> tacotron2 = bundle.get_tacotron2()
        Downloading:
        100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
        >>> vocoder = bundle.get_vocoder()
        Downloading:
        100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
        >>>
        >>> # Encode text
        >>> input, lengths = processor(text)
        >>>
        >>> # Generate (mel-scale) spectrogram
        >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
        >>>
        >>> # Convert spectrogram to waveform
        >>> waveforms, lengths = vocoder(specgram, lengths)
        >>>
        >>> torchaudio.save('hello-tts.wav', waveforms, vocoder.sample_rate)
    """

    # Using the inner class so that these interfaces are not directly exposed on
    # `torchaudio.pipelines`, but still listed in documentation.
    # The thing is, text processing and vocoder are generic and we do not know what kind of
    # new text processing and vocoder will be added in the future, so we want to make these
    # interfaces specific to this Tacotron2TTS pipeline.

    class TextProcessor(_TextProcessor):
        """Interface of the text processing part of Tacotron2TTS pipeline

        See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage.
        """

    class Vocoder(_Vocoder):
        """Interface of the vocoder part of Tacotron2TTS pipeline

        See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
        """

    @abstractmethod
    def get_text_processor(self, *, dl_kwargs=None) -> TextProcessor:
        """Create a text processor

        For character-based pipeline, this processor splits the input text by character.
        For phoneme-based pipeline, this processor converts the input text (grapheme) to
        phonemes.

        If a pre-trained weight file is necessary,
        :func:`torch.hub.download_url_to_file` is used to download it.

        Args:
            dl_kwargs (dictionary of keyword arguments):
                Passed to :func:`torch.hub.download_url_to_file`.

        Returns:
            TextProcessor:
                A callable which takes a string or a list of strings as input and
                returns Tensor of encoded texts and Tensor of valid lengths.
                The object also has ``tokens`` property, which allows to recover the
                tokenized form.

        Example - Character-based
            >>> text = [
            >>>     "Hello World!",
            >>>     "Text-to-speech!",
            >>> ]
            >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
            >>> processor = bundle.get_text_processor()
            >>> input, lengths = processor(text)
            >>>
            >>> print(input)
            tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15,  2,  0,  0,  0],
                    [31, 16, 35, 31,  1, 31, 26,  1, 30, 27, 16, 16, 14, 19,  2]],
                   dtype=torch.int32)
            >>>
            >>> print(lengths)
            tensor([12, 15], dtype=torch.int32)
            >>>
            >>> print([processor.tokens[i] for i in input[0, :lengths[0]]])
            ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!']
            >>> print([processor.tokens[i] for i in input[1, :lengths[1]]])
            ['t', 'e', 'x', 't', '-', 't', 'o', '-', 's', 'p', 'e', 'e', 'c', 'h', '!']

        Example - Phoneme-based
            >>> text = [
            >>>     "Hello World!",
            >>>     "Text-to-speech!",
            >>> ]
            >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
            >>> processor = bundle.get_text_processor()
            Downloading:
            100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
            >>> input, lengths = processor(text)
            >>>
            >>> print(input)
            tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38,  2,  0,  0,  0,  0],
                    [81, 40, 64, 79, 81,  1, 81, 20,  1, 79, 77, 59, 37,  2]],
                   dtype=torch.int32)
            >>>
            >>> print(lengths)
            tensor([10, 14], dtype=torch.int32)
            >>>
            >>> print([processor.tokens[i] for i in input[0]])
            ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', '_', '_', '_', '_']
            >>> print([processor.tokens[i] for i in input[1]])
            ['T', 'EH', 'K', 'S', 'T', '-', 'T', 'AH', '-', 'S', 'P', 'IY', 'CH', '!']
        """

    @abstractmethod
    def get_vocoder(self, *, dl_kwargs=None) -> Vocoder:
        """Create a vocoder module, based off of either WaveRNN or GriffinLim.

        If a pre-trained weight file is necessary,
        :func:`torch.hub.load_state_dict_from_url` is used to download it.

        Args:
            dl_kwargs (dictionary of keyword arguments):
                Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Vocoder:
                A vocoder module, which takes spectrogram Tensor and an optional
                length Tensor, then returns resulting waveform Tensor and an optional
                length Tensor.
        """

    @abstractmethod
    def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
        """Create a Tacotron2 model with pre-trained weight.

        Args:
            dl_kwargs (dictionary of keyword arguments):
                Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Tacotron2:
                The resulting model.
        """
@@ -0,0 +1,230 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ from torchaudio._internal import download_url_to_file, module_utils as _mod_utils
6
+
7
+
8
+ def _get_chars():
9
+ return (
10
+ "_",
11
+ "-",
12
+ "!",
13
+ "'",
14
+ "(",
15
+ ")",
16
+ ",",
17
+ ".",
18
+ ":",
19
+ ";",
20
+ "?",
21
+ " ",
22
+ "a",
23
+ "b",
24
+ "c",
25
+ "d",
26
+ "e",
27
+ "f",
28
+ "g",
29
+ "h",
30
+ "i",
31
+ "j",
32
+ "k",
33
+ "l",
34
+ "m",
35
+ "n",
36
+ "o",
37
+ "p",
38
+ "q",
39
+ "r",
40
+ "s",
41
+ "t",
42
+ "u",
43
+ "v",
44
+ "w",
45
+ "x",
46
+ "y",
47
+ "z",
48
+ )
49
+
50
+
51
+ def _get_phones():
52
+ return (
53
+ "_",
54
+ "-",
55
+ "!",
56
+ "'",
57
+ "(",
58
+ ")",
59
+ ",",
60
+ ".",
61
+ ":",
62
+ ";",
63
+ "?",
64
+ " ",
65
+ "AA",
66
+ "AA0",
67
+ "AA1",
68
+ "AA2",
69
+ "AE",
70
+ "AE0",
71
+ "AE1",
72
+ "AE2",
73
+ "AH",
74
+ "AH0",
75
+ "AH1",
76
+ "AH2",
77
+ "AO",
78
+ "AO0",
79
+ "AO1",
80
+ "AO2",
81
+ "AW",
82
+ "AW0",
83
+ "AW1",
84
+ "AW2",
85
+ "AY",
86
+ "AY0",
87
+ "AY1",
88
+ "AY2",
89
+ "B",
90
+ "CH",
91
+ "D",
92
+ "DH",
93
+ "EH",
94
+ "EH0",
95
+ "EH1",
96
+ "EH2",
97
+ "ER",
98
+ "ER0",
99
+ "ER1",
100
+ "ER2",
101
+ "EY",
102
+ "EY0",
103
+ "EY1",
104
+ "EY2",
105
+ "F",
106
+ "G",
107
+ "HH",
108
+ "IH",
109
+ "IH0",
110
+ "IH1",
111
+ "IH2",
112
+ "IY",
113
+ "IY0",
114
+ "IY1",
115
+ "IY2",
116
+ "JH",
117
+ "K",
118
+ "L",
119
+ "M",
120
+ "N",
121
+ "NG",
122
+ "OW",
123
+ "OW0",
124
+ "OW1",
125
+ "OW2",
126
+ "OY",
127
+ "OY0",
128
+ "OY1",
129
+ "OY2",
130
+ "P",
131
+ "R",
132
+ "S",
133
+ "SH",
134
+ "T",
135
+ "TH",
136
+ "UH",
137
+ "UH0",
138
+ "UH1",
139
+ "UH2",
140
+ "UW",
141
+ "UW0",
142
+ "UW1",
143
+ "UW2",
144
+ "V",
145
+ "W",
146
+ "Y",
147
+ "Z",
148
+ "ZH",
149
+ )
150
+
151
+
152
+ def _to_tensor(indices):
153
+ lengths = torch.tensor([len(i) for i in indices], dtype=torch.int32)
154
+ values = [torch.tensor(i) for i in indices]
155
+ values = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
156
+ return values, lengths
157
+
158
+
159
+ def _load_phonemizer(file, dl_kwargs):
160
+ if not _mod_utils.is_module_available("dp"):
161
+ raise RuntimeError("DeepPhonemizer is not installed. Please install it.")
162
+
163
+ from dp.phonemizer import Phonemizer
164
+ from dp.preprocessing.text import LanguageTokenizer, Preprocessor, SequenceTokenizer
165
+
166
+ # By default, dp issues DEBUG level log.
167
+ logger = logging.getLogger("dp")
168
+ orig_level = logger.level
169
+ logger.setLevel(logging.INFO)
170
+ try:
171
+ url = f"https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}"
172
+ directory = os.path.join(torch.hub.get_dir(), "checkpoints")
173
+ os.makedirs(directory, exist_ok=True)
174
+ path = os.path.join(directory, file)
175
+ if not os.path.exists(path):
176
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
177
+ download_url_to_file(url, path, **dl_kwargs)
178
+ with torch.serialization.safe_globals([Preprocessor, LanguageTokenizer, SequenceTokenizer]):
179
+ return Phonemizer.from_checkpoint(path)
180
+ finally:
181
+ logger.setLevel(orig_level)
182
+
183
+
184
+ def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
185
+ r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
186
+ waveform = torch.clamp(waveform, -1, 1)
187
+ waveform = (waveform + 1.0) * (2**bits - 1) / 2
188
+ return torch.clamp(waveform, 0, 2**bits - 1).int()
189
+
190
+
191
+ def _get_taco_params(n_symbols):
192
+ return {
193
+ "mask_padding": False,
194
+ "n_mels": 80,
195
+ "n_frames_per_step": 1,
196
+ "symbol_embedding_dim": 512,
197
+ "encoder_embedding_dim": 512,
198
+ "encoder_n_convolution": 3,
199
+ "encoder_kernel_size": 5,
200
+ "decoder_rnn_dim": 1024,
201
+ "decoder_max_step": 2000,
202
+ "decoder_dropout": 0.1,
203
+ "decoder_early_stopping": True,
204
+ "attention_rnn_dim": 1024,
205
+ "attention_hidden_dim": 128,
206
+ "attention_location_n_filter": 32,
207
+ "attention_location_kernel_size": 31,
208
+ "attention_dropout": 0.1,
209
+ "prenet_dim": 256,
210
+ "postnet_n_convolution": 5,
211
+ "postnet_kernel_size": 5,
212
+ "postnet_embedding_dim": 512,
213
+ "gate_threshold": 0.5,
214
+ "n_symbol": n_symbols,
215
+ }
216
+
217
+
218
+ def _get_wrnn_params():
219
+ return {
220
+ "upsample_scales": [5, 5, 11],
221
+ "n_classes": 2**8, # n_bits = 8
222
+ "hop_length": 275,
223
+ "n_res_block": 10,
224
+ "n_rnn": 512,
225
+ "n_fc": 512,
226
+ "kernel_size": 5,
227
+ "n_freq": 80,
228
+ "n_hidden": 128,
229
+ "n_output": 128,
230
+ }
File without changes
@@ -0,0 +1,87 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, List
3
+
4
+ import torch
5
+ import torchaudio.functional as F
6
+ from torch import Tensor
7
+ from torchaudio.functional import TokenSpan
8
+
9
+
10
+ class ITokenizer(ABC):
11
+ @abstractmethod
12
+ def __call__(self, transcript: List[str]) -> List[List[str]]:
13
+ """Tokenize the given transcript (list of word)
14
+
15
+ .. note::
16
+
17
+ The toranscript must be normalized.
18
+
19
+ Args:
20
+ transcript (list of str): Transcript (list of word).
21
+
22
+ Returns:
23
+ (list of int): List of token sequences
24
+ """
25
+
26
+
27
+ class Tokenizer(ITokenizer):
28
+ def __init__(self, dictionary: Dict[str, int]):
29
+ self.dictionary = dictionary
30
+
31
+ def __call__(self, transcript: List[str]) -> List[List[int]]:
32
+ return [[self.dictionary[c] for c in word] for word in transcript]
33
+
34
+
35
+ def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
36
+ device = emission.device
37
+ emission = emission.unsqueeze(0)
38
+ targets = torch.tensor([tokens], dtype=torch.int32, device=device)
39
+
40
+ aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
41
+
42
+ scores = scores.exp() # convert back to probability
43
+ aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension
44
+ return aligned_tokens, scores
45
+
46
+
47
+ class IAligner(ABC):
48
+ @abstractmethod
49
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
50
+ """Generate list of time-stamped token sequences
51
+
52
+ Args:
53
+ emission (Tensor): Sequence of token probability distributions in log-domain.
54
+ Shape: `(time, tokens)`.
55
+ tokens (list of integer sequence): Tokenized transcript.
56
+ Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.
57
+
58
+ Returns:
59
+ (list of TokenSpan sequence): Tokens with time stamps and scores.
60
+ """
61
+
62
+
63
+ def _unflatten(list_, lengths):
64
+ assert len(list_) == sum(lengths)
65
+ i = 0
66
+ ret = []
67
+ for l in lengths:
68
+ ret.append(list_[i : i + l])
69
+ i += l
70
+ return ret
71
+
72
+
73
+ def _flatten(nested_list):
74
+ return [item for list_ in nested_list for item in list_]
75
+
76
+
77
+ class Aligner(IAligner):
78
+ def __init__(self, blank):
79
+ self.blank = blank
80
+
81
+ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
82
+ if emission.ndim != 2:
83
+ raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
84
+
85
+ aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
86
+ spans = F.merge_tokens(aligned_tokens, scores)
87
+ return _unflatten(spans, [len(ts) for ts in tokens])