torchaudio 2.9.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchaudio might be problematic. Click here for more details.

Files changed (86)
  1. torchaudio/.dylibs/libc++.1.0.dylib +0 -0
  2. torchaudio/__init__.py +204 -0
  3. torchaudio/_extension/__init__.py +61 -0
  4. torchaudio/_extension/utils.py +133 -0
  5. torchaudio/_internal/__init__.py +10 -0
  6. torchaudio/_internal/module_utils.py +171 -0
  7. torchaudio/_torchcodec.py +340 -0
  8. torchaudio/compliance/__init__.py +5 -0
  9. torchaudio/compliance/kaldi.py +813 -0
  10. torchaudio/datasets/__init__.py +47 -0
  11. torchaudio/datasets/cmuarctic.py +157 -0
  12. torchaudio/datasets/cmudict.py +186 -0
  13. torchaudio/datasets/commonvoice.py +86 -0
  14. torchaudio/datasets/dr_vctk.py +121 -0
  15. torchaudio/datasets/fluentcommands.py +108 -0
  16. torchaudio/datasets/gtzan.py +1118 -0
  17. torchaudio/datasets/iemocap.py +147 -0
  18. torchaudio/datasets/librilight_limited.py +111 -0
  19. torchaudio/datasets/librimix.py +133 -0
  20. torchaudio/datasets/librispeech.py +174 -0
  21. torchaudio/datasets/librispeech_biasing.py +189 -0
  22. torchaudio/datasets/libritts.py +168 -0
  23. torchaudio/datasets/ljspeech.py +107 -0
  24. torchaudio/datasets/musdb_hq.py +139 -0
  25. torchaudio/datasets/quesst14.py +136 -0
  26. torchaudio/datasets/snips.py +157 -0
  27. torchaudio/datasets/speechcommands.py +183 -0
  28. torchaudio/datasets/tedlium.py +218 -0
  29. torchaudio/datasets/utils.py +54 -0
  30. torchaudio/datasets/vctk.py +143 -0
  31. torchaudio/datasets/voxceleb1.py +309 -0
  32. torchaudio/datasets/yesno.py +89 -0
  33. torchaudio/functional/__init__.py +130 -0
  34. torchaudio/functional/_alignment.py +128 -0
  35. torchaudio/functional/filtering.py +1685 -0
  36. torchaudio/functional/functional.py +2505 -0
  37. torchaudio/lib/__init__.py +0 -0
  38. torchaudio/lib/_torchaudio.so +0 -0
  39. torchaudio/lib/libtorchaudio.so +0 -0
  40. torchaudio/models/__init__.py +85 -0
  41. torchaudio/models/_hdemucs.py +1008 -0
  42. torchaudio/models/conformer.py +293 -0
  43. torchaudio/models/conv_tasnet.py +330 -0
  44. torchaudio/models/decoder/__init__.py +64 -0
  45. torchaudio/models/decoder/_ctc_decoder.py +568 -0
  46. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  47. torchaudio/models/deepspeech.py +84 -0
  48. torchaudio/models/emformer.py +884 -0
  49. torchaudio/models/rnnt.py +816 -0
  50. torchaudio/models/rnnt_decoder.py +339 -0
  51. torchaudio/models/squim/__init__.py +11 -0
  52. torchaudio/models/squim/objective.py +326 -0
  53. torchaudio/models/squim/subjective.py +150 -0
  54. torchaudio/models/tacotron2.py +1046 -0
  55. torchaudio/models/wav2letter.py +72 -0
  56. torchaudio/models/wav2vec2/__init__.py +45 -0
  57. torchaudio/models/wav2vec2/components.py +1167 -0
  58. torchaudio/models/wav2vec2/model.py +1579 -0
  59. torchaudio/models/wav2vec2/utils/__init__.py +7 -0
  60. torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
  61. torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
  62. torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
  63. torchaudio/models/wavernn.py +409 -0
  64. torchaudio/pipelines/__init__.py +102 -0
  65. torchaudio/pipelines/_source_separation_pipeline.py +109 -0
  66. torchaudio/pipelines/_squim_pipeline.py +156 -0
  67. torchaudio/pipelines/_tts/__init__.py +16 -0
  68. torchaudio/pipelines/_tts/impl.py +385 -0
  69. torchaudio/pipelines/_tts/interface.py +255 -0
  70. torchaudio/pipelines/_tts/utils.py +230 -0
  71. torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
  72. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  73. torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
  74. torchaudio/pipelines/_wav2vec2/utils.py +346 -0
  75. torchaudio/pipelines/rnnt_pipeline.py +380 -0
  76. torchaudio/transforms/__init__.py +78 -0
  77. torchaudio/transforms/_multi_channel.py +467 -0
  78. torchaudio/transforms/_transforms.py +2138 -0
  79. torchaudio/utils/__init__.py +4 -0
  80. torchaudio/utils/download.py +89 -0
  81. torchaudio/version.py +2 -0
  82. torchaudio-2.9.0.dist-info/LICENSE +25 -0
  83. torchaudio-2.9.0.dist-info/METADATA +122 -0
  84. torchaudio-2.9.0.dist-info/RECORD +86 -0
  85. torchaudio-2.9.0.dist-info/WHEEL +5 -0
  86. torchaudio-2.9.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1118 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torchaudio
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchaudio._internal import download_url_to_file
9
+ from torchaudio.datasets.utils import _extract_tar
10
+
11
+ # The following lists prefixed with `filtered_` provide a filtered split
12
+ # that:
13
+ #
14
+ # a. mitigates a known issue with GTZAN (duplication), and
15
+ #
16
+ # b. provides a standard split for comparing results against other
17
+ # methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
18
+ #
19
+ # These lists are used when GTZAN is initialised with `subset` set to "training", "validation" or "testing".
20
+ # The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
21
+
22
# The ten GTZAN genres. Each name is both a sub-directory of the dataset folder
# and the prefix of every clip identifier in that genre (e.g. "blues.00000").
gtzan_genres = "blues classical country disco hiphop jazz metal pop reggae rock".split()
34
+
35
# Test split of the filtered GTZAN partition: 290 clip identifiers, grouped by genre.
filtered_test = [
    # blues
    "blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016",
    "blues.00017", "blues.00018", "blues.00019", "blues.00020", "blues.00021",
    "blues.00022", "blues.00023", "blues.00024", "blues.00025", "blues.00026",
    "blues.00027", "blues.00028", "blues.00061", "blues.00062", "blues.00063",
    "blues.00064", "blues.00065", "blues.00066", "blues.00067", "blues.00068",
    "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098",
    "blues.00099",
    # classical
    "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015",
    "classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020",
    "classical.00021", "classical.00022", "classical.00023", "classical.00024", "classical.00025",
    "classical.00026", "classical.00027", "classical.00028", "classical.00029", "classical.00034",
    "classical.00035", "classical.00036", "classical.00037", "classical.00038", "classical.00039",
    "classical.00040", "classical.00041", "classical.00049", "classical.00077", "classical.00078",
    "classical.00079",
    # country
    "country.00030", "country.00031", "country.00032", "country.00033", "country.00034",
    "country.00035", "country.00036", "country.00037", "country.00038", "country.00039",
    "country.00040", "country.00043", "country.00044", "country.00046", "country.00047",
    "country.00048", "country.00050", "country.00051", "country.00053", "country.00054",
    "country.00055", "country.00056", "country.00057", "country.00058", "country.00059",
    "country.00060", "country.00061", "country.00062", "country.00063", "country.00064",
    # disco
    "disco.00001", "disco.00021", "disco.00058", "disco.00062", "disco.00063",
    "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076",
    "disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081",
    "disco.00082", "disco.00083", "disco.00084", "disco.00085", "disco.00086",
    "disco.00087", "disco.00088", "disco.00091", "disco.00092", "disco.00093",
    "disco.00094", "disco.00096", "disco.00097", "disco.00099",
    # hiphop
    "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040",
    "hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052",
    "hiphop.00053", "hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064",
    "hiphop.00065", "hiphop.00066", "hiphop.00067", "hiphop.00068", "hiphop.00069",
    "hiphop.00070", "hiphop.00071", "hiphop.00072", "hiphop.00073", "hiphop.00074",
    "hiphop.00075", "hiphop.00099",
    # jazz
    "jazz.00073", "jazz.00074", "jazz.00075", "jazz.00076", "jazz.00077",
    "jazz.00078", "jazz.00079", "jazz.00080", "jazz.00081", "jazz.00082",
    "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086", "jazz.00087",
    "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092",
    "jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097",
    "jazz.00098", "jazz.00099",
    # metal
    "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022",
    "metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028",
    "metal.00029", "metal.00030", "metal.00031", "metal.00032", "metal.00033",
    "metal.00038", "metal.00039", "metal.00067", "metal.00070", "metal.00073",
    "metal.00074", "metal.00075", "metal.00078", "metal.00083", "metal.00085",
    "metal.00087", "metal.00088",
    # pop
    "pop.00000", "pop.00001", "pop.00013", "pop.00014", "pop.00043",
    "pop.00063", "pop.00064", "pop.00065", "pop.00066", "pop.00069",
    "pop.00070", "pop.00071", "pop.00072", "pop.00073", "pop.00074",
    "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079",
    "pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091",
    "pop.00092", "pop.00093", "pop.00094", "pop.00095", "pop.00096",
    # reggae
    "reggae.00034", "reggae.00035", "reggae.00036", "reggae.00037", "reggae.00038",
    "reggae.00039", "reggae.00040", "reggae.00046", "reggae.00047", "reggae.00048",
    "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065", "reggae.00066",
    "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082",
    "reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089",
    "reggae.00090",
    # rock
    "rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014",
    "rock.00015", "rock.00027", "rock.00028", "rock.00029", "rock.00030",
    "rock.00031", "rock.00032", "rock.00033", "rock.00034", "rock.00035",
    "rock.00036", "rock.00037", "rock.00039", "rock.00040", "rock.00041",
    "rock.00042", "rock.00043", "rock.00044", "rock.00045", "rock.00046",
    "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088",
    "rock.00089", "rock.00090",
]
327
+
328
# Training split of the filtered GTZAN partition: 443 clip identifiers, grouped by genre.
filtered_train = [
    # blues
    "blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033",
    "blues.00034", "blues.00035", "blues.00036", "blues.00037", "blues.00038",
    "blues.00039", "blues.00040", "blues.00041", "blues.00042", "blues.00043",
    "blues.00044", "blues.00045", "blues.00046", "blues.00047", "blues.00048",
    "blues.00049", "blues.00073", "blues.00074", "blues.00075", "blues.00076",
    "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081",
    "blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086",
    "blues.00087", "blues.00088", "blues.00089", "blues.00090", "blues.00091",
    "blues.00092", "blues.00093", "blues.00094", "blues.00095", "blues.00096",
    "blues.00097",
    # classical
    "classical.00030", "classical.00031", "classical.00032", "classical.00033", "classical.00043",
    "classical.00044", "classical.00045", "classical.00046", "classical.00047", "classical.00048",
    "classical.00050", "classical.00051", "classical.00052", "classical.00053", "classical.00054",
    "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059",
    "classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064",
    "classical.00065", "classical.00066", "classical.00067", "classical.00080", "classical.00081",
    "classical.00082", "classical.00083", "classical.00084", "classical.00085", "classical.00086",
    "classical.00087", "classical.00088", "classical.00089", "classical.00090", "classical.00091",
    "classical.00092", "classical.00093", "classical.00094", "classical.00095", "classical.00096",
    "classical.00097", "classical.00098", "classical.00099",
    # country
    "country.00019", "country.00020", "country.00021", "country.00022", "country.00023",
    "country.00024", "country.00025", "country.00026", "country.00028", "country.00029",
    "country.00065", "country.00066", "country.00067", "country.00068", "country.00069",
    "country.00070", "country.00071", "country.00072", "country.00073", "country.00074",
    "country.00075", "country.00076", "country.00077", "country.00078", "country.00079",
    "country.00080", "country.00081", "country.00082", "country.00083", "country.00084",
    "country.00085", "country.00086", "country.00087", "country.00088", "country.00089",
    "country.00090", "country.00091", "country.00092", "country.00093", "country.00094",
    "country.00095", "country.00096", "country.00097", "country.00098", "country.00099",
    # disco
    "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018",
    "disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024",
    "disco.00025", "disco.00026", "disco.00027", "disco.00028", "disco.00029",
    "disco.00030", "disco.00031", "disco.00032", "disco.00033", "disco.00034",
    "disco.00035", "disco.00036", "disco.00037", "disco.00039", "disco.00040",
    "disco.00041", "disco.00042", "disco.00043", "disco.00044", "disco.00045",
    "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056",
    "disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073",
    "disco.00074", "disco.00089",
    # hiphop
    "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006",
    "hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011",
    "hiphop.00012", "hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016",
    "hiphop.00017", "hiphop.00018", "hiphop.00019", "hiphop.00020", "hiphop.00021",
    "hiphop.00022", "hiphop.00023", "hiphop.00024", "hiphop.00025", "hiphop.00028",
    "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033", "hiphop.00034",
    "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041",
    "hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058",
    "hiphop.00059", "hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078",
    "hiphop.00079", "hiphop.00080",
    # jazz
    "jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013",
    "jazz.00014", "jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018",
    "jazz.00019", "jazz.00020", "jazz.00021", "jazz.00022", "jazz.00023",
    "jazz.00024", "jazz.00041", "jazz.00047", "jazz.00048", "jazz.00049",
    "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053", "jazz.00054",
    "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059",
    "jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064",
    "jazz.00065", "jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069",
    "jazz.00070", "jazz.00071", "jazz.00072",
    # metal
    "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024",
    "metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049",
    "metal.00050", "metal.00051", "metal.00052", "metal.00053", "metal.00054",
    "metal.00055", "metal.00056", "metal.00057", "metal.00059", "metal.00060",
    "metal.00061", "metal.00062", "metal.00063", "metal.00064", "metal.00065",
    "metal.00066", "metal.00069", "metal.00071", "metal.00072", "metal.00079",
    "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090",
    "metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095",
    "metal.00096", "metal.00097", "metal.00098", "metal.00099",
    # pop
    "pop.00002", "pop.00003", "pop.00004", "pop.00005", "pop.00006",
    "pop.00007", "pop.00008", "pop.00009", "pop.00011", "pop.00012",
    "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020",
    "pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027",
    "pop.00028", "pop.00029", "pop.00031", "pop.00032", "pop.00033",
    "pop.00034", "pop.00035", "pop.00036", "pop.00038", "pop.00039",
    "pop.00040", "pop.00041", "pop.00042", "pop.00044", "pop.00046",
    "pop.00049", "pop.00050", "pop.00080", "pop.00097", "pop.00098",
    "pop.00099",
    # reggae
    "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004", "reggae.00006",
    "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015",
    "reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020",
    "reggae.00021", "reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025",
    "reggae.00026", "reggae.00027", "reggae.00028", "reggae.00029", "reggae.00030",
    "reggae.00031", "reggae.00032", "reggae.00042", "reggae.00043", "reggae.00044",
    "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051", "reggae.00054",
    "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059",
    "reggae.00060", "reggae.00063", "reggae.00069",
    # rock
    "rock.00000", "rock.00001", "rock.00002", "rock.00003", "rock.00004",
    "rock.00005", "rock.00006", "rock.00007", "rock.00008", "rock.00009",
    "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020",
    "rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025",
    "rock.00026", "rock.00057", "rock.00058", "rock.00059", "rock.00060",
    "rock.00061", "rock.00062", "rock.00063", "rock.00064", "rock.00065",
    "rock.00066", "rock.00067", "rock.00068", "rock.00069", "rock.00070",
    "rock.00091", "rock.00092", "rock.00093", "rock.00094", "rock.00095",
    "rock.00096", "rock.00097", "rock.00098", "rock.00099",
]
773
+
774
# Validation split of the filtered GTZAN partition: 197 clip identifiers, grouped by genre.
filtered_valid = [
    # blues
    "blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004",
    "blues.00005", "blues.00006", "blues.00007", "blues.00008", "blues.00009",
    "blues.00010", "blues.00011", "blues.00050", "blues.00051", "blues.00052",
    "blues.00053", "blues.00054", "blues.00055", "blues.00056", "blues.00057",
    "blues.00058", "blues.00059", "blues.00060",
    # classical
    "classical.00000", "classical.00001", "classical.00002", "classical.00003", "classical.00004",
    "classical.00005", "classical.00006", "classical.00007", "classical.00008", "classical.00009",
    "classical.00010", "classical.00068", "classical.00069", "classical.00070", "classical.00071",
    "classical.00072", "classical.00073", "classical.00074", "classical.00075", "classical.00076",
    # country
    "country.00000", "country.00001", "country.00002", "country.00003", "country.00004",
    "country.00005", "country.00006", "country.00007", "country.00009", "country.00010",
    "country.00011", "country.00012", "country.00013", "country.00014", "country.00015",
    "country.00016", "country.00017", "country.00018", "country.00027", "country.00041",
    "country.00042", "country.00045", "country.00049",
    # disco
    "disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006",
    "disco.00007", "disco.00008", "disco.00009", "disco.00010", "disco.00011",
    "disco.00012", "disco.00013", "disco.00014", "disco.00046", "disco.00048",
    "disco.00052", "disco.00067", "disco.00068", "disco.00072", "disco.00075",
    "disco.00090", "disco.00095",
    # hiphop
    "hiphop.00081", "hiphop.00082", "hiphop.00083", "hiphop.00084", "hiphop.00085",
    "hiphop.00086", "hiphop.00087", "hiphop.00088", "hiphop.00089", "hiphop.00090",
    "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094", "hiphop.00095",
    "hiphop.00096", "hiphop.00097", "hiphop.00098",
    # jazz
    "jazz.00002", "jazz.00003", "jazz.00004", "jazz.00005", "jazz.00006",
    "jazz.00007", "jazz.00008", "jazz.00009", "jazz.00010", "jazz.00025",
    "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029", "jazz.00030",
    "jazz.00031", "jazz.00032",
    # metal
    "metal.00000", "metal.00001", "metal.00006", "metal.00007", "metal.00008",
    "metal.00009", "metal.00010", "metal.00011", "metal.00016", "metal.00017",
    "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037",
    "metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082",
    # pop
    "pop.00010", "pop.00053", "pop.00055", "pop.00058", "pop.00059",
    "pop.00060", "pop.00061", "pop.00062", "pop.00081", "pop.00083",
    "pop.00084", "pop.00085", "pop.00086",
    # reggae
    "reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074",
    "reggae.00076", "reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092",
    "reggae.00093", "reggae.00094", "reggae.00095", "reggae.00096", "reggae.00097",
    "reggae.00098", "reggae.00099",
    # rock
    "rock.00038", "rock.00049", "rock.00050", "rock.00051", "rock.00052",
    "rock.00053", "rock.00054", "rock.00055", "rock.00056", "rock.00071",
    "rock.00072", "rock.00073", "rock.00074", "rock.00075", "rock.00076",
    "rock.00077", "rock.00078", "rock.00079", "rock.00080", "rock.00081",
    "rock.00082", "rock.00083", "rock.00084", "rock.00085",
]
973
+
974
+
975
# Default download location of the GTZAN archive. The GTZAN class docstring notes
# that this link has not been working since October 2022.
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
# Name of the top-level directory the archive extracts to; used as the default
# ``folder_in_archive`` of the GTZAN dataset.
FOLDER_IN_ARCHIVE = "genres"
# Expected archive digests keyed by download URL; the matching entry is passed as
# ``hash_prefix`` to ``download_url_to_file``. The 64 hex digits presumably form a
# SHA-256 digest. The key reuses ``URL`` so the lookup cannot silently stop matching
# (which would skip verification) if the URL constant is ever edited.
_CHECKSUMS = {
    URL: "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6",
}
980
+
981
+
982
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, str]:
    """Load a single GTZAN clip.

    Args:
        fileid (str): Clip identifier of the form ``"<genre>.<id>"``, e.g. ``"blues.00078"``.
        path (str): Dataset directory containing one sub-directory per genre.
        ext_audio (str): Extension appended to ``fileid`` to form the file name (e.g. ``".wav"``).

    Returns:
        Tuple[Tensor, int, str]: ``(waveform, sample_rate, label)``, where ``waveform`` and
        ``sample_rate`` come from :func:`torchaudio.load` and ``label`` is the genre part
        of ``fileid``.

    Raises:
        ValueError: If ``fileid`` does not contain exactly one ``"."``.
    """
    # Filenames are of the form label.id, e.g. blues.00078. The label is also the
    # name of the genre sub-directory the file is read from.
    label, _ = fileid.split(".")

    file_audio = os.path.join(path, label, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return waveform, sample_rate, label
996
+
997
+
998
class GTZAN(Dataset):
    """*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.

    Note:
        Please see http://marsyas.info/downloads/datasets.html if you are planning to use
        this dataset to publish results.

    Note:
        As of October 2022, the download link is not currently working. Setting ``download=True``
        in GTZAN dataset will result in a URL connection error.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
        folder_in_archive (str, optional): The top-level directory of the dataset.
            (default: ``"genres"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        subset (str or None, optional): Which subset of the dataset to use.
            One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
            If ``None``, the entire dataset is used: every ``<genre>/<genre>.<5 digits>.wav``
            file found on disk. Otherwise the fixed filtered split for that subset is used;
            its files are not checked for existence until they are loaded. (default: ``None``).

    Raises:
        ValueError: If ``subset`` is neither ``None`` nor one of the names listed above.
        RuntimeError: If ``root/folder_in_archive`` is not a directory (after the download
            attempt, when ``download=True``).
    """

    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:
        super().__init__()

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        self.root = root
        self.url = url
        self.folder_in_archive = folder_in_archive
        self.download = download
        self.subset = subset

        if subset is not None and subset not in ["training", "validation", "testing"]:
            raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")

        archive = os.path.join(root, os.path.basename(url))
        self._path = os.path.join(root, folder_in_archive)

        # Only fetch/extract when the extracted folder is missing; an archive that was
        # already downloaded is reused rather than fetched again.
        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url_to_file(url, archive, hash_prefix=checksum)
            _extract_tar(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        if subset is None:
            self._walker = self._scan_genre_dirs(os.path.expanduser(self._path))
        elif subset == "training":
            # Copy the module-level split so that mutating one dataset's walker
            # cannot alter the split seen by every other GTZAN instance.
            self._walker = list(filtered_train)
        elif subset == "validation":
            self._walker = list(filtered_valid)
        else:  # "testing" -- the only remaining value accepted by the check above
            self._walker = list(filtered_test)

    @staticmethod
    def _scan_genre_dirs(dataset_dir: str) -> list:
        """Collect ``"<genre>.<5 digits>"`` clip identifiers from the genre sub-directories.

        Scanning the directories (instead of using a fixed list) lets users remove, move or
        relabel song files. Missing genre directories are skipped. Identifiers are returned
        in ``gtzan_genres`` order and sorted by file name within each genre.
        """
        fileids = []
        for genre_dir in gtzan_genres:
            fulldir = os.path.join(dataset_dir, genre_dir)
            # isdir (not exists): a regular file named like a genre must not crash listdir.
            if not os.path.isdir(fulldir):
                continue

            for fname in sorted(os.listdir(fulldir)):
                name, ext = os.path.splitext(fname)
                # NOTE(review): the extension is matched case-insensitively, but clips are
                # opened with the lowercase ``_ext_audio`` suffix, so e.g. ``.WAV`` files
                # only load on case-insensitive file systems.
                if ext.lower() != ".wav" or "." not in name:
                    continue
                # Split on the first dot only, so names with extra dots
                # (e.g. "blues.00000.copy") fail the 5-digit check below instead of
                # raising ValueError while unpacking.
                genre, num = name.split(".", 1)
                # NOTE(review): the genre prefix, not the directory the file was found in,
                # decides the sub-directory it is later loaded from.
                if genre in gtzan_genres and len(num) == 5 and num.isdigit():
                    fileids.append(name)
        return fileids

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Label
        """
        fileid = self._walker[n]
        waveform, sample_rate, label = load_gtzan_item(fileid, self._path, self._ext_audio)
        return waveform, sample_rate, label

    def __len__(self) -> int:
        """Return the number of clip identifiers in the selected subset."""
        return len(self._walker)