torchaudio-2.9.1-cp311-cp311-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.1.dist-info/METADATA +133 -0
- torchaudio-2.9.1.dist-info/RECORD +85 -0
- torchaudio-2.9.1.dist-info/WHEEL +5 -0
- torchaudio-2.9.1.dist-info/licenses/LICENSE +25 -0
- torchaudio-2.9.1.dist-info/top_level.txt +1 -0
torchaudio/transforms/_multi_channel.py
@@ -0,0 +1,467 @@
# -*- coding: utf-8 -*-

import warnings
from typing import Optional, Union

import torch
from torch import Tensor
from torchaudio import functional as F


__all__ = []


def _get_mvdr_vector(
    psd_s: torch.Tensor,
    psd_n: torch.Tensor,
    reference_vector: torch.Tensor,
    solution: str = "ref_channel",
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> torch.Tensor:
    r"""Compute the MVDR beamforming weights with ``solution`` argument.

    Args:
        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor with dimensions `(..., freq, channel, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_vector (torch.Tensor): One-hot reference channel matrix.
        solution (str, optional): Solution to compute the MVDR beamforming weights.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
            (Default: ``1e-8``)

    Returns:
        torch.Tensor: The MVDR beamforming weight matrix.
    """
    if solution == "ref_channel":
        beamform_vector = F.mvdr_weights_souden(psd_s, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
    else:
        if solution == "stv_evd":
            stv = F.rtf_evd(psd_s)
        else:
            stv = F.rtf_power(psd_s, psd_n, reference_vector, diagonal_loading=diagonal_loading, diag_eps=diag_eps)
        beamform_vector = F.mvdr_weights_rtf(stv, psd_n, reference_vector, diagonal_loading, diag_eps, eps)

    return beamform_vector

class PSD(torch.nn.Module):
    r"""Compute cross-channel power spectral density (PSD) matrix.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
        normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
        eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
    """

    def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
        super().__init__()
        self.multi_mask = multi_mask
        self.normalize = normalize
        self.eps = eps

    def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
                with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
                (Default: ``None``)

        Returns:
            torch.Tensor: The complex-valued PSD matrix of the input spectrum.
                Tensor with dimensions `(..., freq, channel, channel)`
        """
        if mask is not None:
            if self.multi_mask:
                # Averaging mask along channel dimension
                mask = mask.mean(dim=-3)  # (..., freq, time)
        psd = F.psd(specgram, mask, self.normalize, self.eps)

        return psd

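The PSD transform above can be exercised on its own. The following is a minimal sketch, not part of the packaged file; it assumes ``PSD`` and ``Spectrogram`` are exposed under ``torchaudio.transforms`` as in previous releases, and uses a random signal and mask as stand-ins for a real recording and an estimated speech mask:

    import torch
    from torchaudio.transforms import PSD, Spectrogram

    waveform = torch.randn(2, 16000)                           # dummy 2-channel signal, 1 s at 16 kHz
    specgram = Spectrogram(n_fft=400, power=None)(waveform)    # complex spectrum, (2, 201, frames)
    mask_s = torch.rand(specgram.shape[-2:])                   # stand-in speech mask, (201, frames)
    psd_s = PSD()(specgram, mask_s)                            # complex PSD matrix, (201, 2, 2)
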
class MVDR(torch.nn.Module):
    """Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py

    We provide three solutions of MVDR beamforming. One is based on *reference channel selection*
    :cite:`souden2009optimal` (``solution=ref_channel``).

    .. math::
        \\textbf{w}_{\\text{MVDR}}(f) =\
        \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bf{\\Phi}_{\\textbf{SS}}}}(f)}\
        {\\text{Trace}({{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f) \\bf{\\Phi}_{\\textbf{SS}}}(f))}}\\bm{u}

    where :math:`\\bf{\\Phi}_{\\textbf{SS}}` and :math:`\\bf{\\Phi}_{\\textbf{NN}}` are the covariance\
    matrices of speech and noise, respectively. :math:`\\bf{u}` is a one-hot vector to determine the\
    reference channel.

    The other two solutions are based on the steering vector (``solution=stv_evd`` or ``solution=stv_power``).

    .. math::
        \\textbf{w}_{\\text{MVDR}}(f) =\
        \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}}\
        {{\\bm{v}^{\\mathsf{H}}}(f){\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}

    where :math:`\\bm{v}` is the acoustic transfer function or the steering vector.\
    :math:`.^{\\mathsf{H}}` denotes the Hermitian Conjugate operation.

    We apply either *eigenvalue decomposition*
    :cite:`higuchi2016robust` or the *power method* :cite:`mises1929praktische` to get the
    steering vector from the PSD matrix of speech.

    After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by

    .. math::
        \\hat{\\bf{S}} = {\\bf{w}^\\mathsf{H}}{\\bf{Y}}, {\\bf{w}} \\in \\mathbb{C}^{M \\times F}

    where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the STFT of the multi-channel noisy speech and\
    the single-channel enhanced speech, respectively.

    For online streaming audio, we provide a *recursive method* :cite:`higuchi2017online` to update the
    PSD matrices of speech and noise, respectively.

    Args:
        ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
        solution (str, optional): Solution to compute the MVDR beamforming weights.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
        diag_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
            of the noise. (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diag_loading`` is set to ``True``. (Default: ``1e-7``)
        online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
            the previous covariance matrices. (Default: ``False``)

    Note:
        To improve the numerical stability, the input spectrogram will be converted to double precision
        (``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
        is converted to the dtype of the input spectrogram to be compatible with other modules.

    Note:
        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
    """

    def __init__(
        self,
        ref_channel: int = 0,
        solution: str = "ref_channel",
        multi_mask: bool = False,
        diag_loading: bool = True,
        diag_eps: float = 1e-7,
        online: bool = False,
    ):
        super().__init__()
        if solution not in [
            "ref_channel",
            "stv_evd",
            "stv_power",
        ]:
            raise ValueError(
                '`solution` must be one of ["ref_channel", "stv_evd", "stv_power"]. Given {}'.format(solution)
            )
        self.ref_channel = ref_channel
        self.solution = solution
        self.multi_mask = multi_mask
        self.diag_loading = diag_loading
        self.diag_eps = diag_eps
        self.online = online
        self.psd = PSD(multi_mask)

        psd_s: torch.Tensor = torch.zeros(1)
        psd_n: torch.Tensor = torch.zeros(1)
        mask_sum_s: torch.Tensor = torch.zeros(1)
        mask_sum_n: torch.Tensor = torch.zeros(1)
        self.register_buffer("psd_s", psd_s)
        self.register_buffer("psd_n", psd_n)
        self.register_buffer("mask_sum_s", mask_sum_s)
        self.register_buffer("mask_sum_n", mask_sum_n)

    def _get_updated_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        mask_s: torch.Tensor,
        mask_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = "ref_channel",
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Recursively update the MVDR beamforming vector.

        Args:
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            reference_vector (torch.Tensor): One-hot reference channel matrix.
            solution (str, optional): Solution to compute the MVDR beamforming weights.
                Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: The MVDR beamforming weight matrix.
        """
        if self.multi_mask:
            # Averaging mask along channel dimension
            mask_s = mask_s.mean(dim=-3)  # (..., freq, time)
            mask_n = mask_n.mean(dim=-3)  # (..., freq, time)
        if self.psd_s.ndim == 1:
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = mask_s.sum(dim=-1)
            self.mask_sum_n = mask_n.sum(dim=-1)
            return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
        else:
            psd_s = self._get_updated_psd_speech(psd_s, mask_s)
            psd_n = self._get_updated_psd_noise(psd_n, mask_n)
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
            self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
            return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)

    def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
        r"""Update psd of speech recursively.

        Args:
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
                Tensor with dimensions `(..., freq, time)`.

        Returns:
            torch.Tensor: The updated PSD matrix of target speech.
        """
        numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
        denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
        psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
        return psd_s

    def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        r"""Update psd of noise recursively.

        Args:
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
                Tensor with dimensions `(..., freq, time)`.

        Returns:
            torch.Tensor: The updated PSD matrix of noise.
        """
        numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
        denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
        psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
        return psd_n

    def forward(
        self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Perform MVDR beamforming.

        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`
            mask_s (torch.Tensor): Time-Frequency mask of target speech.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
                (Default: None)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        """
        dtype = specgram.dtype
        if specgram.ndim < 3:
            raise ValueError(f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}")
        if not specgram.is_complex():
            raise ValueError(
                f"The type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.\
                    Found: {specgram.dtype}"
            )
        if specgram.dtype == torch.cfloat:
            specgram = specgram.cdouble()  # Convert specgram to ``torch.cdouble``.

        if mask_n is None:
            warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
            mask_n = 1 - mask_s

        psd_s = self.psd(specgram, mask_s)  # (..., freq, channel, channel)
        psd_n = self.psd(specgram, mask_n)  # (..., freq, channel, channel)

        u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble)  # (..., channel)
        u[..., self.ref_channel].fill_(1)

        if self.online:
            w_mvdr = self._get_updated_mvdr_vector(
                psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
            )
        else:
            w_mvdr = _get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)

        specgram_enhanced = F.apply_beamforming(w_mvdr, specgram)

        return specgram_enhanced.to(dtype)

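A rough end-to-end sketch of the mask-based transform above, not part of the packaged file; the random mask stands in for the output of a neural mask estimator, and the classes are assumed to be exposed under ``torchaudio.transforms`` as in previous releases:

    import torch
    from torchaudio.transforms import MVDR, Spectrogram, InverseSpectrogram

    waveform = torch.randn(2, 16000)                           # dummy 2-channel noisy mixture
    specgram = Spectrogram(n_fft=400, power=None)(waveform)    # complex spectrum, (2, 201, frames)
    mask_s = torch.rand(specgram.shape[-2:])                   # stand-in speech mask in [0, 1]

    mvdr = MVDR(ref_channel=0, solution="ref_channel", online=False)
    enhanced_specgram = mvdr(specgram, mask_s)                 # mask_n defaults to 1 - mask_s
    enhanced_waveform = InverseSpectrogram(n_fft=400)(enhanced_specgram)  # single-channel output

With ``online=True``, the buffers registered in ``__init__`` (``psd_s``, ``psd_n``, ``mask_sum_s``, ``mask_sum_n``) accumulate across calls, so the same module instance can process a stream chunk by chunk.
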
class RTFMVDR(torch.nn.Module):
    r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

    The beamforming weight is computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
    """

    def forward(
        self,
        specgram: Tensor,
        rtf: Tensor,
        psd_n: Tensor,
        reference_channel: Union[int, Tensor],
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> Tensor:
        """
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`
            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
                Tensor with dimensions `(..., freq, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
                is one-hot.
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        """
        w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
        return spectrum_enhanced

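A sketch of how the RTF-based variant above might be driven, not part of the packaged file; the steering vector is estimated with ``torchaudio.functional.rtf_power`` (which this module also uses internally via ``_get_mvdr_vector``), and the dummy signal and mask are placeholders:

    import torch
    import torchaudio.functional as F
    from torchaudio.transforms import PSD, RTFMVDR, Spectrogram

    waveform = torch.randn(2, 16000)                           # dummy 2-channel noisy mixture
    specgram = Spectrogram(n_fft=400, power=None)(waveform)    # complex spectrum, (2, 201, frames)
    mask_s = torch.rand(specgram.shape[-2:])                   # stand-in speech mask
    psd = PSD()
    psd_s, psd_n = psd(specgram, mask_s), psd(specgram, 1 - mask_s)

    rtf = F.rtf_power(psd_s, psd_n, reference_channel=0)       # steering vector, (201, 2)
    enhanced_specgram = RTFMVDR()(specgram, rtf, psd_n, reference_channel=0)  # (201, frames)
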
class SoudenMVDR(torch.nn.Module):
    r"""Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
    based on the method proposed by *Souden et al.* :cite:`souden2009optimal`.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
    of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.

    The beamforming weight is computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
    """

    def forward(
        self,
        specgram: Tensor,
        psd_s: Tensor,
        psd_n: Tensor,
        reference_channel: Union[int, Tensor],
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
                is one-hot.
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        """
        w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
        return spectrum_enhanced

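Finally, a sketch of the Souden variant above, which consumes the two PSD matrices directly instead of an explicit steering vector; again an illustration only, not part of the packaged file, with dummy inputs as in the earlier sketches:

    import torch
    from torchaudio.transforms import PSD, SoudenMVDR, Spectrogram

    waveform = torch.randn(2, 16000)                           # dummy 2-channel noisy mixture
    specgram = Spectrogram(n_fft=400, power=None)(waveform)    # complex spectrum, (2, 201, frames)
    mask_s = torch.rand(specgram.shape[-2:])                   # stand-in speech mask
    psd = PSD()
    psd_s, psd_n = psd(specgram, mask_s), psd(specgram, 1 - mask_s)

    enhanced_specgram = SoudenMVDR()(specgram, psd_s, psd_n, reference_channel=0)  # (201, frames)

The weights follow the trace formula in the docstring, so no separate RTF estimation step is required.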