torchaudio-2.9.1-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +86 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/models/rnnt.py
@@ -0,0 +1,816 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torchaudio.models import Emformer


__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]


class _TimeReduction(torch.nn.Module):
    r"""Coalesces frames along time dimension into a
    fewer number of frames with higher feature dimensionality.

    Args:
        stride (int): number of frames to merge for each output frame.
    """

    def __init__(self, stride: int) -> None:
        super().__init__()
        self.stride = stride

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame.

        Args:
            input (torch.Tensor): input sequences, with shape `(B, T, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output sequences, with shape
                    `(B, T // stride, D * stride)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid frames for i-th batch element in output sequences.
        """
        B, T, D = input.shape
        num_frames = T - (T % self.stride)
        input = input[:, :num_frames, :]
        lengths = lengths.div(self.stride, rounding_mode="trunc")
        T_max = num_frames // self.stride

        output = input.reshape(B, T_max, D * self.stride)
        output = output.contiguous()
        return output, lengths


class _CustomLSTM(torch.nn.Module):
    r"""Custom long-short-term memory (LSTM) block that applies layer normalization
    to internal nodes.

    Args:
        input_dim (int): input dimension.
        hidden_dim (int): hidden dimension.
        layer_norm (bool, optional): if ``True``, enables layer normalization. (Default: ``False``)
        layer_norm_epsilon (float, optional): value of epsilon to use in
            layer normalization layers (Default: 1e-5)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        layer_norm: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        self.x2g = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=(not layer_norm))
        self.p2g = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=False)
        if layer_norm:
            self.c_norm = torch.nn.LayerNorm(hidden_dim, eps=layer_norm_epsilon)
            self.g_norm = torch.nn.LayerNorm(4 * hidden_dim, eps=layer_norm_epsilon)
        else:
            self.c_norm = torch.nn.Identity()
            self.g_norm = torch.nn.Identity()

        self.hidden_dim = hidden_dim

    def forward(
        self, input: torch.Tensor, state: Optional[List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        r"""Forward pass.

        B: batch size;
        T: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): with shape `(T, B, D)`.
            state (List[torch.Tensor] or None): list of tensors
                representing internal state generated in preceding invocation
                of ``forward``.

        Returns:
            (torch.Tensor, List[torch.Tensor]):
                torch.Tensor
                    output, with shape `(T, B, hidden_dim)`.
                List[torch.Tensor]
                    list of tensors representing internal state generated
                    in current invocation of ``forward``.
        """
        if state is None:
            B = input.size(1)
            h = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
            c = torch.zeros(B, self.hidden_dim, device=input.device, dtype=input.dtype)
        else:
            h, c = state

        gated_input = self.x2g(input)
        outputs = []
        for gates in gated_input.unbind(0):
            gates = gates + self.p2g(h)
            gates = self.g_norm(gates)
            input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1)
            input_gate = input_gate.sigmoid()
            forget_gate = forget_gate.sigmoid()
            cell_gate = cell_gate.tanh()
            output_gate = output_gate.sigmoid()
            c = forget_gate * c + input_gate * cell_gate
            c = self.c_norm(c)
            h = output_gate * c.tanh()
            outputs.append(h)

        output = torch.stack(outputs, dim=0)
        state = [h, c]

        return output, state


class _Transcriber(ABC):
    @abstractmethod
    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pass

    @abstractmethod
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        pass


class _EmformerEncoder(torch.nn.Module, _Transcriber):
    r"""Emformer-based recurrent neural network transducer (RNN-T) encoder (transcription network).

    Args:
        input_dim (int): feature dimension of each input sequence element.
        output_dim (int): feature dimension of each output sequence element.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context.
        transformer_dropout (float, optional): transformer dropout probability. (Default: 0.0)
        transformer_activation (str, optional): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
        transformer_max_memory_size (int, optional): maximum number of memory elements to use. (Default: 0)
        transformer_weight_init_scale_strategy (str, optional): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``). (Default: "depthwise")
        transformer_tanh_on_mem (bool, optional): if ``True``, applies tanh to memory elements. (Default: ``False``)
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        segment_length: int,
        right_context_length: int,
        time_reduction_input_dim: int,
        time_reduction_stride: int,
        transformer_num_heads: int,
        transformer_ffn_dim: int,
        transformer_num_layers: int,
        transformer_left_context_length: int,
        transformer_dropout: float = 0.0,
        transformer_activation: str = "relu",
        transformer_max_memory_size: int = 0,
        transformer_weight_init_scale_strategy: str = "depthwise",
        transformer_tanh_on_mem: bool = False,
    ) -> None:
        super().__init__()
        self.input_linear = torch.nn.Linear(
            input_dim,
            time_reduction_input_dim,
            bias=False,
        )
        self.time_reduction = _TimeReduction(time_reduction_stride)
        transformer_input_dim = time_reduction_input_dim * time_reduction_stride
        self.transformer = Emformer(
            transformer_input_dim,
            transformer_num_heads,
            transformer_ffn_dim,
            transformer_num_layers,
            segment_length // time_reduction_stride,
            dropout=transformer_dropout,
            activation=transformer_activation,
            left_context_length=transformer_left_context_length,
            right_context_length=right_context_length // time_reduction_stride,
            max_memory_size=transformer_max_memory_size,
            weight_init_scale_strategy=transformer_weight_init_scale_strategy,
            tanh_on_mem=transformer_tanh_on_mem,
        )
        self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
        self.layer_norm = torch.nn.LayerNorm(output_dim)

    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum input sequence length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        transformer_out, transformer_lengths = self.transformer(time_reduction_out, time_reduction_lengths)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths

    @torch.jit.export
    def infer(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        states: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for inference.

        B: batch size;
        T: maximum input sequence segment length in batch;
        D: feature dimension of each input sequence frame (input_dim).

        Args:
            input (torch.Tensor): input frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``infer``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output input lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation
                    of ``infer``.
        """
        input_linear_out = self.input_linear(input)
        time_reduction_out, time_reduction_lengths = self.time_reduction(input_linear_out, lengths)
        (
            transformer_out,
            transformer_lengths,
            transformer_states,
        ) = self.transformer.infer(time_reduction_out, time_reduction_lengths, states)
        output_linear_out = self.output_linear(transformer_out)
        layer_norm_out = self.layer_norm(output_linear_out)
        return layer_norm_out, transformer_lengths, transformer_states


class _Predictor(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) prediction network.

    Args:
        num_symbols (int): size of target token lexicon.
        output_dim (int): feature dimension of each output sequence element.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_hidden_dim (int): output dimension of each LSTM layer.
        lstm_layer_norm (bool, optional): if ``True``, enables layer normalization
            for LSTM layers. (Default: ``False``)
        lstm_layer_norm_epsilon (float, optional): value of epsilon to use in
            LSTM layer normalization layers. (Default: 1e-5)
        lstm_dropout (float, optional): LSTM dropout probability. (Default: 0.0)

    """

    def __init__(
        self,
        num_symbols: int,
        output_dim: int,
        symbol_embedding_dim: int,
        num_lstm_layers: int,
        lstm_hidden_dim: int,
        lstm_layer_norm: bool = False,
        lstm_layer_norm_epsilon: float = 1e-5,
        lstm_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
        self.input_layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)
        self.lstm_layers = torch.nn.ModuleList(
            [
                _CustomLSTM(
                    symbol_embedding_dim if idx == 0 else lstm_hidden_dim,
                    lstm_hidden_dim,
                    layer_norm=lstm_layer_norm,
                    layer_norm_epsilon=lstm_layer_norm_epsilon,
                )
                for idx in range(num_lstm_layers)
            ]
        )
        self.dropout = torch.nn.Dropout(p=lstm_dropout)
        self.linear = torch.nn.Linear(lstm_hidden_dim, output_dim)
        self.output_layer_norm = torch.nn.LayerNorm(output_dim)

        self.lstm_dropout = lstm_dropout

    def forward(
        self,
        input: torch.Tensor,
        lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass.

        B: batch size;
        U: maximum sequence length in batch;
        D: feature dimension of each input sequence element.

        Args:
            input (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``input``.
            state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output encoding sequences, with shape `(B, U, output_dim)`
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output encoding sequences.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``forward``.
        """
        input_tb = input.permute(1, 0)
        embedding_out = self.embedding(input_tb)
        input_layer_norm_out = self.input_layer_norm(embedding_out)

        lstm_out = input_layer_norm_out
        state_out: List[List[torch.Tensor]] = []
        for layer_idx, lstm in enumerate(self.lstm_layers):
            lstm_out, lstm_state_out = lstm(lstm_out, None if state is None else state[layer_idx])
            lstm_out = self.dropout(lstm_out)
            state_out.append(lstm_state_out)

        linear_out = self.linear(lstm_out)
        output_layer_norm_out = self.output_layer_norm(linear_out)
        return output_layer_norm_out.permute(1, 0, 2), lengths, state_out


class _Joiner(torch.nn.Module):
    r"""Recurrent neural network transducer (RNN-T) joint network.

    Args:
        input_dim (int): source and target input dimension.
        output_dim (int): output dimension.
        activation (str, optional): activation function to use in the joiner.
            Must be one of ("relu", "tanh"). (Default: "relu")

    """

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "tanh":
            self.activation = torch.nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
        activation_out = self.activation(joint_encodings)
        output = self.linear(activation_out)
        return output, source_lengths, target_lengths


class RNNT(torch.nn.Module):
    r"""torchaudio.models.RNNT()

    Recurrent neural network transducer (RNN-T) model.

    Note:
        To build the model, please use one of the factory functions.

    See Also:
        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

    Args:
        transcriber (torch.nn.Module): transcription network.
        predictor (torch.nn.Module): prediction network.
        joiner (torch.nn.Module): joint network.
    """

    def __init__(self, transcriber: _Transcriber, predictor: _Predictor, joiner: _Joiner) -> None:
        super().__init__()
        self.transcriber = transcriber
        self.predictor = predictor
        self.joiner = joiner

    def forward(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictor_state: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Forward pass for training.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: feature dimension of each source sequence element.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                representing prediction network internal state generated in preceding invocation
                of ``forward``. (Default: ``None``)

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    joint network output, with shape
                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing prediction network internal state generated in current invocation
                    of ``forward``.
        """
        source_encodings, source_lengths = self.transcriber(
            input=sources,
            lengths=source_lengths,
        )
        target_encodings, target_lengths, predictor_state = self.predictor(
            input=targets,
            lengths=target_lengths,
            state=predictor_state,
        )
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )

        return (
            output,
            source_lengths,
            target_lengths,
            predictor_state,
        )

    @torch.jit.export
    def transcribe_streaming(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies transcription network to sources in streaming mode.

        B: batch size;
        T: maximum source sequence segment length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequence segments right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing transcription network internal state generated in preceding invocation
                of ``transcribe_streaming``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing transcription network internal state generated in current invocation
                    of ``transcribe_streaming``.
        """
        return self.transcriber.infer(sources, source_lengths, state)

    @torch.jit.export
    def transcribe(
        self,
        sources: torch.Tensor,
        source_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Applies transcription network to sources in non-streaming mode.

        B: batch size;
        T: maximum source sequence length in batch;
        D: feature dimension of each source sequence frame.

        Args:
            sources (torch.Tensor): source frame sequences right-padded with right context, with
                shape `(B, T + right context length, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``sources``.

        Returns:
            (torch.Tensor, torch.Tensor):
                torch.Tensor
                    output frame sequences, with
                    shape `(B, T // time_reduction_stride, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output frame sequences.
        """
        return self.transcriber(sources, source_lengths)

    @torch.jit.export
    def predict(
        self,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        state: Optional[List[List[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
        r"""Applies prediction network to targets.

        B: batch size;
        U: maximum target sequence length in batch;
        D: feature dimension of each target sequence frame.

        Args:
            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
                mapping to a target symbol, i.e. in range `[0, num_symbols)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                number of valid frames for i-th batch element in ``targets``.
            state (List[List[torch.Tensor]] or None): list of lists of tensors
                representing internal state generated in preceding invocation
                of ``predict``.

        Returns:
            (torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
                torch.Tensor
                    output frame sequences, with shape `(B, U, output_dim)`.
                torch.Tensor
                    output lengths, with shape `(B,)` and i-th element representing
                    number of valid elements for i-th batch element in output.
                List[List[torch.Tensor]]
                    output states; list of lists of tensors
                    representing internal state generated in current invocation of ``predict``.
        """
        return self.predictor(input=targets, lengths=target_lengths, state=state)

    @torch.jit.export
    def join(
        self,
        source_encodings: torch.Tensor,
        source_lengths: torch.Tensor,
        target_encodings: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Applies joint network to source and target encodings.

        B: batch size;
        T: maximum source sequence length in batch;
        U: maximum target sequence length in batch;
        D: dimension of each source and target sequence encoding.

        Args:
            source_encodings (torch.Tensor): source encoding sequences, with
                shape `(B, T, D)`.
            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``source_encodings``.
            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
                valid sequence length of i-th batch element in ``target_encodings``.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor):
                torch.Tensor
                    joint network output, with shape `(B, T, U, output_dim)`.
                torch.Tensor
                    output source lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 1 for i-th batch element in joint network output.
                torch.Tensor
                    output target lengths, with shape `(B,)` and i-th element representing
                    number of valid elements along dim 2 for i-th batch element in joint network output.
        """
        output, source_lengths, target_lengths = self.joiner(
            source_encodings=source_encodings,
            source_lengths=source_lengths,
            target_encodings=target_encodings,
            target_lengths=target_lengths,
        )
        return output, source_lengths, target_lengths


def emformer_rnnt_model(
    *,
    input_dim: int,
    encoding_dim: int,
    num_symbols: int,
    segment_length: int,
    right_context_length: int,
    time_reduction_input_dim: int,
    time_reduction_stride: int,
    transformer_num_heads: int,
    transformer_ffn_dim: int,
    transformer_num_layers: int,
    transformer_dropout: float,
    transformer_activation: str,
    transformer_left_context_length: int,
    transformer_max_memory_size: int,
    transformer_weight_init_scale_strategy: str,
    transformer_tanh_on_mem: bool,
    symbol_embedding_dim: int,
    num_lstm_layers: int,
    lstm_layer_norm: bool,
    lstm_layer_norm_epsilon: float,
    lstm_dropout: float,
) -> RNNT:
    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

    Note:
        For non-streaming inference, the expectation is for `transcribe` to be called on input
        sequences right-concatenated with `right_context_length` frames.

        For streaming inference, the expectation is for `transcribe_streaming` to be called
        on input chunks comprising `segment_length` frames right-concatenated with `right_context_length`
        frames.

    Args:
        input_dim (int): dimension of input sequence frames passed to transcription network.
        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
            passed to joint network.
        num_symbols (int): cardinality of set of target tokens.
        segment_length (int): length of input segment expressed as number of frames.
        right_context_length (int): length of right context expressed as number of frames.
        time_reduction_input_dim (int): dimension to scale each element in input sequences to
            prior to applying time reduction block.
        time_reduction_stride (int): factor by which to reduce length of input sequence.
        transformer_num_heads (int): number of attention heads in each Emformer layer.
        transformer_ffn_dim (int): hidden layer dimension of each Emformer layer's feedforward network.
        transformer_num_layers (int): number of Emformer layers to instantiate.
        transformer_left_context_length (int): length of left context considered by Emformer.
        transformer_dropout (float): Emformer dropout probability.
        transformer_activation (str): activation function to use in each Emformer layer's
            feedforward network. Must be one of ("relu", "gelu", "silu").
        transformer_max_memory_size (int): maximum number of memory elements to use.
        transformer_weight_init_scale_strategy (str): per-layer weight initialization scaling
            strategy. Must be one of ("depthwise", "constant", ``None``).
        transformer_tanh_on_mem (bool): if ``True``, applies tanh to memory elements.
        symbol_embedding_dim (int): dimension of each target token embedding.
        num_lstm_layers (int): number of LSTM layers to instantiate.
        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
        lstm_dropout (float): LSTM dropout probability.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    encoder = _EmformerEncoder(
        input_dim=input_dim,
        output_dim=encoding_dim,
        segment_length=segment_length,
        right_context_length=right_context_length,
        time_reduction_input_dim=time_reduction_input_dim,
        time_reduction_stride=time_reduction_stride,
        transformer_num_heads=transformer_num_heads,
        transformer_ffn_dim=transformer_ffn_dim,
        transformer_num_layers=transformer_num_layers,
        transformer_dropout=transformer_dropout,
        transformer_activation=transformer_activation,
        transformer_left_context_length=transformer_left_context_length,
        transformer_max_memory_size=transformer_max_memory_size,
        transformer_weight_init_scale_strategy=transformer_weight_init_scale_strategy,
        transformer_tanh_on_mem=transformer_tanh_on_mem,
    )
    predictor = _Predictor(
        num_symbols,
        encoding_dim,
        symbol_embedding_dim=symbol_embedding_dim,
        num_lstm_layers=num_lstm_layers,
        lstm_hidden_dim=symbol_embedding_dim,
        lstm_layer_norm=lstm_layer_norm,
        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
        lstm_dropout=lstm_dropout,
    )
    joiner = _Joiner(encoding_dim, num_symbols)
    return RNNT(encoder, predictor, joiner)


def emformer_rnnt_base(num_symbols: int) -> RNNT:
    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

    Args:
        num_symbols (int): The size of target token lexicon.

    Returns:
        RNNT:
            Emformer RNN-T model.
    """
    return emformer_rnnt_model(
        input_dim=80,
        encoding_dim=1024,
        num_symbols=num_symbols,
        segment_length=16,
        right_context_length=4,
        time_reduction_input_dim=128,
        time_reduction_stride=4,
        transformer_num_heads=8,
        transformer_ffn_dim=2048,
        transformer_num_layers=20,
        transformer_dropout=0.1,
        transformer_activation="gelu",
        transformer_left_context_length=30,
        transformer_max_memory_size=0,
        transformer_weight_init_scale_strategy="depthwise",
        transformer_tanh_on_mem=True,
        symbol_embedding_dim=512,
        num_lstm_layers=3,
        lstm_layer_norm=True,
        lstm_layer_norm_epsilon=1e-3,
        lstm_dropout=0.3,
    )
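
For orientation, here is a minimal usage sketch (not part of the wheel's contents) showing how a model produced by the factory functions above can be driven. The vocabulary size, batch size, and frame counts are arbitrary assumptions chosen to match the shapes documented in the docstrings.

import torch
from torchaudio.models import emformer_rnnt_base

# Base configuration: 80-dim features, segment_length=16, right_context_length=4.
model = emformer_rnnt_base(num_symbols=4097)

# Training-style forward pass: feature frames right-padded with right context,
# plus a batch of token-index target sequences.
sources = torch.rand(2, 132, 80)                                  # (B, T + right context, input_dim)
source_lengths = torch.full((2,), 132, dtype=torch.int32)
targets = torch.randint(0, 4097, (2, 10), dtype=torch.int32)      # (B, U)
target_lengths = torch.full((2,), 10, dtype=torch.int32)
joint_out, src_lens, tgt_lens, pred_state = model(sources, source_lengths, targets, target_lengths)
# joint_out has shape (B, time-reduced source length, U, num_symbols).

# Streaming transcription: feed chunks of segment_length + right_context_length frames
# and thread the transcription-network state between calls.
state = None
with torch.inference_mode():
    for _ in range(3):
        chunk = torch.rand(2, 16 + 4, 80)
        chunk_lengths = torch.full((2,), 20, dtype=torch.int32)
        encodings, enc_lengths, state = model.transcribe_streaming(chunk, chunk_lengths, state)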