torchaudio-2.8.0-cp313-cp313t-win_amd64.whl → torchaudio-2.9.0-cp313-cp313t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio might be problematic.
Files changed (92)
  1. torchaudio/__init__.py +179 -39
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +12 -3
  5. torchaudio/_torchcodec.py +73 -85
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +0 -2
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +26 -60
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +14 -2
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +1 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +1 -0
  23. torchaudio/transforms/_transforms.py +2 -2
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -3
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -350
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -20
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -150
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -68
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -441
  54. torchaudio/prototype/functional/_rir.py +0 -382
  55. torchaudio/prototype/functional/functional.py +0 -193
  56. torchaudio/prototype/models/__init__.py +0 -39
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -337
  59. torchaudio/prototype/models/conv_emformer.py +0 -529
  60. torchaudio/prototype/models/hifi_gan.py +0 -342
  61. torchaudio/prototype/models/rnnt.py +0 -717
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -402
  63. torchaudio/prototype/pipelines/__init__.py +0 -21
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -461
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -275
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -118
  75. torchaudio-2.8.0.dist-info/RECORD +0 -145
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -977
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -275
  91. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/LICENSE +0 -0
  92. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
torchaudio/prototype/models/rnnt.py (removed)
@@ -1,717 +0,0 @@
- import math
- from typing import Dict, List, Optional, Tuple
-
- import torch
- from torchaudio.models import Conformer, RNNT
- from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
-
- from torchaudio._internal.module_utils import dropping_support
-
-
- TrieNode = Tuple[Dict[int, "TrieNode"], int, Optional[Tuple[int, int]]]
-
-
- class _ConformerEncoder(torch.nn.Module, _Transcriber):
-     def __init__(
-         self,
-         *,
-         input_dim: int,
-         output_dim: int,
-         time_reduction_stride: int,
-         conformer_input_dim: int,
-         conformer_ffn_dim: int,
-         conformer_num_layers: int,
-         conformer_num_heads: int,
-         conformer_depthwise_conv_kernel_size: int,
-         conformer_dropout: float,
-     ) -> None:
-         super().__init__()
-         self.time_reduction = _TimeReduction(time_reduction_stride)
-         self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
-         self.conformer = Conformer(
-             num_layers=conformer_num_layers,
-             input_dim=conformer_input_dim,
-             ffn_dim=conformer_ffn_dim,
-             num_heads=conformer_num_heads,
-             depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-             dropout=conformer_dropout,
-             use_group_norm=True,
-             convolution_first=True,
-         )
-         self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
-         self.layer_norm = torch.nn.LayerNorm(output_dim)
-
-     def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-         time_reduction_out, time_reduction_lengths = self.time_reduction(input, lengths)
-         input_linear_out = self.input_linear(time_reduction_out)
-         x, lengths = self.conformer(input_linear_out, time_reduction_lengths)
-         output_linear_out = self.output_linear(x)
-         layer_norm_out = self.layer_norm(output_linear_out)
-         return layer_norm_out, lengths
-
-     def infer(
-         self,
-         input: torch.Tensor,
-         lengths: torch.Tensor,
-         states: Optional[List[List[torch.Tensor]]],
-     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
-         raise RuntimeError("Conformer does not support streaming inference.")
-
-
- class _JoinerBiasing(torch.nn.Module):
-     r"""Recurrent neural network transducer (RNN-T) joint network.
-
-     Args:
-         input_dim (int): source and target input dimension.
-         output_dim (int): output dimension.
-         activation (str, optional): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-         biasing (bool): perform biasing
-         deepbiasing (bool): perform deep biasing
-         attndim (int): dimension of the biasing vector hptr
-
-     """
-
-     def __init__(
-         self,
-         input_dim: int,
-         output_dim: int,
-         activation: str = "relu",
-         biasing: bool = False,
-         deepbiasing: bool = False,
-         attndim: int = 1,
-     ) -> None:
-         super().__init__()
-         self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
-         self.biasing = biasing
-         self.deepbiasing = deepbiasing
-         if self.biasing and self.deepbiasing:
-             self.biasinglinear = torch.nn.Linear(attndim, input_dim, bias=True)
-             self.attndim = attndim
-         if activation == "relu":
-             self.activation = torch.nn.ReLU()
-         elif activation == "tanh":
-             self.activation = torch.nn.Tanh()
-         else:
-             raise ValueError(f"Unsupported activation {activation}")
-
-     def forward(
-         self,
-         source_encodings: torch.Tensor,
-         source_lengths: torch.Tensor,
-         target_encodings: torch.Tensor,
-         target_lengths: torch.Tensor,
-         hptr: torch.Tensor = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-         r"""Forward pass for training.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: dimension of each source and target sequence encoding.
-
-         Args:
-             source_encodings (torch.Tensor): source encoding sequences, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``source_encodings``.
-             target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``target_encodings``.
-             hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor):
-                 torch.Tensor
-                     joint network output, with shape `(B, T, U, output_dim)`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     output target lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 2 for i-th batch element in joint network output.
-                 torch.Tensor
-                     joint network second last layer output (i.e. before self.linear), with shape `(B, T, U, D)`.
-         """
-         joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
-         if self.biasing and self.deepbiasing and hptr is not None:
-             hptr = self.biasinglinear(hptr)
-             joint_encodings += hptr
-         elif self.biasing and self.deepbiasing:
-             # Hack here for unused parameters
-             joint_encodings += self.biasinglinear(joint_encodings.new_zeros(1, self.attndim)).mean() * 0
-         activation_out = self.activation(joint_encodings)
-         output = self.linear(activation_out)
-         return output, source_lengths, target_lengths, activation_out
-
-
- class RNNTBiasing(RNNT):
-     r"""torchaudio.models.RNNT()
-
-     Recurrent neural network transducer (RNN-T) model.
-
-     Note:
-         To build the model, please use one of the factory functions.
-
-     Args:
-         transcriber (torch.nn.Module): transcription network.
-         predictor (torch.nn.Module): prediction network.
-         joiner (torch.nn.Module): joint network.
-         attndim (int): TCPGen attention dimension
-         biasing (bool): If true, use biasing, otherwise use standard RNN-T
-         deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-         embdim (int): dimension of symbol embeddings
-         jointdim (int): dimension of the joint network joint dimension
-         charlist (list): The list of word piece tokens in the same order as the output layer
-         encoutdim (int): dimension of the encoder output vectors
-         dropout_tcpgen (float): dropout rate for TCPGen
-         tcpsche (int): The epoch at which TCPGen starts to train
-         DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-     """
-
-     def __init__(
-         self,
-         transcriber: _Transcriber,
-         predictor: _Predictor,
-         joiner: _Joiner,
-         attndim: int,
-         biasing: bool,
-         deepbiasing: bool,
-         embdim: int,
-         jointdim: int,
-         charlist: List[str],
-         encoutdim: int,
-         dropout_tcpgen: float,
-         tcpsche: int,
-         DBaverage: bool,
-     ) -> None:
-         super().__init__(transcriber, predictor, joiner)
-         self.attndim = attndim
-         self.deepbiasing = deepbiasing
-         self.jointdim = jointdim
-         self.embdim = embdim
-         self.encoutdim = encoutdim
-         self.char_list = charlist or []
-         self.blank_idx = self.char_list.index("<blank>")
-         self.nchars = len(self.char_list)
-         self.DBaverage = DBaverage
-         self.biasing = biasing
-         if self.biasing:
-             if self.deepbiasing and self.DBaverage:
-                 # Deep biasing without TCPGen
-                 self.biasingemb = torch.nn.Linear(self.nchars, self.attndim, bias=False)
-             else:
-                 # TCPGen parameters
-                 self.ooKBemb = torch.nn.Embedding(1, self.embdim)
-                 self.Qproj_char = torch.nn.Linear(self.embdim, self.attndim)
-                 self.Qproj_acoustic = torch.nn.Linear(self.encoutdim, self.attndim)
-                 self.Kproj = torch.nn.Linear(self.embdim, self.attndim)
-                 self.pointer_gate = torch.nn.Linear(self.attndim + self.jointdim, 1)
-         self.dropout_tcpgen = torch.nn.Dropout(dropout_tcpgen)
-         self.tcpsche = tcpsche
-
-     def forward(
-         self,
-         sources: torch.Tensor,
-         source_lengths: torch.Tensor,
-         targets: torch.Tensor,
-         target_lengths: torch.Tensor,
-         tries: TrieNode,
-         current_epoch: int,
-         predictor_state: Optional[List[List[torch.Tensor]]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor]:
-         r"""Forward pass for training.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: feature dimension of each source sequence element.
-
-         Args:
-             sources (torch.Tensor): source frame sequences right-padded with right context, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``sources``.
-             targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
-                 mapping to a target symbol.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``targets``.
-             tries (TrieNode): wordpiece prefix trees representing the biasing list to be searched
-             current_epoch (Int): the current epoch number to determine if TCPGen should be trained
-                 at this epoch
-             predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                 representing prediction network internal state generated in preceding invocation
-                 of ``forward``. (Default: ``None``)
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
-                 torch.Tensor
-                     joint network output, with shape
-                     `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     output target lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 2 for i-th batch element in joint network output.
-                 List[List[torch.Tensor]]
-                     output states; list of lists of tensors
-                     representing prediction network internal state generated in current invocation
-                     of ``forward``.
-                 torch.Tensor
-                     TCPGen distribution, with shape
-                     `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                 torch.Tensor
-                     Generation probability (or copy probability), with shape
-                     `(B, max output source length, max output target length, 1)`.
-         """
-         source_encodings, source_lengths = self.transcriber(
-             input=sources,
-             lengths=source_lengths,
-         )
-         target_encodings, target_lengths, predictor_state = self.predictor(
-             input=targets,
-             lengths=target_lengths,
-             state=predictor_state,
-         )
-         # Forward TCPGen
-         hptr = None
-         tcpgen_dist, p_gen = None, None
-         if self.biasing and current_epoch >= self.tcpsche and tries != []:
-             ptrdist_mask, p_gen_mask = self.get_tcpgen_step_masks(targets, tries)
-             hptr, tcpgen_dist = self.forward_tcpgen(targets, ptrdist_mask, source_encodings)
-             hptr = self.dropout_tcpgen(hptr)
-         elif self.biasing:
-             # Hack here to bypass unused parameters
-             if self.DBaverage and self.deepbiasing:
-                 dummy = self.biasingemb(source_encodings.new_zeros(1, len(self.char_list))).mean()
-             else:
-                 dummy = source_encodings.new_zeros(1, self.embdim)
-                 dummy = self.Qproj_char(dummy).mean()
-                 dummy += self.Qproj_acoustic(source_encodings.new_zeros(1, source_encodings.size(-1))).mean()
-                 dummy += self.Kproj(source_encodings.new_zeros(1, self.embdim)).mean()
-                 dummy += self.pointer_gate(source_encodings.new_zeros(1, self.attndim + self.jointdim)).mean()
-                 dummy += self.ooKBemb.weight.mean()
-             dummy = dummy * 0
-             source_encodings += dummy
-
-         output, source_lengths, target_lengths, jointer_activation = self.joiner(
-             source_encodings=source_encodings,
-             source_lengths=source_lengths,
-             target_encodings=target_encodings,
-             target_lengths=target_lengths,
-             hptr=hptr,
-         )
-
-         # Calculate Generation Probability
-         if self.biasing and hptr is not None and tcpgen_dist is not None:
-             p_gen = torch.sigmoid(self.pointer_gate(torch.cat((jointer_activation, hptr), dim=-1)))
-             # avoid collapsing to ooKB token in the first few updates
-             # if current_epoch == self.tcpsche:
-             # p_gen = p_gen * 0.1
-             p_gen = p_gen.masked_fill(p_gen_mask.bool().unsqueeze(1).unsqueeze(-1), 0)
-
-         return (output, source_lengths, target_lengths, predictor_state, tcpgen_dist, p_gen)
-
-     def get_tcpgen_distribution(self, query, ptrdist_mask):
-         # Make use of the predictor embedding matrix
-         keyvalues = torch.cat([self.predictor.embedding.weight.data, self.ooKBemb.weight], dim=0)
-         keyvalues = self.dropout_tcpgen(self.Kproj(keyvalues))
-         # B * T * U * attndim, nbpe * attndim -> B * T * U * nbpe
-         tcpgendist = torch.einsum("ntuj,ij->ntui", query, keyvalues)
-         tcpgendist = tcpgendist / math.sqrt(query.size(-1))
-         ptrdist_mask = ptrdist_mask.unsqueeze(1).repeat(1, tcpgendist.size(1), 1, 1)
-         tcpgendist.masked_fill_(ptrdist_mask.bool(), -1e9)
-         tcpgendist = torch.nn.functional.softmax(tcpgendist, dim=-1)
-         # B * T * U * nbpe, nbpe * attndim -> B * T * U * attndim
-         hptr = torch.einsum("ntui,ij->ntuj", tcpgendist[:, :, :, :-1], keyvalues[:-1, :])
-         return hptr, tcpgendist
-
-     def forward_tcpgen(self, targets, ptrdist_mask, source_encodings):
-         tcpgen_dist = None
-         if self.DBaverage and self.deepbiasing:
-             hptr = self.biasingemb(1 - ptrdist_mask[:, :, :-1].float()).unsqueeze(1)
-         else:
-             query_char = self.predictor.embedding(targets)
-             query_char = self.Qproj_char(query_char).unsqueeze(1)  # B * 1 * U * attndim
-             query_acoustic = self.Qproj_acoustic(source_encodings).unsqueeze(2)  # B * T * 1 * attndim
-             query = query_char + query_acoustic  # B * T * U * attndim
-             hptr, tcpgen_dist = self.get_tcpgen_distribution(query, ptrdist_mask)
-         return hptr, tcpgen_dist
-
-     def get_tcpgen_step_masks(self, yseqs, resettrie):
-         seqlen = len(yseqs[0])
-         batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-         p_gen_masks = []
-         for i, yseq in enumerate(yseqs):
-             new_tree = resettrie
-             p_gen_mask = []
-             for j, vy in enumerate(yseq):
-                 vy = vy.item()
-                 new_tree = new_tree[0]
-                 if vy in [self.blank_idx]:
-                     new_tree = resettrie
-                     p_gen_mask.append(0)
-                 elif self.char_list[vy].endswith("▁"):
-                     if vy in new_tree and new_tree[vy][0] != {}:
-                         new_tree = new_tree[vy]
-                     else:
-                         new_tree = resettrie
-                     p_gen_mask.append(0)
-                 elif vy not in new_tree:
-                     new_tree = [{}]
-                     p_gen_mask.append(1)
-                 else:
-                     new_tree = new_tree[vy]
-                     p_gen_mask.append(0)
-                 batch_masks[i, j, list(new_tree[0].keys())] = 0
-                 # In the original paper, ooKB node was not masked
-                 # In this implementation, if not masking ooKB, ooKB probability
-                 # would quickly collapse to 1.0 in the first few updates.
-                 # Haven't found out why this happened.
-                 # batch_masks[i, j, -1] = 0
-             p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-         p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-         return batch_masks, p_gen_masks
-
-     def get_tcpgen_step_masks_prefix(self, yseqs, resettrie):
-         # Implemented for prefix-based wordpieces, not tested yet
-         seqlen = len(yseqs[0])
-         batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-         p_gen_masks = []
-         for i, yseq in enumerate(yseqs):
-             p_gen_mask = []
-             new_tree = resettrie
-             for j, vy in enumerate(yseq):
-                 vy = vy.item()
-                 new_tree = new_tree[0]
-                 if vy in [self.blank_idx]:
-                     new_tree = resettrie
-                     batch_masks[i, j, list(new_tree[0].keys())] = 0
-                 elif self.char_list[vy].startswith("▁"):
-                     new_tree = resettrie
-                     if vy not in new_tree[0]:
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                     else:
-                         new_tree = new_tree[0][vy]
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                         if new_tree[1] != -1:
-                             batch_masks[i, j, list(resettrie[0].keys())] = 0
-                 else:
-                     if vy not in new_tree:
-                         new_tree = resettrie
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                     else:
-                         new_tree = new_tree[vy]
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                         if new_tree[1] != -1:
-                             batch_masks[i, j, list(resettrie[0].keys())] = 0
-                 p_gen_mask.append(0)
-                 # batch_masks[i, j, -1] = 0
-             p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-         p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-
-         return batch_masks, p_gen_masks
-
-     def get_tcpgen_step(self, vy, trie, resettrie):
-         new_tree = trie[0]
-         if vy in [self.blank_idx]:
-             new_tree = resettrie
-         elif self.char_list[vy].endswith("▁"):
-             if vy in new_tree and new_tree[vy][0] != {}:
-                 new_tree = new_tree[vy]
-             else:
-                 new_tree = resettrie
-         elif vy not in new_tree:
-             new_tree = [{}]
-         else:
-             new_tree = new_tree[vy]
-         return new_tree
-
-     def join(
-         self,
-         source_encodings: torch.Tensor,
-         source_lengths: torch.Tensor,
-         target_encodings: torch.Tensor,
-         target_lengths: torch.Tensor,
-         hptr: torch.Tensor = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         r"""Applies joint network to source and target encodings.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: dimension of each source and target sequence encoding.
-         A: TCPGen attention dimension
-
-         Args:
-             source_encodings (torch.Tensor): source encoding sequences, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``source_encodings``.
-             target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``target_encodings``.
-             hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor):
-                 torch.Tensor
-                     joint network output, with shape `(B, T, U, output_dim)`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     joint network second last layer output, with shape `(B, T, U, D)`.
-         """
-         output, source_lengths, target_lengths, jointer_activation = self.joiner(
-             source_encodings=source_encodings,
-             source_lengths=source_lengths,
-             target_encodings=target_encodings,
-             target_lengths=target_lengths,
-             hptr=hptr,
-         )
-         return output, source_lengths, jointer_activation
-
-
- @dropping_support
- def conformer_rnnt_model(
-     *,
-     input_dim: int,
-     encoding_dim: int,
-     time_reduction_stride: int,
-     conformer_input_dim: int,
-     conformer_ffn_dim: int,
-     conformer_num_layers: int,
-     conformer_num_heads: int,
-     conformer_depthwise_conv_kernel_size: int,
-     conformer_dropout: float,
-     num_symbols: int,
-     symbol_embedding_dim: int,
-     num_lstm_layers: int,
-     lstm_hidden_dim: int,
-     lstm_layer_norm: int,
-     lstm_layer_norm_epsilon: int,
-     lstm_dropout: int,
-     joiner_activation: str,
- ) -> RNNT:
-     r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-     Args:
-         input_dim (int): dimension of input sequence frames passed to transcription network.
-         encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-             passed to joint network.
-         time_reduction_stride (int): factor by which to reduce length of input sequence.
-         conformer_input_dim (int): dimension of Conformer input.
-         conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-         conformer_num_layers (int): number of Conformer layers to instantiate.
-         conformer_num_heads (int): number of attention heads in each Conformer layer.
-         conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-         conformer_dropout (float): Conformer dropout probability.
-         num_symbols (int): cardinality of set of target tokens.
-         symbol_embedding_dim (int): dimension of each target token embedding.
-         num_lstm_layers (int): number of LSTM layers to instantiate.
-         lstm_hidden_dim (int): output dimension of each LSTM layer.
-         lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-         lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-         lstm_dropout (float): LSTM dropout probability.
-         joiner_activation (str): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model.
-     """
-     encoder = _ConformerEncoder(
-         input_dim=input_dim,
-         output_dim=encoding_dim,
-         time_reduction_stride=time_reduction_stride,
-         conformer_input_dim=conformer_input_dim,
-         conformer_ffn_dim=conformer_ffn_dim,
-         conformer_num_layers=conformer_num_layers,
-         conformer_num_heads=conformer_num_heads,
-         conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-         conformer_dropout=conformer_dropout,
-     )
-     predictor = _Predictor(
-         num_symbols=num_symbols,
-         output_dim=encoding_dim,
-         symbol_embedding_dim=symbol_embedding_dim,
-         num_lstm_layers=num_lstm_layers,
-         lstm_hidden_dim=lstm_hidden_dim,
-         lstm_layer_norm=lstm_layer_norm,
-         lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-         lstm_dropout=lstm_dropout,
-     )
-     joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
-     return RNNT(encoder, predictor, joiner)
-
-
- @dropping_support
- def conformer_rnnt_base() -> RNNT:
-     r"""Builds basic version of Conformer RNN-T model.
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model.
-     """
-     return conformer_rnnt_model(
-         input_dim=80,
-         encoding_dim=1024,
-         time_reduction_stride=4,
-         conformer_input_dim=256,
-         conformer_ffn_dim=1024,
-         conformer_num_layers=16,
-         conformer_num_heads=4,
-         conformer_depthwise_conv_kernel_size=31,
-         conformer_dropout=0.1,
-         num_symbols=1024,
-         symbol_embedding_dim=256,
-         num_lstm_layers=2,
-         lstm_hidden_dim=512,
-         lstm_layer_norm=True,
-         lstm_layer_norm_epsilon=1e-5,
-         lstm_dropout=0.3,
-         joiner_activation="tanh",
-     )
-
-
- @dropping_support
- def conformer_rnnt_biasing(
-     *,
-     input_dim: int,
-     encoding_dim: int,
-     time_reduction_stride: int,
-     conformer_input_dim: int,
-     conformer_ffn_dim: int,
-     conformer_num_layers: int,
-     conformer_num_heads: int,
-     conformer_depthwise_conv_kernel_size: int,
-     conformer_dropout: float,
-     num_symbols: int,
-     symbol_embedding_dim: int,
-     num_lstm_layers: int,
-     lstm_hidden_dim: int,
-     lstm_layer_norm: int,
-     lstm_layer_norm_epsilon: int,
-     lstm_dropout: int,
-     joiner_activation: str,
-     attndim: int,
-     biasing: bool,
-     charlist: List[str],
-     deepbiasing: bool,
-     tcpsche: int,
-     DBaverage: bool,
- ) -> RNNTBiasing:
-     r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-     Args:
-         input_dim (int): dimension of input sequence frames passed to transcription network.
-         encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-             passed to joint network.
-         time_reduction_stride (int): factor by which to reduce length of input sequence.
-         conformer_input_dim (int): dimension of Conformer input.
-         conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-         conformer_num_layers (int): number of Conformer layers to instantiate.
-         conformer_num_heads (int): number of attention heads in each Conformer layer.
-         conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-         conformer_dropout (float): Conformer dropout probability.
-         num_symbols (int): cardinality of set of target tokens.
-         symbol_embedding_dim (int): dimension of each target token embedding.
-         num_lstm_layers (int): number of LSTM layers to instantiate.
-         lstm_hidden_dim (int): output dimension of each LSTM layer.
-         lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-         lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-         lstm_dropout (float): LSTM dropout probability.
-         joiner_activation (str): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-         attndim (int): TCPGen attention dimension
-         biasing (bool): If true, use biasing, otherwise use standard RNN-T
-         charlist (list): The list of word piece tokens in the same order as the output layer
-         deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-         tcpsche (int): The epoch at which TCPGen starts to train
-         DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model with TCPGen-based biasing support.
-     """
-     encoder = _ConformerEncoder(
-         input_dim=input_dim,
-         output_dim=encoding_dim,
-         time_reduction_stride=time_reduction_stride,
-         conformer_input_dim=conformer_input_dim,
-         conformer_ffn_dim=conformer_ffn_dim,
-         conformer_num_layers=conformer_num_layers,
-         conformer_num_heads=conformer_num_heads,
-         conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-         conformer_dropout=conformer_dropout,
-     )
-     predictor = _Predictor(
-         num_symbols=num_symbols,
-         output_dim=encoding_dim,
-         symbol_embedding_dim=symbol_embedding_dim,
-         num_lstm_layers=num_lstm_layers,
-         lstm_hidden_dim=lstm_hidden_dim,
-         lstm_layer_norm=lstm_layer_norm,
-         lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-         lstm_dropout=lstm_dropout,
-     )
-     joiner = _JoinerBiasing(
-         encoding_dim,
-         num_symbols,
-         activation=joiner_activation,
-         deepbiasing=deepbiasing,
-         attndim=attndim,
-         biasing=biasing,
-     )
-     return RNNTBiasing(
-         encoder,
-         predictor,
-         joiner,
-         attndim,
-         biasing,
-         deepbiasing,
-         symbol_embedding_dim,
-         encoding_dim,
-         charlist,
-         encoding_dim,
-         conformer_dropout,
-         tcpsche,
-         DBaverage,
-     )
-
-
- @dropping_support
- def conformer_rnnt_biasing_base(charlist=None, biasing=True) -> RNNT:
-     r"""Builds basic version of Conformer RNN-T model with TCPGen.
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model with TCPGen-based biasing support.
-     """
-     return conformer_rnnt_biasing(
-         input_dim=80,
-         encoding_dim=576,
-         time_reduction_stride=4,
-         conformer_input_dim=144,
-         conformer_ffn_dim=576,
-         conformer_num_layers=16,
-         conformer_num_heads=4,
-         conformer_depthwise_conv_kernel_size=31,
-         conformer_dropout=0.1,
-         num_symbols=601,
-         symbol_embedding_dim=256,
-         num_lstm_layers=1,
-         lstm_hidden_dim=320,
-         lstm_layer_norm=True,
-         lstm_layer_norm_epsilon=1e-5,
-         lstm_dropout=0.3,
-         joiner_activation="tanh",
-         attndim=256,
-         biasing=biasing,
-         charlist=charlist,
-         deepbiasing=True,
-         tcpsche=30,
-         DBaverage=False,
-     )