torchaudio 2.9.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (86):
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,346 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+
6
+ from torchaudio._internal import load_state_dict_from_url
7
+ from torchaudio.models import wav2vec2_model, Wav2Vec2Model, wavlm_model
8
+
9
+
10
def _get_model(type_, params):
    """Build a model of the requested kind from keyword parameters.

    Args:
        type_: Name of the model family; ``"Wav2Vec2"`` or ``"WavLM"``.
        params: Mapping expanded as keyword arguments into the matching
            factory function.

    Returns:
        Whatever the selected factory returns.

    Raises:
        ValueError: If ``type_`` is not one of the supported names.
    """
    builders = {"Wav2Vec2": wav2vec2_model, "WavLM": wavlm_model}
    build = builders.get(type_)
    if build is None:
        raise ValueError(f"Supported model types are {tuple(builders.keys())}. Found: {type_}")
    return build(**params)
19
+
20
+
21
+ class _Wav2Vec2Model(nn.Module):
22
+ """Wrapper class for :py:class:`~torchaudio.models.Wav2Vec2Model`.
23
+
24
+ This is used for layer normalization at the input
25
+ """
26
+
27
+ def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool):
28
+ super().__init__()
29
+ self.model = model
30
+ self.normalize_waveform = normalize_waveform
31
+ self.apply_log_softmax = apply_log_softmax
32
+ self.append_star = append_star
33
+
34
+ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
35
+ if self.normalize_waveform:
36
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
37
+ output, output_lengths = self.model(waveforms, lengths)
38
+ if self.apply_log_softmax:
39
+ output = torch.nn.functional.log_softmax(output, dim=-1)
40
+ if self.append_star:
41
+ star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
42
+ output = torch.cat((output, star_dim), dim=-1)
43
+ return output, output_lengths
44
+
45
+ @torch.jit.export
46
+ def extract_features(
47
+ self,
48
+ waveforms: Tensor,
49
+ lengths: Optional[Tensor] = None,
50
+ num_layers: Optional[int] = None,
51
+ ) -> Tuple[List[Tensor], Optional[Tensor]]:
52
+ if self.normalize_waveform:
53
+ waveforms = nn.functional.layer_norm(waveforms, waveforms.shape)
54
+ return self.model.extract_features(waveforms, lengths, num_layers)
55
+
56
+
57
def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False):
    """Wrap ``module`` in ``_Wav2Vec2Model`` with the given options.

    Args:
        module: The model to wrap.
        normalize_waveform: Forwarded to the wrapper unchanged.
        apply_log_softmax: Forwarded to the wrapper unchanged.
        append_star: Forwarded to the wrapper unchanged.

    Returns:
        The new ``_Wav2Vec2Model`` instance.
    """
    wrapped = _Wav2Vec2Model(
        model=module,
        normalize_waveform=normalize_waveform,
        apply_log_softmax=apply_log_softmax,
        append_star=append_star,
    )
    return wrapped
60
+
61
+
62
+ def _remove_aux_axes(state_dict, axes):
63
+ # Remove the seemingly unnecessary axis
64
+ # For ASR task, the pretrained weights originated from fairseq has unrelated dimensions at index 1, 2, 3
65
+ # It's originated from the Dictionary implementation of fairseq, which was intended for NLP tasks,
66
+ # but not used during the ASR training.
67
+ # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
68
+ # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
69
+ #
70
+ # Also, some pretrained weights originated from voxpopuli has an extra dimensions that almost never used and
71
+ # that resembles mistake.
72
+ # The label `1` shows up in the training dataset of German (1 out of 16M),
73
+ # English (1 / 28M), Spanish (1 / 9.4M), Romanian (1 / 4.7M) and Polish (6 / 5.8M)
74
+ for key in ["aux.weight", "aux.bias"]:
75
+ mat = state_dict[key]
76
+ state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes])
77
+
78
+
79
def _get_state_dict(url, dl_kwargs, remove_axes=None):
    """Download a state dict and optionally strip rows from its aux layer.

    Args:
        url: Either a string starting with ``"https"``, used as is, or a
            path appended to ``https://download.pytorch.org/torchaudio/models/``.
        dl_kwargs: Keyword arguments for ``load_state_dict_from_url``;
            ``None`` means no extra arguments.
        remove_axes: If truthy, passed to ``_remove_aux_axes`` together with
            the loaded state dict.

    Returns:
        The loaded (and possibly reduced) state dict.
    """
    full_url = url if url.startswith("https") else f"https://download.pytorch.org/torchaudio/models/{url}"
    kwargs = dl_kwargs if dl_kwargs is not None else {}
    weights = load_state_dict_from_url(full_url, **kwargs)
    if remove_axes:
        _remove_aux_axes(weights, remove_axes)
    return weights
