torchaudio-2.0.2-cp310-cp310-win_amd64.whl → torchaudio-2.1.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchaudio might be problematic.
- torchaudio/__init__.py +22 -3
- torchaudio/_backend/__init__.py +55 -4
- torchaudio/_backend/backend.py +53 -0
- torchaudio/_backend/common.py +52 -0
- torchaudio/_backend/ffmpeg.py +373 -0
- torchaudio/_backend/soundfile.py +54 -0
- torchaudio/_backend/soundfile_backend.py +457 -0
- torchaudio/_backend/sox.py +91 -0
- torchaudio/_backend/utils.py +81 -323
- torchaudio/_extension/__init__.py +55 -36
- torchaudio/_extension/utils.py +109 -17
- torchaudio/_internal/__init__.py +4 -1
- torchaudio/_internal/module_utils.py +37 -6
- torchaudio/backend/__init__.py +7 -11
- torchaudio/backend/_no_backend.py +24 -0
- torchaudio/backend/_sox_io_backend.py +297 -0
- torchaudio/backend/common.py +12 -52
- torchaudio/backend/no_backend.py +11 -21
- torchaudio/backend/soundfile_backend.py +11 -448
- torchaudio/backend/sox_io_backend.py +11 -435
- torchaudio/backend/utils.py +9 -18
- torchaudio/datasets/__init__.py +2 -0
- torchaudio/datasets/cmuarctic.py +1 -1
- torchaudio/datasets/cmudict.py +61 -62
- torchaudio/datasets/dr_vctk.py +1 -1
- torchaudio/datasets/gtzan.py +1 -1
- torchaudio/datasets/librilight_limited.py +1 -1
- torchaudio/datasets/librispeech.py +1 -1
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +1 -1
- torchaudio/datasets/ljspeech.py +1 -1
- torchaudio/datasets/musdb_hq.py +1 -1
- torchaudio/datasets/quesst14.py +1 -1
- torchaudio/datasets/speechcommands.py +1 -1
- torchaudio/datasets/tedlium.py +1 -1
- torchaudio/datasets/vctk.py +1 -1
- torchaudio/datasets/voxceleb1.py +1 -1
- torchaudio/datasets/yesno.py +1 -1
- torchaudio/functional/__init__.py +6 -2
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +69 -92
- torchaudio/functional/functional.py +99 -148
- torchaudio/io/__init__.py +4 -1
- torchaudio/io/_effector.py +347 -0
- torchaudio/io/_stream_reader.py +158 -90
- torchaudio/io/_stream_writer.py +196 -10
- torchaudio/lib/_torchaudio.pyd +0 -0
- torchaudio/lib/_torchaudio_ffmpeg4.pyd +0 -0
- torchaudio/lib/_torchaudio_ffmpeg5.pyd +0 -0
- torchaudio/lib/_torchaudio_ffmpeg6.pyd +0 -0
- torchaudio/lib/libtorchaudio.pyd +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg4.pyd +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg5.pyd +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg6.pyd +0 -0
- torchaudio/models/__init__.py +14 -0
- torchaudio/models/decoder/__init__.py +22 -7
- torchaudio/models/decoder/_ctc_decoder.py +123 -69
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/rnnt_decoder.py +10 -14
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/wav2vec2/components.py +6 -10
- torchaudio/pipelines/__init__.py +9 -0
- torchaudio/pipelines/_squim_pipeline.py +176 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +198 -68
- torchaudio/pipelines/_wav2vec2/utils.py +120 -0
- torchaudio/sox_effects/sox_effects.py +7 -30
- torchaudio/transforms/__init__.py +2 -0
- torchaudio/transforms/_transforms.py +99 -54
- torchaudio/utils/download.py +2 -2
- torchaudio/utils/ffmpeg_utils.py +20 -15
- torchaudio/utils/sox_utils.py +8 -9
- torchaudio/version.py +2 -2
- torchaudio-2.1.1.dist-info/METADATA +113 -0
- torchaudio-2.1.1.dist-info/RECORD +115 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +1 -1
- torchaudio/io/_compat.py +0 -241
- torchaudio/lib/_torchaudio_ffmpeg.pyd +0 -0
- torchaudio/lib/flashlight_lib_text_decoder.pyd +0 -0
- torchaudio/lib/flashlight_lib_text_dictionary.pyd +0 -0
- torchaudio/lib/libflashlight-text.pyd +0 -0
- torchaudio/lib/libtorchaudio_ffmpeg.pyd +0 -0
- torchaudio-2.0.2.dist-info/METADATA +0 -26
- torchaudio-2.0.2.dist-info/RECORD +0 -98
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
- {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
torchaudio/io/_stream_writer.py
CHANGED

@@ -1,9 +1,45 @@
+from dataclasses import dataclass
 from typing import BinaryIO, Dict, Optional, Union
 
 import torch
 import torchaudio
 
 
+if torchaudio._extension._FFMPEG_EXT is None:
+    ConfigBase = object
+else:
+    ConfigBase = torchaudio._extension._FFMPEG_EXT.CodecConfig
+    _StreamWriter = torchaudio._extension._FFMPEG_EXT.StreamWriter
+    _StreamWriterFileObj = torchaudio._extension._FFMPEG_EXT.StreamWriterFileObj
+
+
+@dataclass
+class CodecConfig(ConfigBase):
+    """Codec configuration."""
+
+    bit_rate: int = -1
+    """Bit rate"""
+
+    compression_level: int = -1
+    """Compression level"""
+
+    qscale: Optional[int] = None
+    """Global quality factor. Enables variable bit rate. Valid values depend on the encoder.
+
+    For example: MP3 takes ``0`` - ``9`` (https://trac.ffmpeg.org/wiki/Encode/MP3) while
+    libvorbis takes ``-1`` - ``10``.
+    """
+
+    gop_size: int = -1
+    """The number of pictures in a group of pictures, or 0 for intra-only"""
+
+    max_b_frames: int = -1
+    """Maximum number of B-frames between non-B-frames."""
+
+    def __post_init__(self):
+        super().__init__(self.bit_rate, self.compression_level, self.qscale, self.gop_size, self.max_b_frames)
+
+
 def _format_doc(**kwargs):
     def decorator(obj):
         obj.__doc__ = obj.__doc__.format(**kwargs)
@@ -28,7 +64,29 @@ _encoder_option = """Options passed to encoder.
 To list the options for an encoder, you can use the
 ``ffmpeg -h encoder=<ENCODER>`` command.
 
-Default: ``None``."""
+Default: ``None``.
+
+
+
+In addition to encoder-specific options, you can also pass options related
+to multithreading. They are effective only if the encoder supports them.
+If neither of them is provided, encoding defaults to a single thread.
+
+``"threads"``: The number of threads (as str).
+Providing the value ``"0"`` lets FFmpeg decide based on its heuristics.
+
+``"thread_type"``: Which multithreading method to use.
+The valid values are ``"frame"`` and ``"slice"``.
+Note that each encoder supports a different set of methods.
+If not provided, a default value is used.
+
+- ``"frame"``: Encode more than one frame at once.
+  Each thread handles one frame.
+  This increases encoding delay by one frame per thread.
+- ``"slice"``: Encode more than one part of a single frame at once.
+
+"""
@@ -38,13 +96,34 @@ _encoder_format = """Format used to encode media.
 To list the supported formats for an encoder, you can use the
 ``ffmpeg -h encoder=<ENCODER>`` command.
 
+Default: ``None``.
+
+Note:
+    When the ``encoder_format`` option is not provided, the encoder uses its default format.
+
+    For example, when encoding audio into wav format, 16-bit signed integer is used,
+    and when encoding video into mp4 format (h264 encoder), one of the YUV formats is used.
+
+    This is because, typically, 32-bit or 16-bit floating point is used in audio models but
+    those are not commonly used in audio formats. Similarly, RGB24 is commonly used in vision
+    models, but video formats usually (and better) support YUV formats.
+"""
+
+_codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig` for
+configuration options.
+
 Default: ``None``."""
 
 
+_filter_desc = """Additional processing to apply before encoding the input media.
+"""
+
 _format_common_args = _format_doc(
     encoder=_encoder,
     encoder_option=_encoder_option,
     encoder_format=_encoder_format,
+    codec_config=_codec_config,
+    filter_desc=_filter_desc,
 )
@@ -109,11 +188,10 @@ class StreamWriter:
         format: Optional[str] = None,
         buffer_size: int = 4096,
     ):
-        torch._C._log_api_usage_once("torchaudio.io.StreamWriter")
         if isinstance(dst, str):
-            self._s = …
+            self._s = _StreamWriter(dst, format)
         elif hasattr(dst, "write"):
-            self._s = …
+            self._s = _StreamWriterFileObj(dst, format, buffer_size)
         else:
             raise ValueError("`dst` must be either a string or a file-like object.")
         self._is_open = False
@@ -124,9 +202,14 @@ class StreamWriter:
         sample_rate: int,
         num_channels: int,
         format: str = "flt",
+        *,
         encoder: Optional[str] = None,
         encoder_option: Optional[Dict[str, str]] = None,
+        encoder_sample_rate: Optional[int] = None,
+        encoder_num_channels: Optional[int] = None,
         encoder_format: Optional[str] = None,
+        codec_config: Optional[CodecConfig] = None,
+        filter_desc: Optional[str] = None,
     ):
         """Add an output audio stream.
@@ -151,9 +234,56 @@ class StreamWriter:
 
             encoder_option (dict or None, optional): {encoder_option}
 
+            encoder_sample_rate (int or None, optional): Override the sample rate used for encoding.
+                Some encoders pose restrictions on the sample rate that can be used for encoding.
+                If the source sample rate is supported by the encoder, it is used;
+                otherwise, a default one is picked.
+
+                For example, the ``"opus"`` encoder supports only 48k Hz, so when encoding a
+                waveform with the ``"opus"`` encoder, it is always encoded as 48k Hz.
+                Meanwhile ``"mp3"`` (``"libmp3lame"``) supports 44.1k, 48k, 32k, 22.05k,
+                24k, 16k, 11.025k, 12k and 8k Hz.
+                If the original sample rate is one of these, the original sample rate
+                is used; otherwise, the waveform is resampled to a default one (44.1k).
+                When encoding into WAV format, there is no restriction on the sample rate,
+                so the original sample rate is always used.
+
+                Providing ``encoder_sample_rate`` overrides this behavior and
+                makes the encoder attempt to use the provided sample rate.
+                The provided value must be one supported by the encoder.
+
+            encoder_num_channels (int or None, optional): Override the number of channels used for encoding.
+
+                Similar to the sample rate, some encoders (such as ``"opus"``,
+                ``"vorbis"`` and ``"g722"``) pose restrictions on
+                the number of channels that can be used for encoding.
+
+                If the original number of channels is supported by the encoder,
+                it is used; otherwise, the encoder attempts to
+                remix the channels into one of the supported layouts.
+
+                Providing ``encoder_num_channels`` overrides this behavior and
+                makes the encoder attempt to use the provided number of channels.
+                The provided value must be one supported by the encoder.
+
             encoder_format (str or None, optional): {encoder_format}
+
+            codec_config (CodecConfig or None, optional): {codec_config}
+
+            filter_desc (str or None, optional): {filter_desc}
         """
-        self._s.add_audio_stream(…)
+        self._s.add_audio_stream(
+            sample_rate,
+            num_channels,
+            format,
+            encoder,
+            encoder_option,
+            encoder_format,
+            encoder_sample_rate,
+            encoder_num_channels,
+            codec_config,
+            filter_desc,
+        )
 
     @_format_common_args
     def add_video_stream(
@@ -162,9 +292,15 @@ class StreamWriter:
         width: int,
         height: int,
         format: str = "rgb24",
+        *,
         encoder: Optional[str] = None,
         encoder_option: Optional[Dict[str, str]] = None,
+        encoder_frame_rate: Optional[float] = None,
+        encoder_width: Optional[int] = None,
+        encoder_height: Optional[int] = None,
         encoder_format: Optional[str] = None,
+        codec_config: Optional[CodecConfig] = None,
+        filter_desc: Optional[str] = None,
         hw_accel: Optional[str] = None,
     ):
         """Add an output video stream.
@@ -195,8 +331,30 @@ class StreamWriter:
 
             encoder_option (dict or None, optional): {encoder_option}
 
+            encoder_frame_rate (float or None, optional): Override the frame rate used for encoding.
+
+                Some encoders (such as ``"mpeg1"`` and ``"mpeg2"``) pose restrictions on the
+                frame rate that can be used for encoding.
+                In such cases, if the source frame rate (provided as ``frame_rate``) is not
+                one of the supported frame rates, a default one is picked and the frame rate
+                is changed on the fly. Otherwise, the source frame rate is used.
+
+                Providing ``encoder_frame_rate`` overrides this behavior and
+                makes the encoder attempt to use the provided frame rate.
+                The provided value must be one supported by the encoder.
+
+            encoder_width (int or None, optional): Width of the image used for encoding.
+                This allows changing the image size during encoding.
+
+            encoder_height (int or None, optional): Height of the image used for encoding.
+                This allows changing the image size during encoding.
+
             encoder_format (str or None, optional): {encoder_format}
 
+            codec_config (CodecConfig or None, optional): {codec_config}
+
+            filter_desc (str or None, optional): {filter_desc}
+
             hw_accel (str or None, optional): Enable hardware acceleration.
 
                 When video is encoded on CUDA hardware, for example
@@ -207,7 +365,21 @@ class StreamWriter:
             If `None`, the video chunk Tensor has to be a CPU Tensor.
             Default: ``None``.
         """
-        self._s.add_video_stream(…)
+        self._s.add_video_stream(
+            frame_rate,
+            width,
+            height,
+            format,
+            encoder,
+            encoder_option,
+            encoder_format,
+            encoder_frame_rate,
+            encoder_width,
+            encoder_height,
+            hw_accel,
+            codec_config,
+            filter_desc,
+        )
 
     def set_metadata(self, metadata: Dict[str, str]):
         """Set file-level metadata
@@ -276,17 +448,24 @@ class StreamWriter:
         self._s.close()
         self._is_open = False
 
-    def write_audio_chunk(self, i: int, chunk: torch.Tensor):
+    def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
         """Write audio data
 
         Args:
             i (int): Stream index.
             chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
                 The ``dtype`` must match what was passed to the :py:meth:`add_audio_stream` method.
+            pts (float or None, optional): If provided, overwrite the presentation timestamp.
+
+                .. note::
+
+                   The provided value is converted to an integer value expressed in the basis of
+                   the sample rate. Therefore, it is truncated to the nearest value of
+                   ``n / sample_rate``.
         """
-        self._s.write_audio_chunk(i, chunk)
+        self._s.write_audio_chunk(i, chunk, pts)
 
-    def write_video_chunk(self, i: int, chunk: torch.Tensor):
+    def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
         """Write video/image data
 
         Args:
@@ -296,8 +475,15 @@ class StreamWriter:
                 The ``dtype`` must be ``torch.uint8``.
                 The shape (height, width and the number of channels) must match
                 what was configured when calling :py:meth:`add_video_stream`.
+            pts (float or None, optional): If provided, overwrite the presentation timestamp.
+
+                .. note::
+
+                   The provided value is converted to an integer value expressed in the basis of
+                   the frame rate. Therefore, it is truncated to the nearest value of
+                   ``n / frame_rate``.
         """
-        self._s.write_video_chunk(i, chunk)
+        self._s.write_video_chunk(i, chunk, pts)
 
     def flush(self):
         """Flush the frames from encoders and write the frames to the destination."""
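Taken together, the `_stream_writer.py` changes introduce a public `CodecConfig`, keyword-only encoder overrides on `add_audio_stream`/`add_video_stream`, and an optional `pts` argument on the chunk writers. Below is a minimal sketch of the new surface, assuming `CodecConfig` is re-exported from `torchaudio.io` (the diff shows `io/__init__.py` gaining exports) and using the pre-existing `StreamWriter.open()` context manager; the file name and parameter values are illustrative:

```python
import torch
from torchaudio.io import CodecConfig, StreamWriter  # CodecConfig export assumed

# One second of silence at 16 kHz; write_audio_chunk expects (frame, channel).
waveform = torch.zeros((16000, 1), dtype=torch.float32)

w = StreamWriter("out.mp3")
w.add_audio_stream(
    sample_rate=16000,
    num_channels=1,
    # New in 2.1: keyword-only overrides applied at encoding time.
    encoder_sample_rate=44100,                   # resample to a rate libmp3lame supports
    codec_config=CodecConfig(bit_rate=128_000),  # constant 128 kbps
    filter_desc="anull",                         # FFmpeg's no-op audio filter
)
with w.open():
    # New in 2.1: optional presentation timestamp, truncated to a value
    # of the form n / sample_rate.
    w.write_audio_chunk(0, waveform, pts=0.5)
```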
torchaudio/lib/_torchaudio.pyd
CHANGED

Binary files differ.
torchaudio/lib/libtorchaudio.pyd
CHANGED

Binary files differ.
torchaudio/models/__init__.py
CHANGED

@@ -5,6 +5,14 @@ from .deepspeech import DeepSpeech
 from .emformer import Emformer
 from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
 from .rnnt_decoder import Hypothesis, RNNTBeamSearch
+from .squim import (
+    squim_objective_base,
+    squim_objective_model,
+    squim_subjective_base,
+    squim_subjective_model,
+    SquimObjective,
+    SquimSubjective,
+)
 from .tacotron2 import Tacotron2
 from .wav2letter import Wav2Letter
 from .wav2vec2 import (
@@ -68,4 +76,10 @@ __all__ = [
     "hdemucs_low",
     "hdemucs_medium",
     "hdemucs_high",
+    "squim_objective_base",
+    "squim_objective_model",
+    "squim_subjective_base",
+    "squim_subjective_model",
+    "SquimObjective",
+    "SquimSubjective",
 ]
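The new SQUIM exports follow torchaudio's usual factory convention (a `*_base` constructor returning a ready-to-use model). A hedged sketch of what calling the objective model might look like; the waveform shape follows the library's `(batch, time)` convention and the generic output handling is an assumption:

```python
import torch
from torchaudio.models import squim_objective_base

model = squim_objective_base()  # SquimObjective with its base configuration
model.eval()

# SQUIM operates on raw waveforms; a dummy one-second 16 kHz batch.
waveform = torch.randn(1, 16000)

with torch.no_grad():
    # The objective model estimates reference-free speech-quality metrics
    # from the waveform alone.
    scores = model(waveform)
```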
torchaudio/models/decoder/__init__.py
CHANGED

@@ -1,5 +1,4 @@
-
-_LAZILY_IMPORTED = [
+_CTC_DECODERS = [
     "CTCHypothesis",
     "CTCDecoder",
     "CTCDecoderLM",
@@ -7,25 +6,41 @@ _LAZILY_IMPORTED = [
     "ctc_decoder",
     "download_pretrained_files",
 ]
+_CUDA_CTC_DECODERS = [
+    "CUCTCDecoder",
+    "CUCTCHypothesis",
+    "cuda_ctc_decoder",
+]
 
 
 def __getattr__(name: str):
-    if name in _LAZILY_IMPORTED:
+    if name in _CTC_DECODERS:
         try:
             from . import _ctc_decoder
-        except …
+        except Exception as err:
             raise RuntimeError(
-                "CTC …"
+                "The CTC decoder suite requires the flashlight-text package and, optionally, KenLM. Please install them."
             ) from err
 
         item = getattr(_ctc_decoder, name)
         globals()[name] = item
         return item
+    elif name in _CUDA_CTC_DECODERS:
+        try:
+            from . import _cuda_ctc_decoder
+        except AttributeError as err:
+            raise RuntimeError(
+                "To use the CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
+            ) from err
+
+        item = getattr(_cuda_ctc_decoder, name)
+        globals()[name] = item
+        return item
     raise AttributeError(f"module {__name__} has no attribute {name}")
 
 
 def __dir__():
-    return sorted(__all__ + _LAZILY_IMPORTED)
+    return sorted(__all__)
 
 
-__all__ = []
+__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
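The rewritten `__getattr__` keeps both decoder families lazily imported, so `flashlight-text` (or the CUDA extension) is only required when a matching name is first accessed. The underlying PEP 562 pattern, shown standalone with hypothetical module and function names:

```python
# lazy_pkg/__init__.py - a standalone sketch of the PEP 562 lazy-import
# pattern used above; _heavy_impl and expensive_function are hypothetical.
_HEAVY_NAMES = ["expensive_function"]


def __getattr__(name: str):
    # Invoked only when `name` is not already an attribute of this module.
    if name in _HEAVY_NAMES:
        from . import _heavy_impl  # import deferred until first access

        item = getattr(_heavy_impl, name)
        globals()[name] = item  # cache so later lookups bypass __getattr__
        return item
    raise AttributeError(f"module {__name__} has no attribute {name}")
```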
torchaudio/models/decoder/_ctc_decoder.py
CHANGED

@@ -2,69 +2,38 @@ from __future__ import annotations
 
 import itertools as it
 
-import warnings
 from abc import abstractmethod
 from collections import namedtuple
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
-import torchaudio
-from torchaudio.utils import download_asset
-
 
-…
-    Dictionary as _Dictionary,
-    load_words as _load_words,
-)
+from flashlight.lib.text.decoder import (
+    CriterionType as _CriterionType,
+    LexiconDecoder as _LexiconDecoder,
+    LexiconDecoderOptions as _LexiconDecoderOptions,
+    LexiconFreeDecoder as _LexiconFreeDecoder,
+    LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
+    LM as _LM,
+    LMState as _LMState,
+    SmearingMode as _SmearingMode,
+    Trie as _Trie,
+    ZeroLM as _ZeroLM,
+)
+from flashlight.lib.text.dictionary import (
+    create_word_dict as _create_word_dict,
+    Dictionary as _Dictionary,
+    load_words as _load_words,
+)
+from torchaudio.utils import download_asset
 
+try:
+    from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
+except Exception:
     try:
         from flashlight.lib.text.decoder import KenLM as _KenLM
     except Exception:
         _KenLM = None
-else:
-    torchaudio._extension._load_lib("libflashlight-text")
-    from torchaudio.lib.flashlight_lib_text_decoder import (
-        CriterionType as _CriterionType,
-        KenLM as _KenLM,
-        LexiconDecoder as _LexiconDecoder,
-        LexiconDecoderOptions as _LexiconDecoderOptions,
-        LexiconFreeDecoder as _LexiconFreeDecoder,
-        LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-        LM as _LM,
-        LMState as _LMState,
-        SmearingMode as _SmearingMode,
-        Trie as _Trie,
-        ZeroLM as _ZeroLM,
-    )
-    from torchaudio.lib.flashlight_lib_text_dictionary import (
-        create_word_dict as _create_word_dict,
-        Dictionary as _Dictionary,
-        load_words as _load_words,
-    )
-
-    warnings.warn(
-        "The built-in flashlight integration is deprecated, and will be removed in future release. "
-        "Please install flashlight-text. https://pypi.org/project/flashlight-text/ "
-        "For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088."
-    )
 
 
 __all__ = [
     "CTCHypothesis",
@@ -292,10 +261,102 @@ class CTCDecoder:
             timesteps.append(i)
         return torch.IntTensor(timesteps)
 
+    def decode_begin(self):
+        """Initialize the internal state of the decoder.
+
+        See :py:meth:`decode_step` for the usage.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        self.decoder.decode_begin()
+
+    def decode_end(self):
+        """Finalize the internal state of the decoder.
+
+        See :py:meth:`decode_step` for the usage.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        self.decoder.decode_end()
+
+    def decode_step(self, emissions: torch.FloatTensor):
+        """Perform incremental decoding on top of the current internal state.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+
+        Args:
+            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
+                probability distribution over labels; output of acoustic model.
+
+        Example:
+            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
+            >>> decoder.decode_begin()
+            >>> decoder.decode_step(emission1)
+            >>> decoder.decode_step(emission2)
+            >>> decoder.decode_end()
+            >>> result = decoder.get_final_hypothesis()
+        """
+        if emissions.dtype != torch.float32:
+            raise ValueError("emissions must be float32.")
+
+        if not emissions.is_cpu:
+            raise RuntimeError("emissions must be a CPU tensor.")
+
+        if not emissions.is_contiguous():
+            raise RuntimeError("emissions must be contiguous.")
+
+        if emissions.ndim != 2:
+            raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")
+
+        T, N = emissions.size()
+        self.decoder.decode_step(emissions.data_ptr(), T, N)
+
+    def _to_hypo(self, results) -> List[CTCHypothesis]:
+        return [
+            CTCHypothesis(
+                tokens=self._get_tokens(result.tokens),
+                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
+                score=result.score,
+                timesteps=self._get_timesteps(result.tokens),
+            )
+            for result in results
+        ]
+
+    def get_final_hypothesis(self) -> List[CTCHypothesis]:
+        """Get the final hypothesis
+
+        Returns:
+            List[CTCHypothesis]:
+                List of sorted best hypotheses.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        results = self.decoder.get_all_final_hypothesis()
+        return self._to_hypo(results[: self.nbest])
+
     def __call__(
         self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
     ) -> List[List[CTCHypothesis]]:
         """
+        Performs batched offline decoding.
+
+        .. note::
+
+           This method performs offline decoding in one go. To perform incremental decoding,
+           please refer to :py:meth:`decode_step`.
+
         Args:
             emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                 probability distribution over labels; output of acoustic model.
@@ -310,13 +371,16 @@ class CTCDecoder:
         if emissions.dtype != torch.float32:
             raise ValueError("emissions must be float32.")
 
-        if emissions.…
+        if not emissions.is_cpu:
             raise RuntimeError("emissions must be a CPU tensor.")
 
         if not emissions.is_contiguous():
             raise RuntimeError("emissions must be contiguous.")
 
-        if …
+        if emissions.ndim != 3:
+            raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")
+
+        if lengths is not None and not lengths.is_cpu:
             raise RuntimeError("lengths must be a CPU tensor.")
 
         B, T, N = emissions.size()
@@ -329,20 +393,7 @@ class CTCDecoder:
         for b in range(B):
             emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
             results = self.decoder.decode(emissions_ptr, lengths[b], N)
-
-            nbest_results = results[: self.nbest]
-            hypos.append(
-                [
-                    CTCHypothesis(
-                        tokens=self._get_tokens(result.tokens),
-                        words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
-                        score=result.score,
-                        timesteps=self._get_timesteps(result.tokens),
-                    )
-                    for result in nbest_results
-                ]
-            )
-
+            hypos.append(self._to_hypo(results[: self.nbest]))
         return hypos
 
     def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
@@ -450,7 +501,10 @@ def ctc_decoder(
 
     if type(lm) == str:
         if _KenLM is None:
-            raise RuntimeError(…)
+            raise RuntimeError(
+                "flashlight-text is installed, but KenLM is not installed. "
+                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
+            )
         lm = _KenLM(lm, word_dict)
     elif lm is None:
         lm = _ZeroLM()
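The `decode_begin`/`decode_step`/`decode_end` additions turn `CTCDecoder` into a streaming consumer of emission chunks. A sketch of chunked decoding built only from the methods added above; the lexicon/token files, chunk sizes, and token count are placeholders:

```python
import torch
from torchaudio.models.decoder import ctc_decoder

# Placeholder assets; see the ctc_decoder documentation for real files.
decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", nbest=3)

# Emissions arrive chunk by chunk: contiguous float32 CPU tensors of
# shape (frame, num_tokens), e.g. ten chunks of 50 frames over 29 tokens.
chunks = torch.randn(10, 50, 29, dtype=torch.float32).unbind(0)

decoder.decode_begin()
for chunk in chunks:
    decoder.decode_step(chunk.contiguous())  # incremental state update
decoder.decode_end()

# Sorted best hypotheses, the same type that batch __call__ returns.
hypotheses = decoder.get_final_hypothesis()
```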