torchaudio 2.9.0__cp314-cp314-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (86) hide show
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,189 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ from torch import Tensor
6
+ from torch.utils.data import Dataset
7
+ from torchaudio._internal import download_url_to_file
8
+ from torchaudio.datasets.utils import _extract_tar, _load_waveform
9
+
10
# Default subset used when the caller does not specify one.
URL = "train-clean-100"
# Name of the top-level directory contained in every LibriSpeech archive.
FOLDER_IN_ARCHIVE = "LibriSpeech"
# All LibriSpeech recordings are distributed at 16 kHz.
SAMPLE_RATE = 16000
# Every officially released LibriSpeech subset; used to validate `url`.
_DATA_SUBSETS = [
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
]
# Expected download digests keyed by archive URL; passed to
# ``download_url_to_file`` as ``hash_prefix`` to verify each archive.
_CHECKSUMS = {
    "http://www.openslr.org/resources/12/dev-clean.tar.gz": "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",  # noqa: E501
    "http://www.openslr.org/resources/12/dev-other.tar.gz": "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",  # noqa: E501
    "http://www.openslr.org/resources/12/test-clean.tar.gz": "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",  # noqa: E501
    "http://www.openslr.org/resources/12/test-other.tar.gz": "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-100.tar.gz": "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",  # noqa: E501
    "http://www.openslr.org/resources/12/train-clean-360.tar.gz": "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",  # noqa: E501
    "http://www.openslr.org/resources/12/train-other-500.tar.gz": "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2",  # noqa: E501
}
31
+
32
+
33
def _download_librispeech(root, url):
    """Download and extract one LibriSpeech subset archive into ``root``.

    Args:
        root (str): Directory into which the ``.tar.gz`` archive is
            downloaded and then extracted.
        url (str): Subset name, e.g. ``"train-clean-100"``.
    """
    base_url = "http://www.openslr.org/resources/12/"
    ext_archive = ".tar.gz"

    filename = url + ext_archive
    archive = os.path.join(root, filename)
    # Build the URL by plain concatenation: os.path.join would insert a
    # backslash on Windows, yielding an invalid URL that also misses the
    # _CHECKSUMS lookup below (its keys use "/").
    download_url = base_url + filename
    if not os.path.isfile(archive):
        checksum = _CHECKSUMS.get(download_url, None)
        download_url_to_file(download_url, archive, hash_prefix=checksum)
    _extract_tar(archive)
44
+
45
+
46
+ def _get_librispeech_metadata(
47
+ fileid: str, root: str, folder: str, ext_audio: str, ext_txt: str, blist: List[str]
48
+ ) -> Tuple[str, int, str, int, int, int]:
49
+ blist = blist or []
50
+ speaker_id, chapter_id, utterance_id = fileid.split("-")
51
+
52
+ # Get audio path and sample rate
53
+ fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
54
+ filepath = os.path.join(folder, speaker_id, chapter_id, f"{fileid_audio}{ext_audio}")
55
+
56
+ # Load text
57
+ file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
58
+ file_text = os.path.join(root, folder, speaker_id, chapter_id, file_text)
59
+ uttblist = []
60
+ with open(file_text) as ft:
61
+ for line in ft:
62
+ fileid_text, transcript = line.strip().split(" ", 1)
63
+ if fileid_audio == fileid_text:
64
+ # get utterance biasing list
65
+ for word in transcript.split():
66
+ if word in blist and word not in uttblist:
67
+ uttblist.append(word)
68
+ break
69
+ else:
70
+ # Translation not found
71
+ raise FileNotFoundError(f"Translation not found for {fileid_audio}")
72
+
73
+ return (
74
+ filepath,
75
+ SAMPLE_RATE,
76
+ transcript,
77
+ int(speaker_id),
78
+ int(chapter_id),
79
+ int(utterance_id),
80
+ uttblist,
81
+ )
82
+
83
+
84
class LibriSpeechBiasing(Dataset):
    """*LibriSpeech* :cite:`7178964` dataset with prefix-tree construction and biasing support.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to download.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriSpeech"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        blist (list, optional):
            The list of biasing words (default: ``[]``).
    """

    # File-name suffixes used by the LibriSpeech on-disk layout.
    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        blist: List[str] = None,
    ) -> None:
        self._url = url
        if url not in _DATA_SUBSETS:
            raise ValueError(f"Invalid url '{url}' given; please provide one of {_DATA_SUBSETS}.")

        root = os.fspath(root)
        self._archive = os.path.join(root, folder_in_archive)
        self._path = os.path.join(self._archive, url)

        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError(
                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
                )
            _download_librispeech(root, url)

        # One entry per utterance: the audio file stem "<spk>-<chp>-<utt>".
        self._walker = sorted(audio.stem for audio in Path(self._path).glob(f"*/*/*{self._ext_audio}"))
        self.blist = blist

    def get_metadata(self, n: int) -> Tuple[str, int, str, int, int, int]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:func:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
            list:
                List of biasing words in the utterance
        """
        return _get_librispeech_metadata(
            self._walker[n], self._archive, self._url, self._ext_audio, self._ext_txt, self.blist
        )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            int:
                Speaker ID
            int:
                Chapter ID
            int:
                Utterance ID
            list:
                List of biasing words in the utterance
        """
        filepath, sample_rate, *rest = self.get_metadata(n)
        waveform = _load_waveform(self._archive, filepath, sample_rate)
        return (waveform, sample_rate, *rest)

    def __len__(self) -> int:
        return len(self._walker)
