torchaudio-2.8.0-cp313-cp313-win_amd64.whl → torchaudio-2.9.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio has been flagged by the registry; further details are available from the registry.

Files changed (92)
  1. torchaudio/__init__.py +179 -39
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +12 -3
  5. torchaudio/_torchcodec.py +73 -85
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +0 -2
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +26 -60
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +14 -2
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +1 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +1 -0
  23. torchaudio/transforms/_transforms.py +2 -2
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -3
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -350
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -20
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -150
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -68
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -441
  54. torchaudio/prototype/functional/_rir.py +0 -382
  55. torchaudio/prototype/functional/functional.py +0 -193
  56. torchaudio/prototype/models/__init__.py +0 -39
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -337
  59. torchaudio/prototype/models/conv_emformer.py +0 -529
  60. torchaudio/prototype/models/hifi_gan.py +0 -342
  61. torchaudio/prototype/models/rnnt.py +0 -717
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -402
  63. torchaudio/prototype/pipelines/__init__.py +0 -21
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -461
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -275
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -118
  75. torchaudio-2.8.0.dist-info/RECORD +0 -145
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -977
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -275
  91. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
  92. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/licenses/LICENSE +0 -0
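
Most of the deletions above remove the legacy I/O stack in one sweep: torchaudio/_backend/, torchaudio/backend/, torchaudio/io/, torchaudio/sox_effects/, torchaudio/kaldi_io.py, the entire torchaudio/prototype/ namespace, and the bundled torio package, while torchaudio/_torchcodec.py is reworked (+73 -85). A minimal sketch of what basic loading looks like after the upgrade, assuming torchaudio.load keeps its (waveform, sample_rate) return and delegates decoding to the torchcodec-backed path; "sample.wav" is a placeholder:

import torchaudio

# Assumed 2.9.0 usage: torchaudio.load stays the public entry point, with
# decoding handled by the torchcodec-backed code path (torchaudio/_torchcodec.py)
# rather than the removed sox/ffmpeg backend modules.
waveform, sample_rate = torchaudio.load("sample.wav")  # placeholder path
print(waveform.shape, sample_rate)
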
--- a/torchaudio/prototype/datasets/musan.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from pathlib import Path
-from typing import Tuple, Union
-
-import torch
-from torch.utils.data import Dataset
-from torchaudio.datasets.utils import _load_waveform
-from torchaudio._internal.module_utils import dropping_support, dropping_class_support
-
-
-_SUBSETS = ["music", "noise", "speech"]
-_SAMPLE_RATE = 16_000
-
-@dropping_class_support
-class Musan(Dataset):
-    r"""*MUSAN* :cite:`musan2015` dataset.
-
-    Args:
-        root (str or Path): Root directory where the dataset's top-level directory exists.
-        subset (str): Subset of the dataset to use. Options: [``"music"``, ``"noise"``, ``"speech"``].
-    """
-
-    def __init__(self, root: Union[str, Path], subset: str):
-        if subset not in _SUBSETS:
-            raise ValueError(f"Invalid subset '{subset}' given. Please provide one of {_SUBSETS}")
-
-        subset_path = Path(root) / subset
-        self._walker = [str(p) for p in subset_path.glob("*/*.*")]
-
-    def get_metadata(self, n: int) -> Tuple[str, int, str]:
-        r"""Get metadata for the n-th sample in the dataset. Returns filepath instead of waveform,
-        but otherwise returns the same fields as :py:func:`__getitem__`.
-
-        Args:
-            n (int): Index of sample to be loaded.
-
-        Returns:
-            (str, int, str):
-                str
-                    Path to audio.
-                int
-                    Sample rate.
-                str
-                    File name.
-        """
-        audio_path = self._walker[n]
-        return audio_path, _SAMPLE_RATE, Path(audio_path).name
-
-    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, str]:
-        r"""Return the n-th sample in the dataset.
-
-        Args:
-            n (int): Index of sample to be loaded.
-
-        Returns:
-            (torch.Tensor, int, str):
-                torch.Tensor
-                    Waveform.
-                int
-                    Sample rate.
-                str
-                    File name.
-        """
-        audio_path, sample_rate, filename = self.get_metadata(n)
-        path = Path(audio_path)
-        return _load_waveform(path.parent, path.name, sample_rate), sample_rate, filename
-
-    def __len__(self) -> int:
-        return len(self._walker)
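
For reference, a short sketch of how the removed Musan dataset was used in 2.8.0, based directly on the class above; "/data/musan" is a placeholder root containing music/, noise/, and speech/ subdirectories:

from torchaudio.prototype.datasets import Musan  # removed in 2.9.0

dataset = Musan(root="/data/musan", subset="noise")  # placeholder root
waveform, sample_rate, filename = dataset[0]  # sample_rate is always 16_000 (_SAMPLE_RATE)
print(len(dataset), waveform.shape, sample_rate, filename)
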
--- a/torchaudio/prototype/functional/__init__.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from ._dsp import (
-    adsr_envelope,
-    exp_sigmoid,
-    extend_pitch,
-    filter_waveform,
-    frequency_impulse_response,
-    oscillator_bank,
-    sinc_impulse_response,
-)
-from ._rir import ray_tracing, simulate_rir_ism
-from .functional import barkscale_fbanks, chroma_filterbank
-
-
-__all__ = [
-    "adsr_envelope",
-    "exp_sigmoid",
-    "barkscale_fbanks",
-    "chroma_filterbank",
-    "extend_pitch",
-    "filter_waveform",
-    "frequency_impulse_response",
-    "oscillator_bank",
-    "ray_tracing",
-    "sinc_impulse_response",
-    "simulate_rir_ism",
-]
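
Downstream code that imported from this namespace needs a guard (or a vendored copy of the 2.8.0 sources) when upgrading; a minimal sketch:

try:
    from torchaudio.prototype.functional import adsr_envelope, oscillator_bank
except ImportError:  # torchaudio >= 2.9.0: the prototype namespace is gone
    adsr_envelope = oscillator_bank = None  # fall back or vendor the 2.8.0 implementation
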
--- a/torchaudio/prototype/functional/_dsp.py
+++ /dev/null
@@ -1,441 +0,0 @@
-import warnings
-from typing import List, Optional, Union
-
-import torch
-
-from torchaudio.functional import fftconvolve
-from torchaudio._internal.module_utils import dropping_support
-
-
-@dropping_support
-def oscillator_bank(
-    frequencies: torch.Tensor,
-    amplitudes: torch.Tensor,
-    sample_rate: float,
-    reduction: str = "sum",
-    dtype: Optional[torch.dtype] = torch.float64,
-) -> torch.Tensor:
-    """Synthesize waveform from the given instantaneous frequencies and amplitudes.
-
-    .. devices:: CPU CUDA
-
-    .. properties:: Autograd TorchScript
-
-    Note:
-        The phase information of the output waveform is found by taking the cumulative sum
-        of the given instantaneous frequencies (``frequencies``).
-        This incurs roundoff error when the data type does not have enough precision.
-        Using ``torch.float64`` can work around this.
-
-        The following figure shows the difference between ``torch.float32`` and
-        ``torch.float64`` when generating a sin wave of constant frequency and amplitude
-        with sample rate 8000 [Hz].
-        Notice that ``torch.float32`` version shows artifacts that are not seen in
-        ``torch.float64`` version.
-
-        .. image:: https://download.pytorch.org/torchaudio/doc-assets/oscillator_precision.png
-
-    Args:
-        frequencies (Tensor): Sample-wise oscillator frequencies (Hz). Shape `(..., time, N)`.
-        amplitudes (Tensor): Sample-wise oscillator amplitude. Shape: `(..., time, N)`.
-        sample_rate (float): Sample rate
-        reduction (str): Reduction to perform.
-            Valid values are ``"sum"``, ``"mean"`` or ``"none"``. Default: ``"sum"``
-        dtype (torch.dtype or None, optional): The data type on which cumulative sum operation is performed.
-            Default: ``torch.float64``. Pass ``None`` to disable the casting.
-
-    Returns:
-        Tensor:
-            The resulting waveform.
-
-            If ``reduction`` is ``"none"``, then the shape is
-            `(..., time, N)`, otherwise the shape is `(..., time)`.
-    """
-    if frequencies.shape != amplitudes.shape:
-        raise ValueError(
-            "The shapes of `frequencies` and `amplitudes` must match. "
-            f"Found: {frequencies.shape} and {amplitudes.shape} respectively."
-        )
-    reductions = ["sum", "mean", "none"]
-    if reduction not in reductions:
-        raise ValueError(f"The value of reduction must be either {reductions}. Found: {reduction}")
-
-    invalid = torch.abs(frequencies) >= sample_rate / 2
-    if torch.any(invalid):
-        warnings.warn(
-            "Some frequencies are above nyquist frequency. "
-            "Setting the corresponding amplitude to zero. "
-            "This might cause numerically unstable gradient."
-        )
-        amplitudes = torch.where(invalid, 0.0, amplitudes)
-
-    pi2 = 2.0 * torch.pi
-    freqs = frequencies * pi2 / sample_rate % pi2
-    phases = torch.cumsum(freqs, dim=-2, dtype=dtype)
-    if dtype is not None and freqs.dtype != dtype:
-        phases = phases.to(freqs.dtype)
-
-    waveform = amplitudes * torch.sin(phases)
-    if reduction == "sum":
-        return waveform.sum(-1)
-    if reduction == "mean":
-        return waveform.mean(-1)
-    return waveform
-
-
-@dropping_support
-def adsr_envelope(
-    num_frames: int,
-    *,
-    attack: float = 0.0,
-    hold: float = 0.0,
-    decay: float = 0.0,
-    sustain: float = 1.0,
-    release: float = 0.0,
-    n_decay: int = 2,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-):
-    """Generate ADSR Envelope
-
-    .. devices:: CPU CUDA
-
-    Args:
-        num_frames (int): The number of output frames.
-        attack (float, optional):
-            The relative *time* it takes to reach the maximum level from
-            the start. (Default: ``0.0``)
-        hold (float, optional):
-            The relative *time* the maximum level is held before
-            it starts to decay. (Default: ``0.0``)
-        decay (float, optional):
-            The relative *time* it takes to sustain from
-            the maximum level. (Default: ``0.0``)
-        sustain (float, optional): The relative *level* at which
-            the sound should sustain. (Default: ``1.0``)
-
-            .. Note::
-                The duration of sustain is derived as `1.0 - (The sum of attack, hold, decay and release)`.
-
-        release (float, optional): The relative *time* it takes for the sound level to
-            reach zero after the sustain. (Default: ``0.0``)
-        n_decay (int, optional): The degree of polynomial decay. Default: ``2``.
-        dtype (torch.dtype, optional): the desired data type of returned tensor.
-            Default: if ``None``, uses a global default
-            (see :py:func:`torch.set_default_tensor_type`).
-        device (torch.device, optional): the desired device of returned tensor.
-            Default: if ``None``, uses the current device for the default tensor type
-            (see :py:func:`torch.set_default_tensor_type`).
-            device will be the CPU for CPU tensor types and the current CUDA
-            device for CUDA tensor types.
-
-    Returns:
-        Tensor: ADSR Envelope. Shape: `(num_frames, )`
-
-    Example
-        .. image:: https://download.pytorch.org/torchaudio/doc-assets/adsr_examples.png
-
-    """
-    if not 0 <= attack <= 1:
-        raise ValueError(f"The value of `attack` must be within [0, 1]. Found: {attack}")
-    if not 0 <= decay <= 1:
-        raise ValueError(f"The value of `decay` must be within [0, 1]. Found: {decay}")
-    if not 0 <= sustain <= 1:
-        raise ValueError(f"The value of `sustain` must be within [0, 1]. Found: {sustain}")
-    if not 0 <= hold <= 1:
-        raise ValueError(f"The value of `hold` must be within [0, 1]. Found: {hold}")
-    if not 0 <= release <= 1:
-        raise ValueError(f"The value of `release` must be within [0, 1]. Found: {release}")
-    if attack + decay + release + hold > 1:
-        raise ValueError("The sum of `attack`, `hold`, `decay` and `release` must not exceed 1.")
-
-    nframes = num_frames - 1
-    num_a = int(nframes * attack)
-    num_h = int(nframes * hold)
-    num_d = int(nframes * decay)
-    num_r = int(nframes * release)
-
-    # Initialize with sustain
-    out = torch.full((num_frames,), float(sustain), device=device, dtype=dtype)
-
-    # attack
-    if num_a > 0:
-        torch.linspace(0.0, 1.0, num_a + 1, out=out[: num_a + 1])
-
-    # hold
-    if num_h > 0:
-        out[num_a : num_a + num_h + 1] = 1.0
-
-    # decay
-    if num_d > 0:
-        # Compute: sustain + (1.0 - sustain) * (linspace[1, 0] ** n_decay)
-        i = num_a + num_h
-        decay = out[i : i + num_d + 1]
-        torch.linspace(1.0, 0.0, num_d + 1, out=decay)
-        decay **= n_decay
-        decay *= 1.0 - sustain
-        decay += sustain
-
-    # sustain is handled by initialization
-
-    # release
-    if num_r > 0:
-        torch.linspace(sustain, 0, num_r + 1, out=out[-num_r - 1 :])
-
-    return out
-
-
-@dropping_support
-def extend_pitch(
-    base: torch.Tensor,
-    pattern: Union[int, List[float], torch.Tensor],
-):
-    """Extend the given time series values with multipliers of them.
-
-    .. devices:: CPU CUDA
-
-    .. properties:: Autograd TorchScript
-
-    Given a series of fundamental frequencies (pitch), this function appends
-    its harmonic overtones or inharmonic partials.
-
-    Args:
-        base (torch.Tensor):
-            Base time series, like fundamental frequencies (Hz). Shape: `(..., time, 1)`.
-        pattern (int, list of floats or torch.Tensor):
-            If ``int``, the number of pitch series after the operation.
-            `pattern - 1` tones are added, so that the resulting Tensor contains
-            up to `pattern`-th overtones of the given series.
-
-            If list of float or ``torch.Tensor``, it must be one dimensional,
-            representing the custom multiplier of the fundamental frequency.
-
-    Returns:
-        Tensor: Oscillator frequencies (Hz). Shape: `(..., time, num_tones)`.
-
-    Example
-        >>> # fundamental frequency
-        >>> f0 = torch.linspace(1, 5, 5).unsqueeze(-1)
-        >>> f0
-        tensor([[1.],
-                [2.],
-                [3.],
-                [4.],
-                [5.]])
-        >>> # Add harmonic overtones, up to 3rd.
-        >>> f = extend_pitch(f0, 3)
-        >>> f.shape
-        torch.Size([5, 3])
-        >>> f
-        tensor([[ 1.,  2.,  3.],
-                [ 2.,  4.,  6.],
-                [ 3.,  6.,  9.],
-                [ 4.,  8., 12.],
-                [ 5., 10., 15.]])
-        >>> # Add custom (inharmonic) partials.
-        >>> f = extend_pitch(f0, torch.tensor([1, 2.1, 3.3, 4.5]))
-        >>> f.shape
-        torch.Size([5, 4])
-        >>> f
-        tensor([[ 1.0000,  2.1000,  3.3000,  4.5000],
-                [ 2.0000,  4.2000,  6.6000,  9.0000],
-                [ 3.0000,  6.3000,  9.9000, 13.5000],
-                [ 4.0000,  8.4000, 13.2000, 18.0000],
-                [ 5.0000, 10.5000, 16.5000, 22.5000]])
-    """
-    if isinstance(pattern, torch.Tensor):
-        mult = pattern
-    elif isinstance(pattern, int):
-        mult = torch.linspace(1.0, float(pattern), pattern, device=base.device, dtype=base.dtype)
-    else:
-        mult = torch.tensor(pattern, dtype=base.dtype, device=base.device)
-    h_freq = base @ mult.unsqueeze(0)
-    return h_freq
-
-
-@dropping_support
-def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
-    """Create windowed-sinc impulse response for given cutoff frequencies.
-
-    .. devices:: CPU CUDA
-
-    .. properties:: Autograd TorchScript
-
-    Args:
-        cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.
-
-        window_size (int, optional): Size of the Hamming window to apply. Must be odd.
-            (Default: 513)
-
-        high_pass (bool, optional):
-            If ``True``, convert the resulting filter to high-pass.
-            Otherwise low-pass filter is returned. Default: ``False``.
-
-    Returns:
-        Tensor: A series of impulse responses. Shape: `(..., window_size)`.
-    """
-    if window_size % 2 == 0:
-        raise ValueError(f"`window_size` must be odd. Given: {window_size}")
-
-    half = window_size // 2
-    device, dtype = cutoff.device, cutoff.dtype
-    idx = torch.linspace(-half, half, window_size, device=device, dtype=dtype)
-
-    filt = torch.special.sinc(cutoff.unsqueeze(-1) * idx.unsqueeze(0))
-    filt = filt * torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False).unsqueeze(0)
-    filt = filt / filt.sum(dim=-1, keepdim=True).abs()
-
-    # High pass IR is obtained by subtracting low_pass IR from delta function.
-    # https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
-    if high_pass:
-        filt = -filt
-        filt[..., half] = 1.0 + filt[..., half]
-    return filt
-
-
-@dropping_support
-def frequency_impulse_response(magnitudes):
-    """Create filter from desired frequency response
-
-    Args:
-        magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)`
-
-    Returns:
-        Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
-    """
-    if magnitudes.min() < 0.0:
-        # Negative magnitude does not make sense but allowing so that autograd works
-        # around 0.
-        # Should we raise error?
-        warnings.warn("The input frequency response should not contain negative values.")
-    ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
-    device, dtype = magnitudes.device, magnitudes.dtype
-    window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
-    return ir * window
-
-
-def _overlap_and_add(waveform, stride):
-    num_frames, frame_size = waveform.shape[-2:]
-    numel = (num_frames - 1) * stride + frame_size
-    buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
-    for i in range(num_frames):
-        start = i * stride
-        end = start + frame_size
-        buffer[..., start:end] += waveform[..., i, :]
-    return buffer
-
-
-@dropping_support
-def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
-    """Applies filters along time axis of the given waveform.
-
-    This function applies the given filters along time axis in the following manner:
-
-    1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
-    2. Filter each chunk with corresponding filter.
-    3. Place the filtered chunks at the original indices while adding up the overlapping parts.
-    4. Crop the resulting waveform so that delay introduced by the filter is removed and its length
-       matches that of the input waveform.
-
-    The following figure illustrates this.
-
-    .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png
-
-    .. note::
-
-       If the number of filters is one, then the operation becomes stationary.
-       i.e. the same filtering is applied across the time axis.
-
-    Args:
-        waveform (Tensor): Shape `(..., time)`.
-        kernels (Tensor): Impulse responses.
-            Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
-            `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
-            the dimension of waveform.
-
-            In case of 2D input, the same set of filters is used across channels and batches.
-            Otherwise, different sets of filters are applied. In this case, the shape of
-            the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.
-
-        delay_compensation (int): Control how the waveform is cropped after full convolution.
-            If the value is zero or positive, it is interpreted as the length of crop at the
-            beginning of the waveform. The value cannot be larger than the size of filter kernel.
-            Otherwise the initial crop is ``filter_size // 2``.
-            When cropping happens, the waveform is also cropped from the end so that the
-            length of the resulting waveform matches the input waveform.
-
-    Returns:
-        Tensor: `(..., time)`.
-    """
-    if kernels.ndim not in [2, waveform.ndim + 1]:
-        raise ValueError(
-            "`kernels` must be 2 or N+1 dimension where "
-            f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
-        )
-
-    num_filters, filter_size = kernels.shape[-2:]
-    num_frames = waveform.size(-1)
-
-    if delay_compensation > filter_size:
-        raise ValueError(
-            "When `delay_compensation` is provided, it cannot be larger than the size of filters. "
-            f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
-        )
-
-    # Transform waveform's time axis into (num_filters x chunk_length) with optional padding
-    chunk_length = num_frames // num_filters
-    if num_frames % num_filters > 0:
-        chunk_length += 1
-    num_pad = chunk_length * num_filters - num_frames
-    waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
-    chunked = waveform.unfold(-1, chunk_length, chunk_length)
-    assert chunked.numel() >= waveform.numel()
-
-    # Broadcast kernels
-    if waveform.ndim + 1 > kernels.ndim:
-        expand_shape = waveform.shape[:-1] + kernels.shape
-        kernels = kernels.expand(expand_shape)
-
-    convolved = fftconvolve(chunked, kernels)
-    restored = _overlap_and_add(convolved, chunk_length)
-
-    # Trim in a way that the number of samples are same as input,
-    # and the filter delay is compensated
-    if delay_compensation >= 0:
-        start = delay_compensation
-    else:
-        start = filter_size // 2
-    num_crops = restored.size(-1) - num_frames
-    end = num_crops - start
-    result = restored[..., start:-end]
-    return result
-
-
-@dropping_support
-def exp_sigmoid(
-    input: torch.Tensor, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
-) -> torch.Tensor:
-    """Exponential Sigmoid pointwise nonlinearity.
-    Implements the equation:
-    ``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``
-
-    The output has a range of [``threshold``, ``max_value``].
-    ``exponent`` controls the slope of the output.
-
-    .. devices:: CPU CUDA
-
-    Args:
-        input (Tensor): Input Tensor
-        exponent (float, optional): Exponent. Controls the slope of the output
-        max_value (float, optional): Maximum value of the output
-        threshold (float, optional): Minimum value of the output
-
-    Returns:
-        Tensor: Exponential Sigmoid output. Shape: same as input
-
-    """
-
-    return max_value * torch.pow(
-        torch.nn.functional.sigmoid(input),
-        torch.log(torch.tensor(exponent, device=input.device, dtype=input.dtype)),
-    ) + torch.tensor(threshold, device=input.device, dtype=input.dtype)
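
To make the removed DSP API concrete, a hedged usage sketch against 2.8.0, assembled only from the docstrings above: synthesize one second of a 440 Hz tone with oscillator_bank and shape it with adsr_envelope; all parameter values are illustrative:

import torch
from torchaudio.prototype.functional import adsr_envelope, oscillator_bank  # removed in 2.9.0

sample_rate = 8000
num_frames = sample_rate  # one second of audio
frequencies = torch.full((num_frames, 1), 440.0)  # shape (time, N), a single oscillator
amplitudes = torch.ones(num_frames, 1)

tone = oscillator_bank(frequencies, amplitudes, sample_rate=sample_rate)  # reduction="sum" -> (time,)
envelope = adsr_envelope(num_frames, attack=0.1, hold=0.1, decay=0.2, sustain=0.5, release=0.2)
shaped = tone * envelope  # element-wise gain per sample
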