torchaudio 2.0.2__cp39-cp39-win_amd64.whl → 2.1.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (88) hide show
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.pyd +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.pyd +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.pyd +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.pyd +0 -0
  51. torchaudio/lib/libtorchaudio.pyd +0 -0
  52. torchaudio/lib/libtorchaudio_ffmpeg4.pyd +0 -0
  53. torchaudio/lib/libtorchaudio_ffmpeg5.pyd +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg6.pyd +0 -0
  55. torchaudio/models/__init__.py +14 -0
  56. torchaudio/models/decoder/__init__.py +22 -7
  57. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  58. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  59. torchaudio/models/rnnt_decoder.py +10 -14
  60. torchaudio/models/squim/__init__.py +11 -0
  61. torchaudio/models/squim/objective.py +326 -0
  62. torchaudio/models/squim/subjective.py +150 -0
  63. torchaudio/models/wav2vec2/components.py +6 -10
  64. torchaudio/pipelines/__init__.py +9 -0
  65. torchaudio/pipelines/_squim_pipeline.py +176 -0
  66. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  67. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  68. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  69. torchaudio/sox_effects/sox_effects.py +7 -30
  70. torchaudio/transforms/__init__.py +2 -0
  71. torchaudio/transforms/_transforms.py +99 -54
  72. torchaudio/utils/download.py +2 -2
  73. torchaudio/utils/ffmpeg_utils.py +20 -15
  74. torchaudio/utils/sox_utils.py +8 -9
  75. torchaudio/version.py +2 -2
  76. torchaudio-2.1.1.dist-info/METADATA +113 -0
  77. torchaudio-2.1.1.dist-info/RECORD +115 -0
  78. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +1 -1
  79. torchaudio/io/_compat.py +0 -241
  80. torchaudio/lib/_torchaudio_ffmpeg.pyd +0 -0
  81. torchaudio/lib/flashlight_lib_text_decoder.pyd +0 -0
  82. torchaudio/lib/flashlight_lib_text_dictionary.pyd +0 -0
  83. torchaudio/lib/libflashlight-text.pyd +0 -0
  84. torchaudio/lib/libtorchaudio_ffmpeg.pyd +0 -0
  85. torchaudio-2.0.2.dist-info/METADATA +0 -26
  86. torchaudio-2.0.2.dist-info/RECORD +0 -98
  87. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  88. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