@@ -0,0 +1,168 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
# Default subset used when the caller does not specify one.
URL = "train-clean-100"
# Name of the top-level directory contained in every LibriTTS archive.
FOLDER_IN_ARCHIVE = "LibriTTS"
# Expected download digests keyed by archive URL; passed to
# ``download_url_to_file`` as ``hash_prefix`` to verify each archive.
_CHECKSUMS = {
    "http://www.openslr.org/resources/60/dev-clean.tar.gz": "da0864e1bd26debed35da8a869dd5c04dfc27682921936de7cff9c8a254dbe1a",  # noqa: E501
    "http://www.openslr.org/resources/60/dev-other.tar.gz": "d413eda26f3a152ac7c9cf3658ef85504dfb1b625296e5fa83727f5186cca79c",  # noqa: E501
    "http://www.openslr.org/resources/60/test-clean.tar.gz": "234ea5b25859102a87024a4b9b86641f5b5aaaf1197335c95090cde04fe9a4f5",  # noqa: E501
    "http://www.openslr.org/resources/60/test-other.tar.gz": "33a5342094f3bba7ccc2e0500b9e72d558f72eb99328ac8debe1d9080402f10d",  # noqa: E501
    "http://www.openslr.org/resources/60/train-clean-100.tar.gz": "c5608bf1ef74bb621935382b8399c5cdd51cd3ee47cec51f00f885a64c6c7f6b",  # noqa: E501
    "http://www.openslr.org/resources/60/train-clean-360.tar.gz": "ce7cff44dcac46009d18379f37ef36551123a1dc4e5c8e4eb73ae57260de4886",  # noqa: E501
    "http://www.openslr.org/resources/60/train-other-500.tar.gz": "e35f7e34deeb2e2bdfe4403d88c8fdd5fbf64865cae41f027a185a6965f0a5df",  # noqa: E501
}
22
+
23
+
24
def load_libritts_item(
    fileid: str,
    path: str,
    ext_audio: str,
    ext_original_txt: str,
    ext_normalized_txt: str,
) -> Tuple[Tensor, int, str, str, int, int, str]:
    """Load one LibriTTS utterance: audio plus its original and normalized text.

    Args:
        fileid (str): Utterance identifier ``"<speaker>_<chapter>_<segment>_<utterance>"``.
        path (str): Subset directory containing ``<speaker>/<chapter>`` folders.
        ext_audio (str): Audio file suffix (e.g. ``".wav"``).
        ext_original_txt (str): Original-text file suffix.
        ext_normalized_txt (str): Normalized-text file suffix.

    Returns:
        Tuple of waveform, sample rate, original text, normalized text,
        speaker ID, chapter ID, and the (full) utterance ID.
    """
    speaker_id, chapter_id, _, _ = fileid.split("_")
    # The full fileid doubles as the utterance identifier.
    utterance_id = fileid

    # All three files for an utterance live in the same chapter directory.
    chapter_dir = os.path.join(path, speaker_id, chapter_id)
    file_normalized = os.path.join(chapter_dir, utterance_id + ext_normalized_txt)
    file_original = os.path.join(chapter_dir, utterance_id + ext_original_txt)
    file_audio = os.path.join(chapter_dir, utterance_id + ext_audio)

    # Load audio
    waveform, sample_rate = torchaudio.load(file_audio)

    # Load original text
    with open(file_original) as text_file:
        original_text = text_file.readline()

    # Load normalized text
    with open(file_normalized, "r") as text_file:
        normalized_text = text_file.readline()

    return (
        waveform,
        sample_rate,
        original_text,
        normalized_text,
        int(speaker_id),
        int(chapter_id),
        utterance_id,
    )
