torchaudio 2.8.0-cp313-cp313-win_amd64.whl → 2.9.0-cp313-cp313-win_amd64.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Note: this release of torchaudio has been flagged as potentially problematic.
- torchaudio/__init__.py +179 -39
- torchaudio/_extension/__init__.py +1 -14
- torchaudio/_extension/utils.py +0 -47
- torchaudio/_internal/module_utils.py +12 -3
- torchaudio/_torchcodec.py +73 -85
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/utils.py +1 -1
- torchaudio/functional/__init__.py +0 -2
- torchaudio/functional/_alignment.py +1 -1
- torchaudio/functional/filtering.py +70 -55
- torchaudio/functional/functional.py +26 -60
- torchaudio/lib/_torchaudio.pyd +0 -0
- torchaudio/lib/libtorchaudio.pyd +0 -0
- torchaudio/models/decoder/__init__.py +14 -2
- torchaudio/models/decoder/_ctc_decoder.py +6 -6
- torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
- torchaudio/models/squim/objective.py +2 -2
- torchaudio/pipelines/_source_separation_pipeline.py +1 -1
- torchaudio/pipelines/_squim_pipeline.py +2 -2
- torchaudio/pipelines/_tts/utils.py +1 -1
- torchaudio/pipelines/rnnt_pipeline.py +4 -4
- torchaudio/transforms/__init__.py +1 -0
- torchaudio/transforms/_transforms.py +2 -2
- torchaudio/utils/__init__.py +2 -9
- torchaudio/utils/download.py +1 -3
- torchaudio/version.py +2 -2
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
- torchaudio-2.9.0.dist-info/RECORD +85 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
- torchaudio/_backend/__init__.py +0 -61
- torchaudio/_backend/backend.py +0 -53
- torchaudio/_backend/common.py +0 -52
- torchaudio/_backend/ffmpeg.py +0 -334
- torchaudio/_backend/soundfile.py +0 -54
- torchaudio/_backend/soundfile_backend.py +0 -457
- torchaudio/_backend/sox.py +0 -91
- torchaudio/_backend/utils.py +0 -350
- torchaudio/backend/__init__.py +0 -8
- torchaudio/backend/_no_backend.py +0 -25
- torchaudio/backend/_sox_io_backend.py +0 -294
- torchaudio/backend/common.py +0 -13
- torchaudio/backend/no_backend.py +0 -14
- torchaudio/backend/soundfile_backend.py +0 -14
- torchaudio/backend/sox_io_backend.py +0 -14
- torchaudio/io/__init__.py +0 -20
- torchaudio/io/_effector.py +0 -347
- torchaudio/io/_playback.py +0 -72
- torchaudio/kaldi_io.py +0 -150
- torchaudio/prototype/__init__.py +0 -0
- torchaudio/prototype/datasets/__init__.py +0 -4
- torchaudio/prototype/datasets/musan.py +0 -68
- torchaudio/prototype/functional/__init__.py +0 -26
- torchaudio/prototype/functional/_dsp.py +0 -441
- torchaudio/prototype/functional/_rir.py +0 -382
- torchaudio/prototype/functional/functional.py +0 -193
- torchaudio/prototype/models/__init__.py +0 -39
- torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
- torchaudio/prototype/models/_emformer_hubert.py +0 -337
- torchaudio/prototype/models/conv_emformer.py +0 -529
- torchaudio/prototype/models/hifi_gan.py +0 -342
- torchaudio/prototype/models/rnnt.py +0 -717
- torchaudio/prototype/models/rnnt_decoder.py +0 -402
- torchaudio/prototype/pipelines/__init__.py +0 -21
- torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
- torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
- torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
- torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
- torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
- torchaudio/prototype/transforms/__init__.py +0 -9
- torchaudio/prototype/transforms/_transforms.py +0 -461
- torchaudio/sox_effects/__init__.py +0 -10
- torchaudio/sox_effects/sox_effects.py +0 -275
- torchaudio/utils/ffmpeg_utils.py +0 -11
- torchaudio/utils/sox_utils.py +0 -118
- torchaudio-2.8.0.dist-info/RECORD +0 -145
- torio/__init__.py +0 -8
- torio/_extension/__init__.py +0 -13
- torio/_extension/utils.py +0 -147
- torio/io/__init__.py +0 -9
- torio/io/_streaming_media_decoder.py +0 -977
- torio/io/_streaming_media_encoder.py +0 -502
- torio/lib/__init__.py +0 -0
- torio/lib/_torio_ffmpeg4.pyd +0 -0
- torio/lib/_torio_ffmpeg5.pyd +0 -0
- torio/lib/_torio_ffmpeg6.pyd +0 -0
- torio/lib/libtorio_ffmpeg4.pyd +0 -0
- torio/lib/libtorio_ffmpeg5.pyd +0 -0
- torio/lib/libtorio_ffmpeg6.pyd +0 -0
- torio/utils/__init__.py +0 -4
- torio/utils/ffmpeg_utils.py +0 -275
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
- {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/licenses/LICENSE +0 -0
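
The removal list shows the scope of the 2.9.0 cleanup: the legacy I/O backends (torchaudio._backend, torchaudio.backend), the SoX effect bindings (torchaudio.sox_effects, torchaudio.utils.sox_utils), torchaudio.kaldi_io, the entire torchaudio.prototype namespace, and the bundled torio FFmpeg package are all deleted, while torchaudio/_torchcodec.py is reworked. A quick way to confirm which of these modules an installed copy still ships — a minimal sketch, assuming only that torchaudio is importable in the current environment:

import importlib.util

import torchaudio

print(torchaudio.__version__)  # e.g. "2.8.0" on the old wheel, "2.9.0" on the new one

# Modules that appear only on the removal side of the file list above.
for name in ("torchaudio.sox_effects", "torchaudio.kaldi_io", "torchaudio.prototype", "torio"):
    present = importlib.util.find_spec(name) is not None
    print(f"{name}: {'present' if present else 'removed'}")

On a 2.8.0 install all four names should resolve; on 2.9.0 none of them should.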
torchaudio/prototype/models/rnnt.py
@@ -1,717 +0,0 @@
-import math
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from torchaudio.models import Conformer, RNNT
-from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
-
-from torchaudio._internal.module_utils import dropping_support
-
-
-TrieNode = Tuple[Dict[int, "TrieNode"], int, Optional[Tuple[int, int]]]
-
-
-class _ConformerEncoder(torch.nn.Module, _Transcriber):
-    def __init__(
-        self,
-        *,
-        input_dim: int,
-        output_dim: int,
-        time_reduction_stride: int,
-        conformer_input_dim: int,
-        conformer_ffn_dim: int,
-        conformer_num_layers: int,
-        conformer_num_heads: int,
-        conformer_depthwise_conv_kernel_size: int,
-        conformer_dropout: float,
-    ) -> None:
-        super().__init__()
-        self.time_reduction = _TimeReduction(time_reduction_stride)
-        self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
-        self.conformer = Conformer(
-            num_layers=conformer_num_layers,
-            input_dim=conformer_input_dim,
-            ffn_dim=conformer_ffn_dim,
-            num_heads=conformer_num_heads,
-            depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-            dropout=conformer_dropout,
-            use_group_norm=True,
-            convolution_first=True,
-        )
-        self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
-        self.layer_norm = torch.nn.LayerNorm(output_dim)
-
-    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        time_reduction_out, time_reduction_lengths = self.time_reduction(input, lengths)
-        input_linear_out = self.input_linear(time_reduction_out)
-        x, lengths = self.conformer(input_linear_out, time_reduction_lengths)
-        output_linear_out = self.output_linear(x)
-        layer_norm_out = self.layer_norm(output_linear_out)
-        return layer_norm_out, lengths
-
-    def infer(
-        self,
-        input: torch.Tensor,
-        lengths: torch.Tensor,
-        states: Optional[List[List[torch.Tensor]]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
-        raise RuntimeError("Conformer does not support streaming inference.")
-
-
-class _JoinerBiasing(torch.nn.Module):
-    r"""Recurrent neural network transducer (RNN-T) joint network.
-
-    Args:
-        input_dim (int): source and target input dimension.
-        output_dim (int): output dimension.
-        activation (str, optional): activation function to use in the joiner.
-            Must be one of ("relu", "tanh"). (Default: "relu")
-        biasing (bool): perform biasing
-        deepbiasing (bool): perform deep biasing
-        attndim (int): dimension of the biasing vector hptr
-
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        output_dim: int,
-        activation: str = "relu",
-        biasing: bool = False,
-        deepbiasing: bool = False,
-        attndim: int = 1,
-    ) -> None:
-        super().__init__()
-        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
-        self.biasing = biasing
-        self.deepbiasing = deepbiasing
-        if self.biasing and self.deepbiasing:
-            self.biasinglinear = torch.nn.Linear(attndim, input_dim, bias=True)
-            self.attndim = attndim
-        if activation == "relu":
-            self.activation = torch.nn.ReLU()
-        elif activation == "tanh":
-            self.activation = torch.nn.Tanh()
-        else:
-            raise ValueError(f"Unsupported activation {activation}")
-
-    def forward(
-        self,
-        source_encodings: torch.Tensor,
-        source_lengths: torch.Tensor,
-        target_encodings: torch.Tensor,
-        target_lengths: torch.Tensor,
-        hptr: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        r"""Forward pass for training.
-
-        B: batch size;
-        T: maximum source sequence length in batch;
-        U: maximum target sequence length in batch;
-        D: dimension of each source and target sequence encoding.
-
-        Args:
-            source_encodings (torch.Tensor): source encoding sequences, with
-                shape `(B, T, D)`.
-            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                valid sequence length of i-th batch element in ``source_encodings``.
-            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                valid sequence length of i-th batch element in ``target_encodings``.
-            hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-        Returns:
-            (torch.Tensor, torch.Tensor, torch.Tensor):
-                torch.Tensor
-                    joint network output, with shape `(B, T, U, output_dim)`.
-                torch.Tensor
-                    output source lengths, with shape `(B,)` and i-th element representing
-                    number of valid elements along dim 1 for i-th batch element in joint network output.
-                torch.Tensor
-                    output target lengths, with shape `(B,)` and i-th element representing
-                    number of valid elements along dim 2 for i-th batch element in joint network output.
-                torch.Tensor
-                    joint network second last layer output (i.e. before self.linear), with shape `(B, T, U, D)`.
-        """
-        joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
-        if self.biasing and self.deepbiasing and hptr is not None:
-            hptr = self.biasinglinear(hptr)
-            joint_encodings += hptr
-        elif self.biasing and self.deepbiasing:
-            # Hack here for unused parameters
-            joint_encodings += self.biasinglinear(joint_encodings.new_zeros(1, self.attndim)).mean() * 0
-        activation_out = self.activation(joint_encodings)
-        output = self.linear(activation_out)
-        return output, source_lengths, target_lengths, activation_out
-
-
-class RNNTBiasing(RNNT):
-    r"""torchaudio.models.RNNT()
-
-    Recurrent neural network transducer (RNN-T) model.
-
-    Note:
-        To build the model, please use one of the factory functions.
-
-    Args:
-        transcriber (torch.nn.Module): transcription network.
-        predictor (torch.nn.Module): prediction network.
-        joiner (torch.nn.Module): joint network.
-        attndim (int): TCPGen attention dimension
-        biasing (bool): If true, use biasing, otherwise use standard RNN-T
-        deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-        embdim (int): dimension of symbol embeddings
-        jointdim (int): dimension of the joint network joint dimension
-        charlist (list): The list of word piece tokens in the same order as the output layer
-        encoutdim (int): dimension of the encoder output vectors
-        dropout_tcpgen (float): dropout rate for TCPGen
-        tcpsche (int): The epoch at which TCPGen starts to train
-        DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-    """
-
-    def __init__(
-        self,
-        transcriber: _Transcriber,
-        predictor: _Predictor,
-        joiner: _Joiner,
-        attndim: int,
-        biasing: bool,
-        deepbiasing: bool,
-        embdim: int,
-        jointdim: int,
-        charlist: List[str],
-        encoutdim: int,
-        dropout_tcpgen: float,
-        tcpsche: int,
-        DBaverage: bool,
-    ) -> None:
-        super().__init__(transcriber, predictor, joiner)
-        self.attndim = attndim
-        self.deepbiasing = deepbiasing
-        self.jointdim = jointdim
-        self.embdim = embdim
-        self.encoutdim = encoutdim
-        self.char_list = charlist or []
-        self.blank_idx = self.char_list.index("<blank>")
-        self.nchars = len(self.char_list)
-        self.DBaverage = DBaverage
-        self.biasing = biasing
-        if self.biasing:
-            if self.deepbiasing and self.DBaverage:
-                # Deep biasing without TCPGen
-                self.biasingemb = torch.nn.Linear(self.nchars, self.attndim, bias=False)
-            else:
-                # TCPGen parameters
-                self.ooKBemb = torch.nn.Embedding(1, self.embdim)
-                self.Qproj_char = torch.nn.Linear(self.embdim, self.attndim)
-                self.Qproj_acoustic = torch.nn.Linear(self.encoutdim, self.attndim)
-                self.Kproj = torch.nn.Linear(self.embdim, self.attndim)
-                self.pointer_gate = torch.nn.Linear(self.attndim + self.jointdim, 1)
-        self.dropout_tcpgen = torch.nn.Dropout(dropout_tcpgen)
-        self.tcpsche = tcpsche
-
-    def forward(
-        self,
-        sources: torch.Tensor,
-        source_lengths: torch.Tensor,
-        targets: torch.Tensor,
-        target_lengths: torch.Tensor,
-        tries: TrieNode,
-        current_epoch: int,
-        predictor_state: Optional[List[List[torch.Tensor]]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor]:
-        r"""Forward pass for training.
-
-        B: batch size;
-        T: maximum source sequence length in batch;
-        U: maximum target sequence length in batch;
-        D: feature dimension of each source sequence element.
-
-        Args:
-            sources (torch.Tensor): source frame sequences right-padded with right context, with
-                shape `(B, T, D)`.
-            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                number of valid frames for i-th batch element in ``sources``.
-            targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
-                mapping to a target symbol.
-            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                number of valid frames for i-th batch element in ``targets``.
-            tries (TrieNode): wordpiece prefix trees representing the biasing list to be searched
-            current_epoch (Int): the current epoch number to determine if TCPGen should be trained
-                at this epoch
-            predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                representing prediction network internal state generated in preceding invocation
-                of ``forward``. (Default: ``None``)
-
-        Returns:
-            (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
-                torch.Tensor
-                    joint network output, with shape
-                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                torch.Tensor
-                    output source lengths, with shape `(B,)` and i-th element representing
-                    number of valid elements along dim 1 for i-th batch element in joint network output.
-                torch.Tensor
-                    output target lengths, with shape `(B,)` and i-th element representing
-                    number of valid elements along dim 2 for i-th batch element in joint network output.
-                List[List[torch.Tensor]]
-                    output states; list of lists of tensors
-                    representing prediction network internal state generated in current invocation
-                    of ``forward``.
-                torch.Tensor
-                    TCPGen distribution, with shape
-                    `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                torch.Tensor
-                    Generation probability (or copy probability), with shape
-                    `(B, max output source length, max output target length, 1)`.
-        """
-        source_encodings, source_lengths = self.transcriber(
-            input=sources,
-            lengths=source_lengths,
-        )
-        target_encodings, target_lengths, predictor_state = self.predictor(
-            input=targets,
-            lengths=target_lengths,
-            state=predictor_state,
-        )
-        # Forward TCPGen
-        hptr = None
-        tcpgen_dist, p_gen = None, None
-        if self.biasing and current_epoch >= self.tcpsche and tries != []:
-            ptrdist_mask, p_gen_mask = self.get_tcpgen_step_masks(targets, tries)
-            hptr, tcpgen_dist = self.forward_tcpgen(targets, ptrdist_mask, source_encodings)
-            hptr = self.dropout_tcpgen(hptr)
-        elif self.biasing:
-            # Hack here to bypass unused parameters
-            if self.DBaverage and self.deepbiasing:
-                dummy = self.biasingemb(source_encodings.new_zeros(1, len(self.char_list))).mean()
-            else:
-                dummy = source_encodings.new_zeros(1, self.embdim)
-                dummy = self.Qproj_char(dummy).mean()
-                dummy += self.Qproj_acoustic(source_encodings.new_zeros(1, source_encodings.size(-1))).mean()
-                dummy += self.Kproj(source_encodings.new_zeros(1, self.embdim)).mean()
-                dummy += self.pointer_gate(source_encodings.new_zeros(1, self.attndim + self.jointdim)).mean()
-                dummy += self.ooKBemb.weight.mean()
-            dummy = dummy * 0
-            source_encodings += dummy
-
-        output, source_lengths, target_lengths, jointer_activation = self.joiner(
-            source_encodings=source_encodings,
-            source_lengths=source_lengths,
-            target_encodings=target_encodings,
-            target_lengths=target_lengths,
-            hptr=hptr,
-        )
-
-        # Calculate Generation Probability
-        if self.biasing and hptr is not None and tcpgen_dist is not None:
-            p_gen = torch.sigmoid(self.pointer_gate(torch.cat((jointer_activation, hptr), dim=-1)))
-            # avoid collapsing to ooKB token in the first few updates
-            # if current_epoch == self.tcpsche:
-            #     p_gen = p_gen * 0.1
-            p_gen = p_gen.masked_fill(p_gen_mask.bool().unsqueeze(1).unsqueeze(-1), 0)
-
-        return (output, source_lengths, target_lengths, predictor_state, tcpgen_dist, p_gen)
-
-    def get_tcpgen_distribution(self, query, ptrdist_mask):
-        # Make use of the predictor embedding matrix
-        keyvalues = torch.cat([self.predictor.embedding.weight.data, self.ooKBemb.weight], dim=0)
-        keyvalues = self.dropout_tcpgen(self.Kproj(keyvalues))
-        # B * T * U * attndim, nbpe * attndim -> B * T * U * nbpe
-        tcpgendist = torch.einsum("ntuj,ij->ntui", query, keyvalues)
-        tcpgendist = tcpgendist / math.sqrt(query.size(-1))
-        ptrdist_mask = ptrdist_mask.unsqueeze(1).repeat(1, tcpgendist.size(1), 1, 1)
-        tcpgendist.masked_fill_(ptrdist_mask.bool(), -1e9)
-        tcpgendist = torch.nn.functional.softmax(tcpgendist, dim=-1)
-        # B * T * U * nbpe, nbpe * attndim -> B * T * U * attndim
-        hptr = torch.einsum("ntui,ij->ntuj", tcpgendist[:, :, :, :-1], keyvalues[:-1, :])
-        return hptr, tcpgendist
-
-    def forward_tcpgen(self, targets, ptrdist_mask, source_encodings):
-        tcpgen_dist = None
-        if self.DBaverage and self.deepbiasing:
-            hptr = self.biasingemb(1 - ptrdist_mask[:, :, :-1].float()).unsqueeze(1)
-        else:
-            query_char = self.predictor.embedding(targets)
-            query_char = self.Qproj_char(query_char).unsqueeze(1)  # B * 1 * U * attndim
-            query_acoustic = self.Qproj_acoustic(source_encodings).unsqueeze(2)  # B * T * 1 * attndim
-            query = query_char + query_acoustic  # B * T * U * attndim
-            hptr, tcpgen_dist = self.get_tcpgen_distribution(query, ptrdist_mask)
-        return hptr, tcpgen_dist
-
-    def get_tcpgen_step_masks(self, yseqs, resettrie):
-        seqlen = len(yseqs[0])
-        batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-        p_gen_masks = []
-        for i, yseq in enumerate(yseqs):
-            new_tree = resettrie
-            p_gen_mask = []
-            for j, vy in enumerate(yseq):
-                vy = vy.item()
-                new_tree = new_tree[0]
-                if vy in [self.blank_idx]:
-                    new_tree = resettrie
-                    p_gen_mask.append(0)
-                elif self.char_list[vy].endswith("▁"):
-                    if vy in new_tree and new_tree[vy][0] != {}:
-                        new_tree = new_tree[vy]
-                    else:
-                        new_tree = resettrie
-                    p_gen_mask.append(0)
-                elif vy not in new_tree:
-                    new_tree = [{}]
-                    p_gen_mask.append(1)
-                else:
-                    new_tree = new_tree[vy]
-                    p_gen_mask.append(0)
-                batch_masks[i, j, list(new_tree[0].keys())] = 0
-                # In the original paper, ooKB node was not masked
-                # In this implementation, if not masking ooKB, ooKB probability
-                # would quickly collapse to 1.0 in the first few updates.
-                # Haven't found out why this happened.
-                # batch_masks[i, j, -1] = 0
-            p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-        p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-        return batch_masks, p_gen_masks
-
-    def get_tcpgen_step_masks_prefix(self, yseqs, resettrie):
-        # Implemented for prefix-based wordpieces, not tested yet
-        seqlen = len(yseqs[0])
-        batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-        p_gen_masks = []
-        for i, yseq in enumerate(yseqs):
-            p_gen_mask = []
-            new_tree = resettrie
-            for j, vy in enumerate(yseq):
-                vy = vy.item()
-                new_tree = new_tree[0]
-                if vy in [self.blank_idx]:
-                    new_tree = resettrie
-                    batch_masks[i, j, list(new_tree[0].keys())] = 0
-                elif self.char_list[vy].startswith("▁"):
-                    new_tree = resettrie
-                    if vy not in new_tree[0]:
-                        batch_masks[i, j, list(new_tree[0].keys())] = 0
-                    else:
-                        new_tree = new_tree[0][vy]
-                        batch_masks[i, j, list(new_tree[0].keys())] = 0
-                        if new_tree[1] != -1:
-                            batch_masks[i, j, list(resettrie[0].keys())] = 0
-                else:
-                    if vy not in new_tree:
-                        new_tree = resettrie
-                        batch_masks[i, j, list(new_tree[0].keys())] = 0
-                    else:
-                        new_tree = new_tree[vy]
-                        batch_masks[i, j, list(new_tree[0].keys())] = 0
-                        if new_tree[1] != -1:
-                            batch_masks[i, j, list(resettrie[0].keys())] = 0
-                p_gen_mask.append(0)
-                # batch_masks[i, j, -1] = 0
-            p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-        p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-
-        return batch_masks, p_gen_masks
-
-    def get_tcpgen_step(self, vy, trie, resettrie):
-        new_tree = trie[0]
-        if vy in [self.blank_idx]:
-            new_tree = resettrie
-        elif self.char_list[vy].endswith("▁"):
-            if vy in new_tree and new_tree[vy][0] != {}:
-                new_tree = new_tree[vy]
-            else:
-                new_tree = resettrie
-        elif vy not in new_tree:
-            new_tree = [{}]
-        else:
-            new_tree = new_tree[vy]
-        return new_tree
-
-    def join(
-        self,
-        source_encodings: torch.Tensor,
-        source_lengths: torch.Tensor,
-        target_encodings: torch.Tensor,
-        target_lengths: torch.Tensor,
-        hptr: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        r"""Applies joint network to source and target encodings.
-
-        B: batch size;
-        T: maximum source sequence length in batch;
-        U: maximum target sequence length in batch;
-        D: dimension of each source and target sequence encoding.
-        A: TCPGen attention dimension
-
-        Args:
-            source_encodings (torch.Tensor): source encoding sequences, with
-                shape `(B, T, D)`.
-            source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                valid sequence length of i-th batch element in ``source_encodings``.
-            target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-            target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                valid sequence length of i-th batch element in ``target_encodings``.
-            hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-        Returns:
-            (torch.Tensor, torch.Tensor, torch.Tensor):
-                torch.Tensor
-                    joint network output, with shape `(B, T, U, output_dim)`.
-                torch.Tensor
-                    output source lengths, with shape `(B,)` and i-th element representing
-                    number of valid elements along dim 1 for i-th batch element in joint network output.
-                torch.Tensor
-                    joint network second last layer output, with shape `(B, T, U, D)`.
-        """
-        output, source_lengths, target_lengths, jointer_activation = self.joiner(
-            source_encodings=source_encodings,
-            source_lengths=source_lengths,
-            target_encodings=target_encodings,
-            target_lengths=target_lengths,
-            hptr=hptr,
-        )
-        return output, source_lengths, jointer_activation
-
-
-@dropping_support
-def conformer_rnnt_model(
-    *,
-    input_dim: int,
-    encoding_dim: int,
-    time_reduction_stride: int,
-    conformer_input_dim: int,
-    conformer_ffn_dim: int,
-    conformer_num_layers: int,
-    conformer_num_heads: int,
-    conformer_depthwise_conv_kernel_size: int,
-    conformer_dropout: float,
-    num_symbols: int,
-    symbol_embedding_dim: int,
-    num_lstm_layers: int,
-    lstm_hidden_dim: int,
-    lstm_layer_norm: int,
-    lstm_layer_norm_epsilon: int,
-    lstm_dropout: int,
-    joiner_activation: str,
-) -> RNNT:
-    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-    Args:
-        input_dim (int): dimension of input sequence frames passed to transcription network.
-        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-            passed to joint network.
-        time_reduction_stride (int): factor by which to reduce length of input sequence.
-        conformer_input_dim (int): dimension of Conformer input.
-        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-        conformer_num_layers (int): number of Conformer layers to instantiate.
-        conformer_num_heads (int): number of attention heads in each Conformer layer.
-        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-        conformer_dropout (float): Conformer dropout probability.
-        num_symbols (int): cardinality of set of target tokens.
-        symbol_embedding_dim (int): dimension of each target token embedding.
-        num_lstm_layers (int): number of LSTM layers to instantiate.
-        lstm_hidden_dim (int): output dimension of each LSTM layer.
-        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-        lstm_dropout (float): LSTM dropout probability.
-        joiner_activation (str): activation function to use in the joiner.
-            Must be one of ("relu", "tanh"). (Default: "relu")
-
-    Returns:
-        RNNT:
-            Conformer RNN-T model.
-    """
-    encoder = _ConformerEncoder(
-        input_dim=input_dim,
-        output_dim=encoding_dim,
-        time_reduction_stride=time_reduction_stride,
-        conformer_input_dim=conformer_input_dim,
-        conformer_ffn_dim=conformer_ffn_dim,
-        conformer_num_layers=conformer_num_layers,
-        conformer_num_heads=conformer_num_heads,
-        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-        conformer_dropout=conformer_dropout,
-    )
-    predictor = _Predictor(
-        num_symbols=num_symbols,
-        output_dim=encoding_dim,
-        symbol_embedding_dim=symbol_embedding_dim,
-        num_lstm_layers=num_lstm_layers,
-        lstm_hidden_dim=lstm_hidden_dim,
-        lstm_layer_norm=lstm_layer_norm,
-        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-        lstm_dropout=lstm_dropout,
-    )
-    joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
-    return RNNT(encoder, predictor, joiner)
-
-
-@dropping_support
-def conformer_rnnt_base() -> RNNT:
-    r"""Builds basic version of Conformer RNN-T model.
-
-    Returns:
-        RNNT:
-            Conformer RNN-T model.
-    """
-    return conformer_rnnt_model(
-        input_dim=80,
-        encoding_dim=1024,
-        time_reduction_stride=4,
-        conformer_input_dim=256,
-        conformer_ffn_dim=1024,
-        conformer_num_layers=16,
-        conformer_num_heads=4,
-        conformer_depthwise_conv_kernel_size=31,
-        conformer_dropout=0.1,
-        num_symbols=1024,
-        symbol_embedding_dim=256,
-        num_lstm_layers=2,
-        lstm_hidden_dim=512,
-        lstm_layer_norm=True,
-        lstm_layer_norm_epsilon=1e-5,
-        lstm_dropout=0.3,
-        joiner_activation="tanh",
-    )
-
-
-@dropping_support
-def conformer_rnnt_biasing(
-    *,
-    input_dim: int,
-    encoding_dim: int,
-    time_reduction_stride: int,
-    conformer_input_dim: int,
-    conformer_ffn_dim: int,
-    conformer_num_layers: int,
-    conformer_num_heads: int,
-    conformer_depthwise_conv_kernel_size: int,
-    conformer_dropout: float,
-    num_symbols: int,
-    symbol_embedding_dim: int,
-    num_lstm_layers: int,
-    lstm_hidden_dim: int,
-    lstm_layer_norm: int,
-    lstm_layer_norm_epsilon: int,
-    lstm_dropout: int,
-    joiner_activation: str,
-    attndim: int,
-    biasing: bool,
-    charlist: List[str],
-    deepbiasing: bool,
-    tcpsche: int,
-    DBaverage: bool,
-) -> RNNTBiasing:
-    r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-    Args:
-        input_dim (int): dimension of input sequence frames passed to transcription network.
-        encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-            passed to joint network.
-        time_reduction_stride (int): factor by which to reduce length of input sequence.
-        conformer_input_dim (int): dimension of Conformer input.
-        conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-        conformer_num_layers (int): number of Conformer layers to instantiate.
-        conformer_num_heads (int): number of attention heads in each Conformer layer.
-        conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-        conformer_dropout (float): Conformer dropout probability.
-        num_symbols (int): cardinality of set of target tokens.
-        symbol_embedding_dim (int): dimension of each target token embedding.
-        num_lstm_layers (int): number of LSTM layers to instantiate.
-        lstm_hidden_dim (int): output dimension of each LSTM layer.
-        lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-        lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-        lstm_dropout (float): LSTM dropout probability.
-        joiner_activation (str): activation function to use in the joiner.
-            Must be one of ("relu", "tanh"). (Default: "relu")
-        attndim (int): TCPGen attention dimension
-        biasing (bool): If true, use biasing, otherwise use standard RNN-T
-        charlist (list): The list of word piece tokens in the same order as the output layer
-        deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-        tcpsche (int): The epoch at which TCPGen starts to train
-        DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-
-    Returns:
-        RNNT:
-            Conformer RNN-T model with TCPGen-based biasing support.
-    """
-    encoder = _ConformerEncoder(
-        input_dim=input_dim,
-        output_dim=encoding_dim,
-        time_reduction_stride=time_reduction_stride,
-        conformer_input_dim=conformer_input_dim,
-        conformer_ffn_dim=conformer_ffn_dim,
-        conformer_num_layers=conformer_num_layers,
-        conformer_num_heads=conformer_num_heads,
-        conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-        conformer_dropout=conformer_dropout,
-    )
-    predictor = _Predictor(
-        num_symbols=num_symbols,
-        output_dim=encoding_dim,
-        symbol_embedding_dim=symbol_embedding_dim,
-        num_lstm_layers=num_lstm_layers,
-        lstm_hidden_dim=lstm_hidden_dim,
-        lstm_layer_norm=lstm_layer_norm,
-        lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-        lstm_dropout=lstm_dropout,
-    )
-    joiner = _JoinerBiasing(
-        encoding_dim,
-        num_symbols,
-        activation=joiner_activation,
-        deepbiasing=deepbiasing,
-        attndim=attndim,
-        biasing=biasing,
-    )
-    return RNNTBiasing(
-        encoder,
-        predictor,
-        joiner,
-        attndim,
-        biasing,
-        deepbiasing,
-        symbol_embedding_dim,
-        encoding_dim,
-        charlist,
-        encoding_dim,
-        conformer_dropout,
-        tcpsche,
-        DBaverage,
-    )
-
-
-@dropping_support
-def conformer_rnnt_biasing_base(charlist=None, biasing=True) -> RNNT:
-    r"""Builds basic version of Conformer RNN-T model with TCPGen.
-
-    Returns:
-        RNNT:
-            Conformer RNN-T model with TCPGen-based biasing support.
-    """
-    return conformer_rnnt_biasing(
-        input_dim=80,
-        encoding_dim=576,
-        time_reduction_stride=4,
-        conformer_input_dim=144,
-        conformer_ffn_dim=576,
-        conformer_num_layers=16,
-        conformer_num_heads=4,
-        conformer_depthwise_conv_kernel_size=31,
-        conformer_dropout=0.1,
-        num_symbols=601,
-        symbol_embedding_dim=256,
-        num_lstm_layers=1,
-        lstm_hidden_dim=320,
-        lstm_layer_norm=True,
-        lstm_layer_norm_epsilon=1e-5,
-        lstm_dropout=0.3,
-        joiner_activation="tanh",
-        attndim=256,
-        biasing=biasing,
-        charlist=charlist,
-        deepbiasing=True,
-        tcpsche=30,
-        DBaverage=False,
-    )