torchaudio 2.0.2__cp310-cp310-win_amd64.whl → 2.1.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (88) hide show
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.pyd +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.pyd +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.pyd +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.pyd +0 -0
  51. torchaudio/lib/libtorchaudio.pyd +0 -0
  52. torchaudio/lib/libtorchaudio_ffmpeg4.pyd +0 -0
  53. torchaudio/lib/libtorchaudio_ffmpeg5.pyd +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg6.pyd +0 -0
  55. torchaudio/models/__init__.py +14 -0
  56. torchaudio/models/decoder/__init__.py +22 -7
  57. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  58. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  59. torchaudio/models/rnnt_decoder.py +10 -14
  60. torchaudio/models/squim/__init__.py +11 -0
  61. torchaudio/models/squim/objective.py +326 -0
  62. torchaudio/models/squim/subjective.py +150 -0
  63. torchaudio/models/wav2vec2/components.py +6 -10
  64. torchaudio/pipelines/__init__.py +9 -0
  65. torchaudio/pipelines/_squim_pipeline.py +176 -0
  66. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  67. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  68. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  69. torchaudio/sox_effects/sox_effects.py +7 -30
  70. torchaudio/transforms/__init__.py +2 -0
  71. torchaudio/transforms/_transforms.py +99 -54
  72. torchaudio/utils/download.py +2 -2
  73. torchaudio/utils/ffmpeg_utils.py +20 -15
  74. torchaudio/utils/sox_utils.py +8 -9
  75. torchaudio/version.py +2 -2
  76. torchaudio-2.1.1.dist-info/METADATA +113 -0
  77. torchaudio-2.1.1.dist-info/RECORD +115 -0
  78. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +1 -1
  79. torchaudio/io/_compat.py +0 -241
  80. torchaudio/lib/_torchaudio_ffmpeg.pyd +0 -0
  81. torchaudio/lib/flashlight_lib_text_decoder.pyd +0 -0
  82. torchaudio/lib/flashlight_lib_text_dictionary.pyd +0 -0
  83. torchaudio/lib/libflashlight-text.pyd +0 -0
  84. torchaudio/lib/libtorchaudio_ffmpeg.pyd +0 -0
  85. torchaudio-2.0.2.dist-info/METADATA +0 -26
  86. torchaudio-2.0.2.dist-info/RECORD +0 -98
  87. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  88. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
@@ -3,78 +3,77 @@ import re
3
3
  from pathlib import Path
4
4
  from typing import Iterable, List, Tuple, Union
5
5
 
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+
8
9
 
9
10
  _CHECKSUMS = {
10
11
  "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "209a8b4cd265013e96f4658632a9878103b0c5abf62b50d4ef3ae1be226b29e4", # noqa: E501
11
12
  "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "408ccaae803641c6d7b626b6299949320c2dbca96b2220fd3fb17887b023b027", # noqa: E501
12
13
  }