63
+
64
+
65
class LIBRITTS(Dataset):
    """*LibriTTS* :cite:`Zen2019LibriTTSAC` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from,
            or the type of the dataset to download.
            Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
            ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
            ``"train-other-500"``. (default: ``"train-clean-100"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"LibriTTS"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    # File-name suffixes used by the LibriTTS on-disk layout.
    _ext_original_txt = ".original.txt"
    _ext_normalized_txt = ".normalized.txt"
    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
    ) -> None:

        if url in [
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        ]:

            ext_archive = ".tar.gz"
            base_url = "http://www.openslr.org/resources/60/"

            # Plain concatenation instead of os.path.join: a URL separator is
            # always "/", whereas os.path.join uses "\\" on Windows, breaking
            # both the download URL and the _CHECKSUMS lookup below.
            url = base_url + url + ext_archive

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip ".tar.gz" to get the subset directory name.
        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)
        else:
            if not os.path.exists(self._path):
                raise RuntimeError(
                    f"The path {self._path} doesn't exist. "
                    "Please check the ``root`` path or set `download=True` to download it"
                )

        # One entry per utterance: the audio file stem.
        self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*/*/*" + self._ext_audio))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Original text
            str:
                Normalized text
            int:
                Speaker ID
            int:
                Chapter ID
            str:
                Utterance ID
        """
        fileid = self._walker[n]
        return load_libritts_item(
            fileid,
            self._path,
            self._ext_audio,
            self._ext_original_txt,
            self._ext_normalized_txt,
        )

    def __len__(self) -> int:
        return len(self._walker)
@@ -0,0 +1,107 @@
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Tuple, Union
5
+
6
+ import torchaudio
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+ from torchaudio._internal import download_url_to_file
10
+ from torchaudio.datasets.utils import _extract_tar
11
+
12
+
13
# Per-release configuration: download URL, expected digest (passed to
# ``download_url_to_file`` as ``hash_prefix``), and the directory inside
# the extracted archive that holds the wav files.
_RELEASE_CONFIGS = {
    "release1": {
        "folder_in_archive": "wavs",
        "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
        "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
    }
}
20
+
21
+
22
class LJSPEECH(Dataset):
    """*LJSpeech-1.1* :cite:`ljspeech17` dataset.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"wavs"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:

        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
        """Resolve dataset paths, optionally download/extract, and read the metadata csv."""
        root = Path(root)

        archive_name = os.path.basename(url)
        archive = root / archive_name

        # Strip ".tar.bz2" to get the release directory name.
        release_dir = Path(archive_name.split(".tar.bz2")[0])

        self._path = root / release_dir / folder_in_archive
        self._metadata_path = root / release_dir / "metadata.csv"

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _RELEASE_CONFIGS["release1"]["checksum"]
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)
        elif not os.path.exists(self._path):
            raise RuntimeError(
                f"The path {self._path} doesn't exist. "
                "Please check the ``root`` path or set `download=True` to download it"
            )

        # metadata.csv is pipe-delimited, unquoted: fileid|transcript|normalized.
        with open(self._metadata_path, "r", newline="") as metadata:
            self._flist = list(csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE))

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Transcript
            str:
                Normalized Transcript
        """
        fileid, transcript, normalized_transcript = self._flist[n]
        audio_path = self._path / f"{fileid}.wav"

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        return (
            waveform,
            sample_rate,
            transcript,
            normalized_transcript,
        )

    def __len__(self) -> int:
        return len(self._flist)
