torchaudio 2.8.0-cp310-cp310-win_amd64.whl → 2.9.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (92)
  1. torchaudio/__init__.py +179 -39
  2. torchaudio/_extension/__init__.py +1 -14
  3. torchaudio/_extension/utils.py +0 -47
  4. torchaudio/_internal/module_utils.py +12 -3
  5. torchaudio/_torchcodec.py +73 -85
  6. torchaudio/datasets/cmuarctic.py +1 -1
  7. torchaudio/datasets/utils.py +1 -1
  8. torchaudio/functional/__init__.py +0 -2
  9. torchaudio/functional/_alignment.py +1 -1
  10. torchaudio/functional/filtering.py +70 -55
  11. torchaudio/functional/functional.py +26 -60
  12. torchaudio/lib/_torchaudio.pyd +0 -0
  13. torchaudio/lib/libtorchaudio.pyd +0 -0
  14. torchaudio/models/decoder/__init__.py +14 -2
  15. torchaudio/models/decoder/_ctc_decoder.py +6 -6
  16. torchaudio/models/decoder/_cuda_ctc_decoder.py +1 -1
  17. torchaudio/models/squim/objective.py +2 -2
  18. torchaudio/pipelines/_source_separation_pipeline.py +1 -1
  19. torchaudio/pipelines/_squim_pipeline.py +2 -2
  20. torchaudio/pipelines/_tts/utils.py +1 -1
  21. torchaudio/pipelines/rnnt_pipeline.py +4 -4
  22. torchaudio/transforms/__init__.py +1 -0
  23. torchaudio/transforms/_transforms.py +2 -2
  24. torchaudio/utils/__init__.py +2 -9
  25. torchaudio/utils/download.py +1 -3
  26. torchaudio/version.py +2 -2
  27. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/METADATA +8 -11
  28. torchaudio-2.9.0.dist-info/RECORD +85 -0
  29. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/top_level.txt +0 -1
  30. torchaudio/_backend/__init__.py +0 -61
  31. torchaudio/_backend/backend.py +0 -53
  32. torchaudio/_backend/common.py +0 -52
  33. torchaudio/_backend/ffmpeg.py +0 -334
  34. torchaudio/_backend/soundfile.py +0 -54
  35. torchaudio/_backend/soundfile_backend.py +0 -457
  36. torchaudio/_backend/sox.py +0 -91
  37. torchaudio/_backend/utils.py +0 -350
  38. torchaudio/backend/__init__.py +0 -8
  39. torchaudio/backend/_no_backend.py +0 -25
  40. torchaudio/backend/_sox_io_backend.py +0 -294
  41. torchaudio/backend/common.py +0 -13
  42. torchaudio/backend/no_backend.py +0 -14
  43. torchaudio/backend/soundfile_backend.py +0 -14
  44. torchaudio/backend/sox_io_backend.py +0 -14
  45. torchaudio/io/__init__.py +0 -20
  46. torchaudio/io/_effector.py +0 -347
  47. torchaudio/io/_playback.py +0 -72
  48. torchaudio/kaldi_io.py +0 -150
  49. torchaudio/prototype/__init__.py +0 -0
  50. torchaudio/prototype/datasets/__init__.py +0 -4
  51. torchaudio/prototype/datasets/musan.py +0 -68
  52. torchaudio/prototype/functional/__init__.py +0 -26
  53. torchaudio/prototype/functional/_dsp.py +0 -441
  54. torchaudio/prototype/functional/_rir.py +0 -382
  55. torchaudio/prototype/functional/functional.py +0 -193
  56. torchaudio/prototype/models/__init__.py +0 -39
  57. torchaudio/prototype/models/_conformer_wav2vec2.py +0 -801
  58. torchaudio/prototype/models/_emformer_hubert.py +0 -337
  59. torchaudio/prototype/models/conv_emformer.py +0 -529
  60. torchaudio/prototype/models/hifi_gan.py +0 -342
  61. torchaudio/prototype/models/rnnt.py +0 -717
  62. torchaudio/prototype/models/rnnt_decoder.py +0 -402
  63. torchaudio/prototype/pipelines/__init__.py +0 -21
  64. torchaudio/prototype/pipelines/_vggish/__init__.py +0 -7
  65. torchaudio/prototype/pipelines/_vggish/_vggish_impl.py +0 -236
  66. torchaudio/prototype/pipelines/_vggish/_vggish_pipeline.py +0 -83
  67. torchaudio/prototype/pipelines/hifigan_pipeline.py +0 -233
  68. torchaudio/prototype/pipelines/rnnt_pipeline.py +0 -58
  69. torchaudio/prototype/transforms/__init__.py +0 -9
  70. torchaudio/prototype/transforms/_transforms.py +0 -461
  71. torchaudio/sox_effects/__init__.py +0 -10
  72. torchaudio/sox_effects/sox_effects.py +0 -275
  73. torchaudio/utils/ffmpeg_utils.py +0 -11
  74. torchaudio/utils/sox_utils.py +0 -118
  75. torchaudio-2.8.0.dist-info/RECORD +0 -145
  76. torio/__init__.py +0 -8
  77. torio/_extension/__init__.py +0 -13
  78. torio/_extension/utils.py +0 -147
  79. torio/io/__init__.py +0 -9
  80. torio/io/_streaming_media_decoder.py +0 -977
  81. torio/io/_streaming_media_encoder.py +0 -502
  82. torio/lib/__init__.py +0 -0
  83. torio/lib/_torio_ffmpeg4.pyd +0 -0
  84. torio/lib/_torio_ffmpeg5.pyd +0 -0
  85. torio/lib/_torio_ffmpeg6.pyd +0 -0
  86. torio/lib/libtorio_ffmpeg4.pyd +0 -0
  87. torio/lib/libtorio_ffmpeg5.pyd +0 -0
  88. torio/lib/libtorio_ffmpeg6.pyd +0 -0
  89. torio/utils/__init__.py +0 -4
  90. torio/utils/ffmpeg_utils.py +0 -275
  91. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/WHEEL +0 -0
  92. {torchaudio-2.8.0.dist-info → torchaudio-2.9.0.dist-info}/licenses/LICENSE +0 -0
torchaudio/_torchcodec.py CHANGED
@@ -17,28 +17,29 @@ def load_with_torchcodec(
     backend: Optional[str] = None,
 ) -> Tuple[torch.Tensor, int]:
     """Load audio data from source using TorchCodec's AudioDecoder.
-
+
     .. note::
-
+
         This function supports the same API as :func:`~torchaudio.load`, and
         relies on TorchCodec's decoding capabilities under the hood. It is
         provided for convenience, but we do recommend that you port your code to
         natively use ``torchcodec``'s ``AudioDecoder`` class for better
         performance:
         https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.
-        In TorchAudio 2.9, :func:`~torchaudio.load` will be relying on
+        As of TorchAudio 2.9, :func:`~torchaudio.load` relies on
         :func:`~torchaudio.load_with_torchcodec`. Note that some parameters of
         :func:`~torchaudio.load`, like ``normalize``, ``buffer_size``, and
         ``backend``, are ignored by :func:`~torchaudio.load_with_torchcodec`.
-
-
+        To install torchcodec, follow the instructions at https://github.com/pytorch/torchcodec#installing-torchcodec.
+
     Args:
         uri (path-like object or file-like object):
             Source of audio data. The following types are accepted:
-
+
             * ``path-like``: File path or URL.
             * ``file-like``: Object with ``read(size: int) -> bytes`` method.
-
+
         frame_offset (int, optional):
             Number of samples to skip before start reading data.
         num_frames (int, optional):
@@ -58,17 +59,17 @@ def load_with_torchcodec(
             Not used by TorchCodec AudioDecoder. Provided for API compatibility.
         backend (str or None, optional):
             Not used by TorchCodec AudioDecoder. Provided for API compatibility.
-
+
     Returns:
         (torch.Tensor, int): Resulting Tensor and sample rate.
             Always returns float32 tensors. If ``channels_first=True``, shape is
             `[channel, time]`, otherwise `[time, channel]`.
-
+
     Raises:
         ImportError: If torchcodec is not available.
         ValueError: If unsupported parameters are used.
         RuntimeError: If TorchCodec fails to decode the audio.
-
+
     Note:
         - TorchCodec always returns normalized float32 samples, so the ``normalize``
           parameter has no effect.
@@ -81,64 +82,55 @@ def load_with_torchcodec(
         from torchcodec.decoders import AudioDecoder
     except ImportError as e:
         raise ImportError(
-            "TorchCodec is required for load_with_torchcodec. "
-            "Please install torchcodec to use this function."
+            "TorchCodec is required for load_with_torchcodec. " "Please install torchcodec to use this function."
         ) from e
-
+
     # Parameter validation and warnings
     if not normalize:
         import warnings
+
         warnings.warn(
             "TorchCodec AudioDecoder always returns normalized float32 samples. "
             "The 'normalize=False' parameter is ignored.",
             UserWarning,
-            stacklevel=2
+            stacklevel=2,
         )
-
+
     if buffer_size != 4096:
         import warnings
-        warnings.warn(
-            "The 'buffer_size' parameter is not used by TorchCodec AudioDecoder.",
-            UserWarning,
-            stacklevel=2
-        )
-
+
+        warnings.warn("The 'buffer_size' parameter is not used by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)
+
     if backend is not None:
         import warnings
-        warnings.warn(
-            "The 'backend' parameter is not used by TorchCodec AudioDecoder.",
-            UserWarning,
-            stacklevel=2
-        )
-
+
+        warnings.warn("The 'backend' parameter is not used by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)
+
     if format is not None:
         import warnings
-        warnings.warn(
-            "The 'format' parameter is not supported by TorchCodec AudioDecoder.",
-            UserWarning,
-            stacklevel=2
-        )
-
+
+        warnings.warn("The 'format' parameter is not supported by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)
+
     # Create AudioDecoder
     try:
         decoder = AudioDecoder(uri)
     except Exception as e:
         raise RuntimeError(f"Failed to create AudioDecoder for {uri}: {e}") from e
-
+
     # Get sample rate from metadata
     sample_rate = decoder.metadata.sample_rate
     if sample_rate is None:
         raise RuntimeError("Unable to determine sample rate from audio metadata")
-
+
     # Decode the entire file first, then subsample manually
     # This is the simplest approach since torchcodec uses time-based indexing
     try:
         audio_samples = decoder.get_all_samples()
     except Exception as e:
         raise RuntimeError(f"Failed to decode audio samples: {e}") from e
-
+
     data = audio_samples.data
-
+
     # Apply frame_offset and num_frames (which are actually sample offsets)
     if frame_offset > 0:
         if frame_offset >= data.shape[1]:
@@ -146,19 +138,19 @@ def load_with_torchcodec(
             empty_shape = (data.shape[0], 0) if channels_first else (0, data.shape[0])
             return torch.zeros(empty_shape, dtype=torch.float32), sample_rate
         data = data[:, frame_offset:]
-
+
     if num_frames == 0:
         # Return empty tensor if num_frames is 0
         empty_shape = (data.shape[0], 0) if channels_first else (0, data.shape[0])
         return torch.zeros(empty_shape, dtype=torch.float32), sample_rate
     elif num_frames > 0:
         data = data[:, :num_frames]
-
+
     # TorchCodec returns data in [channel, time] format by default
     # Handle channels_first parameter
     if not channels_first:
         data = data.transpose(0, 1)  # [channel, time] -> [time, channel]
-
+
     return data, sample_rate
 
 
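For context, a minimal sketch of how this loading path is exercised in 2.9; the file name is hypothetical and torchcodec must be installed separately (see the installation link added to the docstring):

    import torchaudio

    # load_with_torchcodec mirrors torchaudio.load; normalize/buffer_size/backend are ignored with a warning.
    waveform, sample_rate = torchaudio.load_with_torchcodec("speech.wav")
    print(waveform.dtype, waveform.shape)  # torch.float32, [channel, time]

    # Native torchcodec equivalent recommended by the docstring:
    from torchcodec.decoders import AudioDecoder

    decoder = AudioDecoder("speech.wav")
    samples = decoder.get_all_samples()
    waveform, sample_rate = samples.data, decoder.metadata.sample_rate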
@@ -177,70 +169,71 @@ def save_with_torchcodec(
     """Save audio data to file using TorchCodec's AudioEncoder.
 
     .. note::
-
+
         This function supports the same API as :func:`~torchaudio.save`, and
         relies on TorchCodec's encoding capabilities under the hood. It is
         provided for convenience, but we do recommend that you port your code to
         natively use ``torchcodec``'s ``AudioEncoder`` class for better
         performance:
         https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder.
-        In TorchAudio 2.9, :func:`~torchaudio.save` will be relying on
+        As of TorchAudio 2.9, :func:`~torchaudio.save` relies on
         :func:`~torchaudio.save_with_torchcodec`. Note that some parameters of
         :func:`~torchaudio.save`, like ``format``, ``encoding``,
         ``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored by
         are ignored by :func:`~torchaudio.save_with_torchcodec`.
-
+        To install torchcodec, follow the instructions at https://github.com/pytorch/torchcodec#installing-torchcodec.
+
     This function provides a TorchCodec-based alternative to torchaudio.save
     with the same API. TorchCodec's AudioEncoder provides efficient encoding
     with FFmpeg under the hood.
-
+
     Args:
         uri (path-like object):
             Path to save the audio file. The file extension determines the format.
-
+
         src (torch.Tensor):
             Audio data to save. Must be a 1D or 2D tensor with float32 values
             in the range [-1, 1]. If 2D, shape should be [channel, time] when
             channels_first=True, or [time, channel] when channels_first=False.
-
+
         sample_rate (int):
             Sample rate of the audio data.
-
+
         channels_first (bool, optional):
             Indicates whether the input tensor has channels as the first dimension.
             If True, expects [channel, time]. If False, expects [time, channel].
            Default: True.
-
+
         format (str or None, optional):
             Audio format hint. Not used by TorchCodec (format is determined by
             file extension). A warning is issued if provided.
             Default: None.
-
+
         encoding (str or None, optional):
             Audio encoding. Not fully supported by TorchCodec AudioEncoder.
             A warning is issued if provided. Default: None.
-
+
         bits_per_sample (int or None, optional):
             Bits per sample. Not directly supported by TorchCodec AudioEncoder.
             A warning is issued if provided. Default: None.
-
+
         buffer_size (int, optional):
             Not used by TorchCodec AudioEncoder. Provided for API compatibility.
             A warning is issued if not default value. Default: 4096.
-
+
         backend (str or None, optional):
             Not used by TorchCodec AudioEncoder. Provided for API compatibility.
             A warning is issued if provided. Default: None.
-
+
         compression (float, int or None, optional):
             Compression level or bit rate. Maps to bit_rate parameter in
             TorchCodec AudioEncoder. Default: None.
-
+
     Raises:
         ImportError: If torchcodec is not available.
         ValueError: If input parameters are invalid.
         RuntimeError: If TorchCodec fails to encode the audio.
-
+
     Note:
         - TorchCodec AudioEncoder expects float32 samples in [-1, 1] range.
         - Some parameters (format, encoding, bits_per_sample, buffer_size, backend)
@@ -253,62 +246,56 @@ def save_with_torchcodec(
         from torchcodec.encoders import AudioEncoder
     except ImportError as e:
         raise ImportError(
-            "TorchCodec is required for save_with_torchcodec. "
-            "Please install torchcodec to use this function."
+            "TorchCodec is required for save_with_torchcodec. " "Please install torchcodec to use this function."
         ) from e
-
+
     # Parameter validation and warnings
     if format is not None:
         import warnings
+
         warnings.warn(
             "The 'format' parameter is not used by TorchCodec AudioEncoder. "
             "Format is determined by the file extension.",
             UserWarning,
-            stacklevel=2
+            stacklevel=2,
         )
-
+
     if encoding is not None:
         import warnings
+
         warnings.warn(
-            "The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder.",
-            UserWarning,
-            stacklevel=2
+            "The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder.", UserWarning, stacklevel=2
         )
-
+
     if bits_per_sample is not None:
         import warnings
+
         warnings.warn(
             "The 'bits_per_sample' parameter is not directly supported by TorchCodec AudioEncoder.",
             UserWarning,
-            stacklevel=2
+            stacklevel=2,
         )
-
+
     if buffer_size != 4096:
         import warnings
-        warnings.warn(
-            "The 'buffer_size' parameter is not used by TorchCodec AudioEncoder.",
-            UserWarning,
-            stacklevel=2
-        )
-
+
+        warnings.warn("The 'buffer_size' parameter is not used by TorchCodec AudioEncoder.", UserWarning, stacklevel=2)
+
     if backend is not None:
         import warnings
-        warnings.warn(
-            "The 'backend' parameter is not used by TorchCodec AudioEncoder.",
-            UserWarning,
-            stacklevel=2
-        )
-
+
+        warnings.warn("The 'backend' parameter is not used by TorchCodec AudioEncoder.", UserWarning, stacklevel=2)
+
     # Input validation
     if not isinstance(src, torch.Tensor):
         raise ValueError(f"Expected src to be a torch.Tensor, got {type(src)}")
-
+
     if src.dtype != torch.float32:
         src = src.float()
-
+
     if sample_rate <= 0:
         raise ValueError(f"sample_rate must be positive, got {sample_rate}")
-
+
     # Handle tensor shape and channels_first
     if src.ndim == 1:
         # Convert to 2D: [1, time] for channels_first=True
@@ -324,13 +311,13 @@ def save_with_torchcodec(
         data = src.transpose(0, 1)  # [time, channel] -> [channel, time]
     else:
         raise ValueError(f"Expected 1D or 2D tensor, got {src.ndim}D tensor")
-
+
     # Create AudioEncoder
     try:
         encoder = AudioEncoder(data, sample_rate=sample_rate)
     except Exception as e:
         raise RuntimeError(f"Failed to create AudioEncoder: {e}") from e
-
+
     # Determine bit_rate from compression parameter
     bit_rate = None
     if compression is not None:
@@ -338,13 +325,14 @@ def save_with_torchcodec(
             bit_rate = int(compression)
         else:
             import warnings
+
             warnings.warn(
                 f"Unsupported compression type {type(compression)}. "
                 "TorchCodec AudioEncoder expects int or float for bit_rate.",
                 UserWarning,
-                stacklevel=2
+                stacklevel=2,
             )
-
+
     # Save to file
     try:
         encoder.to_file(uri, bit_rate=bit_rate)
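Likewise for saving, a minimal sketch; the output path and bit rate are illustrative, and torchcodec must be installed:

    import torch
    import torchaudio

    waveform = torch.rand(1, 16000) * 2 - 1  # float32 in [-1, 1], shape [channel, time]

    # compression maps to AudioEncoder's bit_rate; format/encoding/bits_per_sample are ignored with a warning.
    torchaudio.save_with_torchcodec("out.mp3", waveform, 16000, compression=128_000)

    # Native torchcodec equivalent recommended by the docstring:
    from torchcodec.encoders import AudioEncoder

    AudioEncoder(waveform, sample_rate=16000).to_file("out.mp3", bit_rate=128_000)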
torchaudio/datasets/cmuarctic.py CHANGED
@@ -129,7 +129,7 @@ class CMUARCTIC(Dataset):
         self._text = os.path.join(self._path, self._folder_text, self._file_text)
 
         with open(self._text, "r") as text:
-            walker = csv.reader(text, delimiter="\n")
+            walker = csv.reader(text)
             self._walker = list(walker)
 
     def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
torchaudio/datasets/utils.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import os
 import tarfile
 import zipfile
-from typing import Any, List, Optional
+from typing import Any, List, Optional  # noqa: F401
 
 import torchaudio
 
torchaudio/functional/__init__.py CHANGED
@@ -32,7 +32,6 @@ from .functional import (
     add_noise,
     amplitude_to_DB,
     apply_beamforming,
-    apply_codec,
     compute_deltas,
     convolve,
     create_dct,
@@ -111,7 +110,6 @@ __all__ = [
     "riaa_biquad",
     "treble_biquad",
     "vad",
-    "apply_codec",
     "resample",
     "edit_distance",
     "pitch_shift",
torchaudio/functional/_alignment.py CHANGED
@@ -70,7 +70,7 @@ def forced_align(
     assert target_lengths is not None
 
     paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
-    return paths, scores
+    return paths, scores[:, torch.arange(scores.shape[1]), paths[0]]
 
 
 @dataclass
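The new return value gathers, for every frame, the score of the token chosen along the alignment path, so both outputs come back frame-aligned. A quick sketch of the public API with random emissions and illustrative token ids:

    import torch
    from torchaudio.functional import forced_align

    log_probs = torch.randn(1, 50, 30).log_softmax(dim=-1)     # (batch=1, frames, num_classes)
    targets = torch.tensor([[3, 7, 7, 12]], dtype=torch.int32)  # (batch=1, target_length)

    paths, scores = forced_align(log_probs, targets, blank=0)
    print(paths.shape, scores.shape)  # both (1, 50): aligned token ids and their per-frame scores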
torchaudio/functional/filtering.py CHANGED
@@ -3,6 +3,7 @@ import warnings
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 
 from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
@@ -932,69 +933,83 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
 
 
 if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
 else:
-    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
+    _lfilter_core_loop = _lfilter_core_generic_loop
+
+
+class DifferentiableFIR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, waveform, b_coeffs):
+        n_order = b_coeffs.size(1)
+        n_channel = b_coeffs.size(0)
+        b_coeff_flipped = b_coeffs.flip(1).contiguous()
+        padded_waveform = F.pad(waveform, (n_order - 1, 0))
+        output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
+        ctx.save_for_backward(waveform, b_coeffs, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, b_coeffs, y = ctx.saved_tensors
+        n_batch = x.size(0)
+        n_channel = x.size(1)
+        n_order = b_coeffs.size(1)
+        db = (
+            F.conv1d(
+                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+                dy.view(n_batch * n_channel, 1, -1),
+                groups=n_batch * n_channel,
+            )
+            .view(n_batch, n_channel, -1)
+            .sum(0)
+            .flip(1)
+            if b_coeffs.requires_grad
+            else None
+        )
+        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+        return (dx, db)
 
 
-def _lfilter_core(
-    waveform: Tensor,
-    a_coeffs: Tensor,
-    b_coeffs: Tensor,
-) -> Tensor:
+class DifferentiableIIR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, waveform, a_coeffs_normalized):
+        n_batch, n_channel, n_sample = waveform.shape
+        n_order = a_coeffs_normalized.size(1)
+        n_sample_padded = n_sample + n_order - 1
 
-    if a_coeffs.size() != b_coeffs.size():
-        raise ValueError(
-            "Expected coeffs to be the same size."
-            f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
+        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()
+        padded_output_waveform = torch.zeros(
+            n_batch, n_channel, n_sample_padded, device=waveform.device, dtype=waveform.dtype
         )
-    if waveform.ndim != 3:
-        raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
-    if not (waveform.device == a_coeffs.device == b_coeffs.device):
-        raise ValueError(
-            "Expected waveform and coeffs to be on the same device."
-            f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
-            f"b_coeffs device: {b_coeffs.device}"
+        _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
+        output = padded_output_waveform[:, :, n_order - 1 :]
+        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, a_coeffs_normalized, y = ctx.saved_tensors
+        n_channel = x.size(1)
+        n_order = a_coeffs_normalized.size(1)
+        tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2)
+        dx = tmp if x.requires_grad else None
+        da = (
+            -(
+                tmp.transpose(0, 1).reshape(n_channel, 1, -1)
+                @ F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0, 1).reshape(n_channel, -1, n_order)
+            )
+            .squeeze(1)
+            .flip(1)
+            if a_coeffs_normalized.requires_grad
+            else None
         )
+        return (dx, da)
 
-    n_batch, n_channel, n_sample = waveform.size()
-    n_order = a_coeffs.size(1)
-    if n_order <= 0:
-        raise ValueError(f"Expected n_order to be positive. Found: {n_order}")
-
-    # Pad the input and create output
 
-    padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
-    padded_output_waveform = torch.zeros_like(padded_waveform)
-
-    # Set up the coefficients matrix
-    # Flip coefficients' order
-    a_coeffs_flipped = a_coeffs.flip(1)
-    b_coeffs_flipped = b_coeffs.flip(1)
-
-    # calculate windowed_input_signal in parallel using convolution
-    input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)
-
-    input_signal_windows.div_(a_coeffs[:, :1])
-    a_coeffs_flipped.div_(a_coeffs[:, :1])
-
-    if (
-        input_signal_windows.device == torch.device("cpu")
-        and a_coeffs_flipped.device == torch.device("cpu")
-        and padded_output_waveform.device == torch.device("cpu")
-    ):
-        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-    else:
-        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-
-    output = padded_output_waveform[:, :, n_order - 1 :]
-    return output
-
-
-if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter = torch.ops.torchaudio._lfilter
-else:
-    _lfilter = _lfilter_core
+def _lfilter(waveform, a_coeffs, b_coeffs):
+    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
 
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
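With the FIR and IIR stages now routed through torch.autograd.Function subclasses, gradients reach the filter coefficients on any device. A minimal sketch with illustrative first-order coefficients:

    import torch
    from torchaudio.functional import lfilter

    waveform = torch.randn(2, 1000)                   # (channel, time)
    b = torch.tensor([0.5, 0.5], requires_grad=True)  # numerator (FIR) coefficients, illustrative
    a = torch.tensor([1.0, -0.2], requires_grad=True) # denominator (IIR) coefficients, illustrative

    out = lfilter(waveform, a, b, clamp=False)
    out.pow(2).mean().backward()                      # gradients flow to both coefficient tensors
    print(a.grad.shape, b.grad.shape)                 # torch.Size([2]) torch.Size([2])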