torchaudio 2.9.1__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.1.dist-info/METADATA +133 -0
  83. torchaudio-2.9.1.dist-info/RECORD +86 -0
  84. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  85. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  86. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/rnnt.py
@@ -0,0 +1,816 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple
+
+import torch
+from torchaudio.models import Emformer
+
+
+__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
+
+
+class _TimeReduction(torch.nn.Module):
+    r"""Coalesces frames along time dimension into a
+    fewer number of frames with higher feature dimensionality.
+
+    Args:
+        stride (int): number of frames to merge for each output frame.
+    """
+
+    def __init__(self, stride: int) -> None:
+        super().__init__()
+        self.stride = stride
+
+    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Forward pass.
+
+        B: batch size;
+        T: maximum input sequence length in batch;
+        D: feature dimension of each input sequence frame.
+
+        Args:
+            input (torch.Tensor): input sequences, with shape `(B, T, D)`.
+            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``input``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor):
+                torch.Tensor
+                    output sequences, with shape
+                    `(B, T // stride, D * stride)`
+                torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+                    number of valid frames for i-th batch element in output sequences.
+        """
+        B, T, D = input.shape
+        num_frames = T - (T % self.stride)
+        input = input[:, :num_frames, :]
+        lengths = lengths.div(self.stride, rounding_mode="trunc")
+        T_max = num_frames // self.stride
+
+        output = input.reshape(B, T_max, D * self.stride)
+        output = output.contiguous()
+        return output, lengths
+
+
+class _CustomLSTM(torch.nn.Module):
+    r"""Custom long-short-term memory (LSTM) block that applies layer normalization
+    to internal nodes.
+
+    Args:
+        input_dim (int): input dimension.
+        hidden_dim (int): hidden dimension.
+        layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
+        layer_norm_epsilon (float, optional): value of epsilon to use in
+            layer normalization layers (Default: 1e-5)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        layer_norm: bool = False,
+        layer_norm_epsilon: float = 1e-5,
+    ) -> None:
+        super().__init__()
+        self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
+        self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
+        if layer_norm:
+            self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
+            self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
+        else:
+            self.c_norm = torch.nn.Identity()
+            self.g_norm = torch.nn.Identity()
+
+        self.hidden_dim = hidden_dim
+
+    def forward(
+        self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        r"""Forward pass.
+
+        B: batch size;
+        T: maximum sequence length in batch;
+        D: feature dimension of each input sequence element.
+
+        Args:
+            input (torch.Tensor): with shape `(T, B, D)`.
+            state (List[torch.Tensor] or None): list of tensors
+                representing internal state generated in preceding invocation
+                of ``forward``.
+
+        Returns:
+            (torch.Tensor, List[torch.Tensor]):
+                torch.Tensor
+                    output, with shape `(T, B, hidden_dim)`.
+                List[torch.Tensor]
+                    list of tensors representing internal state generated
+                    in current invocation of ``forward``.
+        """
+        if state is None:
+            B = input.size(1)
+            h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
+            c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
+        else:
+            h, c = state
+
+        gated_input = self.x2g(input)
+        outputs = []
+        for gates in gated_input.unbind(0):
+            gates = gates + self.p2g(h)
+            gates = self.g_norm(gates)
+            input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
+            input_gate = input_gate.sigmoid()
+            forget_gate = forget_gate.sigmoid()
+            cell_gate = cell_gate.tanh()
+            output_gate = output_gate.sigmoid()
+            c = forget_gate * c + input_gate * cell_gate
+            c = self.c_norm(c)
+            h = output_gate * c.tanh()
+            outputs.append(h)
+
+        output = torch.stack(outputs, dim=0)
+        state = [h, c]
+
+        return output, state
+
+
+class _Transcriber(ABC):
+    @abstractmethod
+    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        pass
+
+    @abstractmethod
+    def infer(
+        self,
+        input: torch.Tensor,
+        lengths: torch.Tensor,
+        states: Optional[List[List[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        pass
+
+
+class _EmformerEncoder(torch.nn.Module, _Transcriber):
+    r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).
+
+    Args:
+        input_dim (int): feature dimension of each input sequence element.
+        output_dim (int): feature dimension of each output sequence element.
+        segment_length (int): length of input segment expressed as number of frames.
+        right_context_length (int): length of right context expressed as number of frames.
+        time_reduction_input_dim (int): dimension to scale each element in input sequences to
+            prior to applying time reduction block.
+        time_reduction_stride (int): factor by which to reduce length of input sequence.
+        transformer_num_heads (int): number of attention heads in each Emformer layer.
+        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+        transformer_num_layers (int): number of Emformer layers to instantiate.
+        transformer_left_context_length (int): length of left context.
+        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
+        transformer_activation (str, optional): activation function to use in each Emformer layer's
+            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
+            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+    """
+
+    def __init__(
+        self,
+        *,
+        input_dim: int,
+        output_dim: int,
+        segment_length: int,
+        right_context_length: int,
+        time_reduction_input_dim: int,
+        time_reduction_stride: int,
+        transformer_num_heads: int,
+        transformer_ffn_dim: int,
+        transformer_num_layers: int,
+        transformer_left_context_length: int,
+        transformer_dropout: float = 0.0,
+        transformer_activation: str = "relu",
+        transformer_max_memory_size: int = 0,
+        transformer_weight_init_scale_strategy: str = "depthwise",
+        transformer_tanh_on_mem: bool = False,
+    ) -> None:
+        super().__init__()
+        self.input_linear = torch.nn.Linear(
+            input_dim,
+            time_reduction_input_dim,
+            bias=False,
+        )
+        self.time_reduction = _TimeReduction(time_reduction_stride)
+        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
+        self.transformer = Emformer(
+            transformer_input_dim,
+            transformer_num_heads,
+            transformer_ffn_dim,
+            transformer_num_layers,
+            segment_length // time_reduction_stride,
+            dropout=transformer_dropout,
+            activation=transformer_activation,
+            left_context_length=transformer_left_context_length,
+            right_context_length=right_context_length // time_reduction_stride,
+            max_memory_size=transformer_max_memory_size,
+            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
+            tanh_on_mem=transformer_tanh_on_mem,
+        )
+        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
+        self.layer_norm = torch.nn.LayerNorm(output_dim)
+
+    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training.
+
+        B: batch size;
+        T: maximum input sequence length in batch;
+        D: feature dimension of each input sequence frame (input_dim).
+
+        Args:
+            input (torch.Tensor): input frame sequences right-padded with right context, with
+                shape `(B, T + right context length, D)`.
+            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``input``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor):
+                torch.Tensor
+                    output frame sequences, with
+                    shape `(B, T // time_reduction_stride, output_dim)`.
+                torch.Tensor
+                    output input lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output frame sequences.
+        """
+        input_linear_out = self.input_linear(input)
+        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
+        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
+        output_linear_out = self.output_linear(transformer_out)
+        layer_norm_out = self.layer_norm(output_linear_out)
+        return layer_norm_out, transformer_lengths
+
+    @torch.jit.export
+    def infer(
+        self,
+        input: torch.Tensor,
+        lengths: torch.Tensor,
+        states: Optional[List[List[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        r"""Forward pass for inference.
+
+        B: batch size;
+        T: maximum input sequence segment length in batch;
+        D: feature dimension of each input sequence frame (input_dim).
+
+        Args:
+            input (torch.Tensor): input frame sequence segments right-padded with right context, with
+                shape `(B, T + right context length, D)`.
+            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``input``.
+            state (List[List[torch.Tensor]] or None): list of lists of tensors
+                representing internal state generated in preceding invocation
+                of ``infer``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+                torch.Tensor
+                    output frame sequences, with
+                    shape `(B, T // time_reduction_stride, output_dim)`.
+                torch.Tensor
+                    output input lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output.
+                List[List[torch.Tensor]]
+                    output states; list of lists of tensors
+                    representing internal state generated in current invocation
+                    of ``infer``.
+        """
+        input_linear_out = self.input_linear(input)
+        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
+        (
+            transformer_out,
+            transformer_lengths,
+            transformer_states,
+        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
+        output_linear_out = self.output_linear(transformer_out)
+        layer_norm_out = self.layer_norm(output_linear_out)
+        return layer_norm_out, transformer_lengths, transformer_states
+
+
+class _Predictor(torch.nn.Module):
+    r"""Recurrent neural network transducer (RNN-T) prediction network.
+
+    Args:
+        num_symbols (int): size of target token lexicon.
+        output_dim (int): feature dimension of each output sequence element.
+        symbol_embedding_dim (int): dimension of each target token embedding.
+        num_lstm_layers (int): number of LSTM layers to instantiate.
+        lstm_hidden_dim (int): output dimension of each LSTM layer.
+        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
+            for LSTM layers. (Default: ``False``)
+        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
+            LSTM layer normalization layers. (Default: 1e-5)
+        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)
+
+    """
+
+    def __init__(
+        self,
+        num_symbols: int,
+        output_dim: int,
+        symbol_embedding_dim: int,
+        num_lstm_layers: int,
+        lstm_hidden_dim: int,
+        lstm_layer_norm: bool = False,
+        lstm_layer_norm_epsilon: float = 1e-5,
+        lstm_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
+        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
+        self.lstm_layers = torch.nn.ModuleList(
+            [
+                _CustomLSTM(
+                    symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
+                    lstm_hidden_dim,
+                    layer_norm=lstm_layer_norm,
+                    layer_norm_epsilon=lstm_layer_norm_epsilon,
+                )
+                for idx in range(num_lstm_layers)
+            ]
+        )
+        self.dropout = torch.nn.Dropout(p=lstm_dropout)
+        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
+        self.output_layer_norm = torch.nn.LayerNorm(output_dim)
+
+        self.lstm_dropout = lstm_dropout
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        lengths: torch.Tensor,
+        state: Optional[List[List[torch.Tensor]]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        r"""Forward pass.
+
+        B: batch size;
+        U: maximum sequence length in batch;
+        D: feature dimension of each input sequence element.
+
+        Args:
+            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
+                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
+            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``input``.
+            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+                representing internal state generated in preceding invocation
+                of ``forward``. (Default: ``None``)
+
+        Returns:
+            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+                torch.Tensor
+                    output encoding sequences, with shape `(B, U, output_dim)`
+                torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output encoding sequences.
+                List[List[torch.Tensor]]
+                    output states; list of lists of tensors
+                    representing internal state generated in current invocation of ``forward``.
+        """
+        input_tb = input.permute(1, 0)
+        embedding_out = self.embedding(input_tb)
+        input_layer_norm_out = self.input_layer_norm(embedding_out)
+
+        lstm_out = input_layer_norm_out
+        state_out: List[List[torch.Tensor]] = []
+        for layer_idx, lstm in enumerate(self.lstm_layers):
+            lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
+            lstm_out = self.dropout(lstm_out)
+            state_out.append(lstm_state_out)
+
+        linear_out = self.linear(lstm_out)
+        output_layer_norm_out = self.output_layer_norm(linear_out)
+        return output_layer_norm_out.permute(1, 0, 2), lengths, state_out
+
+
+class _Joiner(torch.nn.Module):
+    r"""Recurrent neural network transducer (RNN-T) joint network.
+
+    Args:
+        input_dim (int): source and target input dimension.
+        output_dim (int): output dimension.
+        activation (str, optional): activation function to use in the joiner.
+            Must be one of ("relu", "tanh"). (Default: "relu")
+
+    """
+
+    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
+        if activation == "relu":
+            self.activation = torch.nn.ReLU()
+        elif activation == "tanh":
+            self.activation = torch.nn.Tanh()
+        else:
+            raise ValueError(f"Unsupported activation {activation}")
+
+    def forward(
+        self,
+        source_encodings: torch.Tensor,
+        source_lengths: torch.Tensor,
+        target_encodings: torch.Tensor,
+        target_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        r"""Forward pass for training.
+
+        B: batch size;
+        T: maximum source sequence length in batch;
+        U: maximum target sequence length in batch;
+        D: dimension of each source and target sequence encoding.
+
+        Args:
+            source_encodings (torch.Tensor): source encoding sequences, with
+                shape `(B, T, D)`.
+            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                valid sequence length of i-th batch element in ``source_encodings``.
+            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                valid sequence length of i-th batch element in ``target_encodings``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, torch.Tensor):
+                torch.Tensor
+                    joint network output, with shape `(B, T, U, output_dim)`.
+                torch.Tensor
+                    output source lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 1 for i-th batch element in joint network output.
+                torch.Tensor
+                    output target lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 2 for i-th batch element in joint network output.
+        """
+        joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
+        activation_out = self.activation(joint_encodings)
+        output = self.linear(activation_out)
+        return output, source_lengths, target_lengths
+
+
+class RNNT(torch.nn.Module):
+    r"""torchaudio.models.RNNT()
+
+    Recurrent neural network transducer (RNN-T) model.
+
+    Note:
+        To build the model, please use one of the factory functions.
+
+    See Also:
+        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.
+
+    Args:
+        transcriber (torch.nn.Module): transcription network.
+        predictor (torch.nn.Module): prediction network.
+        joiner (torch.nn.Module): joint network.
+    """
+
+    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
+        super().__init__()
+        self.transcriber = transcriber
+        self.predictor = predictor
+        self.joiner = joiner
+
+    def forward(
+        self,
+        sources: torch.Tensor,
+        source_lengths: torch.Tensor,
+        targets: torch.Tensor,
+        target_lengths: torch.Tensor,
+        predictor_state: Optional[List[List[torch.Tensor]]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        r"""Forward pass for training.
+
+        B: batch size;
+        T: maximum source sequence length in batch;
+        U: maximum target sequence length in batch;
+        D: feature dimension of each source sequence element.
+
+        Args:
+            sources (torch.Tensor): source frame sequences right-padded with right context, with
+                shape `(B, T, D)`.
+            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``sources``.
+            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
+                mapping to a target symbol.
+            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``targets``.
+            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+                representing prediction network internal state generated in preceding invocation
+                of ``forward``. (Default: ``None``)
+
+        Returns:
+            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+                torch.Tensor
+                    joint network output, with shape
+                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
+                torch.Tensor
+                    output source lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 1 for i-th batch element in joint network output.
+                torch.Tensor
+                    output target lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 2 for i-th batch element in joint network output.
+                List[List[torch.Tensor]]
+                    output states; list of lists of tensors
+                    representing prediction network internal state generated in current invocation
+                    of ``forward``.
+        """
+        source_encodings, source_lengths = self.transcriber(
+            input=sources,
+            lengths=source_lengths,
+        )
+        target_encodings, target_lengths, predictor_state = self.predictor(
+            input=targets,
+            lengths=target_lengths,
+            state=predictor_state,
+        )
+        output, source_lengths, target_lengths = self.joiner(
+            source_encodings=source_encodings,
+            source_lengths=source_lengths,
+            target_encodings=target_encodings,
+            target_lengths=target_lengths,
+        )
+
+        return (
+            output,
+            source_lengths,
+            target_lengths,
+            predictor_state,
+        )
+
+    @torch.jit.export
+    def transcribe_streaming(
+        self,
+        sources: torch.Tensor,
+        source_lengths: torch.Tensor,
+        state: Optional[List[List[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        r"""Applies transcription network to sources in streaming mode.
+
+        B: batch size;
+        T: maximum source sequence segment length in batch;
+        D: feature dimension of each source sequence frame.
+
+        Args:
+            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
+                shape `(B, T + right context length, D)`.
+            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``sources``.
+            state (List[List[torch.Tensor]] or None): list of lists of tensors
+                representing transcription network internal state generated in preceding invocation
+                of ``transcribe_streaming``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+                torch.Tensor
+                    output frame sequences, with
+                    shape `(B, T // time_reduction_stride, output_dim)`.
+                torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output.
+                List[List[torch.Tensor]]
+                    output states; list of lists of tensors
+                    representing transcription network internal state generated in current invocation
+                    of ``transcribe_streaming``.
+        """
+        return self.transcriber.infer(sources, source_lengths, state)
+
+    @torch.jit.export
+    def transcribe(
+        self,
+        sources: torch.Tensor,
+        source_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""Applies transcription network to sources in non-streaming mode.
+
+        B: batch size;
+        T: maximum source sequence length in batch;
+        D: feature dimension of each source sequence frame.
+
+        Args:
+            sources (torch.Tensor): source frame sequences right-padded with right context, with
+                shape `(B, T + right context length, D)`.
+            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``sources``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor):
+                torch.Tensor
+                    output frame sequences, with
+                    shape `(B, T // time_reduction_stride, output_dim)`.
+                torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output frame sequences.
+        """
+        return self.transcriber(sources, source_lengths)
+
+    @torch.jit.export
+    def predict(
+        self,
+        targets: torch.Tensor,
+        target_lengths: torch.Tensor,
+        state: Optional[List[List[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+        r"""Applies prediction network to targets.
+
+        B: batch size;
+        U: maximum target sequence length in batch;
+        D: feature dimension of each target sequence frame.
+
+        Args:
+            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
+                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
+            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``targets``.
+            state (List[List[torch.Tensor]] or None): list of lists of tensors
+                representing internal state generated in preceding invocation
+                of ``predict``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
+                torch.Tensor
+                    output frame sequences, with shape `(B, U, output_dim)`.
+                torch.Tensor
+                    output lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements for i-th batch element in output.
+                List[List[torch.Tensor]]
+                    output states; list of lists of tensors
+                    representing internal state generated in current invocation of ``predict``.
+        """
+        return self.predictor(input=targets, lengths=target_lengths, state=state)
+
+    @torch.jit.export
+    def join(
+        self,
+        source_encodings: torch.Tensor,
+        source_lengths: torch.Tensor,
+        target_encodings: torch.Tensor,
+        target_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        r"""Applies joint network to source and target encodings.
+
+        B: batch size;
+        T: maximum source sequence length in batch;
+        U: maximum target sequence length in batch;
+        D: dimension of each source and target sequence encoding.
+
+        Args:
+            source_encodings (torch.Tensor): source encoding sequences, with
+                shape `(B, T, D)`.
+            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                valid sequence length of i-th batch element in ``source_encodings``.
+            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
+            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                valid sequence length of i-th batch element in ``target_encodings``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, torch.Tensor):
+                torch.Tensor
+                    joint network output, with shape `(B, T, U, output_dim)`.
+                torch.Tensor
+                    output source lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 1 for i-th batch element in joint network output.
+                torch.Tensor
+                    output target lengths, with shape `(B,)` and i-th element representing
+                    number of valid elements along dim 2 for i-th batch element in joint network output.
+        """
+        output, source_lengths, target_lengths = self.joiner(
+            source_encodings=source_encodings,
+            source_lengths=source_lengths,
+            target_encodings=target_encodings,
+            target_lengths=target_lengths,
+        )
+        return output, source_lengths, target_lengths
+
+
+def emformer_rnnt_model(
+    *,
+    input_dim: int,
+    encoding_dim: int,
+    num_symbols: int,
+    segment_length: int,
+    right_context_length: int,
+    time_reduction_input_dim: int,
+    time_reduction_stride: int,
+    transformer_num_heads: int,
+    transformer_ffn_dim: int,
+    transformer_num_layers: int,
+    transformer_dropout: float,
+    transformer_activation: str,
+    transformer_left_context_length: int,
+    transformer_max_memory_size: int,
+    transformer_weight_init_scale_strategy: str,
+    transformer_tanh_on_mem: bool,
+    symbol_embedding_dim: int,
+    num_lstm_layers: int,
+    lstm_layer_norm: bool,
+    lstm_layer_norm_epsilon: float,
+    lstm_dropout: float,
+) -> RNNT:
+    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.
+
+    Note:
+        For non-streaming inference, the expectation is for `transcribe` to be called on input
+        sequences right-concatenated with `right_context_length` frames.
+
+        For streaming inference, the expectation is for `transcribe_streaming` to be called
+        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
+        frames.
+
+    Args:
+        input_dim (int): dimension of input sequence frames passed to transcription network.
+        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
+            passed to joint network.
+        num_symbols (int): cardinality of set of target tokens.
+        segment_length (int): length of input segment expressed as number of frames.
+        right_context_length (int): length of right context expressed as number of frames.
+        time_reduction_input_dim (int): dimension to scale each element in input sequences to
+            prior to applying time reduction block.
+        time_reduction_stride (int): factor by which to reduce length of input sequence.
+        transformer_num_heads (int): number of attention heads in each Emformer layer.
+        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+        transformer_num_layers (int): number of Emformer layers to instantiate.
+        transformer_left_context_length (int): length of left context considered by Emformer.
+        transformer_dropout (float): Emformer dropout probability.
+        transformer_activation (str): activation function to use in each Emformer layer's
+            feedforward network. Must be one of ("relu", "gelu", "silu").
+        transformer_max_memory_size (int): maximum number of memory elements to use.
+        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
+            strategy. Must be one of ("depthwise", "constant", ``None``).
+        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
+        symbol_embedding_dim (int): dimension of each target token embedding.
+        num_lstm_layers (int): number of LSTM layers to instantiate.
+        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
+        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
+        lstm_dropout (float): LSTM dropout probability.
+
+    Returns:
+        RNNT:
+            Emformer RNN-T model.
+    """
+    encoder = _EmformerEncoder(
+        input_dim=input_dim,
+        output_dim=encoding_dim,
+        segment_length=segment_length,
+        right_context_length=right_context_length,
+        time_reduction_input_dim=time_reduction_input_dim,
+        time_reduction_stride=time_reduction_stride,
+        transformer_num_heads=transformer_num_heads,
+        transformer_ffn_dim=transformer_ffn_dim,
+        transformer_num_layers=transformer_num_layers,
+        transformer_dropout=transformer_dropout,
+        transformer_activation=transformer_activation,
+        transformer_left_context_length=transformer_left_context_length,
+        transformer_max_memory_size=transformer_max_memory_size,
+        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
+        transformer_tanh_on_mem=transformer_tanh_on_mem,
+    )
+    predictor = _Predictor(
+        num_symbols,
+        encoding_dim,
+        symbol_embedding_dim=symbol_embedding_dim,
+        num_lstm_layers=num_lstm_layers,
+        lstm_hidden_dim=symbol_embedding_dim,
+        lstm_layer_norm=lstm_layer_norm,
+        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
+        lstm_dropout=lstm_dropout,
+    )
+    joiner = _Joiner(encoding_dim, num_symbols)
+    return RNNT(encoder, predictor, joiner)
+
+
+def emformer_rnnt_base(num_symbols: int) -> RNNT:
+    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.
+
+    Args:
+        num_symbols (int): The size of target token lexicon.
+
+    Returns:
+        RNNT:
+            Emformer RNN-T model.
+    """
+    return emformer_rnnt_model(
+        input_dim=80,
+        encoding_dim=1024,
+        num_symbols=num_symbols,
+        segment_length=16,
+        right_context_length=4,
+        time_reduction_input_dim=128,
+        time_reduction_stride=4,
+        transformer_num_heads=8,
+        transformer_ffn_dim=2048,
+        transformer_num_layers=20,
+        transformer_dropout=0.1,
+        transformer_activation="gelu",
+        transformer_left_context_length=30,
+        transformer_max_memory_size=0,
+        transformer_weight_init_scale_strategy="depthwise",
+        transformer_tanh_on_mem=True,
+        symbol_embedding_dim=512,
+        num_lstm_layers=3,
+        lstm_layer_norm=True,
+        lstm_layer_norm_epsilon=1e-3,
+        lstm_dropout=0.3,
+    )
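
For orientation, the snippet below is a minimal usage sketch of the factory functions added in torchaudio/models/rnnt.py. It is not part of the wheel: the vocabulary size (4097) and the input shapes are arbitrary example values, while the feature dimension (80), segment length (16), right-context length (4), time-reduction stride (4), and encoding dimension (1024) follow the defaults in emformer_rnnt_base above.

    import torch
    from torchaudio.models import emformer_rnnt_base

    # Build the "base" Emformer RNN-T; 4097 is an arbitrary example vocabulary size.
    model = emformer_rnnt_base(num_symbols=4097).eval()

    # Non-streaming transcription: append right_context_length (4) frames to the utterance,
    # per the RNNT.transcribe docstring.
    features = torch.rand(1, 128 + 4, 80)  # (B, T + right context, input_dim)
    lengths = torch.tensor([132])
    with torch.no_grad():
        encodings, encoding_lengths = model.transcribe(features, lengths)
    # encodings has shape (B, T // time_reduction_stride, encoding_dim) = (1, 32, 1024)

    # Streaming transcription: feed segment_length + right_context_length (16 + 4) frames
    # per chunk and carry the transcriber state between calls.
    state = None
    chunk = torch.rand(1, 16 + 4, 80)
    chunk_lengths = torch.tensor([20])
    with torch.no_grad():
        out, out_lengths, state = model.transcribe_streaming(chunk, chunk_lengths, state)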