13
- _PUNCTUATIONS = set(
14
- [
15
- "!EXCLAMATION-POINT",
16
- '"CLOSE-QUOTE',
17
- '"DOUBLE-QUOTE',
18
- '"END-OF-QUOTE',
19
- '"END-QUOTE',
20
- '"IN-QUOTES',
21
- '"QUOTE',
22
- '"UNQUOTE',
23
- "#HASH-MARK",
24
- "#POUND-SIGN",
25
- "#SHARP-SIGN",
26
- "%PERCENT",
27
- "&AMPERSAND",
28
- "'END-INNER-QUOTE",
29
- "'END-QUOTE",
30
- "'INNER-QUOTE",
31
- "'QUOTE",
32
- "'SINGLE-QUOTE",
33
- "(BEGIN-PARENS",
34
- "(IN-PARENTHESES",
35
- "(LEFT-PAREN",
36
- "(OPEN-PARENTHESES",
37
- "(PAREN",
38
- "(PARENS",
39
- "(PARENTHESES",
40
- ")CLOSE-PAREN",
41
- ")CLOSE-PARENTHESES",
42
- ")END-PAREN",
43
- ")END-PARENS",
44
- ")END-PARENTHESES",
45
- ")END-THE-PAREN",
46
- ")PAREN",
47
- ")PARENS",
48
- ")RIGHT-PAREN",
49
- ")UN-PARENTHESES",
50
- "+PLUS",
51
- ",COMMA",
52
- "--DASH",
53
- "-DASH",
54
- "-HYPHEN",
55
- "...ELLIPSIS",
56
- ".DECIMAL",
57
- ".DOT",
58
- ".FULL-STOP",
59
- ".PERIOD",
60
- ".POINT",
61
- "/SLASH",
62
- ":COLON",
63
- ";SEMI-COLON",
64
- ";SEMI-COLON(1)",
65
- "?QUESTION-MARK",
66
- "{BRACE",
67
- "{LEFT-BRACE",
68
- "{OPEN-BRACE",
69
- "}CLOSE-BRACE",
70
- "}RIGHT-BRACE",
71
- ]
72
- )
14
+ _PUNCTUATIONS = {
15
+ "!EXCLAMATION-POINT",
16
+ '"CLOSE-QUOTE',
17
+ '"DOUBLE-QUOTE',
18
+ '"END-OF-QUOTE',
19
+ '"END-QUOTE',
20
+ '"IN-QUOTES',
21
+ '"QUOTE',
22
+ '"UNQUOTE',
23
+ "#HASH-MARK",
24
+ "#POUND-SIGN",
25
+ "#SHARP-SIGN",
26
+ "%PERCENT",
27
+ "&AMPERSAND",
28
+ "'END-INNER-QUOTE",
29
+ "'END-QUOTE",
30
+ "'INNER-QUOTE",
31
+ "'QUOTE",
32
+ "'SINGLE-QUOTE",
33
+ "(BEGIN-PARENS",
34
+ "(IN-PARENTHESES",
35
+ "(LEFT-PAREN",
36
+ "(OPEN-PARENTHESES",
37
+ "(PAREN",
38
+ "(PARENS",
39
+ "(PARENTHESES",
40
+ ")CLOSE-PAREN",
41
+ ")CLOSE-PARENTHESES",
42
+ ")END-PAREN",
43
+ ")END-PARENS",
44
+ ")END-PARENTHESES",
45
+ ")END-THE-PAREN",
46
+ ")PAREN",
47
+ ")PARENS",
48
+ ")RIGHT-PAREN",
49
+ ")UN-PARENTHESES",
50
+ "+PLUS",
51
+ ",COMMA",
52
+ "--DASH",
53
+ "-DASH",
54
+ "-HYPHEN",
55
+ "...ELLIPSIS",
56
+ ".DECIMAL",
57
+ ".DOT",
58
+ ".FULL-STOP",
59
+ ".PERIOD",
60
+ ".POINT",
61
+ "/SLASH",
62
+ ":COLON",
63
+ ";SEMI-COLON",
64
+ ";SEMI-COLON(1)",
65
+ "?QUESTION-MARK",
66
+ "{BRACE",
67
+ "{LEFT-BRACE",
68
+ "{OPEN-BRACE",
69
+ "}CLOSE-BRACE",
70
+ "}RIGHT-BRACE",
71
+ }
73
72
 
74
73
 
75
74
  def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
76
75
  _alt_re = re.compile(r"\([0-9]+\)")
77
- cmudict: List[Tuple[str, List[str]]] = list()
76
+ cmudict: List[Tuple[str, List[str]]] = []
78
77
  for line in lines:
79
78
  if not line or line.startswith(";;;"): # ignore comments
80
79
  continue
@@ -3,8 +3,8 @@ from typing import Dict, Tuple, Union
3
3
 
4
4
  import torchaudio