torchaudio/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from torchaudio import ( # noqa: F401
1
+ from . import ( # noqa: F401
2
2
  _extension,
3
3
  compliance,
4
4
  datasets,
@@ -11,15 +11,34 @@ from torchaudio import ( # noqa: F401
11
11
  transforms,
12
12
  utils,
13
13
  )
14
-
15
- from torchaudio.backend import get_audio_backend, list_audio_backends, set_audio_backend
14
+ from ._backend.common import AudioMetaData # noqa
16
15
 
17
16
  try:
18
17
  from .version import __version__, git_version # noqa: F401
19
18
  except ImportError:
20
19
  pass
21
20
 
21
+
22
+ def _is_backend_dispatcher_enabled():
23
+ import os
24
+
25
+ return os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1"
26
+
27
+
28
+ if _is_backend_dispatcher_enabled():
29
+ from ._backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
30
+ else:
31
+ from .backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
32
+
33
+
34
+ _init_backend()
35
+
36
+ # for backward compatibility. This has to happen after _backend is imported.
37
+ from . import backend # noqa: F401
38
+
39
+
22
40
  __all__ = [
41
+ "AudioMetaData",
23
42
  "io",
24
43
  "compliance",
25
44
  "datasets",
@@ -1,6 +1,57 @@
1
- from .utils import get_info_func, get_load_func, get_save_func
1
+ from typing import List, Optional
2
2
 
3
+ import torchaudio
4
+ from torchaudio._internal.module_utils import deprecated
3
5
 
4
- info = get_info_func()
5
- load = get_load_func()
6
- save = get_save_func()
6
+
7
+ # TODO: Once legacy global backend is removed, move this to torchaudio.__init__
8
+ def _init_backend():
9
+ from . import utils
10
+
11
+ torchaudio.info = utils.get_info_func()
12
+ torchaudio.load = utils.get_load_func()
13
+ torchaudio.save = utils.get_save_func()
14
+
15
+
16
+ def list_audio_backends() -> List[str]:
17
+ """List available backends
18
+
19
+ Returns:
20
+ list of str: The list of available backends.
21
+
22
+ The possible values are:
23
+
24
+ - Dispatcher mode: ``"ffmpeg"``, ``"sox"`` and ``"soundfile"``.
25
+ - Legacy backend mode: ``"sox_io"``, ``"soundfile"``.
26
+ """
27
+ from . import utils
28
+
29
+ return list(utils.get_available_backends().keys())
30
+
31
+
32
+ # Temporary until global backend is removed
33
+ @deprecated("With dispatcher enabled, this function is no-op. You can remove the function call.")
34
+ def get_audio_backend() -> Optional[str]:
35
+ """Get the name of the current global backend
36
+
37
+ Returns:
38
+ str or None:
39
+ If dispatcher mode is enabled, returns ``None``; otherwise,
40
+ the name of the current backend or ``None`` (no backend is set).
41
+ """
42
+ return None
43
+
44
+
45
+ # Temporary until global backend is removed
46
+ @deprecated("With dispatcher enabled, this function is no-op. You can remove the function call.")
47
+ def set_audio_backend(backend: Optional[str]): # noqa
48
+ """Set the global backend.
49
+
50
+ This is a no-op when dispatcher mode is enabled.
51
+
52
+ Args:
53
+ backend (str or None): Name of the backend.
54
+ One of ``"sox_io"`` or ``"soundfile"`` based on availability
55
+ of the system. If ``None`` is provided the current backend is unassigned.
56
+ """
57
+ pass
@@ -0,0 +1,53 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import BinaryIO, Optional, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torchaudio.io import CodecConfig
7
+
8
+ from .common import AudioMetaData
9
+
10
+
11
+ class Backend(ABC):
12
+ @staticmethod
13
+ @abstractmethod
14
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
15
+ raise NotImplementedError
16
+
17
+ @staticmethod
18
+ @abstractmethod
19
+ def load(
20
+ uri: Union[BinaryIO, str, os.PathLike],
21
+ frame_offset: int = 0,
22
+ num_frames: int = -1,
23
+ normalize: bool = True,
24
+ channels_first: bool = True,
25
+ format: Optional[str] = None,
26
+ buffer_size: int = 4096,
27
+ ) -> Tuple[Tensor, int]:
28
+ raise NotImplementedError
29
+
30
+ @staticmethod
31
+ @abstractmethod
32
+ def save(
33
+ uri: Union[BinaryIO, str, os.PathLike],
34
+ src: Tensor,
35
+ sample_rate: int,
36
+ channels_first: bool = True,
37
+ format: Optional[str] = None,
38
+ encoding: Optional[str] = None,
39
+ bits_per_sample: Optional[int] = None,
40
+ buffer_size: int = 4096,
41
+ compression: Optional[Union[CodecConfig, float, int]] = None,
42
+ ) -> None:
43
+ raise NotImplementedError
44
+
45
+ @staticmethod
46
+ @abstractmethod
47
+ def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
48
+ raise NotImplementedError
49
+
50
+ @staticmethod
51
+ @abstractmethod
52
+ def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
53
+ raise NotImplementedError
@@ -0,0 +1,52 @@
1
+ class AudioMetaData:
2
+ """AudioMetaData()
3
+
4
+ Return type of ``torchaudio.info`` function.
5
+
6
+ :ivar int sample_rate: Sample rate
7
+ :ivar int num_frames: The number of frames
8
+ :ivar int num_channels: The number of channels
9
+ :ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
10
+ or when it cannot be accurately inferred.
11
+ :ivar str encoding: Audio encoding
12
+ The values encoding can take are one of the following:
13
+
14
+ * ``PCM_S``: Signed integer linear PCM
15
+ * ``PCM_U``: Unsigned integer linear PCM
16
+ * ``PCM_F``: Floating point linear PCM
17
+ * ``FLAC``: Flac, Free Lossless Audio Codec
18
+ * ``ULAW``: Mu-law
19
+ * ``ALAW``: A-law
20
+ * ``MP3`` : MP3, MPEG-1 Audio Layer III
21
+ * ``VORBIS``: OGG Vorbis
22
+ * ``AMR_WB``: Adaptive Multi-Rate Wideband
23
+ * ``AMR_NB``: Adaptive Multi-Rate Narrowband
24
+ * ``OPUS``: Opus
25
+ * ``HTK``: Single channel 16-bit PCM
26
+ * ``UNKNOWN`` : None of the above
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ sample_rate: int,
32
+ num_frames: int,
33
+ num_channels: int,
34
+ bits_per_sample: int,
35
+ encoding: str,
36
+ ):
37
+ self.sample_rate = sample_rate
38
+ self.num_frames = num_frames
39
+ self.num_channels = num_channels
40
+ self.bits_per_sample = bits_per_sample
41
+ self.encoding = encoding
42
+
43
+ def __str__(self):
44
+ return (
45
+ f"AudioMetaData("
46
+ f"sample_rate={self.sample_rate}, "
47
+ f"num_frames={self.num_frames}, "
48
+ f"num_channels={self.num_channels}, "
49
+ f"bits_per_sample={self.bits_per_sample}, "
50
+ f"encoding={self.encoding}"
51
+ f")"
52
+ )
@@ -0,0 +1,373 @@
1
+ import os
2
+ import re
3
+ import sys
4
+ from typing import BinaryIO, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torchaudio
8
+ from torchaudio.io import StreamWriter
9
+
10
+ from .backend import Backend
11
+ from .common import AudioMetaData
12
+
13
+ if torchaudio._extension._FFMPEG_EXT is not None:
14
+ StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
15
+ else:
16
+ StreamReaderFileObj = object
17
+
18
+
19
+ def info_audio(
20
+ src: str,
21
+ format: Optional[str],
22
+ ) -> AudioMetaData:
23
+ i = torch.ops.torchaudio.compat_info(src, format)
24
+ return AudioMetaData(i[0], i[1], i[2], i[3], i[4].upper())
25
+
26
+
27
+ def info_audio_fileobj(
28
+ src,
29
+ format: Optional[str],
30
+ buffer_size: int = 4096,
31
+ ) -> AudioMetaData:
32
+ s = StreamReaderFileObj(src, format, None, buffer_size)
33
+ i = s.find_best_audio_stream()
34
+ sinfo = s.get_src_stream_info(i)
35
+ if sinfo.num_frames == 0:
36
+ waveform = _load_audio_fileobj(s)
37
+ num_frames = waveform.size(1)
38
+ else:
39
+ num_frames = sinfo.num_frames
40
+ return AudioMetaData(
41
+ int(sinfo.sample_rate),
42
+ num_frames,
43
+ sinfo.num_channels,
44
+ sinfo.bits_per_sample,
45
+ sinfo.codec_name.upper(),
46
+ )
47
+
48
+
49
+ def _get_load_filter(
50
+ frame_offset: int = 0,
51
+ num_frames: int = -1,
52
+ convert: bool = True,
53
+ ) -> Optional[str]:
54
+ if frame_offset < 0:
55
+ raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
56
+ if num_frames == 0 or num_frames < -1:
57
+ raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))
58
+
59
+ # All default values -> no filter
60
+ if frame_offset == 0 and num_frames == -1 and not convert:
61
+ return None
62
+ # Only convert
63
+ aformat = "aformat=sample_fmts=fltp"
64
+ if frame_offset == 0 and num_frames == -1 and convert:
65
+ return aformat
66
+ # At least one of frame_offset or num_frames has non-default value
67
+ if num_frames > 0:
68
+ atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
69
+ else:
70
+ atrim = "atrim=start_sample={}".format(frame_offset)
71
+ if not convert:
72
+ return atrim
73
+ return "{},{}".format(atrim, aformat)
74
+
75
+
76
+ def _load_audio_fileobj(
77
+ s: StreamReaderFileObj,
78
+ filter: Optional[str] = None,
79
+ channels_first: bool = True,
80
+ ) -> torch.Tensor:
81
+ i = s.find_best_audio_stream()
82
+ s.add_audio_stream(i, -1, -1, filter, None, None)
83
+ s.process_all_packets()
84
+ chunk = s.pop_chunks()[0]
85
+ if chunk is None:
86
+ raise RuntimeError("Failed to decode audio.")
87
+ waveform = chunk.frames
88
+ return waveform.T if channels_first else waveform
89
+
90
+
91
+ def load_audio(
92
+ src: str,
93
+ frame_offset: int = 0,
94
+ num_frames: int = -1,
95
+ convert: bool = True,
96
+ channels_first: bool = True,
97
+ format: Optional[str] = None,
98
+ ) -> Tuple[torch.Tensor, int]:
99
+ filter = _get_load_filter(frame_offset, num_frames, convert)
100
+ return torch.ops.torchaudio.compat_load(src, format, filter, channels_first)
101
+
102
+
103
+ def load_audio_fileobj(
104
+ src: BinaryIO,
105
+ frame_offset: int = 0,
106
+ num_frames: int = -1,
107
+ convert: bool = True,
108
+ channels_first: bool = True,
109
+ format: Optional[str] = None,
110
+ buffer_size: int = 4096,
111
+ ) -> Tuple[torch.Tensor, int]:
112
+ demuxer = "ogg" if format == "vorbis" else format
113
+ s = StreamReaderFileObj(src, demuxer, None, buffer_size)
114
+ sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
115
+ filter = _get_load_filter(frame_offset, num_frames, convert)
116
+ waveform = _load_audio_fileobj(s, filter, channels_first)
117
+ return waveform, sample_rate
118
+
119
+
120
+ def _get_sample_format(dtype: torch.dtype) -> str:
121
+ dtype_to_format = {
122
+ torch.uint8: "u8",
123
+ torch.int16: "s16",
124
+ torch.int32: "s32",
125
+ torch.int64: "s64",
126
+ torch.float32: "flt",
127
+ torch.float64: "dbl",
128
+ }
129
+ format = dtype_to_format.get(dtype)
130
+ if format is None:
131
+ raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
132
+ return format
133
+
134
+
135
+ def _native_endianness() -> str:
136
+ if sys.byteorder == "little":
137
+ return "le"
138
+ else:
139
+ return "be"
140
+
141
+
142
+ def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
143
+ if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
144
+ raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
145
+ endianness = _native_endianness()
146
+ if not encoding:
147
+ if not bits_per_sample:
148
+ # default to PCM S16
149
+ return f"pcm_s16{endianness}"
150
+ if bits_per_sample == 8:
151
+ return "pcm_u8"
152
+ return f"pcm_s{bits_per_sample}{endianness}"
153
+ if encoding == "PCM_S":
154
+ if not bits_per_sample:
155
+ bits_per_sample = 16
156
+ if bits_per_sample == 8:
157
+ raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
158
+ return f"pcm_s{bits_per_sample}{endianness}"
159
+ if encoding == "PCM_U":
160
+ if bits_per_sample in (None, 8):
161
+ return "pcm_u8"
162
+ raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
163
+ if encoding == "PCM_F":
164
+ if not bits_per_sample:
165
+ bits_per_sample = 32
166
+ if bits_per_sample in (32, 64):
167
+ return f"pcm_f{bits_per_sample}{endianness}"
168
+ raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
169
+ if encoding == "ULAW":
170
+ if bits_per_sample in (None, 8):
171
+ return "pcm_mulaw"
172
+ raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
173
+ if encoding == "ALAW":
174
+ if bits_per_sample in (None, 8):
175
+ return "pcm_alaw"
176
+ raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
177
+ raise ValueError(f"WAV encoding {encoding} is not supported.")
178
+
179
+
180
+ def _get_flac_sample_fmt(bps):
181
+ if bps is None or bps == 16:
182
+ return "s16"
183
+ if bps == 24:
184
+ return "s32"
185
+ raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
186
+
187
+
188
+ def _parse_save_args(
189
+ ext: Optional[str],
190
+ format: Optional[str],
191
+ encoding: Optional[str],
192
+ bps: Optional[int],
193
+ ):
194
+ # torchaudio's save function accepts the following, which do not map
195
+ # one-to-one to FFmpeg.
196
+ #
197
+ # - format: audio format
198
+ # - bits_per_sample: encoder sample format
199
+ # - encoding: such as PCM_U8.
200
+ #
201
+ # In FFmpeg, format is specified with the following three (and more)
202
+ #
203
+ # - muxer: could be audio format or container format.
204
+ # the one we passed to the constructor of StreamWriter
205
+ # - encoder: the audio encoder used to encode audio
206
+ # - encoder sample format: the format used by encoder to encode audio.
207
+ #
208
+ # If encoder sample format is different from source sample format, StreamWriter
209
+ # will insert a filter automatically.
210
+ #
211
+ def _type(spec):
212
+ # either the format is exactly the specified one
213
+ # or the extension matches the spec AND there is no format override.
214
+ return format == spec or (format is None and ext == spec)
215
+
216
+ if _type("wav") or _type("amb"):
217
+ # wav is special because it supports different encodings through encoders;
218
+ # each encoder only supports one encoder format.
219
+ #
220
+ # amb format is a special case originated from libsox.
221
+ # It is basically a WAV format, with slight modification.
222
+ # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
223
+ # It is a format so that decoders will recognize it as ambisonic.
224
+ # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
225
+ # FFmpeg does not recognize amb because it is basically a WAV format.
226
+ muxer = "wav"
227
+ encoder = _get_encoder_for_wav(encoding, bps)
228
+ sample_fmt = None
229
+ elif _type("vorbis"):
230
+ # FFmpeg does not recognize the vorbis extension, while libsox used to.
231
+ # For the sake of backward compatibility (and simplicity),
232
+ # we support the case where users want to do save("foo.vorbis")
233
+ muxer = "ogg"
234
+ encoder = "vorbis"
235
+ sample_fmt = None
236
+ else:
237
+ muxer = format
238
+ encoder = None
239
+ sample_fmt = None
240
+ if _type("flac"):
241
+ sample_fmt = _get_flac_sample_fmt(bps)
242
+ if _type("ogg"):
243
+ sample_fmt = _get_flac_sample_fmt(bps)
244
+ return muxer, encoder, sample_fmt
245
+
246
+
247
+ def save_audio(
248
+ uri: Union[BinaryIO, str, os.PathLike],
249
+ src: torch.Tensor,
250
+ sample_rate: int,
251
+ channels_first: bool = True,
252
+ format: Optional[str] = None,
253
+ encoding: Optional[str] = None,
254
+ bits_per_sample: Optional[int] = None,
255
+ buffer_size: int = 4096,
256
+ compression: Optional[torchaudio.io.CodecConfig] = None,
257
+ ) -> None:
258
+ ext = None
259
+ if hasattr(uri, "write"):
260
+ if format is None:
261
+ raise RuntimeError("'format' is required when saving to file object.")
262
+ else:
263
+ uri = os.path.normpath(uri)
264
+ if tokens := str(uri).split(".")[1:]:
265
+ ext = tokens[-1].lower()
266
+
267
+ muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
268
+
269
+ if channels_first:
270
+ src = src.T
271
+
272
+ s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
273
+ s.add_audio_stream(
274
+ sample_rate,
275
+ num_channels=src.size(-1),
276
+ format=_get_sample_format(src.dtype),
277
+ encoder=encoder,
278
+ encoder_format=enc_fmt,
279
+ codec_config=compression,
280
+ )
281
+ with s.open():
282
+ s.write_audio_chunk(0, src)
283
+
284
+
285
+ def _map_encoding(encoding: str) -> str:
286
+ for dst in ["PCM_S", "PCM_U", "PCM_F"]:
287
+ if dst in encoding:
288
+ return dst
289
+ if encoding == "PCM_MULAW":
290
+ return "ULAW"
291
+ elif encoding == "PCM_ALAW":
292
+ return "ALAW"
293
+ return encoding
294
+
295
+
296
+ def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str:
297
+ if m := re.search(r"PCM_\w(\d+)\w*", encoding):
298
+ return int(m.group(1))
299
+ elif encoding in ["PCM_ALAW", "PCM_MULAW"]:
300
+ return 8
301
+ return bits_per_sample
302
+
303
+
304
+ class FFmpegBackend(Backend):
305
+ @staticmethod
306
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
307
+ if hasattr(uri, "read"):
308
+ metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
309
+ else:
310
+ metadata = info_audio(os.path.normpath(uri), format)
311
+ metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
312
+ metadata.encoding = _map_encoding(metadata.encoding)
313
+ return metadata
314
+
315
+ @staticmethod
316
+ def load(
317
+ uri: Union[BinaryIO, str, os.PathLike],
318
+ frame_offset: int = 0,
319
+ num_frames: int = -1,
320
+ normalize: bool = True,
321
+ channels_first: bool = True,
322
+ format: Optional[str] = None,
323
+ buffer_size: int = 4096,
324
+ ) -> Tuple[torch.Tensor, int]:
325
+ if hasattr(uri, "read"):
326
+ return load_audio_fileobj(
327
+ uri,
328
+ frame_offset,
329
+ num_frames,
330
+ normalize,
331
+ channels_first,
332
+ format,
333
+ buffer_size,
334
+ )
335
+ else:
336
+ return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
337
+
338
+ @staticmethod
339
+ def save(
340
+ uri: Union[BinaryIO, str, os.PathLike],
341
+ src: torch.Tensor,
342
+ sample_rate: int,
343
+ channels_first: bool = True,
344
+ format: Optional[str] = None,
345
+ encoding: Optional[str] = None,
346
+ bits_per_sample: Optional[int] = None,
347
+ buffer_size: int = 4096,
348
+ compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
349
+ ) -> None:
350
+ if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
351
+ raise ValueError(
352
+ "FFmpeg backend expects non-`None` value for argument `compression` to be of ",
353
+ f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
354
+ )
355
+ save_audio(
356
+ uri,
357
+ src,
358
+ sample_rate,
359
+ channels_first,
360
+ format,
361
+ encoding,
362
+ bits_per_sample,
363
+ buffer_size,
364
+ compression,
365
+ )
366
+
367
+ @staticmethod
368
+ def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
369
+ return True
370
+
371
+ @staticmethod
372
+ def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
373
+ return True
@@ -0,0 +1,54 @@
1
+ import os
2
+ from typing import BinaryIO, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torchaudio.io import CodecConfig
6
+
7
+ from . import soundfile_backend
8
+ from .backend import Backend
9
+ from .common import AudioMetaData
10
+
11
+
12
+ class SoundfileBackend(Backend):
13
+ @staticmethod
14
+ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
15
+ return soundfile_backend.info(uri, format)
16
+
17
+ @staticmethod
18
+ def load(
19
+ uri: Union[BinaryIO, str, os.PathLike],
20
+ frame_offset: int = 0,
21
+ num_frames: int = -1,
22
+ normalize: bool = True,
23
+ channels_first: bool = True,
24
+ format: Optional[str] = None,
25
+ buffer_size: int = 4096,
26
+ ) -> Tuple[torch.Tensor, int]:
27
+ return soundfile_backend.load(uri, frame_offset, num_frames, normalize, channels_first, format)
28
+
29
+ @staticmethod
30
+ def save(
31
+ uri: Union[BinaryIO, str, os.PathLike],
32
+ src: torch.Tensor,
33
+ sample_rate: int,
34
+ channels_first: bool = True,
35
+ format: Optional[str] = None,
36
+ encoding: Optional[str] = None,
37
+ bits_per_sample: Optional[int] = None,
38
+ buffer_size: int = 4096,
39
+ compression: Optional[Union[CodecConfig, float, int]] = None,
40
+ ) -> None:
41
+ if compression:
42
+ raise ValueError("soundfile backend does not support argument `compression`.")
43
+
44
+ soundfile_backend.save(
45
+ uri, src, sample_rate, channels_first, format=format, encoding=encoding, bits_per_sample=bits_per_sample
46
+ )
47
+
48
+ @staticmethod
49
+ def can_decode(uri, format) -> bool:
50
+ return True
51
+
52
+ @staticmethod
53
+ def can_encode(uri, format) -> bool:
54
+ return True