torchaudio-2.9.1-cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. torchaudio/__init__.py +204 -0
  2. torchaudio/_extension/__init__.py +61 -0
  3. torchaudio/_extension/utils.py +133 -0
  4. torchaudio/_internal/__init__.py +10 -0
  5. torchaudio/_internal/module_utils.py +171 -0
  6. torchaudio/_torchcodec.py +340 -0
  7. torchaudio/compliance/__init__.py +5 -0
  8. torchaudio/compliance/kaldi.py +813 -0
  9. torchaudio/datasets/__init__.py +47 -0
  10. torchaudio/datasets/cmuarctic.py +157 -0
  11. torchaudio/datasets/cmudict.py +186 -0
  12. torchaudio/datasets/commonvoice.py +86 -0
  13. torchaudio/datasets/dr_vctk.py +121 -0
  14. torchaudio/datasets/fluentcommands.py +108 -0
  15. torchaudio/datasets/gtzan.py +1118 -0
  16. torchaudio/datasets/iemocap.py +147 -0
  17. torchaudio/datasets/librilight_limited.py +111 -0
  18. torchaudio/datasets/librimix.py +133 -0
  19. torchaudio/datasets/librispeech.py +174 -0
  20. torchaudio/datasets/librispeech_biasing.py +189 -0
  21. torchaudio/datasets/libritts.py +168 -0
  22. torchaudio/datasets/ljspeech.py +107 -0
  23. torchaudio/datasets/musdb_hq.py +139 -0
  24. torchaudio/datasets/quesst14.py +136 -0
  25. torchaudio/datasets/snips.py +157 -0
  26. torchaudio/datasets/speechcommands.py +183 -0
  27. torchaudio/datasets/tedlium.py +218 -0
  28. torchaudio/datasets/utils.py +54 -0
  29. torchaudio/datasets/vctk.py +143 -0
  30. torchaudio/datasets/voxceleb1.py +309 -0
  31. torchaudio/datasets/yesno.py +89 -0
  32. torchaudio/functional/__init__.py +130 -0
  33. torchaudio/functional/_alignment.py +128 -0
  34. torchaudio/functional/filtering.py +1685 -0
  35. torchaudio/functional/functional.py +2505 -0
  36. torchaudio/lib/__init__.py +0 -0
  37. torchaudio/lib/_torchaudio.so +0 -0
  38. torchaudio/lib/libtorchaudio.so +0 -0
  39. torchaudio/models/__init__.py +85 -0
  40. torchaudio/models/_hdemucs.py +1008 -0
  41. torchaudio/models/conformer.py +293 -0
  42. torchaudio/models/conv_tasnet.py +330 -0
  43. torchaudio/models/decoder/__init__.py +64 -0
  44. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  45. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  46. torchaudio/models/deepspeech.py +84 -0
  47. torchaudio/models/emformer.py +884 -0
  48. torchaudio/models/rnnt.py +816 -0
  49. torchaudio/models/rnnt_decoder.py +339 -0
  50. torchaudio/models/squim/__init__.py +11 -0
  51. torchaudio/models/squim/objective.py +326 -0
  52. torchaudio/models/squim/subjective.py +150 -0
  53. torchaudio/models/tacotron2.py +1046 -0
  54. torchaudio/models/wav2letter.py +72 -0
  55. torchaudio/models/wav2vec2/__init__.py +45 -0
  56. torchaudio/models/wav2vec2/components.py +1167 -0
  57. torchaudio/models/wav2vec2/model.py +1579 -0
  58. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  59. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  60. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  61. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  62. torchaudio/models/wavernn.py +409 -0
  63. torchaudio/pipelines/__init__.py +102 -0
  64. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  65. torchaudio/pipelines/_squim_pipeline.py +156 -0
  66. torchaudio/pipelines/_tts/__init__.py +16 -0
  67. torchaudio/pipelines/_tts/impl.py +385 -0
  68. torchaudio/pipelines/_tts/interface.py +255 -0
  69. torchaudio/pipelines/_tts/utils.py +230 -0
  70. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  71. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  72. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  73. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  74. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  75. torchaudio/transforms/__init__.py +78 -0
  76. torchaudio/transforms/_multi_channel.py +467 -0
  77. torchaudio/transforms/_transforms.py +2138 -0
  78. torchaudio/utils/__init__.py +4 -0
  79. torchaudio/utils/download.py +89 -0
  80. torchaudio/version.py +2 -0
  81. torchaudio-2.9.1.dist-info/METADATA +133 -0
  82. torchaudio-2.9.1.dist-info/RECORD +85 -0
  83. torchaudio-2.9.1.dist-info/WHEEL +5 -0
  84. torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
  85. torchaudio-2.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,293 @@
+from typing import Optional, Tuple
+
+import torch
+
+
+__all__ = ["Conformer"]
+
+
+def _lengths_to_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
+    batch_size = lengths.shape[0]
+    max_length = int(torch.max(lengths).item())
+    padding_mask = torch.arange(max_length, device=lengths.device, dtype=lengths.dtype).expand(
+        batch_size, max_length
+    ) >= lengths.unsqueeze(1)
+    return padding_mask
+
+
+class _ConvolutionModule(torch.nn.Module):
+    r"""Conformer convolution module.
+
+    Args:
+        input_dim (int): input dimension.
+        num_channels (int): number of depthwise convolution layer input channels.
+        depthwise_kernel_size (int): kernel size of depthwise convolution layer.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+        bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
+        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_channels: int,
+        depthwise_kernel_size: int,
+        dropout: float = 0.0,
+        bias: bool = False,
+        use_group_norm: bool = False,
+    ) -> None:
+        super().__init__()
+        if (depthwise_kernel_size - 1) % 2 != 0:
+            raise ValueError("depthwise_kernel_size must be odd to achieve 'SAME' padding.")
+        self.layer_norm = torch.nn.LayerNorm(input_dim)
+        self.sequential = torch.nn.Sequential(
+            torch.nn.Conv1d(
+                input_dim,
+                2 * num_channels,
+                1,
+                stride=1,
+                padding=0,
+                bias=bias,
+            ),
+            torch.nn.GLU(dim=1),
+            torch.nn.Conv1d(
+                num_channels,
+                num_channels,
+                depthwise_kernel_size,
+                stride=1,
+                padding=(depthwise_kernel_size - 1) // 2,
+                groups=num_channels,
+                bias=bias,
+            ),
+            torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
+            if use_group_norm
+            else torch.nn.BatchNorm1d(num_channels),
+            torch.nn.SiLU(),
+            torch.nn.Conv1d(
+                num_channels,
+                input_dim,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=bias,
+            ),
+            torch.nn.Dropout(dropout),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            input (torch.Tensor): with shape `(B, T, D)`.
+
+        Returns:
+            torch.Tensor: output, with shape `(B, T, D)`.
+        """
+        x = self.layer_norm(input)
+        x = x.transpose(1, 2)
+        x = self.sequential(x)
+        return x.transpose(1, 2)
+
+
+class _FeedForwardModule(torch.nn.Module):
+    r"""Positionwise feed forward layer.
+
+    Args:
+        input_dim (int): input dimension.
+        hidden_dim (int): hidden dimension.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
+        super().__init__()
+        self.sequential = torch.nn.Sequential(
+            torch.nn.LayerNorm(input_dim),
+            torch.nn.Linear(input_dim, hidden_dim, bias=True),
+            torch.nn.SiLU(),
+            torch.nn.Dropout(dropout),
+            torch.nn.Linear(hidden_dim, input_dim, bias=True),
+            torch.nn.Dropout(dropout),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            input (torch.Tensor): with shape `(*, D)`.
+
+        Returns:
+            torch.Tensor: output, with shape `(*, D)`.
+        """
+        return self.sequential(input)
+
+
+class ConformerLayer(torch.nn.Module):
+    r"""Conformer layer that constitutes Conformer.
+
+    Args:
+        input_dim (int): input dimension.
+        ffn_dim (int): hidden layer dimension of feedforward network.
+        num_attention_heads (int): number of attention heads.
+        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
+            in the convolution module. (Default: ``False``)
+        convolution_first (bool, optional): apply the convolution module ahead of
+            the attention module. (Default: ``False``)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        ffn_dim: int,
+        num_attention_heads: int,
+        depthwise_conv_kernel_size: int,
+        dropout: float = 0.0,
+        use_group_norm: bool = False,
+        convolution_first: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.ffn1 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
+
+        self.self_attn_layer_norm = torch.nn.LayerNorm(input_dim)
+        self.self_attn = torch.nn.MultiheadAttention(input_dim, num_attention_heads, dropout=dropout)
+        self.self_attn_dropout = torch.nn.Dropout(dropout)
+
+        self.conv_module = _ConvolutionModule(
+            input_dim=input_dim,
+            num_channels=input_dim,
+            depthwise_kernel_size=depthwise_conv_kernel_size,
+            dropout=dropout,
+            bias=True,
+            use_group_norm=use_group_norm,
+        )
+
+        self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
+        self.final_layer_norm = torch.nn.LayerNorm(input_dim)
+        self.convolution_first = convolution_first
+
+    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
+        residual = input
+        input = input.transpose(0, 1)
+        input = self.conv_module(input)
+        input = input.transpose(0, 1)
+        input = residual + input
+        return input
+
+    def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        r"""
+        Args:
+            input (torch.Tensor): input, with shape `(T, B, D)`.
+            key_padding_mask (torch.Tensor or None): key padding mask to use in self attention layer.
+
+        Returns:
+            torch.Tensor: output, with shape `(T, B, D)`.
+        """
+        residual = input
+        x = self.ffn1(input)
+        x = x * 0.5 + residual
+
+        if self.convolution_first:
+            x = self._apply_convolution(x)
+
+        residual = x
+        x = self.self_attn_layer_norm(x)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )
+        x = self.self_attn_dropout(x)
+        x = x + residual
+
+        if not self.convolution_first:
+            x = self._apply_convolution(x)
+
+        residual = x
+        x = self.ffn2(x)
+        x = x * 0.5 + residual
+
+        x = self.final_layer_norm(x)
+        return x
+
+
+class Conformer(torch.nn.Module):
+    r"""Conformer architecture introduced in
+    *Conformer: Convolution-augmented Transformer for Speech Recognition*
+    :cite:`gulati2020conformer`.
+
+    Args:
+        input_dim (int): input dimension.
+        num_heads (int): number of attention heads in each Conformer layer.
+        ffn_dim (int): hidden layer dimension of feedforward networks.
+        num_layers (int): number of Conformer layers to instantiate.
+        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
+        dropout (float, optional): dropout probability. (Default: 0.0)
+        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
+            in the convolution module. (Default: ``False``)
+        convolution_first (bool, optional): apply the convolution module ahead of
+            the attention module. (Default: ``False``)
+
+    Examples:
+        >>> conformer = Conformer(
+        >>>     input_dim=80,
+        >>>     num_heads=4,
+        >>>     ffn_dim=128,
+        >>>     num_layers=4,
+        >>>     depthwise_conv_kernel_size=31,
+        >>> )
+        >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
+        >>> input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
+        >>> output = conformer(input, lengths)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_heads: int,
+        ffn_dim: int,
+        num_layers: int,
+        depthwise_conv_kernel_size: int,
+        dropout: float = 0.0,
+        use_group_norm: bool = False,
+        convolution_first: bool = False,
+    ):
+        super().__init__()
+
+        self.conformer_layers = torch.nn.ModuleList(
+            [
+                ConformerLayer(
+                    input_dim,
+                    ffn_dim,
+                    num_heads,
+                    depthwise_conv_kernel_size,
+                    dropout=dropout,
+                    use_group_norm=use_group_norm,
+                    convolution_first=convolution_first,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(self, input: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Args:
+            input (torch.Tensor): with shape `(B, T, input_dim)`.
+            lengths (torch.Tensor): with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in ``input``.
+
+        Returns:
+            (torch.Tensor, torch.Tensor)
+            torch.Tensor
+                output frames, with shape `(B, T, input_dim)`
+            torch.Tensor
+                output lengths, with shape `(B,)` and i-th element representing
+                number of valid frames for i-th batch element in output frames.
+        """
+        encoder_padding_mask = _lengths_to_padding_mask(lengths)
+
+        x = input.transpose(0, 1)
+        for layer in self.conformer_layers:
+            x = layer(x, encoder_padding_mask)
+        return x.transpose(0, 1), lengths
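
Usage note (editorial, not part of the packaged file): `Conformer` takes batch-first features of shape `(B, T, input_dim)` plus a `lengths` vector; `_lengths_to_padding_mask` turns `lengths` into a boolean `(B, T)` key-padding mask so self-attention ignores padded frames, and the layers run time-first `(T, B, D)` internally. A minimal sketch, assuming 80-dim filterbank features and illustrative shapes:

import torch
from torchaudio.models import Conformer

model = Conformer(
    input_dim=80,
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
)
# Three utterances zero-padded to the longest length (100 frames).
lengths = torch.tensor([95, 100, 80])    # valid frames per utterance, shape (B,)
features = torch.rand(3, 100, 80)        # (B, T, input_dim)
output, out_lengths = model(features, lengths)
# The padded time axis is preserved; lengths are passed through unchanged.
assert output.shape == (3, 100, 80)
assert torch.equal(out_lengths, lengths)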
@@ -0,0 +1,330 @@
+"""Implements Conv-TasNet with building blocks of it.
+
+Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c
+"""
+
+from typing import Optional, Tuple
+
+import torch
+
+
+class ConvBlock(torch.nn.Module):
+    """1D Convolutional block.
+
+    Args:
+        io_channels (int): The number of input/output channels, <B, Sc>.
+        hidden_channels (int): The number of channels in the internal layers, <H>.
+        kernel_size (int): The convolution kernel size of the middle layer, <P>.
+        padding (int): Padding value of the convolution in the middle layer.
+        dilation (int, optional): Dilation value of the convolution in the middle layer.
+        no_residual (bool, optional): Disable residual block/output.
+
+    Note:
+        This implementation corresponds to the "non-causal" setting in the paper.
+    """
+
+    def __init__(
+        self,
+        io_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        padding: int,
+        dilation: int = 1,
+        no_residual: bool = False,
+    ):
+        super().__init__()
+
+        self.conv_layers = torch.nn.Sequential(
+            torch.nn.Conv1d(in_channels=io_channels, out_channels=hidden_channels, kernel_size=1),
+            torch.nn.PReLU(),
+            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
+            torch.nn.Conv1d(
+                in_channels=hidden_channels,
+                out_channels=hidden_channels,
+                kernel_size=kernel_size,
+                padding=padding,
+                dilation=dilation,
+                groups=hidden_channels,
+            ),
+            torch.nn.PReLU(),
+            torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08),
+        )
+
+        self.res_out = (
+            None
+            if no_residual
+            else torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
+        )
+        self.skip_out = torch.nn.Conv1d(in_channels=hidden_channels, out_channels=io_channels, kernel_size=1)
+
+    def forward(self, input: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
+        feature = self.conv_layers(input)
+        if self.res_out is None:
+            residual = None
+        else:
+            residual = self.res_out(feature)
+        skip_out = self.skip_out(feature)
+        return residual, skip_out
+
+
+class MaskGenerator(torch.nn.Module):
+    """TCN (Temporal Convolution Network) Separation Module
+
+    Generates masks for separation.
+
+    Args:
+        input_dim (int): Input feature dimension, <N>.
+        num_sources (int): The number of sources to separate.
+        kernel_size (int): The convolution kernel size of conv blocks, <P>.
+        num_feats (int): Input/output feature dimension of conv blocks, <B, Sc>.
+        num_hidden (int): Intermediate feature dimension of conv blocks, <H>.
+        num_layers (int): The number of conv blocks in one stack, <X>.
+        num_stacks (int): The number of conv block stacks, <R>.
+        msk_activate (str): The activation function of the mask output.
+
+    Note:
+        This implementation corresponds to the "non-causal" setting in the paper.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_sources: int,
+        kernel_size: int,
+        num_feats: int,
+        num_hidden: int,
+        num_layers: int,
+        num_stacks: int,
+        msk_activate: str,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.num_sources = num_sources
+
+        self.input_norm = torch.nn.GroupNorm(num_groups=1, num_channels=input_dim, eps=1e-8)
+        self.input_conv = torch.nn.Conv1d(in_channels=input_dim, out_channels=num_feats, kernel_size=1)
+
+        self.receptive_field = 0
+        self.conv_layers = torch.nn.ModuleList([])
+        for s in range(num_stacks):
+            for l in range(num_layers):
+                multi = 2**l
+                self.conv_layers.append(
+                    ConvBlock(
+                        io_channels=num_feats,
+                        hidden_channels=num_hidden,
+                        kernel_size=kernel_size,
+                        dilation=multi,
+                        padding=multi,
+                        # The last ConvBlock does not need residual
+                        no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)),
+                    )
+                )
+                self.receptive_field += kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi
+        self.output_prelu = torch.nn.PReLU()
+        self.output_conv = torch.nn.Conv1d(
+            in_channels=num_feats,
+            out_channels=input_dim * num_sources,
+            kernel_size=1,
+        )
+        if msk_activate == "sigmoid":
+            self.mask_activate = torch.nn.Sigmoid()
+        elif msk_activate == "relu":
+            self.mask_activate = torch.nn.ReLU()
+        else:
+            raise ValueError(f"Unsupported activation {msk_activate}")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Generate separation mask.
+
+        Args:
+            input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
+
+        Returns:
+            Tensor: shape [batch, num_sources, features, frames]
+        """
+        batch_size = input.shape[0]
+        feats = self.input_norm(input)
+        feats = self.input_conv(feats)
+        output = 0.0
+        for layer in self.conv_layers:
+            residual, skip = layer(feats)
+            if residual is not None:  # the last conv layer does not produce residual
+                feats = feats + residual
+            output = output + skip
+        output = self.output_prelu(output)
+        output = self.output_conv(output)
+        output = self.mask_activate(output)
+        return output.view(batch_size, self.num_sources, self.input_dim, -1)
+
+
+class ConvTasNet(torch.nn.Module):
+    """Conv-TasNet architecture introduced in
+    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
+    :cite:`Luo_2019`.
+
+    Note:
+        This implementation corresponds to the "non-causal" setting in the paper.
+
+    See Also:
+        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
+
+    Args:
+        num_sources (int, optional): The number of sources to split.
+        enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
+        enc_num_feats (int, optional): The feature dimensions passed to mask generator, <N>.
+        msk_kernel_size (int, optional): The convolution kernel size of the mask generator, <P>.
+        msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, <B, Sc>.
+        msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, <H>.
+        msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
+        msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
+        msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
+    """
+
+    def __init__(
+        self,
+        num_sources: int = 2,
+        # encoder/decoder parameters
+        enc_kernel_size: int = 16,
+        enc_num_feats: int = 512,
+        # mask generator parameters
+        msk_kernel_size: int = 3,
+        msk_num_feats: int = 128,
+        msk_num_hidden_feats: int = 512,
+        msk_num_layers: int = 8,
+        msk_num_stacks: int = 3,
+        msk_activate: str = "sigmoid",
+    ):
+        super().__init__()
+
+        self.num_sources = num_sources
+        self.enc_num_feats = enc_num_feats
+        self.enc_kernel_size = enc_kernel_size
+        self.enc_stride = enc_kernel_size // 2
+
+        self.encoder = torch.nn.Conv1d(
+            in_channels=1,
+            out_channels=enc_num_feats,
+            kernel_size=enc_kernel_size,
+            stride=self.enc_stride,
+            padding=self.enc_stride,
+            bias=False,
+        )
+        self.mask_generator = MaskGenerator(
+            input_dim=enc_num_feats,
+            num_sources=num_sources,
+            kernel_size=msk_kernel_size,
+            num_feats=msk_num_feats,
+            num_hidden=msk_num_hidden_feats,
+            num_layers=msk_num_layers,
+            num_stacks=msk_num_stacks,
+            msk_activate=msk_activate,
+        )
+        self.decoder = torch.nn.ConvTranspose1d(
+            in_channels=enc_num_feats,
+            out_channels=1,
+            kernel_size=enc_kernel_size,
+            stride=self.enc_stride,
+            padding=self.enc_stride,
+            bias=False,
+        )
+
+    def _align_num_frames_with_strides(self, input: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        """Pad input Tensor so that the end of the input tensor corresponds with
+
+        1. (if kernel size is odd) the center of the last convolution kernel
+        or 2. (if kernel size is even) the end of the first half of the last convolution kernel
+
+        Assumption:
+            The resulting Tensor will be padded with the size of stride (== kernel_width // 2)
+            on the both ends in Conv1D
+
+        |<--- k_1 --->|
+        |      |            |<-- k_n-1 -->|
+        |      |                  |<--- k_n --->|
+        |      |                        |      |
+        |      |                        |      |
+        |      v                        v      |
+        |<---->|<--- input signal --->|<--->|<---->|
+         stride                         PAD    stride
+
+        Args:
+            input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
+
+        Returns:
+            Tensor: Padded Tensor
+            int: Number of paddings performed
+        """
+        batch_size, num_channels, num_frames = input.shape
+        is_odd = self.enc_kernel_size % 2
+        num_strides = (num_frames - is_odd) // self.enc_stride
+        num_remainings = num_frames - (is_odd + num_strides * self.enc_stride)
+        if num_remainings == 0:
+            return input, 0
+
+        num_paddings = self.enc_stride - num_remainings
+        pad = torch.zeros(
+            batch_size,
+            num_channels,
+            num_paddings,
+            dtype=input.dtype,
+            device=input.device,
+        )
+        return torch.cat([input, pad], 2), num_paddings
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """Perform source separation. Generate audio source waveforms.
+
+        Args:
+            input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
+
+        Returns:
+            Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
+        """
+        if input.ndim != 3 or input.shape[1] != 1:
+            raise ValueError(f"Expected 3D tensor (batch, channel==1, frames). Found: {input.shape}")
+
+        # B: batch size
+        # L: input frame length
+        # L': padded input frame length
+        # F: feature dimension
+        # M: feature frame length
+        # S: number of sources
+
+        padded, num_pads = self._align_num_frames_with_strides(input)  # B, 1, L'
+        batch_size, num_padded_frames = padded.shape[0], padded.shape[2]
+        feats = self.encoder(padded)  # B, F, M
+        masked = self.mask_generator(feats) * feats.unsqueeze(1)  # B, S, F, M
+        masked = masked.view(batch_size * self.num_sources, self.enc_num_feats, -1)  # B*S, F, M
+        decoded = self.decoder(masked)  # B*S, 1, L'
+        output = decoded.view(batch_size, self.num_sources, num_padded_frames)  # B, S, L'
+        if num_pads > 0:
+            output = output[..., :-num_pads]  # B, S, L
+        return output
+
+
+def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
+    r"""Builds the non-causal version of :class:`~torchaudio.models.ConvTasNet`.
+
+    The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
+    except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
+
+    Args:
+        num_sources (int, optional): Number of sources in the output.
+            (Default: 2)
+    Returns:
+        ConvTasNet:
+            ConvTasNet model.
+    """
+    return ConvTasNet(
+        num_sources=num_sources,
+        enc_kernel_size=16,
+        enc_num_feats=512,
+        msk_kernel_size=3,
+        msk_num_feats=128,
+        msk_num_hidden_feats=512,
+        msk_num_layers=8,
+        msk_num_stacks=3,
+        msk_activate="relu",
+    )
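
Usage note (editorial; the batch size and waveform length below are arbitrary): `conv_tasnet_base` builds the non-causal configuration with a ReLU mask. The encoder uses kernel size 16 and stride 8, and `_align_num_frames_with_strides` zero-pads the mixture so the last encoder window lines up with the stride; the padding is trimmed from the decoded output, so the separated waveforms keep the input length. A minimal sketch:

import torch
from torchaudio.models import conv_tasnet_base

model = conv_tasnet_base(num_sources=2).eval()
mixture = torch.rand(1, 1, 16000)      # (batch, channel==1, frames), e.g. 1 s at 16 kHz
with torch.no_grad():
    separated = model(mixture)         # (batch, num_sources, frames)
assert separated.shape == (1, 2, 16000)

# Padding arithmetic from _align_num_frames_with_strides for a non-aligned length:
# enc_kernel_size=16 -> is_odd=0, enc_stride=8.
# For 16001 frames: num_strides = 16001 // 8 = 2000, remainder = 1, so 8 - 1 = 7 zeros
# are appended (16008 frames) and the 7 extra samples are cut from the decoded output.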
@@ -0,0 +1,64 @@
+import inspect
+
+from torchaudio._internal.module_utils import dropping_class_support, dropping_support
+
+_CTC_DECODERS = [
+    "CTCHypothesis",
+    "CTCDecoder",
+    "CTCDecoderLM",
+    "CTCDecoderLMState",
+    "ctc_decoder",
+    "download_pretrained_files",
+]
+_CUDA_CTC_DECODERS = [
+    "CUCTCDecoder",
+    "CUCTCHypothesis",
+    "cuda_ctc_decoder",
+]
+
+
+def __getattr__(name: str):
+    if name in _CTC_DECODERS:
+        try:
+            from . import _ctc_decoder
+        except Exception as err:
+            raise RuntimeError(
+                "The CTC decoder suite requires the flashlight-text package and optionally KenLM. Please install them."
+            ) from err
+
+        item = getattr(_ctc_decoder, name)
+        globals()[name] = item
+        return item
+    elif name in _CUDA_CTC_DECODERS:
+        try:
+            from . import _cuda_ctc_decoder
+        except AttributeError as err:
+            raise RuntimeError(
+                "To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
+            ) from err
+
+        # TODO: when all unsupported classes are removed, replace the
+        # following if-else block with
+        # item = getattr(_cuda_ctc_decoder, name)
+        orig_item = getattr(_cuda_ctc_decoder, name)
+        if inspect.isclass(orig_item) or (
+            # workaround a failure to detect type instances
+            # after sphinx autodoc mocking, required for
+            # building docs
+            getattr(orig_item, "__sphinx_mock__", False)
+            and inspect.isclass(orig_item.__class__)
+        ):
+            item = dropping_class_support(orig_item)
+        else:
+            item = dropping_support(orig_item)
+
+        globals()[name] = item
+        return item
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+def __dir__():
+    return sorted(__all__)
+
+
+__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
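
Behaviour note (editorial): this module relies on PEP 562 lazy loading, so `torchaudio.models.decoder` stays importable even when the optional decoder backends are missing; the backing implementation is only imported on first attribute access and then cached in `globals()`, with CUDA CTC symbols additionally wrapped by the `dropping_support`/`dropping_class_support` helpers. A sketch of the observable behaviour, assuming the documented "librispeech-4-gram" pretrained-file preset:

import torchaudio.models.decoder as decoder

print(dir(decoder))  # sorted __all__: CTC and CUDA-CTC decoder entry points

try:
    # First access triggers the lazy `from . import _ctc_decoder`.
    files = decoder.download_pretrained_files("librispeech-4-gram")
    print(files.lexicon, files.tokens, files.lm)
except RuntimeError as err:
    # Raised by __getattr__ when flashlight-text (and optionally KenLM) is not installed.
    print(err)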