torchaudio 2.7.1-cp313-cp313t-win_amd64.whl → 2.9.0-cp313-cp313t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic.

Files changed (92)
  1. torchaudio/__init__.py +184 -33
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +68 -10
  5. torchaudio/_torchcodec.py +340 -0
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +6 -3
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +31 -61
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +19 -1
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +3 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +4 -1
  23. torchaudio/transforms/_transforms.py +4 -3
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -1
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/METADATA +15 -7
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -317
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -13
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -144
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -67
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -433
  54. torchaudio/prototype/functional/_rir.py +0 -379
  55. torchaudio/prototype/functional/functional.py +0 -190
  56. torchaudio/prototype/models/__init__.py +0 -36
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -794
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -333
  59. torchaudio/prototype/models/conv_emformer.py +0 -525
  60. torchaudio/prototype/models/hifi_gan.py +0 -336
  61. torchaudio/prototype/models/rnnt.py +0 -711
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -399
  63. torchaudio/prototype/pipelines/__init__.py +0 -12
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -3
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -233
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -82
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -228
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -456
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -272
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -99
  75. torchaudio-2.7.1.dist-info/RECORD +0 -144
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -978
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -247
  91. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/LICENSE +0 -0
  92. {torchaudio-2.7.1.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
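
The most consequential change in this listing is the I/O stack: the torchaudio._backend, torchaudio.backend, torchaudio.io, and torchaudio.sox_effects modules and the entire torio package are removed, while a new torchaudio/_torchcodec.py module (+340 lines) appears. This suggests that audio decoding in 2.9.0 is delegated to the separate torchcodec package rather than the bundled FFmpeg/SoX/soundfile backends. Below is a minimal sketch of the public call that exists in both versions; the claim that it dispatches through the new _torchcodec.py shim in 2.9.0 (and therefore needs torchcodec installed) is an assumption drawn from this file list, not something the diff states.

import torchaudio

# Same public entry point in 2.7.1 and 2.9.0; in 2.9.0 it presumably routes
# through the new _torchcodec.py shim instead of the removed backend modules.
waveform, sample_rate = torchaudio.load("speech.wav")
print(waveform.shape, sample_rate)  # (channels, frames) float32 tensor, sample rate in Hz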
torchaudio/prototype/models/rnnt.py (removed)
@@ -1,711 +0,0 @@
- import math
- from typing import Dict, List, Optional, Tuple
-
- import torch
- from torchaudio.models import Conformer, RNNT
- from torchaudio.models.rnnt import _Joiner, _Predictor, _TimeReduction, _Transcriber
-
-
- TrieNode = Tuple[Dict[int, "TrieNode"], int, Optional[Tuple[int, int]]]
-
-
- class _ConformerEncoder(torch.nn.Module, _Transcriber):
-     def __init__(
-         self,
-         *,
-         input_dim: int,
-         output_dim: int,
-         time_reduction_stride: int,
-         conformer_input_dim: int,
-         conformer_ffn_dim: int,
-         conformer_num_layers: int,
-         conformer_num_heads: int,
-         conformer_depthwise_conv_kernel_size: int,
-         conformer_dropout: float,
-     ) -> None:
-         super().__init__()
-         self.time_reduction = _TimeReduction(time_reduction_stride)
-         self.input_linear = torch.nn.Linear(input_dim * time_reduction_stride, conformer_input_dim)
-         self.conformer = Conformer(
-             num_layers=conformer_num_layers,
-             input_dim=conformer_input_dim,
-             ffn_dim=conformer_ffn_dim,
-             num_heads=conformer_num_heads,
-             depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-             dropout=conformer_dropout,
-             use_group_norm=True,
-             convolution_first=True,
-         )
-         self.output_linear = torch.nn.Linear(conformer_input_dim, output_dim)
-         self.layer_norm = torch.nn.LayerNorm(output_dim)
-
-     def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-         time_reduction_out, time_reduction_lengths = self.time_reduction(input, lengths)
-         input_linear_out = self.input_linear(time_reduction_out)
-         x, lengths = self.conformer(input_linear_out, time_reduction_lengths)
-         output_linear_out = self.output_linear(x)
-         layer_norm_out = self.layer_norm(output_linear_out)
-         return layer_norm_out, lengths
-
-     def infer(
-         self,
-         input: torch.Tensor,
-         lengths: torch.Tensor,
-         states: Optional[List[List[torch.Tensor]]],
-     ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
-         raise RuntimeError("Conformer does not support streaming inference.")
-
-
- class _JoinerBiasing(torch.nn.Module):
-     r"""Recurrent neural network transducer (RNN-T) joint network.
-
-     Args:
-         input_dim (int): source and target input dimension.
-         output_dim (int): output dimension.
-         activation (str, optional): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-         biasing (bool): perform biasing
-         deepbiasing (bool): perform deep biasing
-         attndim (int): dimension of the biasing vector hptr
-
-     """
-
-     def __init__(
-         self,
-         input_dim: int,
-         output_dim: int,
-         activation: str = "relu",
-         biasing: bool = False,
-         deepbiasing: bool = False,
-         attndim: int = 1,
-     ) -> None:
-         super().__init__()
-         self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
-         self.biasing = biasing
-         self.deepbiasing = deepbiasing
-         if self.biasing and self.deepbiasing:
-             self.biasinglinear = torch.nn.Linear(attndim, input_dim, bias=True)
-             self.attndim = attndim
-         if activation == "relu":
-             self.activation = torch.nn.ReLU()
-         elif activation == "tanh":
-             self.activation = torch.nn.Tanh()
-         else:
-             raise ValueError(f"Unsupported activation {activation}")
-
-     def forward(
-         self,
-         source_encodings: torch.Tensor,
-         source_lengths: torch.Tensor,
-         target_encodings: torch.Tensor,
-         target_lengths: torch.Tensor,
-         hptr: torch.Tensor = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-         r"""Forward pass for training.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: dimension of each source and target sequence encoding.
-
-         Args:
-             source_encodings (torch.Tensor): source encoding sequences, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``source_encodings``.
-             target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``target_encodings``.
-             hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor):
-                 torch.Tensor
-                     joint network output, with shape `(B, T, U, output_dim)`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     output target lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 2 for i-th batch element in joint network output.
-                 torch.Tensor
-                     joint network second last layer output (i.e. before self.linear), with shape `(B, T, U, D)`.
-         """
-         joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
-         if self.biasing and self.deepbiasing and hptr is not None:
-             hptr = self.biasinglinear(hptr)
-             joint_encodings += hptr
-         elif self.biasing and self.deepbiasing:
-             # Hack here for unused parameters
-             joint_encodings += self.biasinglinear(joint_encodings.new_zeros(1, self.attndim)).mean() * 0
-         activation_out = self.activation(joint_encodings)
-         output = self.linear(activation_out)
-         return output, source_lengths, target_lengths, activation_out
-
-
- class RNNTBiasing(RNNT):
-     r"""torchaudio.models.RNNT()
-
-     Recurrent neural network transducer (RNN-T) model.
-
-     Note:
-         To build the model, please use one of the factory functions.
-
-     Args:
-         transcriber (torch.nn.Module): transcription network.
-         predictor (torch.nn.Module): prediction network.
-         joiner (torch.nn.Module): joint network.
-         attndim (int): TCPGen attention dimension
-         biasing (bool): If true, use biasing, otherwise use standard RNN-T
-         deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-         embdim (int): dimension of symbol embeddings
-         jointdim (int): dimension of the joint network joint dimension
-         charlist (list): The list of word piece tokens in the same order as the output layer
-         encoutdim (int): dimension of the encoder output vectors
-         dropout_tcpgen (float): dropout rate for TCPGen
-         tcpsche (int): The epoch at which TCPGen starts to train
-         DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-     """
-
-     def __init__(
-         self,
-         transcriber: _Transcriber,
-         predictor: _Predictor,
-         joiner: _Joiner,
-         attndim: int,
-         biasing: bool,
-         deepbiasing: bool,
-         embdim: int,
-         jointdim: int,
-         charlist: List[str],
-         encoutdim: int,
-         dropout_tcpgen: float,
-         tcpsche: int,
-         DBaverage: bool,
-     ) -> None:
-         super().__init__(transcriber, predictor, joiner)
-         self.attndim = attndim
-         self.deepbiasing = deepbiasing
-         self.jointdim = jointdim
-         self.embdim = embdim
-         self.encoutdim = encoutdim
-         self.char_list = charlist or []
-         self.blank_idx = self.char_list.index("<blank>")
-         self.nchars = len(self.char_list)
-         self.DBaverage = DBaverage
-         self.biasing = biasing
-         if self.biasing:
-             if self.deepbiasing and self.DBaverage:
-                 # Deep biasing without TCPGen
-                 self.biasingemb = torch.nn.Linear(self.nchars, self.attndim, bias=False)
-             else:
-                 # TCPGen parameters
-                 self.ooKBemb = torch.nn.Embedding(1, self.embdim)
-                 self.Qproj_char = torch.nn.Linear(self.embdim, self.attndim)
-                 self.Qproj_acoustic = torch.nn.Linear(self.encoutdim, self.attndim)
-                 self.Kproj = torch.nn.Linear(self.embdim, self.attndim)
-                 self.pointer_gate = torch.nn.Linear(self.attndim + self.jointdim, 1)
-         self.dropout_tcpgen = torch.nn.Dropout(dropout_tcpgen)
-         self.tcpsche = tcpsche
-
-     def forward(
-         self,
-         sources: torch.Tensor,
-         source_lengths: torch.Tensor,
-         targets: torch.Tensor,
-         target_lengths: torch.Tensor,
-         tries: TrieNode,
-         current_epoch: int,
-         predictor_state: Optional[List[List[torch.Tensor]]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]], torch.Tensor, torch.Tensor]:
-         r"""Forward pass for training.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: feature dimension of each source sequence element.
-
-         Args:
-             sources (torch.Tensor): source frame sequences right-padded with right context, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``sources``.
-             targets (torch.Tensor): target sequences, with shape `(B, U)` and each element
-                 mapping to a target symbol.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 number of valid frames for i-th batch element in ``targets``.
-             tries (TrieNode): wordpiece prefix trees representing the biasing list to be searched
-             current_epoch (Int): the current epoch number to determine if TCPGen should be trained
-                 at this epoch
-             predictor_state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
-                 representing prediction network internal state generated in preceding invocation
-                 of ``forward``. (Default: ``None``)
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]):
-                 torch.Tensor
-                     joint network output, with shape
-                     `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     output target lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 2 for i-th batch element in joint network output.
-                 List[List[torch.Tensor]]
-                     output states; list of lists of tensors
-                     representing prediction network internal state generated in current invocation
-                     of ``forward``.
-                 torch.Tensor
-                     TCPGen distribution, with shape
-                     `(B, max output source length, max output target length, output_dim (number of target symbols))`.
-                 torch.Tensor
-                     Generation probability (or copy probability), with shape
-                     `(B, max output source length, max output target length, 1)`.
-         """
-         source_encodings, source_lengths = self.transcriber(
-             input=sources,
-             lengths=source_lengths,
-         )
-         target_encodings, target_lengths, predictor_state = self.predictor(
-             input=targets,
-             lengths=target_lengths,
-             state=predictor_state,
-         )
-         # Forward TCPGen
-         hptr = None
-         tcpgen_dist, p_gen = None, None
-         if self.biasing and current_epoch >= self.tcpsche and tries != []:
-             ptrdist_mask, p_gen_mask = self.get_tcpgen_step_masks(targets, tries)
-             hptr, tcpgen_dist = self.forward_tcpgen(targets, ptrdist_mask, source_encodings)
-             hptr = self.dropout_tcpgen(hptr)
-         elif self.biasing:
-             # Hack here to bypass unused parameters
-             if self.DBaverage and self.deepbiasing:
-                 dummy = self.biasingemb(source_encodings.new_zeros(1, len(self.char_list))).mean()
-             else:
-                 dummy = source_encodings.new_zeros(1, self.embdim)
-                 dummy = self.Qproj_char(dummy).mean()
-                 dummy += self.Qproj_acoustic(source_encodings.new_zeros(1, source_encodings.size(-1))).mean()
-                 dummy += self.Kproj(source_encodings.new_zeros(1, self.embdim)).mean()
-                 dummy += self.pointer_gate(source_encodings.new_zeros(1, self.attndim + self.jointdim)).mean()
-                 dummy += self.ooKBemb.weight.mean()
-             dummy = dummy * 0
-             source_encodings += dummy
-
-         output, source_lengths, target_lengths, jointer_activation = self.joiner(
-             source_encodings=source_encodings,
-             source_lengths=source_lengths,
-             target_encodings=target_encodings,
-             target_lengths=target_lengths,
-             hptr=hptr,
-         )
-
-         # Calculate Generation Probability
-         if self.biasing and hptr is not None and tcpgen_dist is not None:
-             p_gen = torch.sigmoid(self.pointer_gate(torch.cat((jointer_activation, hptr), dim=-1)))
-             # avoid collapsing to ooKB token in the first few updates
-             # if current_epoch == self.tcpsche:
-             #     p_gen = p_gen * 0.1
-             p_gen = p_gen.masked_fill(p_gen_mask.bool().unsqueeze(1).unsqueeze(-1), 0)
-
-         return (output, source_lengths, target_lengths, predictor_state, tcpgen_dist, p_gen)
-
-     def get_tcpgen_distribution(self, query, ptrdist_mask):
-         # Make use of the predictor embedding matrix
-         keyvalues = torch.cat([self.predictor.embedding.weight.data, self.ooKBemb.weight], dim=0)
-         keyvalues = self.dropout_tcpgen(self.Kproj(keyvalues))
-         # B * T * U * attndim, nbpe * attndim -> B * T * U * nbpe
-         tcpgendist = torch.einsum("ntuj,ij->ntui", query, keyvalues)
-         tcpgendist = tcpgendist / math.sqrt(query.size(-1))
-         ptrdist_mask = ptrdist_mask.unsqueeze(1).repeat(1, tcpgendist.size(1), 1, 1)
-         tcpgendist.masked_fill_(ptrdist_mask.bool(), -1e9)
-         tcpgendist = torch.nn.functional.softmax(tcpgendist, dim=-1)
-         # B * T * U * nbpe, nbpe * attndim -> B * T * U * attndim
-         hptr = torch.einsum("ntui,ij->ntuj", tcpgendist[:, :, :, :-1], keyvalues[:-1, :])
-         return hptr, tcpgendist
-
-     def forward_tcpgen(self, targets, ptrdist_mask, source_encodings):
-         tcpgen_dist = None
-         if self.DBaverage and self.deepbiasing:
-             hptr = self.biasingemb(1 - ptrdist_mask[:, :, :-1].float()).unsqueeze(1)
-         else:
-             query_char = self.predictor.embedding(targets)
-             query_char = self.Qproj_char(query_char).unsqueeze(1)  # B * 1 * U * attndim
-             query_acoustic = self.Qproj_acoustic(source_encodings).unsqueeze(2)  # B * T * 1 * attndim
-             query = query_char + query_acoustic  # B * T * U * attndim
-             hptr, tcpgen_dist = self.get_tcpgen_distribution(query, ptrdist_mask)
-         return hptr, tcpgen_dist
-
-     def get_tcpgen_step_masks(self, yseqs, resettrie):
-         seqlen = len(yseqs[0])
-         batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-         p_gen_masks = []
-         for i, yseq in enumerate(yseqs):
-             new_tree = resettrie
-             p_gen_mask = []
-             for j, vy in enumerate(yseq):
-                 vy = vy.item()
-                 new_tree = new_tree[0]
-                 if vy in [self.blank_idx]:
-                     new_tree = resettrie
-                     p_gen_mask.append(0)
-                 elif self.char_list[vy].endswith("▁"):
-                     if vy in new_tree and new_tree[vy][0] != {}:
-                         new_tree = new_tree[vy]
-                     else:
-                         new_tree = resettrie
-                     p_gen_mask.append(0)
-                 elif vy not in new_tree:
-                     new_tree = [{}]
-                     p_gen_mask.append(1)
-                 else:
-                     new_tree = new_tree[vy]
-                     p_gen_mask.append(0)
-                 batch_masks[i, j, list(new_tree[0].keys())] = 0
-                 # In the original paper, ooKB node was not masked
-                 # In this implementation, if not masking ooKB, ooKB probability
-                 # would quickly collapse to 1.0 in the first few updates.
-                 # Haven't found out why this happened.
-                 # batch_masks[i, j, -1] = 0
-             p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-         p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-         return batch_masks, p_gen_masks
-
-     def get_tcpgen_step_masks_prefix(self, yseqs, resettrie):
-         # Implemented for prefix-based wordpieces, not tested yet
-         seqlen = len(yseqs[0])
-         batch_masks = yseqs.new_ones(len(yseqs), seqlen, len(self.char_list) + 1)
-         p_gen_masks = []
-         for i, yseq in enumerate(yseqs):
-             p_gen_mask = []
-             new_tree = resettrie
-             for j, vy in enumerate(yseq):
-                 vy = vy.item()
-                 new_tree = new_tree[0]
-                 if vy in [self.blank_idx]:
-                     new_tree = resettrie
-                     batch_masks[i, j, list(new_tree[0].keys())] = 0
-                 elif self.char_list[vy].startswith("▁"):
-                     new_tree = resettrie
-                     if vy not in new_tree[0]:
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                     else:
-                         new_tree = new_tree[0][vy]
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                         if new_tree[1] != -1:
-                             batch_masks[i, j, list(resettrie[0].keys())] = 0
-                 else:
-                     if vy not in new_tree:
-                         new_tree = resettrie
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                     else:
-                         new_tree = new_tree[vy]
-                         batch_masks[i, j, list(new_tree[0].keys())] = 0
-                         if new_tree[1] != -1:
-                             batch_masks[i, j, list(resettrie[0].keys())] = 0
-                 p_gen_mask.append(0)
-                 # batch_masks[i, j, -1] = 0
-             p_gen_masks.append(p_gen_mask + [1] * (seqlen - len(p_gen_mask)))
-         p_gen_masks = torch.Tensor(p_gen_masks).to(yseqs.device).byte()
-
-         return batch_masks, p_gen_masks
-
-     def get_tcpgen_step(self, vy, trie, resettrie):
-         new_tree = trie[0]
-         if vy in [self.blank_idx]:
-             new_tree = resettrie
-         elif self.char_list[vy].endswith("▁"):
-             if vy in new_tree and new_tree[vy][0] != {}:
-                 new_tree = new_tree[vy]
-             else:
-                 new_tree = resettrie
-         elif vy not in new_tree:
-             new_tree = [{}]
-         else:
-             new_tree = new_tree[vy]
-         return new_tree
-
-     def join(
-         self,
-         source_encodings: torch.Tensor,
-         source_lengths: torch.Tensor,
-         target_encodings: torch.Tensor,
-         target_lengths: torch.Tensor,
-         hptr: torch.Tensor = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         r"""Applies joint network to source and target encodings.
-
-         B: batch size;
-         T: maximum source sequence length in batch;
-         U: maximum target sequence length in batch;
-         D: dimension of each source and target sequence encoding.
-         A: TCPGen attention dimension
-
-         Args:
-             source_encodings (torch.Tensor): source encoding sequences, with
-                 shape `(B, T, D)`.
-             source_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``source_encodings``.
-             target_encodings (torch.Tensor): target encoding sequences, with shape `(B, U, D)`.
-             target_lengths (torch.Tensor): with shape `(B,)` and i-th element representing
-                 valid sequence length of i-th batch element in ``target_encodings``.
-             hptr (torch.Tensor): deep biasing vector with shape `(B, T, U, A)`.
-
-         Returns:
-             (torch.Tensor, torch.Tensor, torch.Tensor):
-                 torch.Tensor
-                     joint network output, with shape `(B, T, U, output_dim)`.
-                 torch.Tensor
-                     output source lengths, with shape `(B,)` and i-th element representing
-                     number of valid elements along dim 1 for i-th batch element in joint network output.
-                 torch.Tensor
-                     joint network second last layer output, with shape `(B, T, U, D)`.
-         """
-         output, source_lengths, target_lengths, jointer_activation = self.joiner(
-             source_encodings=source_encodings,
-             source_lengths=source_lengths,
-             target_encodings=target_encodings,
-             target_lengths=target_lengths,
-             hptr=hptr,
-         )
-         return output, source_lengths, jointer_activation
-
-
- def conformer_rnnt_model(
-     *,
-     input_dim: int,
-     encoding_dim: int,
-     time_reduction_stride: int,
-     conformer_input_dim: int,
-     conformer_ffn_dim: int,
-     conformer_num_layers: int,
-     conformer_num_heads: int,
-     conformer_depthwise_conv_kernel_size: int,
-     conformer_dropout: float,
-     num_symbols: int,
-     symbol_embedding_dim: int,
-     num_lstm_layers: int,
-     lstm_hidden_dim: int,
-     lstm_layer_norm: int,
-     lstm_layer_norm_epsilon: int,
-     lstm_dropout: int,
-     joiner_activation: str,
- ) -> RNNT:
-     r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-     Args:
-         input_dim (int): dimension of input sequence frames passed to transcription network.
-         encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-             passed to joint network.
-         time_reduction_stride (int): factor by which to reduce length of input sequence.
-         conformer_input_dim (int): dimension of Conformer input.
-         conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-         conformer_num_layers (int): number of Conformer layers to instantiate.
-         conformer_num_heads (int): number of attention heads in each Conformer layer.
-         conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-         conformer_dropout (float): Conformer dropout probability.
-         num_symbols (int): cardinality of set of target tokens.
-         symbol_embedding_dim (int): dimension of each target token embedding.
-         num_lstm_layers (int): number of LSTM layers to instantiate.
-         lstm_hidden_dim (int): output dimension of each LSTM layer.
-         lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-         lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-         lstm_dropout (float): LSTM dropout probability.
-         joiner_activation (str): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model.
-     """
-     encoder = _ConformerEncoder(
-         input_dim=input_dim,
-         output_dim=encoding_dim,
-         time_reduction_stride=time_reduction_stride,
-         conformer_input_dim=conformer_input_dim,
-         conformer_ffn_dim=conformer_ffn_dim,
-         conformer_num_layers=conformer_num_layers,
-         conformer_num_heads=conformer_num_heads,
-         conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-         conformer_dropout=conformer_dropout,
-     )
-     predictor = _Predictor(
-         num_symbols=num_symbols,
-         output_dim=encoding_dim,
-         symbol_embedding_dim=symbol_embedding_dim,
-         num_lstm_layers=num_lstm_layers,
-         lstm_hidden_dim=lstm_hidden_dim,
-         lstm_layer_norm=lstm_layer_norm,
-         lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-         lstm_dropout=lstm_dropout,
-     )
-     joiner = _Joiner(encoding_dim, num_symbols, activation=joiner_activation)
-     return RNNT(encoder, predictor, joiner)
-
-
- def conformer_rnnt_base() -> RNNT:
-     r"""Builds basic version of Conformer RNN-T model.
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model.
-     """
-     return conformer_rnnt_model(
-         input_dim=80,
-         encoding_dim=1024,
-         time_reduction_stride=4,
-         conformer_input_dim=256,
-         conformer_ffn_dim=1024,
-         conformer_num_layers=16,
-         conformer_num_heads=4,
-         conformer_depthwise_conv_kernel_size=31,
-         conformer_dropout=0.1,
-         num_symbols=1024,
-         symbol_embedding_dim=256,
-         num_lstm_layers=2,
-         lstm_hidden_dim=512,
-         lstm_layer_norm=True,
-         lstm_layer_norm_epsilon=1e-5,
-         lstm_dropout=0.3,
-         joiner_activation="tanh",
-     )
-
-
- def conformer_rnnt_biasing(
-     *,
-     input_dim: int,
-     encoding_dim: int,
-     time_reduction_stride: int,
-     conformer_input_dim: int,
-     conformer_ffn_dim: int,
-     conformer_num_layers: int,
-     conformer_num_heads: int,
-     conformer_depthwise_conv_kernel_size: int,
-     conformer_dropout: float,
-     num_symbols: int,
-     symbol_embedding_dim: int,
-     num_lstm_layers: int,
-     lstm_hidden_dim: int,
-     lstm_layer_norm: int,
-     lstm_layer_norm_epsilon: int,
-     lstm_dropout: int,
-     joiner_activation: str,
-     attndim: int,
-     biasing: bool,
-     charlist: List[str],
-     deepbiasing: bool,
-     tcpsche: int,
-     DBaverage: bool,
- ) -> RNNTBiasing:
-     r"""Builds Conformer-based recurrent neural network transducer (RNN-T) model.
-
-     Args:
-         input_dim (int): dimension of input sequence frames passed to transcription network.
-         encoding_dim (int): dimension of transcription- and prediction-network-generated encodings
-             passed to joint network.
-         time_reduction_stride (int): factor by which to reduce length of input sequence.
-         conformer_input_dim (int): dimension of Conformer input.
-         conformer_ffn_dim (int): hidden layer dimension of each Conformer layer's feedforward network.
-         conformer_num_layers (int): number of Conformer layers to instantiate.
-         conformer_num_heads (int): number of attention heads in each Conformer layer.
-         conformer_depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
-         conformer_dropout (float): Conformer dropout probability.
-         num_symbols (int): cardinality of set of target tokens.
-         symbol_embedding_dim (int): dimension of each target token embedding.
-         num_lstm_layers (int): number of LSTM layers to instantiate.
-         lstm_hidden_dim (int): output dimension of each LSTM layer.
-         lstm_layer_norm (bool): if ``True``, enables layer normalization for LSTM layers.
-         lstm_layer_norm_epsilon (float): value of epsilon to use in LSTM layer normalization layers.
-         lstm_dropout (float): LSTM dropout probability.
-         joiner_activation (str): activation function to use in the joiner.
-             Must be one of ("relu", "tanh"). (Default: "relu")
-         attndim (int): TCPGen attention dimension
-         biasing (bool): If true, use biasing, otherwise use standard RNN-T
-         charlist (list): The list of word piece tokens in the same order as the output layer
-         deepbiasing (bool): If true, use deep biasing by extracting the biasing vector
-         tcpsche (int): The epoch at which TCPGen starts to train
-         DBaverage (bool): If true, instead of TCPGen, use DBRNNT for biasing
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model with TCPGen-based biasing support.
-     """
-     encoder = _ConformerEncoder(
-         input_dim=input_dim,
-         output_dim=encoding_dim,
-         time_reduction_stride=time_reduction_stride,
-         conformer_input_dim=conformer_input_dim,
-         conformer_ffn_dim=conformer_ffn_dim,
-         conformer_num_layers=conformer_num_layers,
-         conformer_num_heads=conformer_num_heads,
-         conformer_depthwise_conv_kernel_size=conformer_depthwise_conv_kernel_size,
-         conformer_dropout=conformer_dropout,
-     )
-     predictor = _Predictor(
-         num_symbols=num_symbols,
-         output_dim=encoding_dim,
-         symbol_embedding_dim=symbol_embedding_dim,
-         num_lstm_layers=num_lstm_layers,
-         lstm_hidden_dim=lstm_hidden_dim,
-         lstm_layer_norm=lstm_layer_norm,
-         lstm_layer_norm_epsilon=lstm_layer_norm_epsilon,
-         lstm_dropout=lstm_dropout,
-     )
-     joiner = _JoinerBiasing(
-         encoding_dim,
-         num_symbols,
-         activation=joiner_activation,
-         deepbiasing=deepbiasing,
-         attndim=attndim,
-         biasing=biasing,
-     )
-     return RNNTBiasing(
-         encoder,
-         predictor,
-         joiner,
-         attndim,
-         biasing,
-         deepbiasing,
-         symbol_embedding_dim,
-         encoding_dim,
-         charlist,
-         encoding_dim,
-         conformer_dropout,
-         tcpsche,
-         DBaverage,
-     )
-
-
- def conformer_rnnt_biasing_base(charlist=None, biasing=True) -> RNNT:
-     r"""Builds basic version of Conformer RNN-T model with TCPGen.
-
-     Returns:
-         RNNT:
-             Conformer RNN-T model with TCPGen-based biasing support.
-     """
-     return conformer_rnnt_biasing(
-         input_dim=80,
-         encoding_dim=576,
-         time_reduction_stride=4,
-         conformer_input_dim=144,
-         conformer_ffn_dim=576,
-         conformer_num_layers=16,
-         conformer_num_heads=4,
-         conformer_depthwise_conv_kernel_size=31,
-         conformer_dropout=0.1,
-         num_symbols=601,
-         symbol_embedding_dim=256,
-         num_lstm_layers=1,
-         lstm_hidden_dim=320,
-         lstm_layer_norm=True,
-         lstm_layer_norm_epsilon=1e-5,
-         lstm_dropout=0.3,
-         joiner_activation="tanh",
-         attndim=256,
-         biasing=biasing,
-         charlist=charlist,
-         deepbiasing=True,
-         tcpsche=30,
-         DBaverage=False,
-     )
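
For context on what this removal takes away: the hunk above deletes the public factory functions of the prototype Conformer RNN-T (conformer_rnnt_model, conformer_rnnt_base, conformer_rnnt_biasing, conformer_rnnt_biasing_base). A minimal sketch of how the non-biasing factory was called against torchaudio 2.7.1, with illustrative random inputs shaped per the docstrings above; the import path assumes the factory was re-exported from torchaudio.prototype.models, as in the 2.7 documentation.

import torch
from torchaudio.prototype.models import conformer_rnnt_base  # module removed in 2.9.0

model = conformer_rnnt_base()              # 80-dim features, 1024 target symbols (defaults above)
features = torch.randn(2, 100, 80)         # (B, T, D) acoustic features
feature_lengths = torch.tensor([100, 80])  # valid frames per batch element
targets = torch.randint(0, 1024, (2, 20))  # (B, U) target token ids
target_lengths = torch.tensor([20, 15])

# RNNT.forward returns the joint-network output plus valid source/target lengths
# and the predictor state; here the output shape is (B, T // stride, U, num_symbols).
output, src_lengths, tgt_lengths, state = model(features, feature_lengths, targets, target_lengths)

In 2.9.0 the torchaudio.prototype namespace no longer ships in the wheel, so downstream code that depends on these factories either pins 2.7.x or vendors the removed module.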