torchaudio-2.9.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic.

Files changed (86)
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
torchaudio/models/rnnt_decoder.py
@@ -0,0 +1,339 @@
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import torch
+ from torchaudio.models import RNNT
+
+
+ __all__ = ["Hypothesis", "RNNTBeamSearch"]
+
+
+ Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
+ Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
+ represented as tuple of (tokens, prediction network output, prediction network state, score).
+ """
+
+
+ def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+     return hypo[0]
+
+
+ def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+     return hypo[1]
+
+
+ def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+     return hypo[2]
+
+
+ def _get_hypo_score(hypo: Hypothesis) -> float:
+     return hypo[3]
+
+
+ def _get_hypo_key(hypo: Hypothesis) -> str:
+     return str(hypo[0])
+
+
+ def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
+     states: List[List[torch.Tensor]] = []
+     for i in range(len(_get_hypo_state(hypos[0]))):
+         batched_state_components: List[torch.Tensor] = []
+         for j in range(len(_get_hypo_state(hypos[0])[i])):
+             batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
+         states.append(batched_state_components)
+     return states
+
+
+ def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device) -> List[List[torch.Tensor]]:
+     idx_tensor = torch.tensor([idx], device=device)
+     return [[state.index_select(0, idx_tensor) for state in state_tuple] for state_tuple in states]
+
+
+ def _default_hypo_sort_key(hypo: Hypothesis) -> float:
+     return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
+
+
+ def _compute_updated_scores(
+     hypos: List[Hypothesis],
+     next_token_probs: torch.Tensor,
+     beam_width: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
+     nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
+     nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
+     nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
+     nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
+     return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
+
+
+ def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
+     for i, elem in enumerate(hypo_list):
+         if _get_hypo_key(hypo) == _get_hypo_key(elem):
+             del hypo_list[i]
+             break
+
+
+ class RNNTBeamSearch(torch.nn.Module):
+     r"""Beam search decoder for RNN-T model.
+
+     See Also:
+         * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.
+
+     Args:
+         model (RNNT): RNN-T model to use.
+         blank (int): index of blank token in vocabulary.
+         temperature (float, optional): temperature to apply to joint network output.
+             Larger values yield more uniform samples. (Default: 1.0)
+         hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
+             for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
+             hypothesis score normalized by token sequence length. (Default: None)
+         step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
+     """
+
+     def __init__(
+         self,
+         model: RNNT,
+         blank: int,
+         temperature: float = 1.0,
+         hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
+         step_max_tokens: int = 100,
+     ) -> None:
+         super().__init__()
+         self.model = model
+         self.blank = blank
+         self.temperature = temperature
+
+         if hypo_sort_key is None:
+             self.hypo_sort_key = _default_hypo_sort_key
+         else:
+             self.hypo_sort_key = hypo_sort_key
+
+         self.step_max_tokens = step_max_tokens
+
+     def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+         token = self.blank
+         state = None
+
+         one_tensor = torch.tensor([1], device=device)
+         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
+         init_hypo = (
+             [token],
+             pred_out[0].detach(),
+             pred_state,
+             0.0,
+         )
+         return [init_hypo]
+
+     def _gen_next_token_probs(
+         self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
+     ) -> torch.Tensor:
+         one_tensor = torch.tensor([1], device=device)
+         predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
+         joined_out, _, _ = self.model.join(
+             enc_out,
+             one_tensor,
+             predictor_out,
+             torch.tensor([1] * len(hypos), device=device),
+         )  # [beam_width, 1, 1, num_tokens]
+         joined_out = torch.nn.functional.log_softmax(joined_out / self.temperature, dim=3)
+         return joined_out[:, 0, 0]
+
+     def _gen_b_hypos(
+         self,
+         b_hypos: List[Hypothesis],
+         a_hypos: List[Hypothesis],
+         next_token_probs: torch.Tensor,
+         key_to_b_hypo: Dict[str, Hypothesis],
+     ) -> List[Hypothesis]:
+         for i in range(len(a_hypos)):
+             h_a = a_hypos[i]
+             append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+             if _get_hypo_key(h_a) in key_to_b_hypo:
+                 h_b = key_to_b_hypo[_get_hypo_key(h_a)]
+                 _remove_hypo(h_b, b_hypos)
+                 score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
+             else:
+                 score = float(append_blank_score)
+             h_b = (
+                 _get_hypo_tokens(h_a),
+                 _get_hypo_predictor_out(h_a),
+                 _get_hypo_state(h_a),
+                 score,
+             )
+             b_hypos.append(h_b)
+             key_to_b_hypo[_get_hypo_key(h_b)] = h_b
+         _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
+         return [b_hypos[idx] for idx in sorted_idx]
+
+     def _gen_a_hypos(
+         self,
+         a_hypos: List[Hypothesis],
+         b_hypos: List[Hypothesis],
+         next_token_probs: torch.Tensor,
+         t: int,
+         beam_width: int,
+         device: torch.device,
+     ) -> List[Hypothesis]:
+         (
+             nonblank_nbest_scores,
+             nonblank_nbest_hypo_idx,
+             nonblank_nbest_token,
+         ) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
+
+         if len(b_hypos) < beam_width:
+             b_nbest_score = -float("inf")
+         else:
+             b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
+
+         base_hypos: List[Hypothesis] = []
+         new_tokens: List[int] = []
+         new_scores: List[float] = []
+         for i in range(beam_width):
+             score = float(nonblank_nbest_scores[i])
+             if score > b_nbest_score:
+                 a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
+                 base_hypos.append(a_hypos[a_hypo_idx])
+                 new_tokens.append(int(nonblank_nbest_token[i]))
+                 new_scores.append(score)
+
+         if base_hypos:
+             new_hypos = self._gen_new_hypos(base_hypos, new_tokens, new_scores, t, device)
+         else:
+             new_hypos: List[Hypothesis] = []
+
+         return new_hypos
+
+     def _gen_new_hypos(
+         self,
+         base_hypos: List[Hypothesis],
+         tokens: List[int],
+         scores: List[float],
+         t: int,
+         device: torch.device,
+     ) -> List[Hypothesis]:
+         tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
+         states = _batch_state(base_hypos)
+         pred_out, _, pred_states = self.model.predict(
+             tgt_tokens,
+             torch.tensor([1] * len(base_hypos), device=device),
+             states,
+         )
+         new_hypos: List[Hypothesis] = []
+         for i, h_a in enumerate(base_hypos):
+             new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+             new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
+         return new_hypos
+
+     def _search(
+         self,
+         enc_out: torch.Tensor,
+         hypo: Optional[List[Hypothesis]],
+         beam_width: int,
+     ) -> List[Hypothesis]:
+         n_time_steps = enc_out.shape[1]
+         device = enc_out.device
+
+         a_hypos: List[Hypothesis] = []
+         b_hypos = self._init_b_hypos(device) if hypo is None else hypo
+         for t in range(n_time_steps):
+             a_hypos = b_hypos
+             b_hypos = torch.jit.annotate(List[Hypothesis], [])
+             key_to_b_hypo: Dict[str, Hypothesis] = {}
+             symbols_current_t = 0
+
+             while a_hypos:
+                 next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
+                 next_token_probs = next_token_probs.cpu()
+                 b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
+
+                 if symbols_current_t == self.step_max_tokens:
+                     break
+
+                 a_hypos = self._gen_a_hypos(
+                     a_hypos,
+                     b_hypos,
+                     next_token_probs,
+                     t,
+                     beam_width,
+                     device,
+                 )
+                 if a_hypos:
+                     symbols_current_t += 1
+
+             _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
+             b_hypos = [b_hypos[idx] for idx in sorted_idx]
+
+         return b_hypos
+
+     def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) -> List[Hypothesis]:
+         r"""Performs beam search for the given input sequence.
+
+         T: number of frames;
+         D: feature dimension of each frame.
+
+         Args:
+             input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+             length (torch.Tensor): number of valid frames in input
+                 sequence, with shape () or (1,).
+             beam_width (int): beam size to use during search.
+
+         Returns:
+             List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
+         """
+         if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+             raise ValueError("input must be of shape (T, D) or (1, T, D)")
+         if input.dim() == 2:
+             input = input.unsqueeze(0)
+
+         if length.shape != () and length.shape != (1,):
+             raise ValueError("length must be of shape () or (1,)")
+         if length.dim() == 0:
+             length = length.unsqueeze(0)
+
+         enc_out, _ = self.model.transcribe(input, length)
+         return self._search(enc_out, None, beam_width)
+
+     @torch.jit.export
+     def infer(
+         self,
+         input: torch.Tensor,
+         length: torch.Tensor,
+         beam_width: int,
+         state: Optional[List[List[torch.Tensor]]] = None,
+         hypothesis: Optional[List[Hypothesis]] = None,
+     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
+         r"""Performs beam search for the given input sequence in streaming mode.
+
+         T: number of frames;
+         D: feature dimension of each frame.
+
+         Args:
+             input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
+             length (torch.Tensor): number of valid frames in input
+                 sequence, with shape () or (1,).
+             beam_width (int): beam size to use during search.
+             state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
+                 representing transcription network internal state generated in preceding
+                 invocation. (Default: ``None``)
+             hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
+                 search with. (Default: ``None``)
+
+         Returns:
+             (List[Hypothesis], List[List[torch.Tensor]]):
+                 List[Hypothesis]
+                     top-``beam_width`` hypotheses found by beam search.
+                 List[List[torch.Tensor]]
+                     list of lists of tensors representing transcription network
+                     internal state generated in current invocation.
+         """
+         if input.dim() != 2 and not (input.dim() == 3 and input.shape[0] == 1):
+             raise ValueError("input must be of shape (T, D) or (1, T, D)")
+         if input.dim() == 2:
+             input = input.unsqueeze(0)
+
+         if length.shape != () and length.shape != (1,):
+             raise ValueError("length must be of shape () or (1,)")
+         if length.dim() == 0:
+             length = length.unsqueeze(0)
+
+         enc_out, _, state = self.model.transcribe_streaming(input, length, state)
+         return self._search(enc_out, hypothesis, beam_width), state
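
For orientation, this decoder is normally driven through the pretrained RNNTBundle pipeline referenced in its docstring. The following is a minimal sketch only; the bundle choice (EMFORMER_RNNT_BASE_LIBRISPEECH), the one-second dummy input at an assumed 16 kHz rate, and the beam width of 10 are illustrative assumptions rather than part of this diff.

import torch
import torchaudio

# Illustrative sketch: the bundle downloads pretrained weights on first use.
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
feature_extractor = bundle.get_feature_extractor()  # waveform -> (features, length)
decoder = bundle.get_decoder()                       # an RNNTBeamSearch instance
token_processor = bundle.get_token_processor()       # token ids -> text

waveform = torch.randn(16000)                        # 1 s of dummy audio at the assumed 16 kHz rate
features, length = feature_extractor(waveform)
hypotheses = decoder(features, length, 10)           # forward(): full-utterance beam search

# Each Hypothesis is a (tokens, predictor output, predictor state, score) tuple.
print(token_processor(hypotheses[0][0]))

The streaming variant, infer, is seeded the same way, except that the state and hypothesis values it returns are fed back in when decoding the next chunk.
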
torchaudio/models/squim/__init__.py
@@ -0,0 +1,11 @@
+ from .objective import squim_objective_base, squim_objective_model, SquimObjective
+ from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective
+
+ __all__ = [
+     "squim_objective_base",
+     "squim_objective_model",
+     "squim_subjective_base",
+     "squim_subjective_model",
+     "SquimObjective",
+     "SquimSubjective",
+ ]
torchaudio/models/squim/objective.py
@@ -0,0 +1,326 @@
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def transform_wb_pesq_range(x: float) -> float:
+     """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+     for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+     defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+     Args:
+         x (float): Narrow-band PESQ score.
+
+     Returns:
+         (float): Wide-band PESQ score.
+     """
+     return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+ PESQRange: Tuple[float, float] = (
+     1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
+     # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+     # We are using 1.0 as a reasonable approximation.
+     transform_wb_pesq_range(4.5),
+ )
+
+
+ class RangeSigmoid(nn.Module):
+     def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+         super(RangeSigmoid, self).__init__()
+         assert isinstance(val_range, tuple) and len(val_range) == 2
+         self.val_range: Tuple[float, float] = val_range
+         self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+         return out
+
+
+ class Encoder(nn.Module):
+     """Encoder module that transforms a 1D waveform to 2D representations.
+
+     Args:
+         feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+         win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+     """
+
+     def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+         super(Encoder, self).__init__()
+
+         self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply waveforms to convolutional layer and ReLU layer.
+
+         Args:
+             x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+         Returns:
+             (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+         """
+         out = x.unsqueeze(dim=1)
+         out = F.relu(self.conv1d(out))
+         return out
+
+
+ class SingleRNN(nn.Module):
+     def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+         super(SingleRNN, self).__init__()
+
+         self.rnn_type = rnn_type
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+
+         self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+             input_size,
+             hidden_size,
+             1,
+             dropout=dropout,
+             batch_first=True,
+             bidirectional=True,
+         )
+
+         self.proj = nn.Linear(hidden_size * 2, input_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # input shape: batch, seq, dim
+         out, _ = self.rnn(x)
+         out = self.proj(out)
+         return out
+
+
+ class DPRNN(nn.Module):
+     """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+     Args:
+         feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+         hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+         num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+         rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+         d_model (int, optional): The number of expected features in the input. (Default: 256)
+         chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+         chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+     """
+
+     def __init__(
+         self,
+         feat_dim: int = 64,
+         hidden_dim: int = 128,
+         num_blocks: int = 6,
+         rnn_type: str = "LSTM",
+         d_model: int = 256,
+         chunk_size: int = 100,
+         chunk_stride: int = 50,
+     ) -> None:
+         super(DPRNN, self).__init__()
+
+         self.num_blocks = num_blocks
+
+         self.row_rnn = nn.ModuleList([])
+         self.col_rnn = nn.ModuleList([])
+         self.row_norm = nn.ModuleList([])
+         self.col_norm = nn.ModuleList([])
+         for _ in range(num_blocks):
+             self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+             self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+             self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+             self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+         self.conv = nn.Sequential(
+             nn.Conv2d(feat_dim, d_model, 1),
+             nn.PReLU(),
+         )
+         self.chunk_size = chunk_size
+         self.chunk_stride = chunk_stride
+
+     def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+         # input shape: (B, N, T)
+         seq_len = x.shape[-1]
+
+         rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+         out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+         return out, rest
+
+     def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+         out, rest = self.pad_chunk(x)
+         batch_size, feat_dim, seq_len = out.shape
+
+         segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+         segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+         out = torch.cat([segments1, segments2], dim=3)
+         out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+         return out, rest
+
+     def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+         batch_size, dim, _, _ = x.shape
+         out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+         out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+         out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+         out = out1 + out2
+         if rest > 0:
+             out = out[:, :, :-rest]
+         out = out.contiguous()
+         return out
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x, rest = self.chunking(x)
+         batch_size, _, dim1, dim2 = x.shape
+         out = x
+         for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+             row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+             row_out = row_rnn(row_in)
+             row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+             row_out = row_norm(row_out)
+             out = out + row_out
+
+             col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+             col_out = col_rnn(col_in)
+             col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+             col_out = col_norm(col_out)
+             out = out + col_out
+         out = self.conv(out)
+         out = self.merging(out, rest)
+         out = out.transpose(1, 2).contiguous()
+         return out
+
+
+ class AutoPool(nn.Module):
+     def __init__(self, pool_dim: int = 1) -> None:
+         super(AutoPool, self).__init__()
+         self.pool_dim: int = pool_dim
+         self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+         self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         weight = self.softmax(torch.mul(x, self.alpha))
+         out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+         return out
+
+
+ class SquimObjective(nn.Module):
+     """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+     for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+     Args:
+         encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+         dprnn (torch.nn.Module): DPRNN module to model sequential features.
+         branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
+     """
+
+     def __init__(
+         self,
+         encoder: nn.Module,
+         dprnn: nn.Module,
+         branches: nn.ModuleList,
+     ):
+         super(SquimObjective, self).__init__()
+         self.encoder = encoder
+         self.dprnn = dprnn
+         self.branches = branches
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Args:
+             x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+         Returns:
+             List(torch.Tensor): List of score Tensors. Each Tensor is with dimension `(batch,)`.
+         """
+         if x.ndim != 2:
+             raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+         x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+         out = self.encoder(x)
+         out = self.dprnn(out)
+         scores = []
+         for branch in self.branches:
+             scores.append(branch(out).squeeze(dim=1))
+         return scores
+
+
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+     """Create branch module after DPRNN model for predicting metric score.
+
+     Args:
+         d_model (int): The number of expected features in the input.
+         nhead (int): Number of heads in the multi-head attention model.
+         metric (str): The metric name to predict.
+
+     Returns:
+         (nn.Module): Returned module to predict corresponding metric score.
+     """
+     layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+     layer2 = AutoPool()
+     if metric == "stoi":
+         layer3 = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.PReLU(),
+             nn.Linear(d_model, 1),
+             RangeSigmoid(),
+         )
+     elif metric == "pesq":
+         layer3 = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.PReLU(),
+             nn.Linear(d_model, 1),
+             RangeSigmoid(val_range=PESQRange),
+         )
+     else:
+         layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+     return nn.Sequential(layer1, layer2, layer3)
+
+
+ def squim_objective_model(
+     feat_dim: int,
+     win_len: int,
+     d_model: int,
+     nhead: int,
+     hidden_dim: int,
+     num_blocks: int,
+     rnn_type: str,
+     chunk_size: int,
+     chunk_stride: Optional[int] = None,
+ ) -> SquimObjective:
+     """Build a custom :class:`torchaudio.models.squim.SquimObjective` model.
+
+     Args:
+         feat_dim (int): The feature dimension after Encoder module.
+         win_len (int): Kernel size in the Encoder module.
+         d_model (int): The number of expected features in the input.
+         nhead (int): Number of heads in the multi-head attention model.
+         hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+         num_blocks (int): Number of DPRNN layers.
+         rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+         chunk_size (int): Chunk size of input for DPRNN.
+         chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+     """
+     if chunk_stride is None:
+         chunk_stride = chunk_size // 2
+     encoder = Encoder(feat_dim, win_len)
+     dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+     branches = nn.ModuleList(
+         [
+             _create_branch(d_model, nhead, "stoi"),
+             _create_branch(d_model, nhead, "pesq"),
+             _create_branch(d_model, nhead, "sisdr"),
+         ]
+     )
+     return SquimObjective(encoder, dprnn, branches)
+
+
+ def squim_objective_base() -> SquimObjective:
+     """Build :class:`torchaudio.models.squim.SquimObjective` model with default arguments."""
+     return squim_objective_model(
+         feat_dim=256,
+         win_len=64,
+         d_model=256,
+         nhead=4,
+         hidden_dim=256,
+         num_blocks=2,
+         rnn_type="LSTM",
+         chunk_size=71,
+     )
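
As a quick usage sketch for the factory functions above (architecture only, randomly initialized weights; the one-second input length and unpacking order are taken from the branch construction in this file, while the random waveform is purely illustrative):

import torch
from torchaudio.models import squim_objective_base

model = squim_objective_base()      # default SquimObjective architecture, untrained
model.eval()

waveform = torch.randn(1, 16000)    # (batch, time) dummy input
with torch.no_grad():
    # One (batch,) score tensor per branch, in the order the branches are built: STOI, PESQ, SI-SDR.
    stoi, pesq, si_sdr = model(waveform)
print(stoi.shape, pesq.shape, si_sdr.shape)

In the released package, pretrained weights are exposed through the pipeline module listed above (torchaudio/pipelines/_squim_pipeline.py) rather than through these constructors.
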