torchaudio-2.9.1-cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. torchaudio/__init__.py +204 -0
  2. torchaudio/_extension/__init__.py +61 -0
  3. torchaudio/_extension/utils.py +133 -0
  4. torchaudio/_internal/__init__.py +10 -0
  5. torchaudio/_internal/module_utils.py +171 -0
  6. torchaudio/_torchcodec.py +340 -0
  7. torchaudio/compliance/__init__.py +5 -0
  8. torchaudio/compliance/kaldi.py +813 -0
  9. torchaudio/datasets/__init__.py +47 -0
  10. torchaudio/datasets/cmuarctic.py +157 -0
  11. torchaudio/datasets/cmudict.py +186 -0
  12. torchaudio/datasets/commonvoice.py +86 -0
  13. torchaudio/datasets/dr_vctk.py +121 -0
  14. torchaudio/datasets/fluentcommands.py +108 -0
  15. torchaudio/datasets/gtzan.py +1118 -0
  16. torchaudio/datasets/iemocap.py +147 -0
  17. torchaudio/datasets/librilight_limited.py +111 -0
  18. torchaudio/datasets/librimix.py +133 -0
  19. torchaudio/datasets/librispeech.py +174 -0
  20. torchaudio/datasets/librispeech_biasing.py +189 -0
  21. torchaudio/datasets/libritts.py +168 -0
  22. torchaudio/datasets/ljspeech.py +107 -0
  23. torchaudio/datasets/musdb_hq.py +139 -0
  24. torchaudio/datasets/quesst14.py +136 -0
  25. torchaudio/datasets/snips.py +157 -0
  26. torchaudio/datasets/speechcommands.py +183 -0
  27. torchaudio/datasets/tedlium.py +218 -0
  28. torchaudio/datasets/utils.py +54 -0
  29. torchaudio/datasets/vctk.py +143 -0
  30. torchaudio/datasets/voxceleb1.py +309 -0
  31. torchaudio/datasets/yesno.py +89 -0
  32. torchaudio/functional/__init__.py +130 -0
  33. torchaudio/functional/_alignment.py +128 -0
  34. torchaudio/functional/filtering.py +1685 -0
  35. torchaudio/functional/functional.py +2505 -0
  36. torchaudio/lib/__init__.py +0 -0
  37. torchaudio/lib/_torchaudio.so +0 -0
  38. torchaudio/lib/libtorchaudio.so +0 -0
  39. torchaudio/models/__init__.py +85 -0
  40. torchaudio/models/_hdemucs.py +1008 -0
  41. torchaudio/models/conformer.py +293 -0
  42. torchaudio/models/conv_tasnet.py +330 -0
  43. torchaudio/models/decoder/__init__.py +64 -0
  44. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  45. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  46. torchaudio/models/deepspeech.py +84 -0
  47. torchaudio/models/emformer.py +884 -0
  48. torchaudio/models/rnnt.py +816 -0
  49. torchaudio/models/rnnt_decoder.py +339 -0
  50. torchaudio/models/squim/__init__.py +11 -0
  51. torchaudio/models/squim/objective.py +326 -0
  52. torchaudio/models/squim/subjective.py +150 -0
  53. torchaudio/models/tacotron2.py +1046 -0
  54. torchaudio/models/wav2letter.py +72 -0
  55. torchaudio/models/wav2vec2/__init__.py +45 -0
  56. torchaudio/models/wav2vec2/components.py +1167 -0
  57. torchaudio/models/wav2vec2/model.py +1579 -0
  58. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  59. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  60. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  61. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  62. torchaudio/models/wavernn.py +409 -0
  63. torchaudio/pipelines/__init__.py +102 -0
  64. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  65. torchaudio/pipelines/_squim_pipeline.py +156 -0
  66. torchaudio/pipelines/_tts/__init__.py +16 -0
  67. torchaudio/pipelines/_tts/impl.py +385 -0
  68. torchaudio/pipelines/_tts/interface.py +255 -0
  69. torchaudio/pipelines/_tts/utils.py +230 -0
  70. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  71. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  72. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  73. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  74. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  75. torchaudio/transforms/__init__.py +78 -0
  76. torchaudio/transforms/_multi_channel.py +467 -0
  77. torchaudio/transforms/_transforms.py +2138 -0
  78. torchaudio/utils/__init__.py +4 -0
  79. torchaudio/utils/download.py +89 -0
  80. torchaudio/version.py +2 -0
  81. torchaudio-2.9.1.dist-info/METADATA +133 -0
  82. torchaudio-2.9.1.dist-info/RECORD +85 -0
  83. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  84. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  85. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/emformer.py
@@ -0,0 +1,884 @@
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+
+
+ __all__ = ["Emformer"]
+
+
+ def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
+     batch_size = lengths.shape[0]
+     max_length = int(torch.max(lengths).item())
+     padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
+         batch_size, max_length
+     ) >= lengths.unsqueeze(1)
+     return padding_mask
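The helper above marks padded positions as True; a quick illustration (not part of the packaged file):

    lengths = torch.tensor([3, 1])
    _lengths_to_padding_mask(lengths)
    # tensor([[False, False, False],
    #         [False,  True,  True]])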
+
+
+ def _gen_padding_mask(
+     utterance: torch.Tensor,
+     right_context: torch.Tensor,
+     summary: torch.Tensor,
+     lengths: torch.Tensor,
+     mems: torch.Tensor,
+     left_context_key: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+     T = right_context.size(0) + utterance.size(0) + summary.size(0)
+     B = right_context.size(1)
+     if B == 1:
+         padding_mask = None
+     else:
+         right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
+         left_context_blocks_length = left_context_key.size(0) if left_context_key is not None else 0
+         klengths = lengths + mems.size(0) + right_context_blocks_length + left_context_blocks_length
+         padding_mask = _lengths_to_padding_mask(lengths=klengths)
+     return padding_mask
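For a batch larger than one, the key lengths fold in the memory, right-context, and left-context block sizes before delegating to _lengths_to_padding_mask. A minimal sketch with illustrative shapes (not part of the packaged file):

    utterance = torch.rand(3, 2, 8)      # (T, B, D)
    right_context = torch.rand(1, 2, 8)  # (R, B, D)
    summary = torch.rand(0, 2, 8)        # (S, B, D)
    mems = torch.rand(0, 2, 8)           # (M, B, D)
    _gen_padding_mask(utterance, right_context, summary, torch.tensor([3, 2]), mems)
    # tensor([[False, False, False, False],
    #         [False, False, False,  True]])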
+
+
+ def _get_activation_module(activation: str) -> torch.nn.Module:
+     if activation == "relu":
+         return torch.nn.ReLU()
+     elif activation == "gelu":
+         return torch.nn.GELU()
+     elif activation == "silu":
+         return torch.nn.SiLU()
+     else:
+         raise ValueError(f"Unsupported activation {activation}")
+
+
+ def _get_weight_init_gains(weight_init_scale_strategy: Optional[str], num_layers: int) -> List[Optional[float]]:
+     if weight_init_scale_strategy is None:
+         return [None for _ in range(num_layers)]
+     elif weight_init_scale_strategy == "depthwise":
+         return [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
+     elif weight_init_scale_strategy == "constant":
+         return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
+     else:
+         raise ValueError(f"Unsupported weight_init_scale_strategy value {weight_init_scale_strategy}")
+
+
+ def _gen_attention_mask_block(
+     col_widths: List[int], col_mask: List[bool], num_rows: int, device: torch.device
+ ) -> torch.Tensor:
+     if len(col_widths) != len(col_mask):
+         raise ValueError("Length of col_widths must match that of col_mask")
+
+     mask_block = [
+         torch.ones(num_rows, col_width, device=device)
+         if is_ones_col
+         else torch.zeros(num_rows, col_width, device=device)
+         for col_width, is_ones_col in zip(col_widths, col_mask)
+     ]
+     return torch.cat(mask_block, dim=1)
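Each block is simply a horizontal concatenation of all-ones and all-zeros column groups; for example (not part of the packaged file):

    _gen_attention_mask_block([2, 3], [True, False], num_rows=1, device=torch.device("cpu"))
    # tensor([[1., 1., 0., 0., 0.]])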
+
+
+ class _EmformerAttention(torch.nn.Module):
+     r"""Emformer layer attention module.
+
+     Args:
+         input_dim (int): input dimension.
+         num_heads (int): number of attention heads in each Emformer layer.
+         dropout (float, optional): dropout probability. (Default: 0.0)
+         weight_init_gain (float or None, optional): scale factor to apply when initializing
+             attention module parameters. (Default: ``None``)
+         tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         num_heads: int,
+         dropout: float = 0.0,
+         weight_init_gain: Optional[float] = None,
+         tanh_on_mem: bool = False,
+         negative_inf: float = -1e8,
+     ):
+         super().__init__()
+
+         if input_dim % num_heads != 0:
+             raise ValueError(f"input_dim ({input_dim}) is not a multiple of num_heads ({num_heads}).")
+
+         self.input_dim = input_dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.tanh_on_mem = tanh_on_mem
+         self.negative_inf = negative_inf
+
+         self.scaling = (self.input_dim // self.num_heads) ** -0.5
+
+         self.emb_to_key_value = torch.nn.Linear(input_dim, 2 * input_dim, bias=True)
+         self.emb_to_query = torch.nn.Linear(input_dim, input_dim, bias=True)
+         self.out_proj = torch.nn.Linear(input_dim, input_dim, bias=True)
+
+         if weight_init_gain:
+             torch.nn.init.xavier_uniform_(self.emb_to_key_value.weight, gain=weight_init_gain)
+             torch.nn.init.xavier_uniform_(self.emb_to_query.weight, gain=weight_init_gain)
+
+     def _gen_key_value(self, input: torch.Tensor, mems: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         T, _, _ = input.shape
+         summary_length = mems.size(0) + 1
+         right_ctx_utterance_block = input[: T - summary_length]
+         mems_right_ctx_utterance_block = torch.cat([mems, right_ctx_utterance_block])
+         key, value = self.emb_to_key_value(mems_right_ctx_utterance_block).chunk(chunks=2, dim=2)
+         return key, value
+
+     def _gen_attention_probs(
+         self,
+         attention_weights: torch.Tensor,
+         attention_mask: torch.Tensor,
+         padding_mask: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         attention_weights_float = attention_weights.float()
+         attention_weights_float = attention_weights_float.masked_fill(attention_mask.unsqueeze(0), self.negative_inf)
+         T = attention_weights.size(1)
+         B = attention_weights.size(0) // self.num_heads
+         if padding_mask is not None:
+             attention_weights_float = attention_weights_float.view(B, self.num_heads, T, -1)
+             attention_weights_float = attention_weights_float.masked_fill(
+                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf
+             )
+             attention_weights_float = attention_weights_float.view(B * self.num_heads, T, -1)
+         attention_probs = torch.nn.functional.softmax(attention_weights_float, dim=-1).type_as(attention_weights)
+         return torch.nn.functional.dropout(attention_probs, p=float(self.dropout), training=self.training)
+
+     def _forward_impl(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         summary: torch.Tensor,
+         mems: torch.Tensor,
+         attention_mask: torch.Tensor,
+         left_context_key: Optional[torch.Tensor] = None,
+         left_context_val: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         B = utterance.size(1)
+         T = right_context.size(0) + utterance.size(0) + summary.size(0)
+
+         # Compute query with [right context, utterance, summary].
+         query = self.emb_to_query(torch.cat([right_context, utterance, summary]))
+
+         # Compute key and value with [mems, right context, utterance].
+         key, value = self.emb_to_key_value(torch.cat([mems, right_context, utterance])).chunk(chunks=2, dim=2)
+
+         if left_context_key is not None and left_context_val is not None:
+             right_context_blocks_length = T - torch.max(lengths).int() - summary.size(0)
+             key = torch.cat(
+                 [
+                     key[: mems.size(0) + right_context_blocks_length],
+                     left_context_key,
+                     key[mems.size(0) + right_context_blocks_length :],
+                 ],
+             )
+             value = torch.cat(
+                 [
+                     value[: mems.size(0) + right_context_blocks_length],
+                     left_context_val,
+                     value[mems.size(0) + right_context_blocks_length :],
+                 ],
+             )
+
+         # Compute attention weights from query, key, and value.
+         reshaped_query, reshaped_key, reshaped_value = [
+             tensor.contiguous().view(-1, B * self.num_heads, self.input_dim // self.num_heads).transpose(0, 1)
+             for tensor in [query, key, value]
+         ]
+         attention_weights = torch.bmm(reshaped_query * self.scaling, reshaped_key.transpose(1, 2))
+
+         # Compute padding mask.
+         padding_mask = _gen_padding_mask(utterance, right_context, summary, lengths, mems, left_context_key)
+
+         # Compute attention probabilities.
+         attention_probs = self._gen_attention_probs(attention_weights, attention_mask, padding_mask)
+
+         # Compute attention.
+         attention = torch.bmm(attention_probs, reshaped_value)
+         if attention.shape != (
+             B * self.num_heads,
+             T,
+             self.input_dim // self.num_heads,
+         ):
+             raise AssertionError("Computed attention has incorrect dimensions")
+         attention = attention.transpose(0, 1).contiguous().view(T, B, self.input_dim)
+
+         # Apply output projection.
+         output_right_context_mems = self.out_proj(attention)
+
+         summary_length = summary.size(0)
+         output_right_context = output_right_context_mems[: T - summary_length]
+         output_mems = output_right_context_mems[T - summary_length :]
+         if self.tanh_on_mem:
+             output_mems = torch.tanh(output_mems)
+         else:
+             output_mems = torch.clamp(output_mems, min=-10, max=10)
+
+         return output_right_context, output_mems, key, value
+
+     def forward(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         summary: torch.Tensor,
+         mems: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         r"""Forward pass for training.
+
+         B: batch size;
+         D: feature dimension of each frame;
+         T: number of utterance frames;
+         R: number of right context frames;
+         S: number of summary elements;
+         M: number of memory elements.
+
+         Args:
+             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid frames for i-th batch element in ``utterance``.
+             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+             summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
+             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+             attention_mask (torch.Tensor): attention mask for underlying attention module.
+
+         Returns:
+             (Tensor, Tensor):
+                 Tensor
+                     output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
+                 Tensor
+                     updated memory elements, with shape `(M, B, D)`.
+         """
+         output, output_mems, _, _ = self._forward_impl(utterance, lengths, right_context, summary, mems, attention_mask)
+         return output, output_mems[:-1]
+
+     @torch.jit.export
+     def infer(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         summary: torch.Tensor,
+         mems: torch.Tensor,
+         left_context_key: torch.Tensor,
+         left_context_val: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         r"""Forward pass for inference.
+
+         B: batch size;
+         D: feature dimension of each frame;
+         T: number of utterance frames;
+         R: number of right context frames;
+         S: number of summary elements;
+         M: number of memory elements.
+
+         Args:
+             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid frames for i-th batch element in ``utterance``.
+             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+             summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
+             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+             left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
+             left_context_val (torch.Tensor): left context attention value computed from preceding invocation.
+
+         Returns:
+             (Tensor, Tensor, Tensor, and Tensor):
+                 Tensor
+                     output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
+                 Tensor
+                     updated memory elements, with shape `(M, B, D)`.
+                 Tensor
+                     attention key computed for left context and utterance.
+                 Tensor
+                     attention value computed for left context and utterance.
+         """
+         query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
+         key_dim = right_context.size(0) + utterance.size(0) + mems.size(0) + left_context_key.size(0)
+         attention_mask = torch.zeros(query_dim, key_dim).to(dtype=torch.bool, device=utterance.device)
+         attention_mask[-1, : mems.size(0)] = True
+         output, output_mems, key, value = self._forward_impl(
+             utterance,
+             lengths,
+             right_context,
+             summary,
+             mems,
+             attention_mask,
+             left_context_key=left_context_key,
+             left_context_val=left_context_val,
+         )
+         return (
+             output,
+             output_mems,
+             key[mems.size(0) + right_context.size(0) :],
+             value[mems.size(0) + right_context.size(0) :],
+         )
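A hedged shape sketch for the module above (illustrative values only; no memory, no summary, batch size 1 so no padding mask is applied):

    attention = _EmformerAttention(input_dim=8, num_heads=2)
    utterance = torch.rand(4, 1, 8)      # (T, B, D)
    right_context = torch.rand(2, 1, 8)  # (R, B, D)
    summary = torch.rand(0, 1, 8)        # (S, B, D)
    mems = torch.rand(0, 1, 8)           # (M, B, D)
    attention_mask = torch.zeros(6, 6, dtype=torch.bool)  # (R + T + S, M + R + T)
    output, output_mems = attention(
        utterance, torch.tensor([4]), right_context, summary, mems, attention_mask
    )
    # output: (T + R, B, D) == (6, 1, 8); output_mems is empty because S == 0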
+
+
+ class _EmformerLayer(torch.nn.Module):
+     r"""Emformer layer that constitutes Emformer.
+
+     Args:
+         input_dim (int): input dimension.
+         num_heads (int): number of attention heads.
+         ffn_dim (int): hidden layer dimension of feedforward network.
+         segment_length (int): length of each input segment.
+         dropout (float, optional): dropout probability. (Default: 0.0)
+         activation (str, optional): activation function to use in feedforward network.
+             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+         left_context_length (int, optional): length of left context. (Default: 0)
+         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+         weight_init_gain (float or None, optional): scale factor to apply when initializing
+             attention module parameters. (Default: ``None``)
+         tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         num_heads: int,
+         ffn_dim: int,
+         segment_length: int,
+         dropout: float = 0.0,
+         activation: str = "relu",
+         left_context_length: int = 0,
+         max_memory_size: int = 0,
+         weight_init_gain: Optional[float] = None,
+         tanh_on_mem: bool = False,
+         negative_inf: float = -1e8,
+     ):
+         super().__init__()
+
+         self.attention = _EmformerAttention(
+             input_dim=input_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             weight_init_gain=weight_init_gain,
+             tanh_on_mem=tanh_on_mem,
+             negative_inf=negative_inf,
+         )
+         self.dropout = torch.nn.Dropout(dropout)
+         self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
+
+         activation_module = _get_activation_module(activation)
+         self.pos_ff = torch.nn.Sequential(
+             torch.nn.LayerNorm(input_dim),
+             torch.nn.Linear(input_dim, ffn_dim),
+             activation_module,
+             torch.nn.Dropout(dropout),
+             torch.nn.Linear(ffn_dim, input_dim),
+             torch.nn.Dropout(dropout),
+         )
+         self.layer_norm_input = torch.nn.LayerNorm(input_dim)
+         self.layer_norm_output = torch.nn.LayerNorm(input_dim)
+
+         self.left_context_length = left_context_length
+         self.segment_length = segment_length
+         self.max_memory_size = max_memory_size
+         self.input_dim = input_dim
+
+         self.use_mem = max_memory_size > 0
+
+     def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
+         empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
+         left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+         left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
+         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
+         return [empty_memory, left_context_key, left_context_val, past_length]
+
+     def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         past_length = state[3][0][0].item()
+         past_left_context_length = min(self.left_context_length, past_length)
+         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
+         pre_mems = state[0][self.max_memory_size - past_mem_length :]
+         lc_key = state[1][self.left_context_length - past_left_context_length :]
+         lc_val = state[2][self.left_context_length - past_left_context_length :]
+         return pre_mems, lc_key, lc_val
+
+     def _pack_state(
+         self,
+         next_k: torch.Tensor,
+         next_v: torch.Tensor,
+         update_length: int,
+         mems: torch.Tensor,
+         state: List[torch.Tensor],
+     ) -> List[torch.Tensor]:
+         new_k = torch.cat([state[1], next_k])
+         new_v = torch.cat([state[2], next_v])
+         state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
+         state[1] = new_k[new_k.shape[0] - self.left_context_length :]
+         state[2] = new_v[new_v.shape[0] - self.left_context_length :]
+         state[3] = state[3] + update_length
+         return state
+
+     def _process_attention_output(
+         self,
+         rc_output: torch.Tensor,
+         utterance: torch.Tensor,
+         right_context: torch.Tensor,
+     ) -> torch.Tensor:
+         result = self.dropout(rc_output) + torch.cat([right_context, utterance])
+         result = self.pos_ff(result) + result
+         result = self.layer_norm_output(result)
+         return result
+
+     def _apply_pre_attention_layer_norm(
+         self, utterance: torch.Tensor, right_context: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         layer_norm_input = self.layer_norm_input(torch.cat([right_context, utterance]))
+         return (
+             layer_norm_input[right_context.size(0) :],
+             layer_norm_input[: right_context.size(0)],
+         )
+
+     def _apply_post_attention_ffn(
+         self, rc_output: torch.Tensor, utterance: torch.Tensor, right_context: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         rc_output = self._process_attention_output(rc_output, utterance, right_context)
+         return rc_output[right_context.size(0) :], rc_output[: right_context.size(0)]
+
+     def _apply_attention_forward(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         mems: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if attention_mask is None:
+             raise ValueError("attention_mask must not be None when for_inference is False")
+
+         if self.use_mem:
+             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+         else:
+             summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+         rc_output, next_m = self.attention(
+             utterance=utterance,
+             lengths=lengths,
+             right_context=right_context,
+             summary=summary,
+             mems=mems,
+             attention_mask=attention_mask,
+         )
+         return rc_output, next_m
+
+     def _apply_attention_infer(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         mems: torch.Tensor,
+         state: Optional[List[torch.Tensor]],
+     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+         if state is None:
+             state = self._init_state(utterance.size(1), device=utterance.device)
+         pre_mems, lc_key, lc_val = self._unpack_state(state)
+         if self.use_mem:
+             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+             summary = summary[:1]
+         else:
+             summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
+         rc_output, next_m, next_k, next_v = self.attention.infer(
+             utterance=utterance,
+             lengths=lengths,
+             right_context=right_context,
+             summary=summary,
+             mems=pre_mems,
+             left_context_key=lc_key,
+             left_context_val=lc_val,
+         )
+         state = self._pack_state(next_k, next_v, utterance.size(0), mems, state)
+         return rc_output, next_m, state
+
+     def forward(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         mems: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         r"""Forward pass for training.
+
+         B: batch size;
+         D: feature dimension of each frame;
+         T: number of utterance frames;
+         R: number of right context frames;
+         M: number of memory elements.
+
+         Args:
+             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid frames for i-th batch element in ``utterance``.
+             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+             attention_mask (torch.Tensor): attention mask for underlying attention module.
+
+         Returns:
+             (Tensor, Tensor, Tensor):
+                 Tensor
+                     encoded utterance frames, with shape `(T, B, D)`.
+                 Tensor
+                     updated right context frames, with shape `(R, B, D)`.
+                 Tensor
+                     updated memory elements, with shape `(M, B, D)`.
+         """
+         (
+             layer_norm_utterance,
+             layer_norm_right_context,
+         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
+         rc_output, output_mems = self._apply_attention_forward(
+             layer_norm_utterance,
+             lengths,
+             layer_norm_right_context,
+             mems,
+             attention_mask,
+         )
+         output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
+         return output_utterance, output_right_context, output_mems
+
+     @torch.jit.export
+     def infer(
+         self,
+         utterance: torch.Tensor,
+         lengths: torch.Tensor,
+         right_context: torch.Tensor,
+         state: Optional[List[torch.Tensor]],
+         mems: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
+         r"""Forward pass for inference.
+
+         B: batch size;
+         D: feature dimension of each frame;
+         T: number of utterance frames;
+         R: number of right context frames;
+         M: number of memory elements.
+
+         Args:
+             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid frames for i-th batch element in ``utterance``.
+             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
+             state (List[torch.Tensor] or None): list of tensors representing layer internal state
+                 generated in preceding invocation of ``infer``.
+             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
+
+         Returns:
+             (Tensor, Tensor, List[torch.Tensor], Tensor):
+                 Tensor
+                     encoded utterance frames, with shape `(T, B, D)`.
+                 Tensor
+                     updated right context frames, with shape `(R, B, D)`.
+                 List[Tensor]
+                     list of tensors representing layer internal state
+                     generated in current invocation of ``infer``.
+                 Tensor
+                     updated memory elements, with shape `(M, B, D)`.
+         """
+         (
+             layer_norm_utterance,
+             layer_norm_right_context,
+         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
+         rc_output, output_mems, output_state = self._apply_attention_infer(
+             layer_norm_utterance, lengths, layer_norm_right_context, mems, state
+         )
+         output_utterance, output_right_context = self._apply_post_attention_ffn(rc_output, utterance, right_context)
+         return output_utterance, output_right_context, output_state, output_mems
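The per-layer streaming state produced by _init_state and consumed by infer is a four-tensor list: accumulated memory, left-context key, left-context value, and the number of frames processed so far. A minimal sketch (illustrative hyperparameters only):

    layer = _EmformerLayer(
        input_dim=8, num_heads=2, ffn_dim=16, segment_length=4, left_context_length=2, max_memory_size=1
    )
    state = layer._init_state(batch_size=1, device=torch.device("cpu"))
    # shapes: (1, 1, 8) memory, (2, 1, 8) key, (2, 1, 8) value, (1, 1) past length (int32)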
+
+
+ class _EmformerImpl(torch.nn.Module):
+     def __init__(
+         self,
+         emformer_layers: torch.nn.ModuleList,
+         segment_length: int,
+         left_context_length: int = 0,
+         right_context_length: int = 0,
+         max_memory_size: int = 0,
+     ):
+         super().__init__()
+
+         self.use_mem = max_memory_size > 0
+         self.memory_op = torch.nn.AvgPool1d(
+             kernel_size=segment_length,
+             stride=segment_length,
+             ceil_mode=True,
+         )
+         self.emformer_layers = emformer_layers
+         self.left_context_length = left_context_length
+         self.right_context_length = right_context_length
+         self.segment_length = segment_length
+         self.max_memory_size = max_memory_size
+
+     def _gen_right_context(self, input: torch.Tensor) -> torch.Tensor:
+         T = input.shape[0]
+         num_segs = math.ceil((T - self.right_context_length) / self.segment_length)
+         right_context_blocks = []
+         for seg_idx in range(num_segs - 1):
+             start = (seg_idx + 1) * self.segment_length
+             end = start + self.right_context_length
+             right_context_blocks.append(input[start:end])
+         right_context_blocks.append(input[T - self.right_context_length :])
+         return torch.cat(right_context_blocks)
+
+     def _gen_attention_mask_col_widths(self, seg_idx: int, utterance_length: int) -> List[int]:
+         num_segs = math.ceil(utterance_length / self.segment_length)
+         rc = self.right_context_length
+         lc = self.left_context_length
+         rc_start = seg_idx * rc
+         rc_end = rc_start + rc
+         seg_start = max(seg_idx * self.segment_length - lc, 0)
+         seg_end = min((seg_idx + 1) * self.segment_length, utterance_length)
+         rc_length = self.right_context_length * num_segs
+
+         if self.use_mem:
+             m_start = max(seg_idx - self.max_memory_size, 0)
+             mem_length = num_segs - 1
+             col_widths = [
+                 m_start,  # before memory
+                 seg_idx - m_start,  # memory
+                 mem_length - seg_idx,  # after memory
+                 rc_start,  # before right context
+                 rc,  # right context
+                 rc_length - rc_end,  # after right context
+                 seg_start,  # before query segment
+                 seg_end - seg_start,  # query segment
+                 utterance_length - seg_end,  # after query segment
+             ]
+         else:
+             col_widths = [
+                 rc_start,  # before right context
+                 rc,  # right context
+                 rc_length - rc_end,  # after right context
+                 seg_start,  # before query segment
+                 seg_end - seg_start,  # query segment
+                 utterance_length - seg_end,  # after query segment
+             ]
+
+         return col_widths
+
+     def _gen_attention_mask(self, input: torch.Tensor) -> torch.Tensor:
+         utterance_length = input.size(0)
+         num_segs = math.ceil(utterance_length / self.segment_length)
+
+         rc_mask = []
+         query_mask = []
+         summary_mask = []
+
+         if self.use_mem:
+             num_cols = 9
+             # memory, right context, query segment
+             rc_q_cols_mask = [idx in [1, 4, 7] for idx in range(num_cols)]
+             # right context, query segment
+             s_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
+             masks_to_concat = [rc_mask, query_mask, summary_mask]
+         else:
+             num_cols = 6
+             # right context, query segment
+             rc_q_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
+             s_cols_mask = None
+             masks_to_concat = [rc_mask, query_mask]
+
+         for seg_idx in range(num_segs):
+             col_widths = self._gen_attention_mask_col_widths(seg_idx, utterance_length)
+
+             rc_mask_block = _gen_attention_mask_block(
+                 col_widths, rc_q_cols_mask, self.right_context_length, input.device
+             )
+             rc_mask.append(rc_mask_block)
+
+             query_mask_block = _gen_attention_mask_block(
+                 col_widths,
+                 rc_q_cols_mask,
+                 min(
+                     self.segment_length,
+                     utterance_length - seg_idx * self.segment_length,
+                 ),
+                 input.device,
+             )
+             query_mask.append(query_mask_block)
+
+             if s_cols_mask is not None:
+                 summary_mask_block = _gen_attention_mask_block(col_widths, s_cols_mask, 1, input.device)
+                 summary_mask.append(summary_mask_block)
+
+         attention_mask = (1 - torch.cat([torch.cat(mask) for mask in masks_to_concat])).to(torch.bool)
+         return attention_mask
+
+     def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         r"""Forward pass for training and non-streaming inference.
+
+         B: batch size;
+         T: max number of input frames in batch;
+         D: feature dimension of each frame.
+
+         Args:
+             input (torch.Tensor): utterance frames right-padded with right context frames, with
+                 shape `(B, T + right_context_length, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid utterance frames for i-th batch element in ``input``.
+
+         Returns:
+             (Tensor, Tensor):
+                 Tensor
+                     output frames, with shape `(B, T, D)`.
+                 Tensor
+                     output lengths, with shape `(B,)` and i-th element representing
+                     number of valid frames for i-th batch element in output frames.
+         """
+         input = input.permute(1, 0, 2)
+         right_context = self._gen_right_context(input)
+         utterance = input[: input.size(0) - self.right_context_length]
+         attention_mask = self._gen_attention_mask(utterance)
+         mems = (
+             self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
+             if self.use_mem
+             else torch.empty(0).to(dtype=input.dtype, device=input.device)
+         )
+         output = utterance
+         for layer in self.emformer_layers:
+             output, right_context, mems = layer(output, lengths, right_context, mems, attention_mask)
+         return output.permute(1, 0, 2), lengths
+
+     @torch.jit.export
+     def infer(
+         self,
+         input: torch.Tensor,
+         lengths: torch.Tensor,
+         states: Optional[List[List[torch.Tensor]]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
+         r"""Forward pass for streaming inference.
+
+         B: batch size;
+         D: feature dimension of each frame.
+
+         Args:
+             input (torch.Tensor): utterance frames right-padded with right context frames, with
+                 shape `(B, segment_length + right_context_length, D)`.
+             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                 number of valid frames for i-th batch element in ``input``.
+             states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+                 representing internal state generated in preceding invocation of ``infer``. (Default: ``None``)
+
+         Returns:
+             (Tensor, Tensor, List[List[Tensor]]):
+                 Tensor
+                     output frames, with shape `(B, segment_length, D)`.
+                 Tensor
+                     output lengths, with shape `(B,)` and i-th element representing
+                     number of valid frames for i-th batch element in output frames.
+                 List[List[Tensor]]
+                     output states; list of lists of tensors representing internal state
+                     generated in current invocation of ``infer``.
+         """
+         if input.size(1) != self.segment_length + self.right_context_length:
+             raise ValueError(
+                 "Per configured segment_length and right_context_length"
+                 f", expected size of {self.segment_length + self.right_context_length} for dimension 1 of input"
+                 f", but got {input.size(1)}."
+             )
+         input = input.permute(1, 0, 2)
+         right_context_start_idx = input.size(0) - self.right_context_length
+         right_context = input[right_context_start_idx:]
+         utterance = input[:right_context_start_idx]
+         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
+         mems = (
+             self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
+             if self.use_mem
+             else torch.empty(0).to(dtype=input.dtype, device=input.device)
+         )
+         output = utterance
+         output_states: List[List[torch.Tensor]] = []
+         for layer_idx, layer in enumerate(self.emformer_layers):
+             output, right_context, output_state, mems = layer.infer(
+                 output,
+                 output_lengths,
+                 right_context,
+                 None if states is None else states[layer_idx],
+                 mems,
+             )
+             output_states.append(output_state)
+
+         return output.permute(1, 0, 2), output_lengths, output_states
+
+
+ class Emformer(_EmformerImpl):
+     r"""Emformer architecture introduced in
+     *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
+     :cite:`shi2021emformer`.
+
+     See Also:
+         * :func:`~torchaudio.models.emformer_rnnt_model`,
+           :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
+         * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.
+
+     Args:
+         input_dim (int): input dimension.
+         num_heads (int): number of attention heads in each Emformer layer.
+         ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
+         num_layers (int): number of Emformer layers to instantiate.
+         segment_length (int): length of each input segment.
+         dropout (float, optional): dropout probability. (Default: 0.0)
+         activation (str, optional): activation function to use in each Emformer layer's
+             feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
+         left_context_length (int, optional): length of left context. (Default: 0)
+         right_context_length (int, optional): length of right context. (Default: 0)
+         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
+         weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
+             strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
+         tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
+         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
+
+     Examples:
+         >>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
+         >>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
+         >>> lengths = torch.randint(1, 200, (128,))  # batch
+         >>> output, lengths = emformer(input, lengths)
+         >>> input = torch.rand(128, 5, 512)
+         >>> lengths = torch.ones(128) * 5
+         >>> output, lengths, states = emformer.infer(input, lengths, None)
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         num_heads: int,
+         ffn_dim: int,
+         num_layers: int,
+         segment_length: int,
+         dropout: float = 0.0,
+         activation: str = "relu",
+         left_context_length: int = 0,
+         right_context_length: int = 0,
+         max_memory_size: int = 0,
+         weight_init_scale_strategy: Optional[str] = "depthwise",
+         tanh_on_mem: bool = False,
+         negative_inf: float = -1e8,
+     ):
+         weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
+         emformer_layers = torch.nn.ModuleList(
+             [
+                 _EmformerLayer(
+                     input_dim,
+                     num_heads,
+                     ffn_dim,
+                     segment_length,
+                     dropout=dropout,
+                     activation=activation,
+                     left_context_length=left_context_length,
+                     max_memory_size=max_memory_size,
+                     weight_init_gain=weight_init_gains[layer_idx],
+                     tanh_on_mem=tanh_on_mem,
+                     negative_inf=negative_inf,
+                 )
+                 for layer_idx in range(num_layers)
+             ]
+         )
+         super().__init__(
+             emformer_layers,
+             segment_length,
+             left_context_length=left_context_length,
+             right_context_length=right_context_length,
+             max_memory_size=max_memory_size,
+         )
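Building on the docstring example, a sketch of how the returned states are threaded through successive infer() calls when streaming a longer recording segment by segment (illustrative sizes only):

    emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
    segment_length, right_context_length = 4, 1
    features = torch.rand(1, 41, 512)  # (B, 10 * segment_length + right_context_length, D)
    states = None
    outputs = []
    for start in range(0, features.size(1) - right_context_length, segment_length):
        chunk = features[:, start : start + segment_length + right_context_length]
        output, output_lengths, states = emformer.infer(chunk, torch.tensor([chunk.size(1)]), states)
        outputs.append(output)
    streamed = torch.cat(outputs, dim=1)  # (1, 40, 512); each call emits segment_length frames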