torchaudio-2.0.2-cp310-cp310-manylinux1_x86_64.whl → torchaudio-2.1.1-cp310-cp310-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchaudio might be problematic.
Files changed (92)
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.so +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
  51. torchaudio/lib/_torchaudio_sox.so +0 -0
  52. torchaudio/lib/libctc_prefix_decoder.so +0 -0
  53. torchaudio/lib/libtorchaudio.so +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
  55. torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
  56. torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
  57. torchaudio/lib/libtorchaudio_sox.so +0 -0
  58. torchaudio/lib/pybind11_prefixctc.so +0 -0
  59. torchaudio/models/__init__.py +14 -0
  60. torchaudio/models/decoder/__init__.py +22 -7
  61. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  62. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  63. torchaudio/models/rnnt_decoder.py +10 -14
  64. torchaudio/models/squim/__init__.py +11 -0
  65. torchaudio/models/squim/objective.py +326 -0
  66. torchaudio/models/squim/subjective.py +150 -0
  67. torchaudio/models/wav2vec2/components.py +6 -10
  68. torchaudio/pipelines/__init__.py +9 -0
  69. torchaudio/pipelines/_squim_pipeline.py +176 -0
  70. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  71. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  72. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  73. torchaudio/sox_effects/sox_effects.py +7 -30
  74. torchaudio/transforms/__init__.py +2 -0
  75. torchaudio/transforms/_transforms.py +99 -54
  76. torchaudio/utils/download.py +2 -2
  77. torchaudio/utils/ffmpeg_utils.py +20 -15
  78. torchaudio/utils/sox_utils.py +8 -9
  79. torchaudio/version.py +2 -2
  80. torchaudio-2.1.1.dist-info/METADATA +113 -0
  81. torchaudio-2.1.1.dist-info/RECORD +119 -0
  82. torchaudio/io/_compat.py +0 -241
  83. torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
  84. torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
  85. torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
  86. torchaudio/lib/libflashlight-text.so +0 -0
  87. torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
  88. torchaudio-2.0.2.dist-info/METADATA +0 -26
  89. torchaudio-2.0.2.dist-info/RECORD +0 -100
  90. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  91. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +0 -0
  92. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
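
The new `torchaudio/_backend/` modules above implement the I/O dispatcher that becomes the default in 2.1: `load`/`save`/`info` route to FFmpeg, SoX, or soundfile per call instead of relying on a process-global backend. A minimal sketch of the user-facing change (file names are hypothetical, and each `backend` value requires the corresponding library to be installed):

    import torchaudio

    # With the 2.1 dispatcher, load/save accept an optional per-call `backend` argument.
    waveform, sample_rate = torchaudio.load("speech.wav")                       # auto-select
    waveform, sample_rate = torchaudio.load("speech.wav", backend="soundfile")  # force one
    torchaudio.save("out.flac", waveform, sample_rate, backend="ffmpeg")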
torchaudio/functional/filtering.py CHANGED
@@ -1390,11 +1390,11 @@ def _measure(
     cepstrum_end: int,
     noise_reduction_amount: float,
     measure_smooth_time_mult: float,
-    noise_up_time_mult: float,
-    noise_down_time_mult: float,
-    index_ns: int,
+    noise_up_time_mult: Tensor,
+    noise_down_time_mult: Tensor,
     boot_count: int,
 ) -> float:
+    device = samples.device
 
     if spectrum.size(-1) != noise_spectrum.size(-1):
         raise ValueError(
@@ -1402,37 +1402,29 @@ def _measure(
             f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}"
         )
 
-    samplesLen_ns = samples.size()[-1]
     dft_len_ws = spectrum.size()[-1]
 
-    dftBuf = torch.zeros(dft_len_ws)
+    dftBuf = torch.zeros(dft_len_ws, device=device)
 
-    _index_ns = torch.tensor([index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)])
-    dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws]
-
-    # memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
-    dftBuf[measure_len_ws:dft_len_ws].zero_()
+    dftBuf[:measure_len_ws] = samples * spectrum_window[:measure_len_ws]
 
     # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
     _dftBuf = torch.fft.rfft(dftBuf)
 
-    # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
-    _dftBuf[:spectrum_start].zero_()
-
     mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
 
     _d = _dftBuf[spectrum_start:spectrum_end].abs()
     spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
     _d = spectrum[spectrum_start:spectrum_end] ** 2
 
-    _zeros = torch.zeros(spectrum_end - spectrum_start)
+    _zeros = torch.zeros(spectrum_end - spectrum_start, device=device)
     _mult = (
         _zeros
         if boot_count >= 0
         else torch.where(
             _d > noise_spectrum[spectrum_start:spectrum_end],
-            torch.tensor(noise_up_time_mult),  # if
-            torch.tensor(noise_down_time_mult),  # else
+            noise_up_time_mult,  # if
+            noise_down_time_mult,  # else
         )
     )
 
@@ -1441,10 +1433,10 @@ def _measure(
         torch.max(
             _zeros,
             _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end],
-        )
+        ),
     )
 
-    _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
+    _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1, device=device)
     _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
     _cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()
 
@@ -1539,6 +1531,7 @@ def vad(
     Reference:
         - http://sox.sourceforge.net/sox.html
     """
+    device = waveform.device
 
     if waveform.ndim > 2:
         warnings.warn(
@@ -1566,23 +1559,23 @@ def vad(
     fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
     samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
 
-    spectrum_window = torch.zeros(measure_len_ws)
+    spectrum_window = torch.zeros(measure_len_ws, device=device)
     for i in range(measure_len_ws):
         # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
         spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws))
     # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
-    spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)
+    spectrum_window *= torch.hann_window(measure_len_ws, device=device, dtype=torch.float)
 
     spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5)
     spectrum_start: int = max(spectrum_start, 1)
     spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5)
     spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
 
-    cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
+    cepstrum_window = torch.zeros(spectrum_end - spectrum_start, device=device)
     for i in range(spectrum_end - spectrum_start):
         cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
     # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
-    cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)
+    cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, device=device, dtype=torch.float)
 
     cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
     cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
@@ -1594,14 +1587,13 @@ def vad(
         f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}."
     )
 
-    noise_up_time_mult = math.exp(-1.0 / (noise_up_time * measure_freq))
-    noise_down_time_mult = math.exp(-1.0 / (noise_down_time * measure_freq))
+    noise_up_time_mult = torch.tensor(math.exp(-1.0 / (noise_up_time * measure_freq)), device=device)
+    noise_down_time_mult = torch.tensor(math.exp(-1.0 / (noise_down_time * measure_freq)), device=device)
     measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq))
     trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq))
 
     boot_count_max = int(boot_time * measure_freq - 0.5)
-    measure_timer_ns = measure_len_ns
-    boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0
+    boot_count = measures_index = flushedLen_ns = 0
 
     # pack batch
     shape = waveform.size()
@@ -1609,80 +1601,65 @@ def vad(
 
     n_channels, ilen = waveform.size()
 
-    mean_meas = torch.zeros(n_channels)
-    samples = torch.zeros(n_channels, samplesLen_ns)
-    spectrum = torch.zeros(n_channels, dft_len_ws)
-    noise_spectrum = torch.zeros(n_channels, dft_len_ws)
-    measures = torch.zeros(n_channels, measures_len)
+    mean_meas = torch.zeros(n_channels, device=device)
+    spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
+    noise_spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
+    measures = torch.zeros(n_channels, measures_len, device=device)
 
     has_triggered: bool = False
     num_measures_to_flush: int = 0
-    pos: int = 0
 
-    while pos < ilen and not has_triggered:
-        measure_timer_ns -= 1
+    pos = 0
+    for pos in range(measure_len_ns, ilen, measure_period_ns):
         for i in range(n_channels):
-            samples[i, samplesIndex_ns] = waveform[i, pos]
-            # if (!p->measure_timer_ns) {
-            if measure_timer_ns == 0:
-                index_ns: int = (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
-                meas: float = _measure(
-                    measure_len_ws=measure_len_ws,
-                    samples=samples[i],
-                    spectrum=spectrum[i],
-                    noise_spectrum=noise_spectrum[i],
-                    spectrum_window=spectrum_window,
-                    spectrum_start=spectrum_start,
-                    spectrum_end=spectrum_end,
-                    cepstrum_window=cepstrum_window,
-                    cepstrum_start=cepstrum_start,
-                    cepstrum_end=cepstrum_end,
-                    noise_reduction_amount=noise_reduction_amount,
-                    measure_smooth_time_mult=measure_smooth_time_mult,
-                    noise_up_time_mult=noise_up_time_mult,
-                    noise_down_time_mult=noise_down_time_mult,
-                    index_ns=index_ns,
-                    boot_count=boot_count,
-                )
-                measures[i, measures_index] = meas
-                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
-
-                has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
-                if has_triggered:
-                    n: int = measures_len
-                    k: int = measures_index
-                    jTrigger: int = n
-                    jZero: int = n
-                    j: int = 0
-
-                    for j in range(n):
-                        if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
-                            jZero = jTrigger = j
-                        elif (measures[i, k] == 0) and (jTrigger >= jZero):
-                            jZero = j
-                        k = (k + n - 1) % n
-                    j = min(j, jZero)
-                    # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
-                    num_measures_to_flush = min(max(num_measures_to_flush, j), n)
-                # end if has_triggered
-            # end if (measure_timer_ns == 0):
-        # end for
-        samplesIndex_ns += 1
-        pos += 1
-        # end while
-        if samplesIndex_ns == samplesLen_ns:
-            samplesIndex_ns = 0
-        if measure_timer_ns == 0:
-            measure_timer_ns = measure_period_ns
-            measures_index += 1
-            measures_index = measures_index % measures_len
-            if boot_count >= 0:
-                boot_count = -1 if boot_count == boot_count_max else boot_count + 1
+            meas: float = _measure(
+                measure_len_ws=measure_len_ws,
+                samples=waveform[i, pos - measure_len_ws : pos],
+                spectrum=spectrum[i],
+                noise_spectrum=noise_spectrum[i],
+                spectrum_window=spectrum_window,
+                spectrum_start=spectrum_start,
+                spectrum_end=spectrum_end,
+                cepstrum_window=cepstrum_window,
+                cepstrum_start=cepstrum_start,
+                cepstrum_end=cepstrum_end,
+                noise_reduction_amount=noise_reduction_amount,
+                measure_smooth_time_mult=measure_smooth_time_mult,
+                noise_up_time_mult=noise_up_time_mult,
+                noise_down_time_mult=noise_down_time_mult,
+                boot_count=boot_count,
+            )
+            measures[i, measures_index] = meas
+            mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
+
+            has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
+            if has_triggered:
+                n: int = measures_len
+                k: int = measures_index
+                jTrigger: int = n
+                jZero: int = n
+                j: int = 0
+
+                for j in range(n):
+                    if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
+                        jZero = jTrigger = j
+                    elif (measures[i, k] == 0) and (jTrigger >= jZero):
+                        jZero = j
+                    k = (k + n - 1) % n
+                j = min(j, jZero)
+                # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
+                num_measures_to_flush = min(max(num_measures_to_flush, j), n)
+            # end if has_triggered
+        # end for channel
+        measures_index += 1
+        measures_index = measures_index % measures_len
+        if boot_count >= 0:
+            boot_count = -1 if boot_count == boot_count_max else boot_count + 1
 
         if has_triggered:
             flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
-            samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns
-
+            break
+    # end for window
     res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
     # unpack batch
     return res.view(shape[:-1] + res.shape[-1:])
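
The rewrite of `vad` above replaces the per-sample ring buffer with strided window slices and threads the input's `device` through every intermediate tensor, so the function no longer silently allocates on CPU. A minimal sketch of what this enables (the file name is hypothetical; the CUDA branch only runs on a GPU machine):

    import torch
    import torchaudio

    waveform, sample_rate = torchaudio.load("speech.wav")
    trimmed = torchaudio.functional.vad(waveform, sample_rate)  # CPU, as before

    # As of 2.1 the same call works with the waveform on another device:
    if torch.cuda.is_available():
        trimmed_gpu = torchaudio.functional.vad(waveform.to("cuda"), sample_rate)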
torchaudio/functional/functional.py CHANGED
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 
-import io
 import math
+import tempfile
 import warnings
 from collections.abc import Sequence
 from typing import List, Optional, Tuple, Union
@@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torchaudio
 from torch import Tensor
+from torchaudio._internal.module_utils import deprecated
 
 from .filtering import highpass_biquad, treble_biquad
 
@@ -19,7 +20,6 @@ __all__ = [
     "amplitude_to_DB",
     "DB_to_amplitude",
     "compute_deltas",
-    "compute_kaldi_pitch",
     "melscale_fbanks",
     "linear_fbanks",
     "create_dct",
@@ -83,7 +83,7 @@ def spectrogram(
         hop_length (int): Length of hop between STFT windows
         win_length (int): Window size
         power (float or None): Exponent for the magnitude spectrogram,
-            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+            (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
             If None, then the complex spectrum is returned instead.
         normalized (bool or str): Whether to normalize by magnitude after stft. If input is str, choices are
             ``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
@@ -286,7 +286,7 @@ def griffinlim(
             Default: ``win_length // 2``)
         win_length (int): Window size. (Default: ``n_fft``)
         power (float): Exponent for the magnitude spectrogram,
-            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+            (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
         n_iter (int): Number of iteration for phase recovery process.
         momentum (float): The momentum parameter for fast Griffin-Lim.
             Setting this to 0 recovers the original Griffin-Lim method.
@@ -370,9 +370,17 @@ def amplitude_to_DB(
 
     Args:
 
-        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
-            the form `(..., freq, time)`. Batched inputs should include a channel dimension and
-            have the form `(batch, channel, freq, time)`.
+        x (Tensor): Input spectrogram(s) before being converted to decibel scale.
+            The expected shapes are ``(freq, time)``, ``(channel, freq, time)`` or
+            ``(..., batch, channel, freq, time)``.
+
+            .. note::
+
+               When ``top_db`` is specified, cut-off values are computed for each audio
+               in the batch. Therefore if the input shape is 4D (or larger), different
+               cut-off values are used for audio data in the batch.
+               If the input shape is 2D or 3D, a single cutoff value is used.
+
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
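
To illustrate the `top_db` note above: with a 4D (or larger) input, the cut-off is computed per item along the batch dimension, whereas 2D/3D inputs get a single global cut-off. A minimal sketch with arbitrary shapes:

    import torch
    import torchaudio.functional as F

    batch = torch.rand(4, 2, 201, 100)  # (batch, channel, freq, time): per-item cut-off
    flat = batch.reshape(8, 201, 100)   # 3D view: one global cut-off for everything

    db_batch = F.amplitude_to_DB(batch, multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0)
    db_flat = F.amplitude_to_DB(flat, multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0)

    # db_batch clamps each of the 4 items relative to its own maximum, db_flat
    # clamps everything against one maximum, so the two can differ whenever the
    # items' maxima differ.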
@@ -547,7 +555,7 @@ def melscale_fbanks(
         meaning number of frequencies to highlight/apply to x the number of filterbanks.
         Each column is a filterbank so that assuming there is a matrix A of
         size (..., ``n_freqs``), the applied result would be
-        ``A * melscale_fbanks(A.size(-1), ...)``.
+        ``A @ melscale_fbanks(A.size(-1), ...)``.
 
     """
 
@@ -825,18 +833,25 @@ def mask_along_axis_iid(
     ``max_v = min(mask_param, floor(specgrams.size(axis) * p))`` otherwise.
 
     Args:
-        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
+        specgrams (Tensor): Real spectrograms `(..., freq, time)`, with at least 3 dimensions.
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
-        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
+        axis (int): Axis to apply masking on, which should be the one of the last two dimensions.
         p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
 
     Returns:
-        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
+        Tensor: Masked spectrograms with the same dimensions as input specgrams Tensor
     """
 
-    if axis not in [2, 3]:
-        raise ValueError("Only Frequency and Time masking are supported")
+    dim = specgrams.dim()
+
+    if dim < 3:
+        raise ValueError(f"Spectrogram must have at least three dimensions ({dim} given).")
+
+    if axis not in [dim - 2, dim - 1]:
+        raise ValueError(
+            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+        )
 
     if not 0.0 <= p <= 1.0:
         raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
@@ -848,8 +863,8 @@ def mask_along_axis_iid(
     device = specgrams.device
     dtype = specgrams.dtype
 
-    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
-    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
+    value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * mask_param
+    min_value = torch.rand(specgrams.shape[: (dim - 2)], device=device, dtype=dtype) * (specgrams.size(axis) - value)
 
     # Create broadcastable mask
     mask_start = min_value.long()[..., None, None]
@@ -879,24 +894,31 @@ def mask_along_axis(
 
     Mask will be applied from indices ``[v_0, v_0 + v)``,
     where ``v`` is sampled from ``uniform(0, max_v)`` and
-    ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
+    ``v_0`` from ``uniform(0, specgram.size(axis) - v)``, with
     ``max_v = mask_param`` when ``p = 1.0`` and
-    ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
+    ``max_v = min(mask_param, floor(specgram.size(axis) * p))``
     otherwise.
     All examples will have the same mask interval.
 
     Args:
-        specgram (Tensor): Real spectrogram `(channel, freq, time)`
+        specgram (Tensor): Real spectrograms `(..., freq, time)`, with at least 2 dimensions.
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
-        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
+        axis (int): Axis to apply masking on, which should be the one of the last two dimensions.
         p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
 
     Returns:
-        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
+        Tensor: Masked spectrograms with the same dimensions as input specgram Tensor
     """
-    if axis not in [1, 2]:
-        raise ValueError("Only Frequency and Time masking are supported")
+    dim = specgram.dim()
+
+    if dim < 2:
+        raise ValueError(f"Spectrogram must have at least two dimensions (time and frequency) ({dim} given).")
+
+    if axis not in [dim - 2, dim - 1]:
+        raise ValueError(
+            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+        )
 
     if not 0.0 <= p <= 1.0:
         raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
@@ -908,14 +930,17 @@ def mask_along_axis(
     # pack batch
     shape = specgram.size()
     specgram = specgram.reshape([-1] + list(shape[-2:]))
+    # After packing, specgram is a 3D tensor, and the axis corresponding to the to-be-masked dimension
+    # is now (axis - dim + 3), e.g. a tensor of shape (10, 2, 50, 10, 2) becomes a tensor of shape (1000, 10, 2).
     value = torch.rand(1) * mask_param
-    min_value = torch.rand(1) * (specgram.size(axis) - value)
+    min_value = torch.rand(1) * (specgram.size(axis - dim + 3) - value)
 
     mask_start = (min_value.long()).squeeze()
     mask_end = (min_value.long() + value.long()).squeeze()
-    mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
+    mask = torch.arange(0, specgram.shape[axis - dim + 3], device=specgram.device, dtype=specgram.dtype)
     mask = (mask >= mask_start) & (mask < mask_end)
-    if axis == 1:
+    # unsqueeze the mask if the axis is frequency
+    if axis == dim - 2:
         mask = mask.unsqueeze(-1)
 
     if mask_end - mask_start >= mask_param:
1019
1044
 
1020
1045
  output_frames = (
1021
1046
  (s1 * s2).sum(-1)
1022
- / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
1023
- / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
1047
+ / (EPSILON + torch.linalg.vector_norm(s1, ord=2, dim=-1)).pow(2)
1048
+ / (EPSILON + torch.linalg.vector_norm(s2, ord=2, dim=-1)).pow(2)
1024
1049
  )
1025
1050
 
1026
1051
  output_lag.append(output_frames.unsqueeze(-1))
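
The change above swaps the deprecated `torch.norm` for `torch.linalg.vector_norm`; for this usage the two are numerically identical:

    import torch

    s1 = torch.rand(8, 100)
    old = torch.norm(s1, p=2, dim=-1)                  # deprecated spelling
    new = torch.linalg.vector_norm(s1, ord=2, dim=-1)  # preferred spelling
    assert torch.allclose(old, new)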
@@ -1271,6 +1296,7 @@ def spectral_centroid(
 
 
 @torchaudio._extension.fail_if_no_sox
+@deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False)
 def apply_codec(
     waveform: Tensor,
     sample_rate: int,
@@ -1303,129 +1329,17 @@ def apply_codec(
         Tensor: Resulting Tensor.
         If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
     """
-    bytes = io.BytesIO()
-    torchaudio.backend.sox_io_backend.save(
-        bytes, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
-    )
-    bytes.seek(0)
-    augmented, sr = torchaudio.backend.sox_io_backend.load(bytes, channels_first=channels_first, format=format)
+    with tempfile.NamedTemporaryFile() as f:
+        torchaudio.backend.sox_io_backend.save(
+            f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
+        )
+        augmented, sr = torchaudio.backend.sox_io_backend.load(f.name, channels_first=channels_first, format=format)
     if sr != sample_rate:
         augmented = resample(augmented, sr, sample_rate)
     return augmented
 
 
-@torchaudio._extension.fail_if_no_kaldi
-def compute_kaldi_pitch(
-    waveform: torch.Tensor,
-    sample_rate: float,
-    frame_length: float = 25.0,
-    frame_shift: float = 10.0,
-    min_f0: float = 50,
-    max_f0: float = 400,
-    soft_min_f0: float = 10.0,
-    penalty_factor: float = 0.1,
-    lowpass_cutoff: float = 1000,
-    resample_frequency: float = 4000,
-    delta_pitch: float = 0.005,
-    nccf_ballast: float = 7000,
-    lowpass_filter_width: int = 1,
-    upsample_filter_width: int = 5,
-    max_frames_latency: int = 0,
-    frames_per_chunk: int = 0,
-    simulate_first_pass_online: bool = False,
-    recompute_frame: int = 500,
-    snip_edges: bool = True,
-) -> torch.Tensor:
-    """Extract pitch based on method described in *A pitch extraction algorithm tuned
-    for automatic speech recognition* :cite:`6854049`.
-
-    .. devices:: CPU
-
-    .. properties:: TorchScript
-
-    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.
-
-    Args:
-        waveform (Tensor):
-            The input waveform of shape `(..., time)`.
-        sample_rate (float):
-            Sample rate of `waveform`.
-        frame_length (float, optional):
-            Frame length in milliseconds. (default: 25.0)
-        frame_shift (float, optional):
-            Frame shift in milliseconds. (default: 10.0)
-        min_f0 (float, optional):
-            Minimum F0 to search for (Hz) (default: 50.0)
-        max_f0 (float, optional):
-            Maximum F0 to search for (Hz) (default: 400.0)
-        soft_min_f0 (float, optional):
-            Minimum f0, applied in soft way, must not exceed min-f0 (default: 10.0)
-        penalty_factor (float, optional):
-            Cost factor for FO change. (default: 0.1)
-        lowpass_cutoff (float, optional):
-            Cutoff frequency for LowPass filter (Hz) (default: 1000)
-        resample_frequency (float, optional):
-            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
-            (default: 4000)
-        delta_pitch( float, optional):
-            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
-        nccf_ballast (float, optional):
-            Increasing this factor reduces NCCF for quiet frames (default: 7000)
-        lowpass_filter_width (int, optional):
-            Integer that determines filter width of lowpass filter, more gives sharper filter.
-            (default: 1)
-        upsample_filter_width (int, optional):
-            Integer that determines filter width when upsampling NCCF. (default: 5)
-        max_frames_latency (int, optional):
-            Maximum number of frames of latency that we allow pitch tracking to introduce into
-            the feature processing (affects output only if ``frames_per_chunk > 0`` and
-            ``simulate_first_pass_online=True``) (default: 0)
-        frames_per_chunk (int, optional):
-            The number of frames used for energy normalization. (default: 0)
-        simulate_first_pass_online (bool, optional):
-            If true, the function will output features that correspond to what an online decoder
-            would see in the first pass of decoding -- not the final version of the features,
-            which is the default. (default: False)
-            Relevant if ``frames_per_chunk > 0``.
-        recompute_frame (int, optional):
-            Only relevant for compatibility with online pitch extraction.
-            A non-critical parameter; the frame at which we recompute some of the forward pointers,
-            after revising our estimate of the signal energy.
-            Relevant if ``frames_per_chunk > 0``. (default: 500)
-        snip_edges (bool, optional):
-            If this is set to false, the incomplete frames near the ending edge won't be snipped,
-            so that the number of frames is the file size divided by the frame-shift.
-            This makes different types of features give the same number of frames. (default: True)
-
-    Returns:
-        Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
-        corresponds to pitch and NCCF.
-    """
-    shape = waveform.shape
-    waveform = waveform.reshape(-1, shape[-1])
-    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
-        waveform,
-        sample_rate,
-        frame_length,
-        frame_shift,
-        min_f0,
-        max_f0,
-        soft_min_f0,
-        penalty_factor,
-        lowpass_cutoff,
-        resample_frequency,
-        delta_pitch,
-        nccf_ballast,
-        lowpass_filter_width,
-        upsample_filter_width,
-        max_frames_latency,
-        frames_per_chunk,
-        simulate_first_pass_online,
-        recompute_frame,
-        snip_edges,
-    )
-    result = result.reshape(shape[:-1] + result.shape[-2:])
-    return result
+_CPU = torch.device("cpu")
 
 
 def _get_sinc_resample_kernel(
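
`apply_codec` is now deprecated in favor of the new `torchaudio.io.AudioEffector` (added in `torchaudio/io/_effector.py`, see the file list above). A minimal migration sketch; it assumes an FFmpeg-enabled build, and note that `AudioEffector.apply` expects `(time, channel)` layout rather than `apply_codec`'s default `(channel, time)`:

    import torch
    from torchaudio.io import AudioEffector

    waveform = torch.rand(16000, 1)  # (time, channel); one second of mono at 16 kHz

    # Round-trip through an MP3 encode/decode, similar in spirit to
    # apply_codec(waveform, 16000, format="mp3").
    effector = AudioEffector(format="mp3")
    degraded = effector.apply(waveform, sample_rate=16000)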
@@ -1436,10 +1350,9 @@ def _get_sinc_resample_kernel(
     rolloff: float = 0.99,
     resampling_method: str = "sinc_interp_hann",
     beta: Optional[float] = None,
-    device: torch.device = torch.device("cpu"),
+    device: torch.device = _CPU,
     dtype: Optional[torch.dtype] = None,
 ):
-
     if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
         raise Exception(
             "Frequencies must be of integer type to ensure quality resampling computation. "
@@ -1550,7 +1463,7 @@ def _apply_sinc_resample_kernel(
     waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
     resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
     resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
-    target_length = int(math.ceil(new_freq * length / orig_freq))
+    target_length = torch.ceil(torch.as_tensor(new_freq * length / orig_freq)).long()
     resampled = resampled[..., :target_length]
 
     # unpack batch
@@ -2580,3 +2493,41 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
     a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
     b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
     return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
+
+
+def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
+    r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.
+
+    Concretely, for multivariate Gaussians :math:`X(\mu_X, \Sigma_X)`
+    and :math:`Y(\mu_Y, \Sigma_Y)`, the function computes and returns :math:`F` as
+
+    .. math::
+        F(X, Y) = || \mu_X - \mu_Y ||_2^2
+        + \text{Tr}\left( \Sigma_X + \Sigma_Y - 2 \sqrt{\Sigma_X \Sigma_Y} \right)
+
+    Args:
+        mu_x (torch.Tensor): mean :math:`\mu_X` of multivariate Gaussian :math:`X`, with shape `(N,)`.
+        sigma_x (torch.Tensor): covariance matrix :math:`\Sigma_X` of :math:`X`, with shape `(N, N)`.
+        mu_y (torch.Tensor): mean :math:`\mu_Y` of multivariate Gaussian :math:`Y`, with shape `(N,)`.
+        sigma_y (torch.Tensor): covariance matrix :math:`\Sigma_Y` of :math:`Y`, with shape `(N, N)`.
+
+    Returns:
+        torch.Tensor: the Fréchet distance between :math:`X` and :math:`Y`.
+    """
+    if len(mu_x.size()) != 1:
+        raise ValueError(f"Input mu_x must be one-dimensional; got dimension {len(mu_x.size())}.")
+    if len(sigma_x.size()) != 2:
+        raise ValueError(f"Input sigma_x must be two-dimensional; got dimension {len(sigma_x.size())}.")
+    if sigma_x.size(0) != sigma_x.size(1) != mu_x.size(0):
+        raise ValueError("Each of sigma_x's dimensions must match mu_x's size.")
+    if mu_x.size() != mu_y.size():
+        raise ValueError(f"Inputs mu_x and mu_y must have the same shape; got {mu_x.size()} and {mu_y.size()}.")
+    if sigma_x.size() != sigma_y.size():
+        raise ValueError(
+            f"Inputs sigma_x and sigma_y must have the same shape; got {sigma_x.size()} and {sigma_y.size()}."
+        )
+
+    a = (mu_x - mu_y).square().sum()
+    b = sigma_x.trace() + sigma_y.trace()
+    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum()
+    return a + b - 2 * c
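
A quick sanity check of the new `frechet_distance`: with identical inputs the distance is 0, and for two unit-covariance Gaussians it reduces to the squared mean distance, since the trace terms cancel:

    import torch
    from torchaudio.functional import frechet_distance

    mu_x, sigma = torch.zeros(4), torch.eye(4)
    mu_y = torch.full((4,), 2.0)

    print(frechet_distance(mu_x, sigma, mu_x, sigma))  # tensor(0.)
    print(frechet_distance(mu_x, sigma, mu_y, sigma))  # tensor(16.) == ||mu_x - mu_y||^2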
torchaudio/io/__init__.py CHANGED
@@ -1,10 +1,13 @@
+from ._effector import AudioEffector
 from ._playback import play_audio
 from ._stream_reader import StreamReader
-from ._stream_writer import StreamWriter
+from ._stream_writer import CodecConfig, StreamWriter
 
 
 __all__ = [
+    "AudioEffector",
     "StreamReader",
     "StreamWriter",
+    "CodecConfig",
     "play_audio",
 ]
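
`CodecConfig` is now re-exported next to `StreamWriter`; it bundles encoder options (bit rate, compression level, etc.) for a stream. A minimal sketch of where it plugs in (output path and parameter values are arbitrary; requires an FFmpeg-enabled build):

    import torch
    from torchaudio.io import CodecConfig, StreamWriter

    writer = StreamWriter("out.mp3")
    writer.add_audio_stream(
        sample_rate=16000,
        num_channels=1,
        codec_config=CodecConfig(bit_rate=128_000),
    )
    with writer.open():
        writer.write_audio_chunk(0, torch.rand(16000, 1))  # (frame, channel)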