@@ -0,0 +1,139 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torchaudio
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_zip
10
+
11
# Download location of the MUSDB18-HQ archive and its expected digest
# (passed to ``download_url_to_file`` as ``hash_prefix``).
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
# Every stem is a wav file; loaded audio must match this sample rate.
_EXT = ".wav"
_SAMPLE_RATE = 44100
# Track names forming the validation split of the training set.
_VALIDATION_SET = [
    "Actions - One Minute Smile",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Johnny Lokke - Promises & Lies",
    "Patrick Talbot - A Reason To Leave",
    "Triviul - Angelsaint",
    "Alexander Ross - Goodbye Bolero",
    "Fergessen - Nos Palpitants",
    "Leaf - Summerghost",
    "Skelpolu - Human Mistakes",
    "Young Griffo - Pennies",
    "ANiMAL - Rockshow",
    "James May - On The Line",
    "Meaxic - Take A Step",
    "Traffic Experiment - Sirens",
]
31
+
32
+
33
class MUSDB_HQ(Dataset):
    """*MUSDB_HQ* :cite:`MUSDB18HQ` dataset.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        subset (str): Subset of the dataset to use. Options: [``"train"``, ``"test"``].
        sources (List[str] or None, optional): Sources extract data from.
            List can contain the following options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
            If ``None``, dataset consists of tracks except mixture.
            (default: ``None``)
        split (str or None, optional): Whether to split training set into train and validation set.
            If ``None``, no splitting occurs. If ``train`` or ``validation``, returns respective set.
            (default: ``None``)
        download (bool, optional): Whether to download the dataset if it is not found at root path.
            (default: ``False``)
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str,
        sources: Optional[List[str]] = None,
        split: Optional[str] = None,
        download: bool = False,
    ) -> None:
        # Default to every stem except the mixture track.
        self.sources = sources if sources else ["bass", "drums", "other", "vocals"]
        self.split = split

        if subset not in ["test", "train"]:
            raise ValueError("`subset` must be one of ['test', 'train']")
        if self.split is not None and self.split not in ["train", "validation"]:
            raise ValueError("`split` must be one of ['train', 'validation']")

        archive_name = os.path.basename(_URL)
        archive = os.path.join(root, archive_name)
        # Drop the ".zip" suffix to get the dataset directory name.
        dataset_dir = archive_name.rsplit(".", 2)[0]

        base_path = os.path.join(root, dataset_dir)
        self._path = os.path.join(base_path, subset)
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                if not download:
                    raise RuntimeError("Dataset not found. Please use `download=True` to download")
                download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
            os.makedirs(base_path, exist_ok=True)
            _extract_zip(archive, base_path)

        self.names = self._collect_songs()

    def _get_track(self, name, source):
        # Path of a single stem file for the given track.
        return Path(self._path) / name / f"{source}{_EXT}"

    def _load_sample(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Read every requested stem of track ``n`` and stack them into one tensor."""
        name = self.names[n]
        wavs = []
        num_frames = None
        for source in self.sources:
            wav, sr = torchaudio.load(str(self._get_track(name, source)))
            if sr != _SAMPLE_RATE:
                raise ValueError(f"expected sample rate {_SAMPLE_RATE}, but got {sr}")
            # All stems of a track must be the same length.
            if num_frames is None:
                num_frames = wav.shape[-1]
            elif wav.shape[-1] != num_frames:
                raise ValueError("num_frames do not match across sources")
            wavs.append(wav)

        return torch.stack(wavs), _SAMPLE_RATE, num_frames, name

    def _collect_songs(self):
        """Return the sorted track names for the configured subset/split."""
        # The validation split is a fixed, published list of track names.
        if self.split == "validation":
            return _VALIDATION_SET
        base = Path(self._path)
        names = []
        for dirpath, subdirs, _ in os.walk(base, followlinks=True):
            current = Path(dirpath)
            # Only leaf directories count (one per track); skip hidden
            # entries and the subset root itself.
            if current == base or subdirs or current.name.startswith("."):
                continue
            track = str(current.relative_to(base))
            # When splitting, validation tracks are excluded from training.
            if self.split and track in _VALIDATION_SET:
                continue
            names.append(track)
        return sorted(names)

    def __getitem__(self, n: int) -> Tuple[torch.Tensor, int, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded
        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            int:
                Num frames
            str:
                Track name
        """
        return self._load_sample(n)

    def __len__(self) -> int:
        return len(self.names)