torchaudio 2.0.2-cp39-cp39-manylinux1_x86_64.whl → 2.1.1-cp39-cp39-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of torchaudio might be problematic; details are available on the package's registry page.

Files changed (92)
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.so +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
  51. torchaudio/lib/_torchaudio_sox.so +0 -0
  52. torchaudio/lib/libctc_prefix_decoder.so +0 -0
  53. torchaudio/lib/libtorchaudio.so +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
  55. torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
  56. torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
  57. torchaudio/lib/libtorchaudio_sox.so +0 -0
  58. torchaudio/lib/pybind11_prefixctc.so +0 -0
  59. torchaudio/models/__init__.py +14 -0
  60. torchaudio/models/decoder/__init__.py +22 -7
  61. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  62. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  63. torchaudio/models/rnnt_decoder.py +10 -14
  64. torchaudio/models/squim/__init__.py +11 -0
  65. torchaudio/models/squim/objective.py +326 -0
  66. torchaudio/models/squim/subjective.py +150 -0
  67. torchaudio/models/wav2vec2/components.py +6 -10
  68. torchaudio/pipelines/__init__.py +9 -0
  69. torchaudio/pipelines/_squim_pipeline.py +176 -0
  70. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  71. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  72. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  73. torchaudio/sox_effects/sox_effects.py +7 -30
  74. torchaudio/transforms/__init__.py +2 -0
  75. torchaudio/transforms/_transforms.py +99 -54
  76. torchaudio/utils/download.py +2 -2
  77. torchaudio/utils/ffmpeg_utils.py +20 -15
  78. torchaudio/utils/sox_utils.py +8 -9
  79. torchaudio/version.py +2 -2
  80. torchaudio-2.1.1.dist-info/METADATA +113 -0
  81. torchaudio-2.1.1.dist-info/RECORD +119 -0
  82. torchaudio/io/_compat.py +0 -241
  83. torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
  84. torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
  85. torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
  86. torchaudio/lib/libflashlight-text.so +0 -0
  87. torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
  88. torchaudio-2.0.2.dist-info/METADATA +0 -26
  89. torchaudio-2.0.2.dist-info/RECORD +0 -100
  90. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  91. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +0 -0
  92. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
torchaudio/io/_stream_writer.py
@@ -1,9 +1,45 @@
+from dataclasses import dataclass
 from typing import BinaryIO, Dict, Optional, Union
 
 import torch
 import torchaudio
 
 
+if torchaudio._extension._FFMPEG_EXT is None:
+    ConfigBase = object
+else:
+    ConfigBase = torchaudio._extension._FFMPEG_EXT.CodecConfig
+    _StreamWriter = torchaudio._extension._FFMPEG_EXT.StreamWriter
+    _StreamWriterFileObj = torchaudio._extension._FFMPEG_EXT.StreamWriterFileObj
+
+
+@dataclass
+class CodecConfig(ConfigBase):
+    """Codec configuration."""
+
+    bit_rate: int = -1
+    """Bit rate"""
+
+    compression_level: int = -1
+    """Compression level"""
+
+    qscale: Optional[int] = None
+    """Global quality factor. Enables variable bit rate. Valid values depend on encoder.
+
+    For example: MP3 takes ``0`` - ``9`` (https://trac.ffmpeg.org/wiki/Encode/MP3) while
+    libvorbis takes ``-1`` - ``10``.
+    """
+
+    gop_size: int = -1
+    """The number of pictures in a group of pictures, or 0 for intra_only"""
+
+    max_b_frames: int = -1
+    """Maximum number of B-frames between non-B-frames."""
+
+    def __post_init__(self):
+        super().__init__(self.bit_rate, self.compression_level, self.qscale, self.gop_size, self.max_b_frames)
+
+
 def _format_doc(**kwargs):
     def decorator(obj):
         obj.__doc__ = obj.__doc__.format(**kwargs)
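
The new ``CodecConfig`` dataclass shown above is exported from ``torchaudio.io`` and is passed to ``add_audio_stream``/``add_video_stream`` via the new ``codec_config`` argument. A minimal sketch, not part of the diff; the output path and bit rate are illustrative:

    import torch
    from torchaudio.io import CodecConfig, StreamWriter

    writer = StreamWriter("output.mp3")
    writer.add_audio_stream(
        sample_rate=44100,
        num_channels=2,
        codec_config=CodecConfig(bit_rate=192_000),  # request ~192 kbps
    )
    with writer.open():
        writer.write_audio_chunk(0, torch.zeros(44100, 2))  # one second of silence, float32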
@@ -28,7 +64,29 @@ _encoder_option = """Options passed to encoder.
 To list encoder options for an encoder, you can use
 ``ffmpeg -h encoder=<ENCODER>`` command.
 
-Default: ``None``."""
+Default: ``None``.
+
+|
+
+In addition to encoder-specific options, you can also pass options related
+to multithreading. They are effective only if the encoder supports them.
+If neither of them is provided, StreamWriter defaults to a single thread.
+
+``"threads"``: The number of threads (in str).
+Providing the value ``"0"`` will let FFmpeg decide based on its heuristics.
+
+``"thread_type"``: Which multithreading method to use.
+The valid values are ``"frame"`` or ``"slice"``.
+Note that each encoder supports a different set of methods.
+If not provided, a default value is used.
+
+- ``"frame"``: Encode more than one frame at once.
+  Each thread handles one frame.
+  This will increase encoding delay by one frame per thread.
+- ``"slice"``: Encode more than one part of a single frame at once.
+
+|
+"""
 
 
 _encoder_format = """Format used to encode media.
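
The multithreading keys described above travel through the same ``encoder_option`` dict as encoder-specific options. A hedged sketch; the values are illustrative:

    writer.add_video_stream(
        frame_rate=30,
        width=640,
        height=480,
        encoder_option={"threads": "4", "thread_type": "frame"},  # 4 frame-parallel threads
    )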
@@ -38,13 +96,34 @@ _encoder_format = """Format used to encode media.
 To list supported formats for the encoder, you can use
 ``ffmpeg -h encoder=<ENCODER>`` command.
 
+Default: ``None``.
+
+Note:
+    When the ``encoder_format`` option is not provided, the encoder uses its default format.
+
+    For example, when encoding audio into wav format, 16-bit signed integer is used,
+    and when encoding video into mp4 format (h264 encoder), one of the YUV formats is used.
+
+    This is because, typically, 32-bit or 16-bit floating point is used in audio models but
+    these are not commonly used in audio formats. Similarly, RGB24 is commonly used in vision
+    models, but video formats usually (and better) support YUV formats.
+"""
+
+_codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig` for
+configuration options.
+
 Default: ``None``."""
 
 
+_filter_desc = """Additional processing to apply before encoding the input media.
+"""
+
 _format_common_args = _format_doc(
     encoder=_encoder,
     encoder_option=_encoder_option,
     encoder_format=_encoder_format,
+    codec_config=_codec_config,
+    filter_desc=_filter_desc,
 )
@@ -109,11 +188,10 @@ class StreamWriter:
         format: Optional[str] = None,
         buffer_size: int = 4096,
     ):
-        torch._C._log_api_usage_once("torchaudio.io.StreamWriter")
         if isinstance(dst, str):
-            self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format)
+            self._s = _StreamWriter(dst, format)
         elif hasattr(dst, "write"):
-            self._s = torchaudio.lib._torchaudio_ffmpeg.StreamWriterFileObj(dst, format, buffer_size)
+            self._s = _StreamWriterFileObj(dst, format, buffer_size)
         else:
             raise ValueError("`dst` must be either a string or a file-like object.")
         self._is_open = False
@@ -124,9 +202,14 @@ class StreamWriter:
         sample_rate: int,
         num_channels: int,
         format: str = "flt",
+        *,
         encoder: Optional[str] = None,
         encoder_option: Optional[Dict[str, str]] = None,
+        encoder_sample_rate: Optional[int] = None,
+        encoder_num_channels: Optional[int] = None,
         encoder_format: Optional[str] = None,
+        codec_config: Optional[CodecConfig] = None,
+        filter_desc: Optional[str] = None,
     ):
         """Add an output audio stream.
 
@@ -151,9 +234,56 @@ class StreamWriter:
 
             encoder_option (dict or None, optional): {encoder_option}
 
+            encoder_sample_rate (int or None, optional): Override the sample rate used at encoding time.
+                Some encoders pose restrictions on the sample rate that can be used for encoding.
+                If the source sample rate is supported by the encoder, it is used;
+                otherwise, a default one is picked.
+
+                For example, the ``"opus"`` encoder only supports 48 kHz, so, when encoding a
+                waveform with the ``"opus"`` encoder, it is always encoded as 48 kHz.
+                Meanwhile ``"mp3"`` (``"libmp3lame"``) supports 44.1, 48, 32, 22.05,
+                24, 16, 11.025, 12 and 8 kHz.
+                If the original sample rate is one of these, then the original sample rate
+                is used; otherwise it will be resampled to a default one (44.1 kHz).
+                When encoding into WAV format, there is no restriction on sample rate,
+                so the original sample rate will be used.
+
+                Providing ``encoder_sample_rate`` will override this behavior and
+                make the encoder attempt to use the provided sample rate.
+                The provided value must be one supported by the encoder.
+
+            encoder_num_channels (int or None, optional): Override the number of channels used for encoding.
+
+                Similar to sample rate, some encoders (such as ``"opus"``,
+                ``"vorbis"`` and ``"g722"``) pose restrictions on
+                the number of channels that can be used for encoding.
+
+                If the original number of channels is supported by the encoder,
+                then it will be used; otherwise, the encoder attempts to
+                remix the channels to one of the supported configurations.
+
+                Providing ``encoder_num_channels`` will override this behavior and
+                make the encoder attempt to use the provided number of channels.
+                The provided value must be one supported by the encoder.
+
             encoder_format (str or None, optional): {encoder_format}
+
+            codec_config (CodecConfig or None, optional): {codec_config}
+
+            filter_desc (str or None, optional): {filter_desc}
         """
-        self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format)
+        self._s.add_audio_stream(
+            sample_rate,
+            num_channels,
+            format,
+            encoder,
+            encoder_option,
+            encoder_format,
+            encoder_sample_rate,
+            encoder_num_channels,
+            codec_config,
+            filter_desc,
+        )
 
     @_format_common_args
     def add_video_stream(
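
A hedged sketch of the ``encoder_sample_rate`` override documented above (the file name, rates, and filter are illustrative): the source tensors are 22.05 kHz, but libmp3lame is asked to encode at 32 kHz, so the writer resamples on the fly; ``filter_desc`` accepts an FFmpeg filter description applied before encoding.

    writer = StreamWriter("output.mp3")
    writer.add_audio_stream(
        sample_rate=22050,          # rate of the tensors passed to write_audio_chunk
        num_channels=1,
        encoder_sample_rate=32000,  # 32 kHz is in libmp3lame's supported set
        filter_desc="volume=0.8",   # example FFmpeg filter applied before encoding
    )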
@@ -162,9 +292,15 @@ class StreamWriter:
         width: int,
         height: int,
         format: str = "rgb24",
+        *,
         encoder: Optional[str] = None,
         encoder_option: Optional[Dict[str, str]] = None,
+        encoder_frame_rate: Optional[float] = None,
+        encoder_width: Optional[int] = None,
+        encoder_height: Optional[int] = None,
         encoder_format: Optional[str] = None,
+        codec_config: Optional[CodecConfig] = None,
+        filter_desc: Optional[str] = None,
         hw_accel: Optional[str] = None,
     ):
         """Add an output video stream.
@@ -195,8 +331,30 @@ class StreamWriter:
 
             encoder_option (dict or None, optional): {encoder_option}
 
+            encoder_frame_rate (float or None, optional): Override the frame rate used for encoding.
+
+                Some encoders (such as ``"mpeg1"`` and ``"mpeg2"``) pose restrictions on the
+                frame rate that can be used for encoding.
+                In that case, if the source frame rate (provided as ``frame_rate``) is not
+                one of the supported frame rates, then a default one is picked, and the frame rate
+                is changed on-the-fly. Otherwise, the source frame rate is used.
+
+                Providing ``encoder_frame_rate`` will override this behavior and
+                make the encoder attempt to use the provided frame rate.
+                The provided value must be one supported by the encoder.
+
+            encoder_width (int or None, optional): Width of the image used for encoding.
+                This allows changing the image size during encoding.
+
+            encoder_height (int or None, optional): Height of the image used for encoding.
+                This allows changing the image size during encoding.
+
             encoder_format (str or None, optional): {encoder_format}
 
+            codec_config (CodecConfig or None, optional): {codec_config}
+
+            filter_desc (str or None, optional): {filter_desc}
+
             hw_accel (str or None, optional): Enable hardware acceleration.
 
                 When video is encoded on CUDA hardware, for example
@@ -207,7 +365,21 @@ class StreamWriter:
                 If `None`, the video chunk Tensor has to be CPU Tensor.
                 Default: ``None``.
         """
-        self._s.add_video_stream(frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel)
+        self._s.add_video_stream(
+            frame_rate,
+            width,
+            height,
+            format,
+            encoder,
+            encoder_option,
+            encoder_format,
+            encoder_frame_rate,
+            encoder_width,
+            encoder_height,
+            hw_accel,
+            codec_config,
+            filter_desc,
+        )
 
     def set_metadata(self, metadata: Dict[str, str]):
         """Set file-level metadata
@@ -276,17 +448,24 @@ class StreamWriter:
         self._s.close()
         self._is_open = False
 
-    def write_audio_chunk(self, i: int, chunk: torch.Tensor):
+    def write_audio_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
         """Write audio data
 
         Args:
             i (int): Stream index.
             chunk (Tensor): Waveform tensor. Shape: `(frame, channel)`.
                 The ``dtype`` must match what was passed to the :py:meth:`add_audio_stream` method.
+            pts (float or None, optional): If provided, overwrite the presentation timestamp.
+
+                .. note::
+
+                   The provided value is converted to an integer value expressed in the basis of
+                   sample rate. Therefore, it is truncated to the nearest value of
+                   ``n / sample_rate``.
         """
-        self._s.write_audio_chunk(i, chunk)
+        self._s.write_audio_chunk(i, chunk, pts)
 
-    def write_video_chunk(self, i: int, chunk: torch.Tensor):
+    def write_video_chunk(self, i: int, chunk: torch.Tensor, pts: Optional[float] = None):
         """Write video/image data
 
         Args:
@@ -296,8 +475,15 @@ class StreamWriter:
                 The ``dtype`` must be ``torch.uint8``.
                 The shape (height, width and the number of channels) must match
                 what was configured when calling :py:meth:`add_video_stream`.
+            pts (float or None, optional): If provided, overwrite the presentation timestamp.
+
+                .. note::
+
+                   The provided value is converted to an integer value expressed in the basis of
+                   frame rate. Therefore, it is truncated to the nearest value of
+                   ``n / frame_rate``.
         """
-        self._s.write_video_chunk(i, chunk)
+        self._s.write_video_chunk(i, chunk, pts)
 
     def flush(self):
         """Flush the frames from encoders and write the frames to the destination."""
Binary files changed (the torchaudio/lib/*.so shared libraries listed above): contents not shown.
torchaudio/models/__init__.py
@@ -5,6 +5,14 @@ from .deepspeech import DeepSpeech
 from .emformer import Emformer
 from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
 from .rnnt_decoder import Hypothesis, RNNTBeamSearch
+from .squim import (
+    squim_objective_base,
+    squim_objective_model,
+    squim_subjective_base,
+    squim_subjective_model,
+    SquimObjective,
+    SquimSubjective,
+)
 from .tacotron2 import Tacotron2
 from .wav2letter import Wav2Letter
 from .wav2vec2 import (
@@ -68,4 +76,10 @@ __all__ = [
     "hdemucs_low",
     "hdemucs_medium",
     "hdemucs_high",
+    "squim_objective_base",
+    "squim_objective_model",
+    "squim_subjective_base",
+    "squim_subjective_model",
+    "SquimObjective",
+    "SquimSubjective",
 ]
torchaudio/models/decoder/__init__.py
@@ -1,5 +1,4 @@
-_INITIALIZED = False
-_LAZILY_IMPORTED = [
+_CTC_DECODERS = [
     "CTCHypothesis",
     "CTCDecoder",
     "CTCDecoderLM",
@@ -7,25 +6,41 @@ _LAZILY_IMPORTED = [
     "ctc_decoder",
     "download_pretrained_files",
 ]
+_CUDA_CTC_DECODERS = [
+    "CUCTCDecoder",
+    "CUCTCHypothesis",
+    "cuda_ctc_decoder",
+]
 
 
 def __getattr__(name: str):
-    if name in _LAZILY_IMPORTED:
+    if name in _CTC_DECODERS:
         try:
             from . import _ctc_decoder
-        except AttributeError as err:
+        except Exception as err:
             raise RuntimeError(
-                "CTC decoder requires the decoder extension. Please set BUILD_CTC_DECODER=1 when building from source."
+                "The CTC decoder suite requires the flashlight-text package and, optionally, KenLM. Please install them."
             ) from err
 
         item = getattr(_ctc_decoder, name)
         globals()[name] = item
         return item
+    elif name in _CUDA_CTC_DECODERS:
+        try:
+            from . import _cuda_ctc_decoder
+        except AttributeError as err:
+            raise RuntimeError(
+                "To use the CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
+            ) from err
+
+        item = getattr(_cuda_ctc_decoder, name)
+        globals()[name] = item
+        return item
     raise AttributeError(f"module {__name__} has no attribute {name}")
 
 
 def __dir__():
-    return sorted(__all__ + _LAZILY_IMPORTED)
+    return sorted(__all__)
 
 
-__all__ = []
+__all__ = _CTC_DECODERS + _CUDA_CTC_DECODERS
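
The module-level ``__getattr__``/``__dir__`` pair above is the PEP 562 lazy-import pattern: the decoder submodules (and their native dependencies) are imported only when one of the listed names is first accessed. A sketch of what triggers each branch:

    import torchaudio.models.decoder

    # First access runs __getattr__, which imports _ctc_decoder and caches
    # the attribute in the module globals for subsequent lookups.
    factory = torchaudio.models.decoder.ctc_decoder

    # Similarly, this first access imports _cuda_ctc_decoder.
    cuda_factory = torchaudio.models.decoder.cuda_ctc_decoder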
torchaudio/models/decoder/_ctc_decoder.py
@@ -2,69 +2,38 @@ from __future__ import annotations
 
 import itertools as it
 
-import warnings
 from abc import abstractmethod
 from collections import namedtuple
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
-import torchaudio
-from torchaudio.utils import download_asset
-
 
-# We prioritize the version from upstream flashlight here.
-# This will allow applications that use the upstream flashlight
-# alongside torchaudio.
-if torchaudio._internal.module_utils.is_module_available("flashlight"):
-    from flashlight.lib.text.decoder import (
-        CriterionType as _CriterionType,
-        LexiconDecoder as _LexiconDecoder,
-        LexiconDecoderOptions as _LexiconDecoderOptions,
-        LexiconFreeDecoder as _LexiconFreeDecoder,
-        LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-        LM as _LM,
-        LMState as _LMState,
-        SmearingMode as _SmearingMode,
-        Trie as _Trie,
-        ZeroLM as _ZeroLM,
-    )
-    from flashlight.lib.text.dictionary import (
-        create_word_dict as _create_word_dict,
-        Dictionary as _Dictionary,
-        load_words as _load_words,
-    )
+from flashlight.lib.text.decoder import (
+    CriterionType as _CriterionType,
+    LexiconDecoder as _LexiconDecoder,
+    LexiconDecoderOptions as _LexiconDecoderOptions,
+    LexiconFreeDecoder as _LexiconFreeDecoder,
+    LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
+    LM as _LM,
+    LMState as _LMState,
+    SmearingMode as _SmearingMode,
+    Trie as _Trie,
+    ZeroLM as _ZeroLM,
+)
+from flashlight.lib.text.dictionary import (
+    create_word_dict as _create_word_dict,
+    Dictionary as _Dictionary,
+    load_words as _load_words,
+)
+from torchaudio.utils import download_asset
 
+try:
+    from flashlight.lib.text.decoder.kenlm import KenLM as _KenLM
+except Exception:
     try:
         from flashlight.lib.text.decoder import KenLM as _KenLM
     except Exception:
         _KenLM = None
-else:
-    torchaudio._extension._load_lib("libflashlight-text")
-    from torchaudio.lib.flashlight_lib_text_decoder import (
-        CriterionType as _CriterionType,
-        KenLM as _KenLM,
-        LexiconDecoder as _LexiconDecoder,
-        LexiconDecoderOptions as _LexiconDecoderOptions,
-        LexiconFreeDecoder as _LexiconFreeDecoder,
-        LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-        LM as _LM,
-        LMState as _LMState,
-        SmearingMode as _SmearingMode,
-        Trie as _Trie,
-        ZeroLM as _ZeroLM,
-    )
-    from torchaudio.lib.flashlight_lib_text_dictionary import (
-        create_word_dict as _create_word_dict,
-        Dictionary as _Dictionary,
-        load_words as _load_words,
-    )
-
-    warnings.warn(
-        "The built-in flashlight integration is deprecated, and will be removed in future release. "
-        "Please install flashlight-text. https://pypi.org/project/flashlight-text/ "
-        "For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088."
-    )
-
 
 __all__ = [
     "CTCHypothesis",
@@ -292,10 +261,102 @@ class CTCDecoder:
             timesteps.append(i)
         return torch.IntTensor(timesteps)
 
+    def decode_begin(self):
+        """Initialize the internal state of the decoder.
+
+        See :py:meth:`decode_step` for the usage.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        self.decoder.decode_begin()
+
+    def decode_end(self):
+        """Finalize the internal state of the decoder.
+
+        See :py:meth:`decode_step` for the usage.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        self.decoder.decode_end()
+
+    def decode_step(self, emissions: torch.FloatTensor):
+        """Perform incremental decoding on top of the current internal state.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+
+        Args:
+            emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing sequences of
+                probability distribution over labels; output of acoustic model.
+
+        Example:
+            >>> decoder = torchaudio.models.decoder.ctc_decoder(...)
+            >>> decoder.decode_begin()
+            >>> decoder.decode_step(emission1)
+            >>> decoder.decode_step(emission2)
+            >>> decoder.decode_end()
+            >>> result = decoder.get_final_hypothesis()
+        """
+        if emissions.dtype != torch.float32:
+            raise ValueError("emissions must be float32.")
+
+        if not emissions.is_cpu:
+            raise RuntimeError("emissions must be a CPU tensor.")
+
+        if not emissions.is_contiguous():
+            raise RuntimeError("emissions must be contiguous.")
+
+        if emissions.ndim != 2:
+            raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")
+
+        T, N = emissions.size()
+        self.decoder.decode_step(emissions.data_ptr(), T, N)
+
+    def _to_hypo(self, results) -> List[CTCHypothesis]:
+        return [
+            CTCHypothesis(
+                tokens=self._get_tokens(result.tokens),
+                words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
+                score=result.score,
+                timesteps=self._get_timesteps(result.tokens),
+            )
+            for result in results
+        ]
+
+    def get_final_hypothesis(self) -> List[CTCHypothesis]:
+        """Get the final hypothesis
+
+        Returns:
+            List[CTCHypothesis]:
+                List of sorted best hypotheses.
+
+        .. note::
+
+           This method is required only when performing online decoding.
+           It is not necessary when performing batch decoding with :py:meth:`__call__`.
+        """
+        results = self.decoder.get_all_final_hypothesis()
+        return self._to_hypo(results[: self.nbest])
+
     def __call__(
         self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
     ) -> List[List[CTCHypothesis]]:
         """
+        Performs batched offline decoding.
+
+        .. note::
+
+           This method performs offline decoding in one go. To perform incremental decoding,
+           please refer to :py:meth:`decode_step`.
+
         Args:
             emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
                 probability distribution over labels; output of acoustic model.
@@ -310,13 +371,16 @@ class CTCDecoder:
         if emissions.dtype != torch.float32:
             raise ValueError("emissions must be float32.")
 
-        if emissions.is_cuda:
+        if not emissions.is_cpu:
             raise RuntimeError("emissions must be a CPU tensor.")
 
         if not emissions.is_contiguous():
             raise RuntimeError("emissions must be contiguous.")
 
-        if lengths is not None and lengths.is_cuda:
+        if emissions.ndim != 3:
+            raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")
+
+        if lengths is not None and not lengths.is_cpu:
             raise RuntimeError("lengths must be a CPU tensor.")
 
         B, T, N = emissions.size()
@@ -329,20 +393,7 @@ class CTCDecoder:
         for b in range(B):
             emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
             results = self.decoder.decode(emissions_ptr, lengths[b], N)
-
-            nbest_results = results[: self.nbest]
-            hypos.append(
-                [
-                    CTCHypothesis(
-                        tokens=self._get_tokens(result.tokens),
-                        words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
-                        score=result.score,
-                        timesteps=self._get_timesteps(result.tokens),
-                    )
-                    for result in nbest_results
-                ]
-            )
-
+            hypos.append(self._to_hypo(results[: self.nbest]))
         return hypos
 
     def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
@@ -450,7 +501,10 @@ def ctc_decoder(
 
     if type(lm) == str:
         if _KenLM is None:
-            raise RuntimeError("flashlight is installed, but KenLM is not installed. Please install KenLM.")
+            raise RuntimeError(
+                "flashlight-text is installed, but KenLM is not installed. "
+                "Please refer to https://github.com/kpu/kenlm#python-module for how to install it."
+            )
         lm = _KenLM(lm, word_dict)
     elif lm is None:
         lm = _ZeroLM()