torchaudio-2.9.0-cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio might be problematic.

Files changed (86)
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
torchaudio/transforms/_multi_channel.py
@@ -0,0 +1,467 @@
+# -*- coding: utf-8 -*-
+
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import Tensor
+from torchaudio import functional as F
+
+
+__all__ = []
+
+
+def _get_mvdr_vector(
+    psd_s: torch.Tensor,
+    psd_n: torch.Tensor,
+    reference_vector: torch.Tensor,
+    solution: str = "ref_channel",
+    diagonal_loading: bool = True,
+    diag_eps: float = 1e-7,
+    eps: float = 1e-8,
+) -> torch.Tensor:
+    r"""Compute the MVDR beamforming weights with ``solution`` argument.
+
+    Args:
+        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_vector (torch.Tensor): one-hot reference channel matrix.
+        solution (str, optional): Solution to compute the MVDR beamforming weights.
+            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+            (Default: ``True``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+            (Default: ``1e-8``)
+
+    Returns:
+        torch.Tensor: the mvdr beamforming weight matrix
+    """
+    if solution == "ref_channel":
+        beamform_vector = F.mvdr_weights_souden(psd_s, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
+    else:
+        if solution == "stv_evd":
+            stv = F.rtf_evd(psd_s)
+        else:
+            stv = F.rtf_power(psd_s, psd_n, reference_vector, diagonal_loading=diagonal_loading, diag_eps=diag_eps)
+        beamform_vector = F.mvdr_weights_rtf(stv, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
+
+    return beamform_vector
+
+
+class PSD(torch.nn.Module):
+    r"""Compute cross-channel power spectral density (PSD) matrix.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Args:
+        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
+        normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
+        eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
+    """
+
+    def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
+        super().__init__()
+        self.multi_mask = multi_mask
+        self.normalize = normalize
+        self.eps = eps
+
+    def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
+        """
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`.
+            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
+                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
+                with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+                (Default: ``None``)
+
+        Returns:
+            torch.Tensor: The complex-valued PSD matrix of the input spectrum.
+                Tensor with dimensions `(..., freq, channel, channel)`
+        """
+        if mask is not None:
+            if self.multi_mask:
+                # Averaging mask along channel dimension
+                mask = mask.mean(dim=-3)  # (..., freq, time)
+        psd = F.psd(specgram, mask, self.normalize, self.eps)
+
+        return psd
+
+
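For orientation, a minimal usage sketch of the PSD transform defined above, assuming the class is exported as torchaudio.transforms.PSD as in earlier torchaudio releases; the random signal and mask are purely illustrative.

import torch
import torchaudio

# Four-channel signal -> complex STFT with shape (channel, freq, time)
window = torch.hann_window(512)
specgram = torch.stft(torch.randn(4, 16000), n_fft=512, hop_length=256, window=window, return_complex=True)

# Illustrative Time-Frequency mask of shape (freq, time)
mask = torch.rand(specgram.shape[-2:])

psd_transform = torchaudio.transforms.PSD(normalize=True)
psd_matrix = psd_transform(specgram, mask)  # complex PSD matrix, (freq, channel, channel)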
+class MVDR(torch.nn.Module):
+    """Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py
+
+    We provide three solutions of MVDR beamforming. One is based on *reference channel selection*
+    :cite:`souden2009optimal` (``solution=ref_channel``).
+
+    .. math::
+        \\textbf{w}_{\\text{MVDR}}(f) =\
+        \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bf{\\Phi}_{\\textbf{SS}}}}(f)}\
+        {\\text{Trace}({{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f) \\bf{\\Phi}_{\\textbf{SS}}}(f))}}\\bm{u}
+
+    where :math:`\\bf{\\Phi}_{\\textbf{SS}}` and :math:`\\bf{\\Phi}_{\\textbf{NN}}` are the covariance\
+    matrices of speech and noise, respectively. :math:`\\bf{u}` is a one-hot vector to determine the\
+    reference channel.
+
+    The other two solutions are based on the steering vector (``solution=stv_evd`` or ``solution=stv_power``).
+
+    .. math::
+        \\textbf{w}_{\\text{MVDR}}(f) =\
+        \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}}\
+        {{\\bm{v}^{\\mathsf{H}}}(f){\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}
+
+    where :math:`\\bm{v}` is the acoustic transfer function or the steering vector.\
+    :math:`.^{\\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+    We apply either *eigenvalue decomposition*
+    :cite:`higuchi2016robust` or the *power method* :cite:`mises1929praktische` to get the
+    steering vector from the PSD matrix of speech.
+
+    After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by
+
+    .. math::
+        \\hat{\\bf{S}} = {\\bf{w}^\\mathsf{H}}{\\bf{Y}}, {\\bf{w}} \\in \\mathbb{C}^{M \\times F}
+
+    where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the STFT of the multi-channel noisy speech and\
+    the single-channel enhanced speech, respectively.
+
+    For online streaming audio, we provide a *recursive method* :cite:`higuchi2017online` to update the
+    PSD matrices of speech and noise, respectively.
+
+    Args:
+        ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
+        solution (str, optional): Solution to compute the MVDR beamforming weights.
+            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
+        diag_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
+            of the noise. (Default: ``True``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diag_loading`` is set to ``True``. (Default: ``1e-7``)
+        online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
+            the previous covariance matrices. (Default: ``False``)
+
+    Note:
+        To improve the numerical stability, the input spectrogram will be converted to double precision
+        (``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
+        is converted to the dtype of the input spectrogram to be compatible with other modules.
+
+    Note:
+        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
+        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
+    """
+
+    def __init__(
+        self,
+        ref_channel: int = 0,
+        solution: str = "ref_channel",
+        multi_mask: bool = False,
+        diag_loading: bool = True,
+        diag_eps: float = 1e-7,
+        online: bool = False,
+    ):
+        super().__init__()
+        if solution not in [
+            "ref_channel",
+            "stv_evd",
+            "stv_power",
+        ]:
+            raise ValueError(
+                '`solution` must be one of ["ref_channel", "stv_evd", "stv_power"]. Given {}'.format(solution)
+            )
+        self.ref_channel = ref_channel
+        self.solution = solution
+        self.multi_mask = multi_mask
+        self.diag_loading = diag_loading
+        self.diag_eps = diag_eps
+        self.online = online
+        self.psd = PSD(multi_mask)
+
+        psd_s: torch.Tensor = torch.zeros(1)
+        psd_n: torch.Tensor = torch.zeros(1)
+        mask_sum_s: torch.Tensor = torch.zeros(1)
+        mask_sum_n: torch.Tensor = torch.zeros(1)
+        self.register_buffer("psd_s", psd_s)
+        self.register_buffer("psd_n", psd_n)
+        self.register_buffer("mask_sum_s", mask_sum_s)
+        self.register_buffer("mask_sum_n", mask_sum_n)
+
+    def _get_updated_mvdr_vector(
+        self,
+        psd_s: torch.Tensor,
+        psd_n: torch.Tensor,
+        mask_s: torch.Tensor,
+        mask_n: torch.Tensor,
+        reference_vector: torch.Tensor,
+        solution: str = "ref_channel",
+        diagonal_loading: bool = True,
+        diag_eps: float = 1e-7,
+        eps: float = 1e-8,
+    ) -> torch.Tensor:
+        r"""Recursively update the MVDR beamforming vector.
+
+        Args:
+            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
+                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
+                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+            reference_vector (torch.Tensor): One-hot reference channel matrix.
+            solution (str, optional): Solution to compute the MVDR beamforming weights.
+                Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
+            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+                (Default: ``True``)
+            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: The MVDR beamforming weight matrix.
+        """
+        if self.multi_mask:
+            # Averaging mask along channel dimension
+            mask_s = mask_s.mean(dim=-3)  # (..., freq, time)
+            mask_n = mask_n.mean(dim=-3)  # (..., freq, time)
+        if self.psd_s.ndim == 1:
+            self.psd_s = psd_s
+            self.psd_n = psd_n
+            self.mask_sum_s = mask_s.sum(dim=-1)
+            self.mask_sum_n = mask_n.sum(dim=-1)
+            return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
+        else:
+            psd_s = self._get_updated_psd_speech(psd_s, mask_s)
+            psd_n = self._get_updated_psd_noise(psd_n, mask_n)
+            self.psd_s = psd_s
+            self.psd_n = psd_n
+            self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
+            self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
+            return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
+
+    def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
+        r"""Update psd of speech recursively.
+
+        Args:
+            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
+                Tensor with dimensions `(..., freq, time)`.
+
+        Returns:
+            torch.Tensor: The updated PSD matrix of target speech.
+        """
+        numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
+        denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
+        psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
+        return psd_s
+
+    def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
+        r"""Update psd of noise recursively.
+
+        Args:
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
+                Tensor with dimensions `(..., freq, time)`.
+
+        Returns:
+            torch.Tensor: The updated PSD matrix of noise.
+        """
+        numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
+        denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
+        psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
+        return psd_n
+
+    def forward(
+        self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Perform MVDR beamforming.
+
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`
+            mask_s (torch.Tensor): Time-Frequency mask of target speech.
+                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
+                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
+                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
+                (Default: None)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        dtype = specgram.dtype
+        if specgram.ndim < 3:
+            raise ValueError(f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}")
+        if not specgram.is_complex():
+            raise ValueError(
+                f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
+                    Found: {specgram.dtype}"
+            )
+        if specgram.dtype == torch.cfloat:
+            specgram = specgram.cdouble()  # Convert specgram to ``torch.cdouble``.
+
+        if mask_n is None:
+            warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
+            mask_n = 1 - mask_s
+
+        psd_s = self.psd(specgram, mask_s)  # (..., freq, channel, channel)
+        psd_n = self.psd(specgram, mask_n)  # (..., freq, channel, channel)
+
+        u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble)  # (..., channel)
+        u[..., self.ref_channel].fill_(1)
+
+        if self.online:
+            w_mvdr = self._get_updated_mvdr_vector(
+                psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
+            )
+        else:
+            w_mvdr = _get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
+
+        specgram_enhanced = F.apply_beamforming(w_mvdr, specgram)
+
+        return specgram_enhanced.to(dtype)
+
+
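For context, a minimal mask-based usage sketch of the MVDR module above, assuming it is exported as torchaudio.transforms.MVDR as in earlier releases; the random waveform and masks are illustrative stand-ins for real data and a mask-estimation network. With online=True the module would instead update its PSD buffers recursively across successive calls, as described in the docstring.

import torch
import torchaudio

# Two-channel noisy recording -> complex STFT with shape (channel, freq, time)
waveform = torch.randn(2, 16000)
window = torch.hann_window(400)
specgram = torch.stft(waveform, n_fft=400, hop_length=100, window=window, return_complex=True)

# Illustrative speech/noise Time-Frequency masks of shape (freq, time)
mask_s = torch.rand(specgram.shape[-2:])
mask_n = 1.0 - mask_s

mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution="ref_channel")
specgram_enhanced = mvdr(specgram, mask_s, mask_n)  # single-channel spectrum, (freq, time)
waveform_enhanced = torch.istft(specgram_enhanced, n_fft=400, hop_length=100, window=window)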
+class RTFMVDR(torch.nn.Module):
+    r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
+    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
+    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+    .. math::
+        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
+    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+    The beamforming weight is computed by:
+
+    .. math::
+        \textbf{w}_{\text{MVDR}}(f) =
+            \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
+            {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
+    """
+
+    def forward(
+        self,
+        specgram: Tensor,
+        rtf: Tensor,
+        psd_n: Tensor,
+        reference_channel: Union[int, Tensor],
+        diagonal_loading: bool = True,
+        diag_eps: float = 1e-7,
+        eps: float = 1e-8,
+    ) -> Tensor:
+        """
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`
+            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+                Tensor with dimensions `(..., freq, channel)`.
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            reference_channel (int or torch.Tensor): Specifies the reference channel.
+                If the dtype is ``int``, it represents the reference channel index.
+                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+                is one-hot.
+            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+                (Default: ``True``)
+            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+        return spectrum_enhanced
+
+
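A minimal sketch for RTFMVDR under the same assumptions (class exported as torchaudio.transforms.RTFMVDR, with the torchaudio.functional helpers psd and rtf_evd referenced by this module); the masks are illustrative.

import torch
import torchaudio
from torchaudio import functional as F

window = torch.hann_window(400)
specgram = torch.stft(torch.randn(2, 16000), n_fft=400, hop_length=100, window=window, return_complex=True)

mask_s = torch.rand(specgram.shape[-2:])  # illustrative speech mask, (freq, time)
mask_n = 1.0 - mask_s                     # illustrative noise mask

psd_s = F.psd(specgram, mask_s)           # (freq, channel, channel)
psd_n = F.psd(specgram, mask_n)
rtf = F.rtf_evd(psd_s)                    # steering vector estimate, (freq, channel)

rtf_mvdr = torchaudio.transforms.RTFMVDR()
specgram_enhanced = rtf_mvdr(specgram, rtf, psd_n, reference_channel=0)  # (freq, time)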
+class SoudenMVDR(torch.nn.Module):
+    r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
+    based on the method proposed by *Souden et al.* :cite:`souden2009optimal`.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
+    of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+    .. math::
+        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.
+
+    The beamforming weight is computed by:
+
+    .. math::
+        \textbf{w}_{\text{MVDR}}(f) =
+            \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
+            {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
+    """
+
+    def forward(
+        self,
+        specgram: Tensor,
+        psd_s: Tensor,
+        psd_n: Tensor,
+        reference_channel: Union[int, Tensor],
+        diagonal_loading: bool = True,
+        diag_eps: float = 1e-7,
+        eps: float = 1e-8,
+    ) -> torch.Tensor:
+        """
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`.
+            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            reference_channel (int or torch.Tensor): Specifies the reference channel.
+                If the dtype is ``int``, it represents the reference channel index.
+                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+                is one-hot.
+            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+                (Default: ``True``)
+            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+        return spectrum_enhanced
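Likewise, a minimal sketch for SoudenMVDR under the same assumptions; unlike RTFMVDR it takes the speech PSD matrix directly rather than an RTF/steering vector.

import torch
import torchaudio
from torchaudio import functional as F

window = torch.hann_window(400)
specgram = torch.stft(torch.randn(2, 16000), n_fft=400, hop_length=100, window=window, return_complex=True)

mask_s = torch.rand(specgram.shape[-2:])  # illustrative speech mask, (freq, time)
mask_n = 1.0 - mask_s                     # illustrative noise mask

psd_s = F.psd(specgram, mask_s)           # (freq, channel, channel)
psd_n = F.psd(specgram, mask_n)

souden_mvdr = torchaudio.transforms.SoudenMVDR()
specgram_enhanced = souden_mvdr(specgram, psd_s, psd_n, reference_channel=0)  # (freq, time)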