torchaudio-2.7.1-cp313-cp313t-win_amd64.whl → torchaudio-2.9.0-cp313-cp313t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio has been flagged as possibly problematic.

Files changed (92)
  1. torchaudio/__init__.py +184 -33
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +68 -10
  5. torchaudio/_torchcodec.py +340 -0
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +6 -3
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +31 -61
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +19 -1
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +3 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +4 -1
  23. torchaudio/transforms/_transforms.py +4 -3
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -1
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/METADATA +15 -7
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -317
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -13
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -144
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -67
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -433
  54. torchaudio/prototype/functional/_rir.py +0 -379
  55. torchaudio/prototype/functional/functional.py +0 -190
  56. torchaudio/prototype/models/__init__.py +0 -36
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -794
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -333
  59. torchaudio/prototype/models/conv_emformer.py +0 -525
  60. torchaudio/prototype/models/hifi_gan.py +0 -336
  61. torchaudio/prototype/models/rnnt.py +0 -711
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -399
  63. torchaudio/prototype/pipelines/__init__.py +0 -12
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -3
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -233
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -82
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -228
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -456
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -272
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -99
  75. torchaudio-2.7.1.dist-info/RECORD +0 -144
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -978
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -247
  91. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/LICENSE +0 -0
  92. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
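
Editor's note on the list above: this release removes torchaudio's bundled I/O stack (the torchaudio/_backend and torchaudio/backend modules, the SoX effect bindings and sox_utils, kaldi_io, the whole torio streaming package with its FFmpeg .pyd libraries, and the prototype namespace), while adding a new torchaudio/_torchcodec.py module. Decoding and encoding therefore appear to be delegated to the separately distributed TorchCodec package. A minimal post-upgrade smoke test might look like the hedged sketch below; it assumes only the public torchaudio.load API, that torchcodec is installed, and uses a placeholder file name.

    # Hedged sketch: check that basic decoding still works after upgrading to 2.9.0.
    # Assumes torchcodec is installed alongside torchaudio; "sample.wav" is a placeholder path.
    import torchaudio

    waveform, sample_rate = torchaudio.load("sample.wav")
    print(waveform.shape, sample_rate)
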
@@ -1,399 +0,0 @@
- from typing import Callable, Dict, List, Optional, Tuple
-
- import torch
- from torchaudio.models import RNNT
- from torchaudio.prototype.models.rnnt import TrieNode
-
- __all__ = ["Hypothesis", "RNNTBeamSearchBiasing"]
-
-
- Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float, list]
- Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
- represented as tuple of (tokens, prediction network output, prediction network state, score).
- """
-
-
- def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
-     return hypo[0]
-
-
- def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
-     return hypo[1]
-
-
- def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
-     return hypo[2]
-
-
- def _get_hypo_score(hypo: Hypothesis) -> float:
-     return hypo[3]
-
-
- def _get_hypo_trie(hypo: Hypothesis) -> TrieNode:
-     return hypo[4]
-
-
- def _set_hypo_trie(hypo: Hypothesis, trie: TrieNode) -> None:
-     hypo[4] = trie
-
-
- def _get_hypo_key(hypo: Hypothesis) -> str:
-     return str(hypo[0])
-
-
- def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
-     states: List[List[torch.Tensor]] = []
-     for i in range(len(_get_hypo_state(hypos[0]))):
-         batched_state_components: List[torch.Tensor] = []
-         for j in range(len(_get_hypo_state(hypos[0])[i])):
-             batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
-         states.append(batched_state_components)
-     return states
-
-
- def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
-     idx_tensor = torch.tensor([idx], device=device)
-     return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
-
-
- def _default_hypo_sort_key(hypo: Hypothesis) -> float:
-     return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
-
-
- def _compute_updated_scores(
-     hypos: List[Hypothesis],
-     next_token_probs: torch.Tensor,
-     beam_width: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
-     nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
-     nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
-     nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
-     nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
-     return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
-
-
- def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
-     for i, elem in enumerate(hypo_list):
-         if _get_hypo_key(hypo) == _get_hypo_key(elem):
-             del hypo_list[i]
-             break
-
-
- class RNNTBeamSearchBiasing(torch.nn.Module):
-     r"""Beam search decoder for RNN-T model with biasing support.
-
-     Args:
-         model (RNNT): RNN-T model to use.
-         blank (int): index of blank token in vocabulary.
-         temperature (float, optional): temperature to apply to joint network output.
-             Larger values yield more uniform samples. (Default: 1.0)
-         hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
-             for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
-             hypothesis score normalized by token sequence length. (Default: None)
-         step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
-         trie (list, optional): the prefix tree for TCPGen biasing
-         biasing (bool, optional): If true, do biasing, otherwise use standard RNN-T support
-     """
-
-     def __init__(
-         self,
-         model: RNNT,
-         blank: int,
-         temperature: float = 1.0,
-         hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
-         step_max_tokens: int = 100,
-         trie: TrieNode = None,
-         biasing: bool = False,
-     ) -> None:
-         super().__init__()
-         self.model = model
-         self.blank = blank
-         self.temperature = temperature
-         self.resettrie = trie or []
-         self.dobiasing = biasing
-
-         if hypo_sort_key is None:
-             self.hypo_sort_key = _default_hypo_sort_key
-         else:
-             self.hypo_sort_key = hypo_sort_key
-
-         self.step_max_tokens = step_max_tokens
-
-     def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-         if hypo is not None:
-             token = _get_hypo_tokens(hypo)[-1]
-             state = _get_hypo_state(hypo)
-         else:
-             token = self.blank
-             state = None
-
-         one_tensor = torch.tensor([1], device=device)
-         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
-         init_hypo = ([token], pred_out[0].detach(), pred_state, 0.0, self.resettrie)
-         return [init_hypo]
-
-     def _get_trie_mask(self, trie):
-         step_mask = torch.ones(len(self.model.char_list) + 1)
-         step_mask[list(trie[0].keys())] = 0
-         # step_mask[-1] = 0
-         return step_mask
-
-     def _get_generation_prob(self, trie):
-         if len(trie[0].keys()) == 0:
-             return True
-         else:
-             return False
-
-     def _gen_next_token_probs(
-         self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
-     ) -> torch.Tensor:
-         one_tensor = torch.tensor([1], device=device)
-         predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
-         if self.dobiasing:
-             # Get valid subset of wordpieces
-             trie_masks = torch.stack([self._get_trie_mask(_get_hypo_trie(h)) for h in hypos], dim=0)
-             trie_masks = trie_masks.to(enc_out.device).unsqueeze(1)  # beam_width, 1, nchars
-             # Determine if there is any paths on the trie
-             genprob_masks = torch.tensor([self._get_generation_prob(_get_hypo_trie(h)) for h in hypos])  # beam_width
-             genprob_masks = genprob_masks.to(enc_out.device)
-             # Forward TCPGen component
-             last_tokens = torch.tensor([_get_hypo_tokens(h)[-1] for h in hypos]).unsqueeze(-1).to(enc_out.device)
-             hptr, tcpgen_dist = self.model.forward_tcpgen(last_tokens, trie_masks, enc_out)
-         else:
-             hptr = None
-         # hptr sent to joiner, if deepbiasing is True joiner will use it
-         joined_out, _, joined_activation = self.model.join(
-             enc_out,
-             one_tensor,
-             predictor_out,
-             torch.tensor([1] * len(hypos), device=device),
-             hptr=hptr,
-         )  # [beam_width, 1, 1, num_tokens]
-         if self.dobiasing:
-             p_gen = torch.sigmoid(self.model.pointer_gate(torch.cat((joined_activation, hptr), dim=-1)))
-             p_gen = p_gen.masked_fill(genprob_masks.view(p_gen.size(0), 1, 1, 1), 0)
-             model_tu = torch.softmax(joined_out / self.temperature, dim=3)
-             # assuming last token is blank
-             p_not_null = 1.0 - model_tu[:, :, :, -1:]
-             ptr_dist_fact = torch.cat([tcpgen_dist[:, :, :, :-2], tcpgen_dist[:, :, :, -1:]], dim=-1) * p_not_null
-             ptr_gen_complement = tcpgen_dist[:, :, :, -1:] * p_gen
-             p_partial = ptr_dist_fact[:, :, :, :-1] * p_gen + model_tu[:, :, :, :-1] * (1 - p_gen + ptr_gen_complement)
-             p_final = torch.cat([p_partial, model_tu[:, :, :, -1:]], dim=-1)
-             joined_out = torch.log(p_final)
-         else:
-             joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
-         return joined_out[:, 0, 0]
-
-     def _gen_b_hypos(
-         self,
-         b_hypos: List[Hypothesis],
-         a_hypos: List[Hypothesis],
-         next_token_probs: torch.Tensor,
-         key_to_b_hypo: Dict[str, Hypothesis],
-     ) -> List[Hypothesis]:
-         for i in range(len(a_hypos)):
-             h_a = a_hypos[i]
-             append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
-             if _get_hypo_key(h_a) in key_to_b_hypo:
-                 h_b = key_to_b_hypo[_get_hypo_key(h_a)]
-                 _remove_hypo(h_b, b_hypos)
-                 score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
-             else:
-                 score = float(append_blank_score)
-             h_b = (
-                 _get_hypo_tokens(h_a),
-                 _get_hypo_predictor_out(h_a),
-                 _get_hypo_state(h_a),
-                 score,
-                 _get_hypo_trie(h_a),
-             )
-             b_hypos.append(h_b)
-             key_to_b_hypo[_get_hypo_key(h_b)] = h_b
-         _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
-         return [b_hypos[idx] for idx in sorted_idx]
-
-     def _gen_a_hypos(
-         self,
-         a_hypos: List[Hypothesis],
-         b_hypos: List[Hypothesis],
-         next_token_probs: torch.Tensor,
-         t: int,
-         beam_width: int,
-         device: torch.device,
-     ) -> List[Hypothesis]:
-         (
-             nonblank_nbest_scores,
-             nonblank_nbest_hypo_idx,
-             nonblank_nbest_token,
-         ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
-
-         if len(b_hypos) < beam_width:
-             b_nbest_score = -float("inf")
-         else:
-             b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
-
-         base_hypos: List[Hypothesis] = []
-         new_tokens: List[int] = []
-         new_scores: List[float] = []
-         for i in range(beam_width):
-             score = float(nonblank_nbest_scores[i])
-             if score > b_nbest_score:
-                 a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
-                 base_hypos.append(a_hypos[a_hypo_idx])
-                 new_tokens.append(int(nonblank_nbest_token[i]))
-                 new_scores.append(score)
-
-         if base_hypos:
-             new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
-         else:
-             new_hypos: List[Hypothesis] = []
-
-         return new_hypos
-
-     def _gen_new_hypos(
-         self,
-         base_hypos: List[Hypothesis],
-         tokens: List[int],
-         scores: List[float],
-         t: int,
-         device: torch.device,
-     ) -> List[Hypothesis]:
-         tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
-         states = _batch_state(base_hypos)
-         pred_out, _, pred_states = self.model.predict(
-             tgt_tokens,
-             torch.tensor([1] * len(base_hypos), device=device),
-             states,
-         )
-         new_hypos: List[Hypothesis] = []
-         for i, h_a in enumerate(base_hypos):
-             new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
-             if self.dobiasing:
-                 new_trie = self.model.get_tcpgen_step(tokens[i], _get_hypo_trie(h_a), self.resettrie)
-             else:
-                 new_trie = self.resettrie
-             new_hypos.append(
-                 (new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i], new_trie)
-             )
-         return new_hypos
-
-     def _search(
-         self,
-         enc_out: torch.Tensor,
-         hypo: Optional[Hypothesis],
-         beam_width: int,
-     ) -> List[Hypothesis]:
-         n_time_steps = enc_out.shape[1]
-         device = enc_out.device
-
-         a_hypos: List[Hypothesis] = []
-         b_hypos = self._init_b_hypos(hypo, device)
-         for t in range(n_time_steps):
-             a_hypos = b_hypos
-             b_hypos = torch.jit.annotate(List[Hypothesis], [])
-             key_to_b_hypo: Dict[str, Hypothesis] = {}
-             symbols_current_t = 0
-
-             while a_hypos:
-                 next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
-                 next_token_probs = next_token_probs.cpu()
-                 b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
-
-                 if symbols_current_t == self.step_max_tokens:
-                     break
-
-                 a_hypos = self._gen_a_hypos(
-                     a_hypos,
-                     b_hypos,
-                     next_token_probs,
-                     t,
-                     beam_width,
-                     device,
-                 )
-                 if a_hypos:
-                     symbols_current_t += 1
-
-             _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
-             b_hypos = [b_hypos[idx] for idx in sorted_idx]
-
-         return b_hypos
-
-     def forward(
-         self,
-         input: torch.Tensor,
-         length: torch.Tensor,
-         beam_width: int,
-     ) -> List[Hypothesis]:
-         r"""Performs beam search for the given input sequence.
-
-         T: number of frames;
-         D: feature dimension of each frame.
-
-         Args:
-             input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
-             length (torch.Tensor): number of valid frames in input
-                 sequence, with shape () or (1,).
-             beam_width (int): beam size to use during search.
-
-         Returns:
-             List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
-         """
-         if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
-             raise ValueError("input must be of shape (T, D) or (1, T, D)")
-         if input.dim() == 2:
-             input = input.unsqueeze(0)
-
-         if length.shape != () and length.shape != (1,):
-             raise ValueError("length must be of shape () or (1,)")
-         if input.dim() == 0:
-             input = input.unsqueeze(0)
-
-         enc_out, _ = self.model.transcribe(input, length)
-         return self._search(enc_out, None, beam_width)
-
-     @torch.jit.export
-     def infer(
-         self,
-         input: torch.Tensor,
-         length: torch.Tensor,
-         beam_width: int,
-         state: Optional[List[List[torch.Tensor]]] = None,
-         hypothesis: Optional[Hypothesis] = None,
-     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
-         r"""Performs beam search for the given input sequence in streaming mode.
-
-         T: number of frames;
-         D: feature dimension of each frame.
-
-         Args:
-             input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
-             length (torch.Tensor): number of valid frames in input
-                 sequence, with shape () or (1,).
-             beam_width (int): beam size to use during search.
-             state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                 representing transcription network internal state generated in preceding
-                 invocation. (Default: ``None``)
-             hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
-                 search with. (Default: ``None``)
-
-         Returns:
-             (List[Hypothesis], List[List[torch.Tensor]]):
-                 List[Hypothesis]
-                     top-``beam_width`` hypotheses found by beam search.
-                 List[List[torch.Tensor]]
-                     list of lists of tensors representing transcription network
-                     internal state generated in current invocation.
-         """
-         if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
-             raise ValueError("input must be of shape (T, D) or (1, T, D)")
-         if input.dim() == 2:
-             input = input.unsqueeze(0)
-
-         if length.shape != () and length.shape != (1,):
-             raise ValueError("length must be of shape () or (1,)")
-         if length.dim() == 0:
-             length = length.unsqueeze(0)
-
-         enc_out, _, state = self.model.transcribe_streaming(input, length, state)
-         return self._search(enc_out, hypothesis, beam_width), state
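
Editor's note: the removed RNNTBeamSearchBiasing decoder above was self-contained, so for reference the sketch below shows roughly how its forward() API was driven, with shapes and the return type taken from the docstrings in the deleted file. The import path follows the prototype models package the file lived in, and the model, features, and blank_id objects are assumed placeholders; none of this is available in 2.9.0.

    # Hedged sketch of driving the removed prototype biasing decoder (pre-2.9.0 API).
    # model is an RNNT instance, features a (T, D) feature tensor, blank_id an int; all assumed.
    import torch
    from torchaudio.prototype.models import RNNTBeamSearchBiasing  # removed in this release

    decoder = RNNTBeamSearchBiasing(model, blank=blank_id, biasing=False)
    hypotheses = decoder(features, torch.tensor([features.shape[0]]), beam_width=5)
    tokens = hypotheses[0][0]  # token id list of the first returned hypothesis
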
@@ -1,12 +0,0 @@
- from ._vggish import VGGISH, VGGishBundle
- from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
- from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
-
- __all__ = [
-     "EMFORMER_RNNT_BASE_MUSTC",
-     "EMFORMER_RNNT_BASE_TEDLIUM3",
-     "HIFIGAN_VOCODER_V3_LJSPEECH",
-     "HiFiGANVocoderBundle",
-     "VGGISH",
-     "VGGishBundle",
- ]
@@ -1,3 +0,0 @@
- from ._vggish_pipeline import VGGISH, VGGishBundle
-
- __all__ = ["VGGISH", "VGGishBundle"]
@@ -1,233 +0,0 @@
- # Derived from torchvggish (https://github.com/harritaylor/torchvggish).
- # Copyright 2017 The TensorFlow Authors All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- import math
-
- import torch
-
-
- _MEL_BREAK_FREQUENCY_HERTZ = 700.0
- _MEL_HIGH_FREQUENCY_Q = 1127.0
-
-
- _SAMPLE_RATE = 16000
- _STFT_WINDOW_LENGTH_SECONDS = 0.025
- _STFT_HOP_LENGTH_SECONDS = 0.010
- _MEL_MIN_HZ = 125
- _MEL_MAX_HZ = 7500
- _NUM_BANDS = 64
- _LOG_OFFSET = 0.01
- _EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
- _EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.
-
-
- def _build_features_network():
-     layers = []
-
-     for input_dim, output_dim in [(1, 64), (64, 128)]:
-         layers += [
-             torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-             torch.nn.ReLU(inplace=True),
-             torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
-         ]
-
-     for input_dim, output_dim in [(128, 256), (256, 512)]:
-         layers += [
-             torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
-             torch.nn.ReLU(inplace=True),
-             torch.nn.Conv2d(
-                 output_dim,
-                 output_dim,
-                 kernel_size=(3, 3),
-                 stride=(1, 1),
-                 padding=(1, 1),
-             ),
-             torch.nn.ReLU(inplace=True),
-             torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
-         ]
-
-     return torch.nn.Sequential(*layers)
-
-
- def _build_embedding_network():
-     return torch.nn.Sequential(
-         torch.nn.Linear(512 * 4 * 6, 4096),
-         torch.nn.ReLU(True),
-         torch.nn.Linear(4096, 4096),
-         torch.nn.ReLU(True),
-         torch.nn.Linear(4096, 128),
-         torch.nn.ReLU(True),
-     )
-
-
- def _frame(data, window_length, hop_length):
-     num_samples = data.shape[0]
-     num_frames = 1 + int(math.floor((num_samples - window_length) / hop_length))
-     shape = (num_frames, window_length) + data.shape[1:]
-     strides = (data.stride()[0] * hop_length,) + data.stride()
-     return torch.as_strided(data, shape, strides)
-
-
- def _stft_magnitude(signal, fft_length, hop_length=None, window_length=None):
-     frames = _frame(signal, window_length, hop_length)
-     window = torch.hann_window(window_length, periodic=True).to(signal.device)
-     windowed_frames = frames * window
-     return torch.abs(torch.fft.rfft(windowed_frames, int(fft_length)))
-
-
- def _hertz_to_mel(frequencies_hertz):
-     return _MEL_HIGH_FREQUENCY_Q * torch.log(1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
-
-
- def _spectrogram_to_mel_matrix(
-     num_mel_bins=20,
-     num_spectrogram_bins=129,
-     audio_sample_rate=8000,
-     lower_edge_hertz=125.0,
-     upper_edge_hertz=3800.0,
- ):
-     nyquist_hertz = audio_sample_rate / 2.0
-     if lower_edge_hertz < 0.0:
-         raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
-     if lower_edge_hertz >= upper_edge_hertz:
-         raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % (lower_edge_hertz, upper_edge_hertz))
-
-     if upper_edge_hertz > nyquist_hertz:
-         raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % (upper_edge_hertz, nyquist_hertz))
-     spectrogram_bins_hertz = torch.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
-
-     spectrogram_bins_mel = _hertz_to_mel(spectrogram_bins_hertz)
-     # The i'th mel band (starting from i=1) has center frequency
-     # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
-     # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
-     # the band_edges_mel arrays.
-     band_edges_mel = torch.linspace(
-         _hertz_to_mel(torch.tensor(lower_edge_hertz)),
-         _hertz_to_mel(torch.tensor(upper_edge_hertz)),
-         num_mel_bins + 2,
-     )
-     # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
-     # of spectrogram values.
-     mel_weights_matrix = torch.empty((num_spectrogram_bins, num_mel_bins))
-     for i in range(num_mel_bins):
-         lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i : i + 3]
-         # Calculate lower and upper slopes for every spectrogram bin.
-         # Line segments are linear in the *mel* domain, not hertz.
-         lower_slope = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel)
-         upper_slope = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel)
-
-         # .. then intersect them with each other and zero.
-         mel_weights_matrix[:, i] = torch.maximum(torch.tensor(0.0), torch.minimum(lower_slope, upper_slope))
-
-     # HTK excludes the spectrogram DC bin; make sure it always gets a zero
-     # coefficient.
-     mel_weights_matrix[0, :] = 0.0
-     return mel_weights_matrix
-
-
- def _log_mel_spectrogram(
-     data,
-     audio_sample_rate=8000,
-     log_offset=0.0,
-     window_length_secs=0.025,
-     hop_length_secs=0.010,
-     **kwargs,
- ):
-     window_length_samples = int(round(audio_sample_rate * window_length_secs))
-     hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
-     fft_length = 2 ** int(math.ceil(math.log(window_length_samples) / math.log(2.0)))
-
-     spectrogram = _stft_magnitude(
-         data,
-         fft_length=fft_length,
-         hop_length=hop_length_samples,
-         window_length=window_length_samples,
-     )
-     mel_spectrogram = torch.matmul(
-         spectrogram,
-         _spectrogram_to_mel_matrix(
-             num_spectrogram_bins=spectrogram.shape[1],
-             audio_sample_rate=audio_sample_rate,
-             **kwargs,
-         ).to(spectrogram),
-     )
-     return torch.log(mel_spectrogram + log_offset)
-
-
- def _waveform_to_examples(data):
-     # Compute log mel spectrogram features, with shape (n_frame, n_mel)
-     log_mel = _log_mel_spectrogram(
-         data,
-         audio_sample_rate=_SAMPLE_RATE,
-         log_offset=_LOG_OFFSET,
-         window_length_secs=_STFT_WINDOW_LENGTH_SECONDS,
-         hop_length_secs=_STFT_HOP_LENGTH_SECONDS,
-         num_mel_bins=_NUM_BANDS,
-         lower_edge_hertz=_MEL_MIN_HZ,
-         upper_edge_hertz=_MEL_MAX_HZ,
-     )
-
-     # Frame features into examples, with shape (n_example, n_frame, n_mel)
-     features_sample_rate = 1.0 / _STFT_HOP_LENGTH_SECONDS
-     example_window_length = int(round(_EXAMPLE_WINDOW_SECONDS * features_sample_rate))
-
-     example_hop_length = int(round(_EXAMPLE_HOP_SECONDS * features_sample_rate))
-     log_mel_examples = _frame(log_mel, window_length=example_window_length, hop_length=example_hop_length)
-
-     # (n_example, 1, n_frame, n_mel)
-     return log_mel_examples.unsqueeze(1)
-
-
- class VGGish(torch.nn.Module):
-     """Implementation of VGGish model :cite:`45611`."""
-
-     def __init__(self):
-         super().__init__()
-
-         self.features_network = _build_features_network()
-         self.embedding_network = _build_embedding_network()
-
-     def forward(self, input: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             input (torch.Tensor): batch of spectrograms, with shape `(n_example, 1, n_frame, 64)`.
-
-         Returns:
-             torch.Tensor: model output, with shape `(n_example, 128)`.
-         """
-         x = self.features_network(input)
-
-         x = x.permute(0, 2, 3, 1)
-         x = x.reshape(x.size(0), -1)
-
-         return self.embedding_network(x)
-
-
- class VGGishInputProcessor:
-     """Converts raw waveforms to batches of examples to use as inputs to VGGish."""
-
-     def __call__(self, input: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             input (torch.Tensor): waveform, with shape `(T,)`.
-             sample_rate (int): sample rate of waveform in hertz.
-
-         Returns:
-             torch.Tensor: batch of examples to pass to VGGish, with shape `(n_example, 1, n_frame, 64)`.
-         """
-         if len(input.shape) != 1:
-             raise ValueError("input waveform must have dimension of 1.")
-         return _waveform_to_examples(input)
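
Editor's note: per the docstrings in the deleted module above, the prototype VGGish pipeline framed a 16 kHz mono waveform of shape (T,) into log-mel examples of shape (n_example, 1, 96, 64) and mapped each example to a 128-dimensional embedding; the mel warping it used is mel(f) = 1127 * ln(1 + f / 700), matching _hertz_to_mel. A minimal sketch of that flow, using randomly initialized weights and dummy audio rather than the pretrained VGGISH bundle:

    # Hedged sketch of the removed prototype VGGish flow (dummy input, untrained weights).
    import torch

    waveform = torch.randn(3 * 16000)             # 3 seconds of fake 16 kHz audio, shape (T,)
    examples = VGGishInputProcessor()(waveform)   # log-mel examples, (n_example, 1, 96, 64)
    embeddings = VGGish()(examples)               # embeddings, (n_example, 128)
    print(embeddings.shape)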