5
5
  from torch import Tensor
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
8
  from torchaudio.datasets.utils import _extract_zip
9
9
 
10
10
 
@@ -4,8 +4,8 @@ from typing import Optional, Tuple, Union
4
4
 
5
5
  import torchaudio
6
6
  from torch import Tensor
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_tar
10
10
 
11
11
  # The following lists prefixed with `filtered_` provide a filtered split
@@ -4,8 +4,8 @@ from typing import List, Tuple, Union
4
4
 
5
5
  import torchaudio
6
6
  from torch import Tensor
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.librispeech import _get_librispeech_metadata
10
10
  from torchaudio.datasets.utils import _extract_tar
11
11
 
@@ -3,8 +3,8 @@ from pathlib import Path
3
3
  from typing import Tuple, Union
4
4
 
5
5
  from torch import Tensor
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
8
  from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
9
 
10
10
  URL = "train-clean-100"
@@ -0,0 +1,189 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
+
10
+ URL = "train-clean-100"
11
+ FOLDER_IN_ARCHIVE = "LibriSpeech"
12
+ SAMPLE_RATE = 16000
13
+ _DATA_SUBSETS = [
14
+ "dev-clean",
15
+ "dev-other",
16
+ "test-clean",
17
+ "test-other",
18
+ "train-clean-100",
19
+ "train-clean-360",
20
+ "train-other-500",
21
+ ]
22
+ _CHECKSUMS = {
23
+ "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3", # noqa: E501
24
+ "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365", # noqa: E501
25
+ "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23", # noqa: E501
26
+ "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29", # noqa: E501
27
+ "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2", # noqa: E501
28
+ "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf", # noqa: E501
29
+ "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2", # noqa: E501
30
+ }
31
+
32
+
33
+ def _download_librispeech(root, url):
34
+ base_url = "http://www.openslr.org/resources/12/"
35
+ ext_archive = ".tar.gz"
36
+
37
+ filename = url + ext_archive
38
+ archive = os.path.join(root, filename)
39
+ download_url = os.path.join(base_url, filename)
40
+ if not os.path.isfile(archive):
41
+ checksum = _CHECKSUMS.get(download_url, None)
42
+ download_url_to_file(download_url, archive, hash_prefix=checksum)
43
+ _extract_tar(archive)
44
+
45
+
46
+ def _get_librispeech_metadata(
47
+ fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str, blist: List[str]
48
+ ) -> Tuple[str, int, str, int, int, int]:
49
+ blist = blist or []
50
+ speaker_id, chapter_id, utterance_id = fileid.split("-")
51
+
52
+ # Get audio path and sample rate
53
+ fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
54
+ filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
55
+
56
+ # Load text
57
+ file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
58
+ file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text)
59
+ uttblist = []
60
+ with open(file_text) as ft:
61
+ for line in ft:
62
+ fileid_text, transcript = line.strip().split(" ", 1)
63
+ if fileid_audio == fileid_text:
64
+ # get utterance biasing list
65
+ for word in transcript.split():
66
+ if word in blist and word not in uttblist:
67
+ uttblist.append(word)
68
+ break
69
+ else:
70
+ # Translation not found
71
+ raise FileNotFoundError(f"Translation not found for {fileid_audio}")
72
+
73
+ return (
74
+ filepath,
75
+ SAMPLE_RATE,
76
+ transcript,
77
+ int(speaker_id),
78
+ int(chapter_id),
79
+ int(utterance_id),
80
+ uttblist,
81
+ )
82
+
83
+
84
+ class LibriSpeechBiasing(Dataset):
85
+ """*LibriSpeech* :cite:`7178964` dataset with prefix-tree construction and biasing support.
86
+
87
+ Args:
88
+ root (str or Path): Path to the directory where the dataset is found or downloaded.
89
+ url (str, optional): The URL to download the dataset from,
90
+ or the type of the dataset to download.
91
+ Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
92
+ ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
93
+ ``"train-other-500"``. (default: ``"train-clean-100"``)
94
+ folder_in_archive (str, optional):
95
+ The top-level directory of the dataset. (default: ``"LibriSpeech"``)
96
+ download (bool, optional):
97
+ Whether to download the dataset if it is not found at root path. (default: ``False``).
98
+ blist (list, optional):
99
+ The list of biasing words (default: ``[]``).
100
+ """
101
+
102
+ _ext_txt = ".trans.txt"
103
+ _ext_audio = ".flac"
104
+
105
+ def __init__(
106
+ self,
107
+ root: Union[str, Path],
108
+ url: str = URL,
109
+ folder_in_archive: str = FOLDER_IN_ARCHIVE,
110
+ download: bool = False,
111
+ blist: List[str] = None,
112
+ ) -> None:
113
+ self._url = url
114
+ if url not in _DATA_SUBSETS:
115
+ raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")
116
+
117
+ root = os.fspath(root)
118
+ self._archive = os.path.join(root, folder_in_archive)
119
+ self._path = os.path.join(root, folder_in_archive, url)
120
+
121
+ if not os.path.isdir(self._path):
122
+ if download:
123
+ _download_librispeech(root, url)
124
+ else:
125
+ raise RuntimeError(
126
+ f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
127
+ )
128
+
129
+ self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))
130
+ self.blist = blist
131
+
132
+ def get_metadata(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
133
+ """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
134
+ but otherwise returns the same fields as :py:func:`__getitem__`.
135
+
136
+ Args:
137
+ n (int): The index of the sample to be loaded
138
+
139
+ Returns:
140
+ Tuple of the following items;
141
+
142
+ str:
143
+ Path to audio
144
+ int:
145
+ Sample rate
146
+ str:
147
+ Transcript
148
+ int:
149
+ Speaker ID
150
+ int:
151
+ Chapter ID
152
+ int:
153
+ Utterance ID
154
+ list:
155
+ List of biasing words in the utterance
156
+ """
157
+ fileid = self._walker[n]
158
+ return _get_librispeech_metadata(fileid, self._archive, self._url, self._ext_audio, self._ext_txt, self.blist)
159
+
160
+ def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
161
+ """Load the n-th sample from the dataset.
162
+
163
+ Args:
164
+ n (int): The index of the sample to be loaded
165
+
166
+ Returns:
167
+ Tuple of the following items;
168
+
169
+ Tensor:
170
+ Waveform
171
+ int:
172
+ Sample rate
173
+ str:
174
+ Transcript
175
+ int:
176
+ Speaker ID
177
+ int:
178
+ Chapter ID
179
+ int:
180
+ Utterance ID
181
+ list:
182
+ List of biasing words in the utterance
183
+ """
184
+ metadata = self.get_metadata(n)
185
+ waveform = _load_waveform(self._archive, metadata[0], metadata[1])
186
+ return (waveform,) + metadata[1:]
187
+
188
+ def __len__(self) -> int:
189
+ return len(self._walker)
@@ -4,8 +4,8 @@ from typing import Tuple, Union
4
4
 
5
5
  import torchaudio
6
6
  from torch import Tensor
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_tar
10
10
 
11
11
  URL = "train-clean-100"
@@ -5,8 +5,8 @@ from typing import Tuple, Union
5
5
 
6
6
  import torchaudio
7
7
  from torch import Tensor
8
- from torch.hub import download_url_to_file
9
8
  from torch.utils.data import Dataset
9
+ from torchaudio._internal import download_url_to_file
10
10
  from torchaudio.datasets.utils import _extract_tar
11
11
 
12
12
 
@@ -4,8 +4,8 @@ from typing import List, Optional, Tuple, Union
4
4
 
5
5
  import torch
6
6
  import torchaudio
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_zip
10
10
 
11
11
  _URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
@@ -4,8 +4,8 @@ from pathlib import Path
4
4
  from typing import Optional, Tuple, Union
5
5
 
6
6
  import torch
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_tar, _load_waveform
10
10
 
11
11
 
@@ -3,8 +3,8 @@ from pathlib import Path
3
3
  from typing import Optional, Tuple, Union
4
4
 
5
5
  from torch import Tensor
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
8
  from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
9
 
10
10
  FOLDER_IN_ARCHIVE = "SpeechCommands"
@@ -4,8 +4,8 @@ from typing import Tuple, Union
4
4
 
5
5
  import torchaudio
6
6
  from torch import Tensor
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_tar
10
10
 
11
11
 
@@ -3,8 +3,8 @@ from typing import Tuple
3
3
 
4
4
  import torchaudio
5
5
  from torch import Tensor
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
8
  from torchaudio.datasets.utils import _extract_zip
9
9
 
10
10
  URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
@@ -3,8 +3,8 @@ from pathlib import Path
3
3
  from typing import List, Tuple, Union
4
4
 
5
5
  from torch import Tensor
6
- from torch.hub import download_url_to_file
7
6
  from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
8
  from torchaudio.datasets.utils import _extract_zip, _load_waveform
9
9
 
10
10
 
@@ -4,8 +4,8 @@ from typing import List, Tuple, Union
4
4
 
5
5
  import torchaudio
6
6
  from torch import Tensor
7
- from torch.hub import download_url_to_file
8
7
  from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
9
  from torchaudio.datasets.utils import _extract_tar
10
10
 
11
11
 
@@ -1,3 +1,4 @@
1
+ from ._alignment import forced_align, merge_tokens, TokenSpan
1
2
  from .filtering import (
2
3
  allpass_biquad,
3
4
  band_biquad,
@@ -28,7 +29,6 @@ from .functional import (
28
29
  apply_beamforming,
29
30
  apply_codec,
30
31
  compute_deltas,
31
- compute_kaldi_pitch,
32
32
  convolve,
33
33
  create_dct,
34
34
  DB_to_amplitude,
@@ -36,6 +36,7 @@ from .functional import (
36
36
  detect_pitch_frequency,
37
37
  edit_distance,
38
38
  fftconvolve,
39
+ frechet_distance,
39
40
  griffinlim,
40
41
  inverse_spectrogram,
41
42
  linear_fbanks,
@@ -64,7 +65,6 @@ from .functional import (
64
65
  __all__ = [
65
66
  "amplitude_to_DB",
66
67
  "compute_deltas",
67
- "compute_kaldi_pitch",
68
68
  "create_dct",
69
69
  "melscale_fbanks",
70
70
  "linear_fbanks",
@@ -94,6 +94,9 @@ __all__ = [
94
94
  "equalizer_biquad",
95
95
  "filtfilt",
96
96
  "flanger",
97
+ "forced_align",
98
+ "merge_tokens",
99
+ "TokenSpan",
97
100
  "gain",
98
101
  "highpass_biquad",
99
102
  "lfilter",
@@ -120,4 +123,5 @@ __all__ = [
120
123
  "speed",
121
124
  "preemphasis",
122
125
  "deemphasis",
126
+ "frechet_distance",
123
127
  ]
@@ -0,0 +1,128 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torchaudio._extension import fail_if_no_align
7
+
8
+ __all__ = []
9
+
10
+
11
+ @fail_if_no_align
12
+ def forced_align(
13
+ log_probs: Tensor,
14
+ targets: Tensor,
15
+ input_lengths: Optional[Tensor] = None,
16
+ target_lengths: Optional[Tensor] = None,
17
+ blank: int = 0,
18
+ ) -> Tuple[Tensor, Tensor]:
19
+ r"""Align a CTC label sequence to an emission.
20
+
21
+ .. devices:: CPU CUDA
22
+
23
+ .. properties:: TorchScript
24
+
25
+ Args:
26
+ log_probs (Tensor): log probability of CTC emission output.
27
+ Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
28
+ `C` is the number of characters in alphabet including blank.
29
+ targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
30
+ where `L` is the target length.
31
+ input_lengths (Tensor or None, optional):
32
+ Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
33
+ target_lengths (Tensor or None, optional):
34
+ Lengths of the targets. 1-D Tensor of shape `(B,)`.
35
+ blank (int, optional): The index of blank symbol in CTC emission. (Default: 0)
36
+
37
+ Returns:
38
+ Tuple(Tensor, Tensor):
39
+ Tensor: Label for each time step in the alignment path computed using forced alignment.
40
+
41
+ Tensor: Log probability scores of the labels for each time step.
42
+
43
+ Note:
44
+ The sequence length of `log_probs` must satisfy:
45
+
46
+
47
+ .. math::
48
+ L_{\text{log\_probs}} \ge L_{\text{label}} + N_{\text{repeat}}
49
+
50
+ where :math:`N_{\text{repeat}}` is the number of consecutively repeated tokens.
51
+ For example, in str `"aabbc"`, the number of repeats is `2`.
52
+
53
+ Note:
54
+ The current version only supports ``batch_size==1``.
55
+ """
56
+ if blank in targets:
57
+ raise ValueError(f"targets Tensor shouldn't contain blank index. Found {targets}.")
58
+ if torch.max(targets) >= log_probs.shape[-1]:
59
+ raise ValueError("targets values must be less than the CTC dimension")
60
+
61
+ if input_lengths is None:
62
+ batch_size, length = log_probs.size(0), log_probs.size(1)
63
+ input_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=log_probs.device)
64
+ if target_lengths is None:
65
+ batch_size, length = targets.size(0), targets.size(1)
66
+ target_lengths = torch.full((batch_size,), length, dtype=torch.int64, device=targets.device)
67
+
68
+ # For TorchScript compatibility
69
+ assert input_lengths is not None
70
+ assert target_lengths is not None
71
+
72
+ paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
73
+ return paths, scores
74
+
75
+
76
+ @dataclass
77
+ class TokenSpan:
78
+ """TokenSpan()
79
+ Token with time stamps and score. Returned by :py:func:`merge_tokens`.
80
+ """
81
+
82
+ token: int
83
+ """The token"""
84
+ start: int
85
+ """The start time (inclusive) in emission time axis."""
86
+ end: int
87
+ """The end time (exclusive) in emission time axis."""
88
+ score: float
89
+ """The score of this token."""
90
+
91
+ def __len__(self) -> int:
92
+ """Returns the time span"""
93
+ return self.end - self.start
94
+
95
+
96
+ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSpan]:
97
+ """Removes repeated tokens and blank tokens from the given CTC token sequence.
98
+
99
+ Args:
100
+ tokens (Tensor): Alignment tokens (unbatched) returned from :py:func:`forced_align`.
101
+ Shape: `(time, )`.
102
+ scores (Tensor): Alignment scores (unbatched) returned from :py:func:`forced_align`.
103
+ Shape: `(time, )`. When computing the token-size score, the given score is averaged
104
+ across the corresponding time span.
105
+
106
+ Returns:
107
+ list of TokenSpan
108
+
109
+ Example:
110
+ >>> aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths)
111
+ >>> token_spans = merge_tokens(aligned_tokens[0], scores[0])
112
+ """
113
+ if tokens.ndim != 1 or scores.ndim != 1:
114
+ raise ValueError("`tokens` and `scores` must be 1D Tensor.")
115
+ if len(tokens) != len(scores):
116
+ raise ValueError("`tokens` and `scores` must be the same length.")
117
+
118
+ diff = torch.diff(
119
+ tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
120
+ )
121
+ changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
122
+ tokens = tokens.tolist()
123
+ spans = [
124
+ TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
125
+ for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
126
+ if (token := tokens[start]) != blank
127
+ ]
128
+ return spans