torchaudio 2.8.0-cp313-cp313t-win_amd64.whl → 2.9.0-cp313-cp313t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio might be problematic.

Files changed (92)
  1. torchaudio/__init__.py +179 -39
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +12 -3
  5. torchaudio/_torchcodec.py +73 -85
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +0 -2
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +26 -60
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +14 -2
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +1 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +1 -0
  23. torchaudio/transforms/_transforms.py +2 -2
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -3
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -350
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -20
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -150
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -68
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -441
  54. torchaudio/prototype/functional/_rir.py +0 -382
  55. torchaudio/prototype/functional/functional.py +0 -193
  56. torchaudio/prototype/models/__init__.py +0 -39
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -337
  59. torchaudio/prototype/models/conv_emformer.py +0 -529
  60. torchaudio/prototype/models/hifi_gan.py +0 -342
  61. torchaudio/prototype/models/rnnt.py +0 -717
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -402
  63. torchaudio/prototype/pipelines/__init__.py +0 -21
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -461
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -275
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -118
  75. torchaudio-2.8.0.dist-info/RECORD +0 -145
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -977
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -275
  91. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/LICENSE +0 -0
  92. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
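
Most of the removals in this list sit in the I/O and prototype layers: the torchaudio/backend and torchaudio/_backend packages, torchaudio/io, torchaudio/sox_effects, torchaudio/kaldi_io, the whole torchaudio/prototype tree, and the separate torio package are deleted, while torchaudio/_torchcodec.py is reworked. The snippet below is an illustrative sketch only and is not part of either wheel: it probes which of the removed modules are still importable in the installed environment, which can help when porting code written against 2.8.0.

# Hedged sketch, not part of the torchaudio release: probe for modules that the
# file list above shows as deleted in 2.9.0 before importing them directly.
import importlib.util

REMOVED_IN_2_9 = (
    "torchaudio.io",
    "torchaudio.sox_effects",
    "torchaudio.kaldi_io",
    "torchaudio.prototype",
    "torio",
)

def is_available(module_name: str) -> bool:
    """Return True if the module can be located in the installed environment."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # A missing parent package also means the module is gone.
        return False

for name in REMOVED_IN_2_9:
    print(f"{name}: {'available' if is_available(name) else 'removed'}")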
torchaudio/prototype/models/conv_emformer.py (deleted, 529 lines)
@@ -1,529 +0,0 @@
- import math
- from typing import List, Optional, Tuple
-
- import torch
- from torchaudio.models.emformer import _EmformerAttention, _EmformerImpl, _get_weight_init_gains
- from torchaudio._internal.module_utils import dropping_class_support, dropping_support
-
-
-
- def _get_activation_module(activation: str) -> torch.nn.Module:
-     if activation == "relu":
-         return torch.nn.ReLU()
-     elif activation == "gelu":
-         return torch.nn.GELU()
-     elif activation == "silu":
-         return torch.nn.SiLU()
-     else:
-         raise ValueError(f"Unsupported activation {activation}")
-
-
- class _ResidualContainer(torch.nn.Module):
-     def __init__(self, module: torch.nn.Module, output_weight: int):
-         super().__init__()
-         self.module = module
-         self.output_weight = output_weight
-
-     def forward(self, input: torch.Tensor):
-         output = self.module(input)
-         return output * self.output_weight + input
-
-
- class _ConvolutionModule(torch.nn.Module):
-     def __init__(
-         self,
-         input_dim: int,
-         segment_length: int,
-         right_context_length: int,
-         kernel_size: int,
-         activation: str = "silu",
-         dropout: float = 0.0,
-     ):
-         super().__init__()
-         self.input_dim = input_dim
-         self.segment_length = segment_length
-         self.right_context_length = right_context_length
-         self.state_size = kernel_size - 1
-
-         self.pre_conv = torch.nn.Sequential(
-             torch.nn.LayerNorm(input_dim), torch.nn.Linear(input_dim, 2 * input_dim, bias=True), torch.nn.GLU()
-         )
-         self.conv = torch.nn.Conv1d(
-             in_channels=input_dim,
-             out_channels=input_dim,
-             kernel_size=kernel_size,
-             stride=1,
-             padding=0,
-             groups=input_dim,
-         )
-         self.post_conv = torch.nn.Sequential(
-             torch.nn.LayerNorm(input_dim),
-             _get_activation_module(activation),
-             torch.nn.Linear(input_dim, input_dim, bias=True),
-             torch.nn.Dropout(p=dropout),
-         )
-
-     def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Tensor) -> torch.Tensor:
-         T, B, D = right_context.size()
-         if T % self.right_context_length != 0:
-             raise ValueError("Tensor length should be divisible by its right context length")
-         num_segments = T // self.right_context_length
-         # (num_segments, right context length, B, D)
-         right_context_segments = right_context.reshape(num_segments, self.right_context_length, B, D)
-         right_context_segments = right_context_segments.permute(0, 2, 1, 3).reshape(
-             num_segments * B, self.right_context_length, D
-         )
-
-         pad_segments = []  # [(kernel_size - 1, B, D), ...]
-         for seg_idx in range(num_segments):
-             end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
-             start_idx = end_idx - self.state_size
-             pad_segments.append(utterance[start_idx:end_idx, :, :])
-
-         pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2)  # (num_segments * B, kernel_size - 1, D)
-         return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)
-
-     def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor:
-         # (num_segments * B, D, right_context_length)
-         right_context = right_context.reshape(-1, B, self.input_dim, self.right_context_length)
-         right_context = right_context.permute(0, 3, 1, 2)
-         return right_context.reshape(-1, B, self.input_dim)  # (right_context_length * num_segments, B, D)
-
-     def forward(
-         self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         input = torch.cat((right_context, utterance))  # input: (T, B, D)
-         x = self.pre_conv(input)
-         x_right_context, x_utterance = x[: right_context.size(0), :, :], x[right_context.size(0) :, :, :]
-         x_utterance = x_utterance.permute(1, 2, 0)  # (B, D, T_utterance)
-
-         if state is None:
-             state = torch.zeros(
-                 input.size(1),
-                 input.size(2),
-                 self.state_size,
-                 device=input.device,
-                 dtype=input.dtype,
-             )  # (B, D, T)
-         state_x_utterance = torch.cat([state, x_utterance], dim=2)
-
-         conv_utterance = self.conv(state_x_utterance)  # (B, D, T_utterance)
-         conv_utterance = conv_utterance.permute(2, 0, 1)
-
-         if self.right_context_length > 0:
-             # (B * num_segments, D, right_context_length + kernel_size - 1)
-             right_context_block = self._split_right_context(state_x_utterance.permute(2, 0, 1), x_right_context)
-             conv_right_context_block = self.conv(right_context_block)  # (B * num_segments, D, right_context_length)
-             # (T_right_context, B, D)
-             conv_right_context = self._merge_right_context(conv_right_context_block, input.size(1))
-             y = torch.cat([conv_right_context, conv_utterance], dim=0)
-         else:
-             y = conv_utterance
-
-         output = self.post_conv(y) + input
-         new_state = state_x_utterance[:, :, -self.state_size :]
-         return output[right_context.size(0) :], output[: right_context.size(0)], new_state
-
-     def infer(
-         self, utterance: torch.Tensor, right_context: torch.Tensor, state: Optional[torch.Tensor]
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         input = torch.cat((utterance, right_context))
-         x = self.pre_conv(input)  # (T, B, D)
-         x = x.permute(1, 2, 0)  # (B, D, T)
-
-         if state is None:
-             state = torch.zeros(
-                 input.size(1),
-                 input.size(2),
-                 self.state_size,
-                 device=input.device,
-                 dtype=input.dtype,
-             )  # (B, D, T)
-         state_x = torch.cat([state, x], dim=2)
-         conv_out = self.conv(state_x)
-         conv_out = conv_out.permute(2, 0, 1)  # T, B, D
-         output = self.post_conv(conv_out) + input
-         new_state = state_x[:, :, -self.state_size - right_context.size(0) : -right_context.size(0)]
-         return output[: utterance.size(0)], output[utterance.size(0) :], new_state
-
-
- class _ConvEmformerLayer(torch.nn.Module):
-     r"""Convolution-augmented Emformer layer that constitutes ConvEmformer.
-
-     Args:
-         input_dim (int): input dimension.
-         num_heads (int): number of attention heads.
-         ffn_dim: (int): hidden layer dimension of feedforward network.
-         segment_length (int): length of each input segment.
-         kernel_size (int): size of kernel to use in convolution module.
-         dropout (float, optional): dropout probability. (Default: 0.0)
-         ffn_activation (str, optional): activation function to use in feedforward network.
-             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
-         left_context_length (int, optional): length of left context. (Default: 0)
-         right_context_length (int, optional): length of right context. (Default: 0)
-         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
-         weight_init_gain (float or None, optional): scale factor to apply when initializing
-             attention module parameters. (Default: ``None``)
-         tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
-         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
-         conv_activation (str, optional): activation function to use in convolution module.
-             Must be one of ("relu", "gelu", "silu"). (Default: "silu")
-     """
-
-     def __init__(
-         self,
-         input_dim: int,
-         num_heads: int,
-         ffn_dim: int,
-         segment_length: int,
-         kernel_size: int,
-         dropout: float = 0.0,
-         ffn_activation: str = "relu",
-         left_context_length: int = 0,
-         right_context_length: int = 0,
-         max_memory_size: int = 0,
-         weight_init_gain: Optional[float] = None,
-         tanh_on_mem: bool = False,
-         negative_inf: float = -1e8,
-         conv_activation: str = "silu",
-     ):
-         super().__init__()
-         # TODO: implement talking heads attention.
-         self.attention = _EmformerAttention(
-             input_dim=input_dim,
-             num_heads=num_heads,
-             dropout=dropout,
-             weight_init_gain=weight_init_gain,
-             tanh_on_mem=tanh_on_mem,
-             negative_inf=negative_inf,
-         )
-         self.dropout = torch.nn.Dropout(dropout)
-         self.memory_op = torch.nn.AvgPool1d(kernel_size=segment_length, stride=segment_length, ceil_mode=True)
-
-         activation_module = _get_activation_module(ffn_activation)
-         self.ffn0 = _ResidualContainer(
-             torch.nn.Sequential(
-                 torch.nn.LayerNorm(input_dim),
-                 torch.nn.Linear(input_dim, ffn_dim),
-                 activation_module,
-                 torch.nn.Dropout(dropout),
-                 torch.nn.Linear(ffn_dim, input_dim),
-                 torch.nn.Dropout(dropout),
-             ),
-             0.5,
-         )
-         self.ffn1 = _ResidualContainer(
-             torch.nn.Sequential(
-                 torch.nn.LayerNorm(input_dim),
-                 torch.nn.Linear(input_dim, ffn_dim),
-                 activation_module,
-                 torch.nn.Dropout(dropout),
-                 torch.nn.Linear(ffn_dim, input_dim),
-                 torch.nn.Dropout(dropout),
-             ),
-             0.5,
-         )
-         self.layer_norm_input = torch.nn.LayerNorm(input_dim)
-         self.layer_norm_output = torch.nn.LayerNorm(input_dim)
-
-         self.conv = _ConvolutionModule(
-             input_dim=input_dim,
-             kernel_size=kernel_size,
-             activation=conv_activation,
-             dropout=dropout,
-             segment_length=segment_length,
-             right_context_length=right_context_length,
-         )
-
-         self.left_context_length = left_context_length
-         self.segment_length = segment_length
-         self.max_memory_size = max_memory_size
-         self.input_dim = input_dim
-         self.kernel_size = kernel_size
-         self.use_mem = max_memory_size > 0
-
-     def _init_state(self, batch_size: int, device: Optional[torch.device]) -> List[torch.Tensor]:
-         empty_memory = torch.zeros(self.max_memory_size, batch_size, self.input_dim, device=device)
-         left_context_key = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
-         left_context_val = torch.zeros(self.left_context_length, batch_size, self.input_dim, device=device)
-         past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device)
-         conv_cache = torch.zeros(
-             batch_size,
-             self.input_dim,
-             self.kernel_size - 1,
-             device=device,
-         )
-         return [empty_memory, left_context_key, left_context_val, past_length, conv_cache]
-
-     def _unpack_state(self, state: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-         past_length = state[3][0][0].item()
-         past_left_context_length = min(self.left_context_length, past_length)
-         past_mem_length = min(self.max_memory_size, math.ceil(past_length / self.segment_length))
-         pre_mems = state[0][self.max_memory_size - past_mem_length :]
-         lc_key = state[1][self.left_context_length - past_left_context_length :]
-         lc_val = state[2][self.left_context_length - past_left_context_length :]
-         conv_cache = state[4]
-         return pre_mems, lc_key, lc_val, conv_cache
-
-     def _pack_state(
-         self,
-         next_k: torch.Tensor,
-         next_v: torch.Tensor,
-         update_length: int,
-         mems: torch.Tensor,
-         conv_cache: torch.Tensor,
-         state: List[torch.Tensor],
-     ) -> List[torch.Tensor]:
-         new_k = torch.cat([state[1], next_k])
-         new_v = torch.cat([state[2], next_v])
-         state[0] = torch.cat([state[0], mems])[-self.max_memory_size :]
-         state[1] = new_k[new_k.shape[0] - self.left_context_length :]
-         state[2] = new_v[new_v.shape[0] - self.left_context_length :]
-         state[3] = state[3] + update_length
-         state[4] = conv_cache
-         return state
-
-     def _apply_pre_attention(
-         self, utterance: torch.Tensor, right_context: torch.Tensor, summary: torch.Tensor
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-         x = torch.cat([right_context, utterance, summary])
-         ffn0_out = self.ffn0(x)
-         layer_norm_input_out = self.layer_norm_input(ffn0_out)
-         layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary = (
-             layer_norm_input_out[: right_context.size(0)],
-             layer_norm_input_out[right_context.size(0) : right_context.size(0) + utterance.size(0)],
-             layer_norm_input_out[right_context.size(0) + utterance.size(0) :],
-         )
-         return ffn0_out, layer_norm_input_right_context, layer_norm_input_utterance, layer_norm_input_summary
-
-     def _apply_post_attention(
-         self,
-         rc_output: torch.Tensor,
-         ffn0_out: torch.Tensor,
-         conv_cache: Optional[torch.Tensor],
-         rc_length: int,
-         utterance_length: int,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         result = self.dropout(rc_output) + ffn0_out[: rc_length + utterance_length]
-         conv_utterance, conv_right_context, conv_cache = self.conv(result[rc_length:], result[:rc_length], conv_cache)
-         result = torch.cat([conv_right_context, conv_utterance])
-         result = self.ffn1(result)
-         result = self.layer_norm_output(result)
-         output_utterance, output_right_context = result[rc_length:], result[:rc_length]
-         return output_utterance, output_right_context, conv_cache
-
-     def forward(
-         self,
-         utterance: torch.Tensor,
-         lengths: torch.Tensor,
-         right_context: torch.Tensor,
-         mems: torch.Tensor,
-         attention_mask: torch.Tensor,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         r"""Forward pass for training.
-
-         B: batch size;
-         D: feature dimension of each frame;
-         T: number of utterance frames;
-         R: number of right context frames;
-         M: number of memory elements.
-
-         Args:
-             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``utterance``.
-             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
-             attention_mask (torch.Tensor): attention mask for underlying attention module.
-
-         Returns:
-             (Tensor, Tensor, Tensor):
-                 Tensor
-                     encoded utterance frames, with shape `(T, B, D)`.
-                 Tensor
-                     updated right context frames, with shape `(R, B, D)`.
-                 Tensor
-                     updated memory elements, with shape `(M, B, D)`.
-         """
-         if self.use_mem:
-             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
-         else:
-             summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
-
-         (
-             ffn0_out,
-             layer_norm_input_right_context,
-             layer_norm_input_utterance,
-             layer_norm_input_summary,
-         ) = self._apply_pre_attention(utterance, right_context, summary)
-
-         rc_output, output_mems = self.attention(
-             utterance=layer_norm_input_utterance,
-             lengths=lengths,
-             right_context=layer_norm_input_right_context,
-             summary=layer_norm_input_summary,
-             mems=mems,
-             attention_mask=attention_mask,
-         )
-
-         output_utterance, output_right_context, _ = self._apply_post_attention(
-             rc_output, ffn0_out, None, right_context.size(0), utterance.size(0)
-         )
-
-         return output_utterance, output_right_context, output_mems
-
-     @torch.jit.export
-     def infer(
-         self,
-         utterance: torch.Tensor,
-         lengths: torch.Tensor,
-         right_context: torch.Tensor,
-         state: Optional[List[torch.Tensor]],
-         mems: torch.Tensor,
-     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
-         r"""Forward pass for inference.
-
-         B: batch size;
-         D: feature dimension of each frame;
-         T: number of utterance frames;
-         R: number of right context frames;
-         M: number of memory elements.
-
-         Args:
-             utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-             lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``utterance``.
-             right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-             state (List[torch.Tensor] or None): list of tensors representing layer internal state
-                 generated in preceding invocation of ``infer``.
-             mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
-
-         Returns:
-             (Tensor, Tensor, List[torch.Tensor], Tensor):
-                 Tensor
-                     encoded utterance frames, with shape `(T, B, D)`.
-                 Tensor
-                     updated right context frames, with shape `(R, B, D)`.
-                 List[Tensor]
-                     list of tensors representing layer internal state
-                     generated in current invocation of ``infer``.
-                 Tensor
-                     updated memory elements, with shape `(M, B, D)`.
-         """
-         if self.use_mem:
-             summary = self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:1]
-         else:
-             summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device)
-
-         (
-             ffn0_out,
-             layer_norm_input_right_context,
-             layer_norm_input_utterance,
-             layer_norm_input_summary,
-         ) = self._apply_pre_attention(utterance, right_context, summary)
-
-         if state is None:
-             state = self._init_state(layer_norm_input_utterance.size(1), device=layer_norm_input_utterance.device)
-         pre_mems, lc_key, lc_val, conv_cache = self._unpack_state(state)
-
-         rc_output, next_m, next_k, next_v = self.attention.infer(
-             utterance=layer_norm_input_utterance,
-             lengths=lengths,
-             right_context=layer_norm_input_right_context,
-             summary=layer_norm_input_summary,
-             mems=pre_mems,
-             left_context_key=lc_key,
-             left_context_val=lc_val,
-         )
-
-         output_utterance, output_right_context, conv_cache = self._apply_post_attention(
-             rc_output, ffn0_out, conv_cache, right_context.size(0), utterance.size(0)
-         )
-         output_state = self._pack_state(next_k, next_v, utterance.size(0), mems, conv_cache, state)
-         return output_utterance, output_right_context, output_state, next_m
-
-
- @dropping_class_support
- class ConvEmformer(_EmformerImpl):
-     r"""Implements the convolution-augmented streaming transformer architecture introduced in
-     *Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution*
-     :cite:`9747706`.
-
-     Args:
-         input_dim (int): input dimension.
-         num_heads (int): number of attention heads in each ConvEmformer layer.
-         ffn_dim (int): hidden layer dimension of each ConvEmformer layer's feedforward network.
-         num_layers (int): number of ConvEmformer layers to instantiate.
-         segment_length (int): length of each input segment.
-         kernel_size (int): size of kernel to use in convolution modules.
-         dropout (float, optional): dropout probability. (Default: 0.0)
-         ffn_activation (str, optional): activation function to use in feedforward networks.
-             Must be one of ("relu", "gelu", "silu"). (Default: "relu")
-         left_context_length (int, optional): length of left context. (Default: 0)
-         right_context_length (int, optional): length of right context. (Default: 0)
-         max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
-         weight_init_scale_strategy (str or None, optional): per-layer weight initialization scaling
-             strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
-         tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
-         negative_inf (float, optional): value to use for negative infinity in attention weights. (Default: -1e8)
-         conv_activation (str, optional): activation function to use in convolution modules.
-             Must be one of ("relu", "gelu", "silu"). (Default: "silu")
-
-     Examples:
-         >>> conv_emformer = ConvEmformer(80, 4, 1024, 12, 16, 8, right_context_length=4)
-         >>> input = torch.rand(10, 200, 80)
-         >>> lengths = torch.randint(1, 200, (10,))
-         >>> output, lengths = conv_emformer(input, lengths)
-         >>> input = torch.rand(4, 20, 80)
-         >>> lengths = torch.ones(4) * 20
-         >>> output, lengths, states = conv_emformer.infer(input, lengths, None)
-     """
-
-     @dropping_support
-     def __init__(
-         self,
-         input_dim: int,
-         num_heads: int,
-         ffn_dim: int,
-         num_layers: int,
-         segment_length: int,
-         kernel_size: int,
-         dropout: float = 0.0,
-         ffn_activation: str = "relu",
-         left_context_length: int = 0,
-         right_context_length: int = 0,
-         max_memory_size: int = 0,
-         weight_init_scale_strategy: Optional[str] = "depthwise",
-         tanh_on_mem: bool = False,
-         negative_inf: float = -1e8,
-         conv_activation: str = "silu",
-     ):
-         weight_init_gains = _get_weight_init_gains(weight_init_scale_strategy, num_layers)
-         emformer_layers = torch.nn.ModuleList(
-             [
-                 _ConvEmformerLayer(
-                     input_dim,
-                     num_heads,
-                     ffn_dim,
-                     segment_length,
-                     kernel_size,
-                     dropout=dropout,
-                     ffn_activation=ffn_activation,
-                     left_context_length=left_context_length,
-                     right_context_length=right_context_length,
-                     max_memory_size=max_memory_size,
-                     weight_init_gain=weight_init_gains[layer_idx],
-                     tanh_on_mem=tanh_on_mem,
-                     negative_inf=negative_inf,
-                     conv_activation=conv_activation,
-                 )
-                 for layer_idx in range(num_layers)
-             ]
-         )
-         super().__init__(
-             emformer_layers,
-             segment_length,
-             left_context_length=left_context_length,
-             right_context_length=right_context_length,
-             max_memory_size=max_memory_size,
-         )
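
The deleted torchaudio/prototype/models/conv_emformer.py above defined the streaming ConvEmformer model. The snippet below is a hedged illustration, not code from either wheel: it follows the chunked-inference pattern suggested by the class docstring (feed segment_length + right_context_length frames per infer() call and carry the returned states forward), and it assumes an environment with torchaudio 2.8.0, where torchaudio.prototype.models.ConvEmformer still exists.

# Hedged sketch (assumes torchaudio 2.8.0 is installed): stream a longer feature
# sequence through ConvEmformer.infer() in chunks of
# segment_length + right_context_length frames, carrying the returned states
# into the next call, as the class docstring's Examples block suggests.
import torch
from torchaudio.prototype.models import ConvEmformer

segment_length, right_context_length = 16, 4
model = ConvEmformer(
    80, 4, 1024, 12, segment_length, 8, right_context_length=right_context_length
).eval()

features = torch.rand(4, 200, 80)  # (batch, time, feature_dim)
chunk = segment_length + right_context_length
states = None
outputs = []
with torch.no_grad():
    for start in range(0, features.size(1) - chunk + 1, segment_length):
        frames = features[:, start : start + chunk, :]
        lengths = torch.full((features.size(0),), frames.size(1))
        out, _, states = model.infer(frames, lengths, states)
        outputs.append(out)
streamed = torch.cat(outputs, dim=1)  # encoded segments concatenated along time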