torchaudio: 2.0.2-cp39-cp39-manylinux1_x86_64.whl → 2.1.1-cp39-cp39-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (92)
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.so +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
  51. torchaudio/lib/_torchaudio_sox.so +0 -0
  52. torchaudio/lib/libctc_prefix_decoder.so +0 -0
  53. torchaudio/lib/libtorchaudio.so +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
  55. torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
  56. torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
  57. torchaudio/lib/libtorchaudio_sox.so +0 -0
  58. torchaudio/lib/pybind11_prefixctc.so +0 -0
  59. torchaudio/models/__init__.py +14 -0
  60. torchaudio/models/decoder/__init__.py +22 -7
  61. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  62. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  63. torchaudio/models/rnnt_decoder.py +10 -14
  64. torchaudio/models/squim/__init__.py +11 -0
  65. torchaudio/models/squim/objective.py +326 -0
  66. torchaudio/models/squim/subjective.py +150 -0
  67. torchaudio/models/wav2vec2/components.py +6 -10
  68. torchaudio/pipelines/__init__.py +9 -0
  69. torchaudio/pipelines/_squim_pipeline.py +176 -0
  70. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  71. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  72. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  73. torchaudio/sox_effects/sox_effects.py +7 -30
  74. torchaudio/transforms/__init__.py +2 -0
  75. torchaudio/transforms/_transforms.py +99 -54
  76. torchaudio/utils/download.py +2 -2
  77. torchaudio/utils/ffmpeg_utils.py +20 -15
  78. torchaudio/utils/sox_utils.py +8 -9
  79. torchaudio/version.py +2 -2
  80. torchaudio-2.1.1.dist-info/METADATA +113 -0
  81. torchaudio-2.1.1.dist-info/RECORD +119 -0
  82. torchaudio/io/_compat.py +0 -241
  83. torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
  84. torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
  85. torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
  86. torchaudio/lib/libflashlight-text.so +0 -0
  87. torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
  88. torchaudio-2.0.2.dist-info/METADATA +0 -26
  89. torchaudio-2.0.2.dist-info/RECORD +0 -100
  90. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  91. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +0 -0
  92. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
torchaudio/models/decoder/_cuda_ctc_decoder.py
@@ -0,0 +1,187 @@
+ from __future__ import annotations
+
+ import math
+
+ from typing import List, NamedTuple, Union
+
+ import torch
+ import torchaudio
+
+ torchaudio._extension._load_lib("libctc_prefix_decoder")
+ import torchaudio.lib.pybind11_prefixctc as cuctc
+
+
+ __all__ = ["CUCTCHypothesis", "CUCTCDecoder", "cuda_ctc_decoder"]
+
+
+ def _get_vocab_list(vocab_file):
+     vocab = []
+     with open(vocab_file, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip().split()
+             vocab.append(line[0])
+     return vocab
+
+
+ class CUCTCHypothesis(NamedTuple):
+     r"""Represents hypothesis generated by CUCTC beam search decoder :class:`CUCTCDecoder`."""
+     tokens: List[int]
+     """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+     words: List[str]
+     """List of predicted tokens. Algin with modeling unit.
+     """
+
+     score: float
+     """Score corresponding to hypothesis"""
+
+
+ _DEFAULT_BLANK_SKIP_THREASHOLD = 0.95
+
+
+ class CUCTCDecoder:
+     """CUDA CTC beam search decoder.
+
+     .. devices:: CUDA
+
+     Note:
+         To build the decoder, please use the factory function :func:`cuda_ctc_decoder`.
+     """
+
+     def __init__(
+         self,
+         vocab_list: List[str],
+         blank_id: int = 0,
+         beam_size: int = 10,
+         nbest: int = 1,
+         blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+         cuda_stream: torch.cuda.streams.Stream = None,
+     ):
+         """
+         Args:
+             blank_id (int): token id corresopnding to blank, only support 0 for now. (Default: 0)
+             vocab_list (List[str]): list of vocabulary tokens
+             beam_size (int, optional): max number of hypos to hold after each decode step (Default: 10)
+             nbest (int): number of best decodings to return
+             blank_skip_threshold (float):
+                 skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding.
+                 (Default: 0.95).
+             cuda_stream (torch.cuda.streams.Stream): using assigned cuda stream (Default: using default stream)
+
+         """
+         if cuda_stream:
+             if not isinstance(cuda_stream, torch.cuda.streams.Stream):
+                 raise AssertionError("cuda_stream must be torch.cuda.streams.Stream")
+         cuda_stream_ = cuda_stream.cuda_stream if cuda_stream else torch.cuda.current_stream().cuda_stream
+         self.internal_data = cuctc.prefixCTC_alloc(cuda_stream_)
+         self.memory = torch.empty(0, dtype=torch.int8, device=torch.device("cuda"))
+         if blank_id != 0:
+             raise AssertionError("blank_id must be 0")
+         self.blank_id = blank_id
+         self.vocab_list = vocab_list
+         self.space_id = 0
+         self.nbest = nbest
+         if not (blank_skip_threshold >= 0 and blank_skip_threshold <= 1):
+             raise AssertionError("blank_skip_threshold must be between 0 and 1")
+         self.blank_skip_threshold = math.log(blank_skip_threshold)
+         self.beam_size = min(beam_size, len(vocab_list))  # beam size must be smaller than vocab size
+
+     def __del__(self):
+         if cuctc is not None:
+             cuctc.prefixCTC_free(self.internal_data)
+
+     def __call__(self, log_prob: torch.Tensor, encoder_out_lens: torch.Tensor):
+         """
+         Args:
+             log_prob (torch.FloatTensor): GPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
+                 probability distribution over labels; log_softmax(output of acoustic model).
+             lengths (dtype torch.int32): GPU tensor of shape `(batch, )` storing the valid length of
+                 in time axis of the output Tensor in each batch.
+
+         Returns:
+             List[List[CUCTCHypothesis]]:
+                 List of sorted best hypotheses for each audio sequence in the batch.
+         """
+         if not encoder_out_lens.dtype == torch.int32:
+             raise AssertionError("encoder_out_lens must be torch.int32")
+         if not log_prob.dtype == torch.float32:
+             raise AssertionError("log_prob must be torch.float32")
+         if not (log_prob.is_cuda and encoder_out_lens.is_cuda):
+             raise AssertionError("inputs must be cuda tensors")
+         if not (log_prob.is_contiguous() and encoder_out_lens.is_contiguous()):
+             raise AssertionError("input tensors must be contiguous")
+         required_size, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+             self.internal_data,
+             self.memory.data_ptr(),
+             self.memory.size(0),
+             log_prob.data_ptr(),
+             encoder_out_lens.data_ptr(),
+             log_prob.size(),
+             log_prob.stride(),
+             self.beam_size,
+             self.blank_id,
+             self.space_id,
+             self.blank_skip_threshold,
+         )
+         if required_size > 0:
+             self.memory = torch.empty(required_size, dtype=torch.int8, device=log_prob.device).contiguous()
+             _, score_hyps = cuctc.ctc_beam_search_decoder_batch_gpu_v2(
+                 self.internal_data,
+                 self.memory.data_ptr(),
+                 self.memory.size(0),
+                 log_prob.data_ptr(),
+                 encoder_out_lens.data_ptr(),
+                 log_prob.size(),
+                 log_prob.stride(),
+                 self.beam_size,
+                 self.blank_id,
+                 self.space_id,
+                 self.blank_skip_threshold,
+             )
+         batch_size = len(score_hyps)
+         hypos = []
+         for i in range(batch_size):
+             hypos.append(
+                 [
+                     CUCTCHypothesis(
+                         tokens=score_hyps[i][j][1],
+                         words=[self.vocab_list[word_id] for word_id in score_hyps[i][j][1]],
+                         score=score_hyps[i][j][0],
+                     )
+                     for j in range(self.nbest)
+                 ]
+             )
+         return hypos
+
+
+ def cuda_ctc_decoder(
+     tokens: Union[str, List[str]],
+     nbest: int = 1,
+     beam_size: int = 10,
+     blank_skip_threshold: float = _DEFAULT_BLANK_SKIP_THREASHOLD,
+ ) -> CUCTCDecoder:
+     """Builds an instance of :class:`CUCTCDecoder`.
+
+     Args:
+         tokens (str or List[str]): File or list containing valid tokens.
+             If using a file, the expected format is for tokens mapping to the same index to be on the same line
+         beam_size (int, optional): The maximum number of hypos to hold after each decode step (Default: 10)
+         nbest (int): The number of best decodings to return
+         blank_id (int): The token ID corresopnding to the blank symbol.
+         blank_skip_threshold (float): skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding
+             (Default: 0.95).
+
+     Returns:
+         CUCTCDecoder: decoder
+
+     Example
+         >>> decoder = cuda_ctc_decoder(
+         >>>     vocab_file="tokens.txt",
+         >>>     blank_skip_threshold=0.95,
+         >>> )
+         >>> results = decoder(log_probs, encoder_out_lens)  # List of shape (B, nbest) of Hypotheses
+     """
+     if type(tokens) == str:
+         tokens = _get_vocab_list(tokens)
+
+     return CUCTCDecoder(vocab_list=tokens, beam_size=beam_size, nbest=nbest, blank_skip_threshold=blank_skip_threshold)
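The hunk above adds the CUDA CTC beam search decoder that 2.1.x exposes as torchaudio.models.decoder.cuda_ctc_decoder. The sketch below shows one plausible way to call it; the toy vocabulary, random emissions, and tensor shapes are illustrative placeholders (not part of the diff), and a CUDA device is required.

import torch

from torchaudio.models.decoder import cuda_ctc_decoder

tokens = ["<blk>", "a", "b", "c", "|"]  # toy vocabulary; index 0 must be the blank token

# Hypothetical CTC emissions: contiguous float32 log-probabilities of shape (batch, frame, num_tokens) on GPU.
log_probs = torch.randn(1, 50, len(tokens), device="cuda").log_softmax(dim=-1).contiguous()
lengths = torch.full((1,), 50, dtype=torch.int32, device="cuda")

decoder = cuda_ctc_decoder(tokens, nbest=1, beam_size=10, blank_skip_threshold=0.95)
hypos = decoder(log_probs, lengths)  # List[List[CUCTCHypothesis]], one inner list per utterance
best = hypos[0][0]
print(best.tokens, best.words, best.score)

Per the checks in __call__, the emissions must be float32, the lengths int32, and both must be contiguous CUDA tensors.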
torchaudio/models/rnnt_decoder.py
@@ -109,13 +109,9 @@ class RNNTBeamSearch(torch.nn.Module):
 
          self.step_max_tokens = step_max_tokens
 
-     def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-         if hypo is not None:
-             token = _get_hypo_tokens(hypo)[-1]
-             state = _get_hypo_state(hypo)
-         else:
-             token = self.blank
-             state = None
+     def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+         token = self.blank
+         state = None
 
          one_tensor = torch.tensor([1], device=device)
          pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
@@ -230,14 +226,14 @@ class RNNTBeamSearch(torch.nn.Module):
      def _search(
          self,
          enc_out: torch.Tensor,
-         hypo: Optional[Hypothesis],
+         hypo: Optional[List[Hypothesis]],
          beam_width: int,
      ) -> List[Hypothesis]:
          n_time_steps = enc_out.shape[1]
          device = enc_out.device
 
          a_hypos: List[Hypothesis] = []
-         b_hypos = self._init_b_hypos(hypo, device)
+         b_hypos = self._init_b_hypos(device) if hypo is None else hypo
          for t in range(n_time_steps):
              a_hypos = b_hypos
              b_hypos = torch.jit.annotate(List[Hypothesis], [])
@@ -263,7 +259,7 @@ class RNNTBeamSearch(torch.nn.Module):
                  if a_hypos:
                      symbols_current_t += 1
 
-             _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+             _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
              b_hypos = [b_hypos[idx] for idx in sorted_idx]
 
          return b_hypos
@@ -290,8 +286,8 @@ class RNNTBeamSearch(torch.nn.Module):
 
          if length.shape != () and length.shape != (1,):
              raise ValueError("length must be of shape () or (1,)")
-         if input.dim() == 0:
-             input = input.unsqueeze(0)
+         if length.dim() == 0:
+             length = length.unsqueeze(0)
 
          enc_out, _ = self.model.transcribe(input, length)
          return self._search(enc_out, None, beam_width)
@@ -303,7 +299,7 @@ class RNNTBeamSearch(torch.nn.Module):
          length: torch.Tensor,
          beam_width: int,
          state: Optional[List[List[torch.Tensor]]] = None,
-         hypothesis: Optional[Hypothesis] = None,
+         hypothesis: Optional[List[Hypothesis]] = None,
      ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
          r"""Performs beam search for the given input sequence in streaming mode.
 
@@ -318,7 +314,7 @@ class RNNTBeamSearch(torch.nn.Module):
              state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
                  representing transcription network internal state generated in preceding
                  invocation. (Default: ``None``)
-             hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
+             hypothesis (List[Hypothesis] or None): hypotheses from preceding invocation to seed
                  search with. (Default: ``None``)
 
          Returns:
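These hunks change RNNTBeamSearch.infer so that the hypothesis argument, like the returned value, is now the full beam (List[Hypothesis]) rather than a single Hypothesis. The sketch below illustrates chunked streaming decoding with the pretrained Emformer RNNT bundle; the random waveform and the simple zero-padded chunking are placeholders (a real application would cache right-context samples as in the torchaudio online ASR tutorial), and the bundle downloads weights on first use.

import torch
import torchaudio

bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
feature_extractor = bundle.get_streaming_feature_extractor()
decoder = bundle.get_decoder()  # an RNNTBeamSearch instance
token_processor = bundle.get_token_processor()

segment_samples = bundle.segment_length * bundle.hop_length
chunk_samples = segment_samples + bundle.right_context_length * bundle.hop_length

waveform = torch.randn(bundle.sample_rate)  # placeholder 1 s clip; substitute real 16 kHz mono audio

state, hypotheses = None, None
for start in range(0, waveform.numel(), segment_samples):
    chunk = waveform[start : start + chunk_samples]
    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - chunk.numel()))
    features, length = feature_extractor(chunk)
    # With 2.1.x, `hypothesis` carries the whole beam between calls; 2.0.x passed a single Hypothesis.
    hypotheses, state = decoder.infer(features, length, beam_width=10, state=state, hypothesis=hypotheses)

print(token_processor(hypotheses[0][0]))  # token IDs of the best hypothesis -> transcript string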
torchaudio/models/squim/__init__.py
@@ -0,0 +1,11 @@
+ from .objective import squim_objective_base, squim_objective_model, SquimObjective
+ from .subjective import squim_subjective_base, squim_subjective_model, SquimSubjective
+
+ __all__ = [
+     "squim_objective_base",
+     "squim_objective_model",
+     "squim_subjective_base",
+     "squim_subjective_model",
+     "SquimObjective",
+     "SquimSubjective",
+ ]
torchaudio/models/squim/objective.py
@@ -0,0 +1,326 @@
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def transform_wb_pesq_range(x: float) -> float:
+     """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+     for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+     defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+     Args:
+         x (float): Narrow-band PESQ score.
+
+     Returns:
+         (float): Wide-band PESQ score.
+     """
+     return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+ PESQRange: Tuple[float, float] = (
+     1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
+     # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+     # We are using 1.0 as a reasonable approximation.
+     transform_wb_pesq_range(4.5),
+ )
+
+
+ class RangeSigmoid(nn.Module):
+     def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+         super(RangeSigmoid, self).__init__()
+         assert isinstance(val_range, tuple) and len(val_range) == 2
+         self.val_range: Tuple[float, float] = val_range
+         self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+         return out
+
+
+ class Encoder(nn.Module):
+     """Encoder module that transform 1D waveform to 2D representations.
+
+     Args:
+         feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+         win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+     """
+
+     def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+         super(Encoder, self).__init__()
+
+         self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply waveforms to convolutional layer and ReLU layer.
+
+         Args:
+             x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+         Returns:
+             (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+         """
+         out = x.unsqueeze(dim=1)
+         out = F.relu(self.conv1d(out))
+         return out
+
+
+ class SingleRNN(nn.Module):
+     def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+         super(SingleRNN, self).__init__()
+
+         self.rnn_type = rnn_type
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+
+         self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+             input_size,
+             hidden_size,
+             1,
+             dropout=dropout,
+             batch_first=True,
+             bidirectional=True,
+         )
+
+         self.proj = nn.Linear(hidden_size * 2, input_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # input shape: batch, seq, dim
+         out, _ = self.rnn(x)
+         out = self.proj(out)
+         return out
+
+
+ class DPRNN(nn.Module):
+     """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+     Args:
+         feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+         hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+         num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+         rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+         d_model (int, optional): The number of expected features in the input. (Default: 256)
+         chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+         chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+     """
+
+     def __init__(
+         self,
+         feat_dim: int = 64,
+         hidden_dim: int = 128,
+         num_blocks: int = 6,
+         rnn_type: str = "LSTM",
+         d_model: int = 256,
+         chunk_size: int = 100,
+         chunk_stride: int = 50,
+     ) -> None:
+         super(DPRNN, self).__init__()
+
+         self.num_blocks = num_blocks
+
+         self.row_rnn = nn.ModuleList([])
+         self.col_rnn = nn.ModuleList([])
+         self.row_norm = nn.ModuleList([])
+         self.col_norm = nn.ModuleList([])
+         for _ in range(num_blocks):
+             self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+             self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+             self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+             self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+         self.conv = nn.Sequential(
+             nn.Conv2d(feat_dim, d_model, 1),
+             nn.PReLU(),
+         )
+         self.chunk_size = chunk_size
+         self.chunk_stride = chunk_stride
+
+     def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+         # input shape: (B, N, T)
+         seq_len = x.shape[-1]
+
+         rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+         out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+         return out, rest
+
+     def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+         out, rest = self.pad_chunk(x)
+         batch_size, feat_dim, seq_len = out.shape
+
+         segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+         segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+         out = torch.cat([segments1, segments2], dim=3)
+         out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+         return out, rest
+
+     def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+         batch_size, dim, _, _ = x.shape
+         out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+         out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+         out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+         out = out1 + out2
+         if rest > 0:
+             out = out[:, :, :-rest]
+         out = out.contiguous()
+         return out
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x, rest = self.chunking(x)
+         batch_size, _, dim1, dim2 = x.shape
+         out = x
+         for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+             row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+             row_out = row_rnn(row_in)
+             row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+             row_out = row_norm(row_out)
+             out = out + row_out
+
+             col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+             col_out = col_rnn(col_in)
+             col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+             col_out = col_norm(col_out)
+             out = out + col_out
+         out = self.conv(out)
+         out = self.merging(out, rest)
+         out = out.transpose(1, 2).contiguous()
+         return out
+
+
+ class AutoPool(nn.Module):
+     def __init__(self, pool_dim: int = 1) -> None:
+         super(AutoPool, self).__init__()
+         self.pool_dim: int = pool_dim
+         self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+         self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         weight = self.softmax(torch.mul(x, self.alpha))
+         out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+         return out
+
+
+ class SquimObjective(nn.Module):
+     """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+     for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+     Args:
+         encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+         dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+         branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score.
+     """
+
+     def __init__(
+         self,
+         encoder: nn.Module,
+         dprnn: nn.Module,
+         branches: nn.ModuleList,
+     ):
+         super(SquimObjective, self).__init__()
+         self.encoder = encoder
+         self.dprnn = dprnn
+         self.branches = branches
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """
+         Args:
+             x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+         Returns:
+             List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`.
+         """
+         if x.ndim != 2:
+             raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+         x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+         out = self.encoder(x)
+         out = self.dprnn(out)
+         scores = []
+         for branch in self.branches:
+             scores.append(branch(out).squeeze(dim=1))
+         return scores
+
+
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+     """Create branch module after DPRNN model for predicting metric score.
+
+     Args:
+         d_model (int): The number of expected features in the input.
+         nhead (int): Number of heads in the multi-head attention model.
+         metric (str): The metric name to predict.
+
+     Returns:
+         (nn.Module): Returned module to predict corresponding metric score.
+     """
+     layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+     layer2 = AutoPool()
+     if metric == "stoi":
+         layer3 = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.PReLU(),
+             nn.Linear(d_model, 1),
+             RangeSigmoid(),
+         )
+     elif metric == "pesq":
+         layer3 = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.PReLU(),
+             nn.Linear(d_model, 1),
+             RangeSigmoid(val_range=PESQRange),
+         )
+     else:
+         layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+     return nn.Sequential(layer1, layer2, layer3)
+
+
+ def squim_objective_model(
+     feat_dim: int,
+     win_len: int,
+     d_model: int,
+     nhead: int,
+     hidden_dim: int,
+     num_blocks: int,
+     rnn_type: str,
+     chunk_size: int,
+     chunk_stride: Optional[int] = None,
+ ) -> SquimObjective:
+     """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
+
+     Args:
+         feat_dim (int, optional): The feature dimension after Encoder module.
+         win_len (int): Kernel size in the Encoder module.
+         d_model (int): The number of expected features in the input.
+         nhead (int): Number of heads in the multi-head attention model.
+         hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+         num_blocks (int): Number of DPRNN layers.
+         rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+         chunk_size (int): Chunk size of input for DPRNN.
+         chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+     """
+     if chunk_stride is None:
+         chunk_stride = chunk_size // 2
+     encoder = Encoder(feat_dim, win_len)
+     dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+     branches = nn.ModuleList(
+         [
+             _create_branch(d_model, nhead, "stoi"),
+             _create_branch(d_model, nhead, "pesq"),
+             _create_branch(d_model, nhead, "sisdr"),
+         ]
+     )
+     return SquimObjective(encoder, dprnn, branches)
+
+
+ def squim_objective_base() -> SquimObjective:
+     """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
+     return squim_objective_model(
+         feat_dim=256,
+         win_len=64,
+         d_model=256,
+         nhead=4,
+         hidden_dim=256,
+         num_blocks=2,
+         rnn_type="LSTM",
+         chunk_size=71,
+     )
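The new torchaudio/models/squim/objective.py adds the SquimObjective model, which maps a batch of waveforms to three reference-free objective scores (STOI, PESQ, SI-SDR), one per branch. A minimal sketch follows, using the default factory with random initialization; the random waveform is a placeholder, and the resulting scores are meaningless without pretrained weights (the accompanying torchaudio.pipelines.SQUIM_OBJECTIVE bundle added in this release provides those).

import torch

from torchaudio.models import squim_objective_base

model = squim_objective_base()        # default hyper-parameters, randomly initialized
waveform = torch.randn(1, 16000)      # hypothetical (batch, time) input: 1 s at 16 kHz
stoi, pesq, si_sdr = model(waveform)  # three (batch,) tensors, in branch order: STOI, PESQ, SI-SDR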