87
+
88
+
89
+ def _get_en_labels():
90
+ return (
91
+ "|",
92
+ "E",
93
+ "T",
94
+ "A",
95
+ "O",
96
+ "N",
97
+ "I",
98
+ "H",
99
+ "S",
100
+ "R",
101
+ "D",
102
+ "L",
103
+ "U",
104
+ "M",
105
+ "W",
106
+ "C",
107
+ "F",
108
+ "G",
109
+ "Y",
110
+ "P",
111
+ "B",
112
+ "V",
113
+ "K",
114
+ "'",
115
+ "X",
116
+ "J",
117
+ "Q",
118
+ "Z",
119
+ )
120
+
121
+
122
+ def _get_de_labels():
123
+ return (
124
+ "|",
125
+ "e",
126
+ "n",
127
+ "i",
128
+ "r",
129
+ "s",
130
+ "t",
131
+ "a",
132
+ "d",
133
+ "h",
134
+ "u",
135
+ "l",
136
+ "g",
137
+ "c",
138
+ "m",
139
+ "o",
140
+ "b",
141
+ "w",
142
+ "f",
143
+ "k",
144
+ "z",
145
+ "p",
146
+ "v",
147
+ "ü",
148
+ "ä",
149
+ "ö",
150
+ "j",
151
+ "ß",
152
+ "y",
153
+ "x",
154
+ "q",
155
+ )
156
+
157
+
158
+ def _get_vp_en_labels():
159
+ return (
160
+ "|",
161
+ "e",
162
+ "t",
163
+ "o",
164
+ "i",
165
+ "a",
166
+ "n",
167
+ "s",
168
+ "r",
169
+ "h",
170
+ "l",
171
+ "d",
172
+ "c",
173
+ "u",
174
+ "m",
175
+ "p",
176
+ "f",
177
+ "g",
178
+ "w",
179
+ "y",
180
+ "b",
181
+ "v",
182
+ "k",
183
+ "x",
184
+ "j",
185
+ "q",
186
+ "z",
187
+ )
188
+
189
+
190
+ def _get_es_labels():
191
+ return (
192
+ "|",
193
+ "e",
194
+ "a",
195
+ "o",
196
+ "s",
197
+ "n",
198
+ "r",
199
+ "i",
200
+ "l",
201
+ "d",
202
+ "c",
203
+ "t",
204
+ "u",
205
+ "p",
206
+ "m",
207
+ "b",
208
+ "q",
209
+ "y",
210
+ "g",
211
+ "v",
212
+ "h",
213
+ "ó",
214
+ "f",
215
+ "í",
216
+ "á",
217
+ "j",
218
+ "z",
219
+ "ñ",
220
+ "é",
221
+ "x",
222
+ "ú",
223
+ "k",
224
+ "w",
225
+ "ü",
226
+ )
227
+
228
+
229
+ def _get_fr_labels():
230
+ return (
231
+ "|",
232
+ "e",
233
+ "s",
234
+ "n",
235
+ "i",
236
+ "t",
237
+ "r",
238
+ "a",
239
+ "o",
240
+ "u",
241
+ "l",
242
+ "d",
243
+ "c",
244
+ "p",
245
+ "m",
246
+ "é",
247
+ "v",
248
+ "q",
249
+ "f",
250
+ "g",
251
+ "b",
252
+ "h",
253
+ "x",
254
+ "à",
255
+ "j",
256
+ "è",
257
+ "y",
258
+ "ê",
259
+ "z",
260
+ "ô",
261
+ "k",
262
+ "ç",
263
+ "œ",
264
+ "û",
265
+ "ù",
266
+ "î",
267
+ "â",
268
+ "w",
269
+ "ï",
270
+ "ë",
271
+ "ü",
272
+ "æ",
273
+ )
274
+
275
+
276
+ def _get_it_labels():
277
+ return (
278
+ "|",
279
+ "e",
280
+ "i",
281
+ "a",
282
+ "o",
283
+ "n",
284
+ "t",
285
+ "r",
286
+ "l",
287
+ "s",
288
+ "c",
289
+ "d",
290
+ "u",
291
+ "p",
292
+ "m",
293
+ "g",
294
+ "v",
295
+ "h",
296
+ "z",
297
+ "f",
298
+ "b",
299
+ "q",
300
+ "à",
301
+ "è",
302
+ "ù",
303
+ "é",
304
+ "ò",
305
+ "ì",
306
+ "k",
307
+ "y",
308
+ "x",
309
+ "w",
310
+ "j",
311
+ "ó",
312
+ "í",
313
+ "ï",
314
+ )
315
+
316
+
317
+ def _get_mms_labels():
318
+ return (
319
+ "a",
320
+ "i",
321
+ "e",
322
+ "n",
323
+ "o",
324
+ "u",
325
+ "t",
326
+ "s",
327
+ "r",
328
+ "m",
329
+ "k",
330
+ "l",
331
+ "d",
332
+ "g",
333
+ "h",
334
+ "y",
335
+ "b",
336
+ "p",
337
+ "w",
338
+ "c",
339
+ "v",
340
+ "j",
341
+ "z",
342
+ "f",
343
+ "'",
344
+ "q",
345
+ "x",
346
+ )