torchaudio-2.0.2-cp311-cp311-manylinux2014_aarch64.whl → torchaudio-2.1.1-cp311-cp311-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (90)
  1. torchaudio/__init__.py +22 -3
  2. torchaudio/_backend/__init__.py +55 -4
  3. torchaudio/_backend/backend.py +53 -0
  4. torchaudio/_backend/common.py +52 -0
  5. torchaudio/_backend/ffmpeg.py +373 -0
  6. torchaudio/_backend/soundfile.py +54 -0
  7. torchaudio/_backend/soundfile_backend.py +457 -0
  8. torchaudio/_backend/sox.py +91 -0
  9. torchaudio/_backend/utils.py +81 -323
  10. torchaudio/_extension/__init__.py +55 -36
  11. torchaudio/_extension/utils.py +109 -17
  12. torchaudio/_internal/__init__.py +4 -1
  13. torchaudio/_internal/module_utils.py +37 -6
  14. torchaudio/backend/__init__.py +7 -11
  15. torchaudio/backend/_no_backend.py +24 -0
  16. torchaudio/backend/_sox_io_backend.py +297 -0
  17. torchaudio/backend/common.py +12 -52
  18. torchaudio/backend/no_backend.py +11 -21
  19. torchaudio/backend/soundfile_backend.py +11 -448
  20. torchaudio/backend/sox_io_backend.py +11 -435
  21. torchaudio/backend/utils.py +9 -18
  22. torchaudio/datasets/__init__.py +2 -0
  23. torchaudio/datasets/cmuarctic.py +1 -1
  24. torchaudio/datasets/cmudict.py +61 -62
  25. torchaudio/datasets/dr_vctk.py +1 -1
  26. torchaudio/datasets/gtzan.py +1 -1
  27. torchaudio/datasets/librilight_limited.py +1 -1
  28. torchaudio/datasets/librispeech.py +1 -1
  29. torchaudio/datasets/librispeech_biasing.py +189 -0
  30. torchaudio/datasets/libritts.py +1 -1
  31. torchaudio/datasets/ljspeech.py +1 -1
  32. torchaudio/datasets/musdb_hq.py +1 -1
  33. torchaudio/datasets/quesst14.py +1 -1
  34. torchaudio/datasets/speechcommands.py +1 -1
  35. torchaudio/datasets/tedlium.py +1 -1
  36. torchaudio/datasets/vctk.py +1 -1
  37. torchaudio/datasets/voxceleb1.py +1 -1
  38. torchaudio/datasets/yesno.py +1 -1
  39. torchaudio/functional/__init__.py +6 -2
  40. torchaudio/functional/_alignment.py +128 -0
  41. torchaudio/functional/filtering.py +69 -92
  42. torchaudio/functional/functional.py +99 -148
  43. torchaudio/io/__init__.py +4 -1
  44. torchaudio/io/_effector.py +347 -0
  45. torchaudio/io/_stream_reader.py +158 -90
  46. torchaudio/io/_stream_writer.py +196 -10
  47. torchaudio/lib/_torchaudio.so +0 -0
  48. torchaudio/lib/_torchaudio_ffmpeg4.so +0 -0
  49. torchaudio/lib/_torchaudio_ffmpeg5.so +0 -0
  50. torchaudio/lib/_torchaudio_ffmpeg6.so +0 -0
  51. torchaudio/lib/_torchaudio_sox.so +0 -0
  52. torchaudio/lib/libtorchaudio.so +0 -0
  53. torchaudio/lib/libtorchaudio_ffmpeg4.so +0 -0
  54. torchaudio/lib/libtorchaudio_ffmpeg5.so +0 -0
  55. torchaudio/lib/libtorchaudio_ffmpeg6.so +0 -0
  56. torchaudio/lib/libtorchaudio_sox.so +0 -0
  57. torchaudio/models/__init__.py +14 -0
  58. torchaudio/models/decoder/__init__.py +22 -7
  59. torchaudio/models/decoder/_ctc_decoder.py +123 -69
  60. torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
  61. torchaudio/models/rnnt_decoder.py +10 -14
  62. torchaudio/models/squim/__init__.py +11 -0
  63. torchaudio/models/squim/objective.py +326 -0
  64. torchaudio/models/squim/subjective.py +150 -0
  65. torchaudio/models/wav2vec2/components.py +6 -10
  66. torchaudio/pipelines/__init__.py +9 -0
  67. torchaudio/pipelines/_squim_pipeline.py +176 -0
  68. torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
  69. torchaudio/pipelines/_wav2vec2/impl.py +198 -68
  70. torchaudio/pipelines/_wav2vec2/utils.py +120 -0
  71. torchaudio/sox_effects/sox_effects.py +7 -30
  72. torchaudio/transforms/__init__.py +2 -0
  73. torchaudio/transforms/_transforms.py +99 -54
  74. torchaudio/utils/download.py +2 -2
  75. torchaudio/utils/ffmpeg_utils.py +20 -15
  76. torchaudio/utils/sox_utils.py +8 -9
  77. torchaudio/version.py +2 -2
  78. torchaudio-2.1.1.dist-info/METADATA +113 -0
  79. torchaudio-2.1.1.dist-info/RECORD +117 -0
  80. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/WHEEL +1 -1
  81. torchaudio/io/_compat.py +0 -241
  82. torchaudio/lib/_torchaudio_ffmpeg.so +0 -0
  83. torchaudio/lib/flashlight_lib_text_decoder.so +0 -0
  84. torchaudio/lib/flashlight_lib_text_dictionary.so +0 -0
  85. torchaudio/lib/libflashlight-text.so +0 -0
  86. torchaudio/lib/libtorchaudio_ffmpeg.so +0 -0
  87. torchaudio-2.0.2.dist-info/METADATA +0 -30
  88. torchaudio-2.0.2.dist-info/RECORD +0 -100
  89. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/LICENSE +0 -0
  90. {torchaudio-2.0.2.dist-info → torchaudio-2.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torchaudio
+
+
+class AttPool(nn.Module):
+    """Attention-Pooling module that estimates the attention score.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(AttPool, self).__init__()
+
+        self.linear1 = nn.Linear(input_dim, 1)
+        self.linear2 = nn.Linear(input_dim, att_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply attention and pooling.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.
+        """
+
+        att = self.linear1(x)  # (batch, time, 1)
+        att = att.transpose(2, 1)  # (batch, 1, time)
+        att = nn.functional.softmax(att, dim=2)
+        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
+        x = self.linear2(x)  # (batch, att_dim)
+        return x
+
+
+class Predictor(nn.Module):
+    """Prediction module that applies pooling and attention, then predicts subjective metric scores.
+
+    Args:
+        input_dim (int): Input feature dimension.
+        att_dim (int): Attention Tensor dimension.
+    """
+
+    def __init__(self, input_dim: int, att_dim: int):
+        super(Predictor, self).__init__()
+        self.att_pool_layer = AttPool(input_dim, att_dim)
+        self.att_dim = att_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict subjective evaluation metric score.
+
+        Args:
+            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        x = self.att_pool_layer(x)
+        x = nn.functional.softmax(x, dim=1)
+        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
+        x = (x * B).sum(dim=1)
+        return x
+
+
+class SquimSubjective(nn.Module):
+    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores
+    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adopted from *NORESQA-MOS*
+    :cite:`manocha2022speech`, which predicts MOS scores given the input speech and a non-matching reference.
+
+    Args:
+        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.
+        projector (torch.nn.Module): Projection layer that projects SSL feature to a lower dimension.
+        predictor (torch.nn.Module): Predict the subjective scores.
+    """
+
+    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
+        super(SquimSubjective, self).__init__()
+        self.ssl_model = ssl_model
+        self.projector = projector
+        self.predictor = predictor
+
+    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Cut or pad the reference Tensor to make it aligned with the waveform Tensor.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors
+                with same dimensions `(batch, time)`.
+        """
+        T_waveform = waveform.shape[-1]
+        T_reference = reference.shape[-1]
+        if T_reference < T_waveform:
+            num_padding = T_waveform // T_reference + 1
+            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
+        return waveform, reference[:, :T_waveform]
+
+    def forward(self, waveform: torch.Tensor, reference: torch.Tensor):
+        """Predict subjective evaluation metric score.
+
+        Args:
+            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.
+            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.
+
+        Returns:
+            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.
+        """
+        waveform, reference = self._align_shapes(waveform, reference)
+        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
+        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
+        concat = torch.cat((reference, waveform), dim=2)
+        score_diff = self.predictor(concat)  # Score difference compared to the reference
+        return 5 - score_diff
+
+
+def squim_subjective_model(
+    ssl_type: str,
+    feat_dim: int,
+    proj_dim: int,
+    att_dim: int,
+) -> SquimSubjective:
+    """Build a custom :class:`torchaudio.prototype.models.SquimSubjective` model.
+
+    Args:
+        ssl_type (str): Type of self-supervised learning (SSL) models.
+            Must be one of ["wav2vec2_base", "wav2vec2_large"].
+        feat_dim (int): Feature dimension of the SSL feature representation.
+        proj_dim (int): Output dimension of projection layer.
+        att_dim (int): Dimension of attention scores.
+    """
+    ssl_model = getattr(torchaudio.models, ssl_type)()
+    projector = nn.Linear(feat_dim, proj_dim)
+    predictor = Predictor(proj_dim * 2, att_dim)
+    return SquimSubjective(ssl_model, projector, predictor)
+
+
+def squim_subjective_base() -> SquimSubjective:
+    """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
+    return squim_subjective_model(
+        ssl_type="wav2vec2_base",
+        feat_dim=768,
+        proj_dim=32,
+        att_dim=5,
+    )
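As a sanity check of the data flow in the new module, here is a minimal, hypothetical smoke test of the factory above. It builds the default model with randomly initialized `wav2vec2_base` weights (no download) and feeds random tensors in place of real audio; it is a sketch, not part of the package.

import torch
from torchaudio.models import squim_subjective_base

# Default SquimSubjective: wav2vec2_base backbone, 32-dim projector, 5-way attention pooling.
model = squim_subjective_base().eval()

# Batch of 2 one-second 16 kHz "waveforms" and a shorter non-matching "reference".
waveform = torch.randn(2, 16000)
reference = torch.randn(2, 8000)  # _align_shapes tiles this to match `waveform`

with torch.inference_mode():
    score = model(waveform, reference)  # predicted MOS, shape (batch,)
print(score.shape)  # torch.Size([2])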
@@ -208,18 +208,13 @@ class ConvolutionalPositionalEmbedding(Module):
             groups=groups,
         )
 
-        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+        self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
         self.num_remove: int = 1 if kernel_size % 2 == 0 else 0
 
     def __prepare_scriptable__(self):
-        for hook in self.conv._forward_pre_hooks.values():
-            # The hook we want to remove is an instance of WeightNorm class, so
-            # normally we would do `if isinstance(...)` but this class is not accessible
-            # because of shadowing, so we check the module name directly.
-            # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3
-            if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm":
-                _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
-                torch.nn.utils.remove_weight_norm(self.conv)
+        if self.conv.__class__.__name__ == "ParametrizedConv1d":
+            _LG.warning("Removing weight_norm from %s", self.__class__.__name__)
+            torch.nn.utils.parametrize.remove_parametrizations(self.conv, "weight")
         return self
 
     def forward(self, x):
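The hunk above swaps the hook-based `nn.utils.weight_norm` for the parametrization-based variant, so `__prepare_scriptable__` now detects it via the injected `ParametrizedConv1d` class and undoes it with `remove_parametrizations`. A small self-contained sketch of that API pair (not torchaudio code; requires a PyTorch version that ships `torch.nn.utils.parametrizations.weight_norm`):

import torch
import torch.nn as nn

conv = nn.Conv1d(16, 16, kernel_size=3)

# New style: weight norm as a parametrization over the kernel dimension.
conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
print(type(conv).__name__)  # "ParametrizedConv1d" -- the class name checked above

# Remove it before TorchScript export, mirroring the new __prepare_scriptable__ branch.
torch.nn.utils.parametrize.remove_parametrizations(conv, "weight")
print(type(conv).__name__)  # restored to "Conv1d"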
@@ -458,9 +453,10 @@ class Transformer(Module):
                 raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]")
 
         ret: List[Tensor] = []
+        position_bias = None
         x = self._preprocess(x)
         for layer in self.layers:
-            x, _ = layer(x, attention_mask)  # Ignore position_bias
+            x, position_bias = layer(x, attention_mask, position_bias=position_bias)
             ret.append(x)
             if num_layers is not None and len(ret) >= num_layers:
                 return ret
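This change matters for encoder layers that compute a relative position bias (e.g. WavLM-style layers): instead of being discarded, the bias produced by the first layer is now threaded back into the following layers. A minimal sketch of that caching pattern, with a hypothetical `DummyLayer` standing in for the real encoder layer:

from typing import List, Optional

import torch
from torch import Tensor


class DummyLayer(torch.nn.Module):
    """Hypothetical layer: computes a position bias only when none is supplied."""

    def forward(self, x: Tensor, attention_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None):
        if position_bias is None:
            position_bias = torch.zeros(x.size(1), x.size(1))  # (time, time)
        return x, position_bias


layers = torch.nn.ModuleList([DummyLayer() for _ in range(3)])
x = torch.randn(1, 10, 8)

outputs: List[Tensor] = []
position_bias = None  # computed by the first layer, then reused by the rest
for layer in layers:
    x, position_bias = layer(x, None, position_bias=position_bias)
    outputs.append(x)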
@@ -4,6 +4,7 @@ from ._source_separation_pipeline import (
     HDEMUCS_HIGH_MUSDB_PLUS,
     SourceSeparationBundle,
 )
+from ._squim_pipeline import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, SquimObjectiveBundle, SquimSubjectiveBundle
 from ._tts import (
     TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
     TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
@@ -17,6 +18,7 @@ from ._wav2vec2.impl import (
     HUBERT_BASE,
     HUBERT_LARGE,
     HUBERT_XLARGE,
+    MMS_FA,
     VOXPOPULI_ASR_BASE_10K_DE,
     VOXPOPULI_ASR_BASE_10K_EN,
     VOXPOPULI_ASR_BASE_10K_ES,
@@ -40,6 +42,7 @@ from ._wav2vec2.impl import (
     WAV2VEC2_XLSR_300M,
     Wav2Vec2ASRBundle,
     Wav2Vec2Bundle,
+    Wav2Vec2FABundle,
     WAVLM_BASE,
     WAVLM_BASE_PLUS,
     WAVLM_LARGE,
@@ -50,6 +53,7 @@ from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
 __all__ = [
     "Wav2Vec2Bundle",
     "Wav2Vec2ASRBundle",
+    "Wav2Vec2FABundle",
     "WAV2VEC2_BASE",
     "WAV2VEC2_LARGE",
     "WAV2VEC2_LARGE_LV60K",
@@ -76,6 +80,7 @@ __all__ = [
     "HUBERT_XLARGE",
     "HUBERT_ASR_LARGE",
     "HUBERT_ASR_XLARGE",
+    "MMS_FA",
     "WAVLM_BASE",
     "WAVLM_BASE_PLUS",
     "WAVLM_LARGE",
@@ -90,4 +95,8 @@ __all__ = [
     "CONVTASNET_BASE_LIBRI2MIX",
     "HDEMUCS_HIGH_MUSDB_PLUS",
     "HDEMUCS_HIGH_MUSDB",
+    "SQUIM_OBJECTIVE",
+    "SQUIM_SUBJECTIVE",
+    "SquimObjectiveBundle",
+    "SquimSubjectiveBundle",
 ]
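For reference, the names added to `__all__` above become importable directly from `torchaudio.pipelines`; a quick sketch:

from torchaudio.pipelines import MMS_FA, SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE, Wav2Vec2FABundle

# Bundles are lightweight dataclass instances; nothing is downloaded until get_model().
print(isinstance(MMS_FA, Wav2Vec2FABundle))                       # True
print(SQUIM_OBJECTIVE.sample_rate, SQUIM_SUBJECTIVE.sample_rate)  # 16000.0 16000.0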
@@ -0,0 +1,176 @@
+from dataclasses import dataclass
+
+from torchaudio._internal import load_state_dict_from_url
+
+from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
+
+
+@dataclass
+class SquimObjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimObjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.
+    A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
+
+    Example: Estimate the objective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimObjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
+        100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate objective metric scores
+        >>> scores = model(waveform)
+        >>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
+    """  # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def _get_state_dict(self, dl_kwargs):
+        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        return state_dict
+
+    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
+        """Construct the SquimObjective model, and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimObjective`.
+        """
+        model = squim_objective_base()
+        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+    "squim_objective_dns2020.pth",
+    _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
+    :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+    The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+    The weights are under `Creative Commons Attribution 4.0 International License
+    <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
+
+    Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+    """
+
+
+@dataclass
+class SquimSubjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimSubjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
+    A typical use case would be a flow like `waveform -> score`. Please see below for the code example.
+
+    Example: Estimate the subjective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimSubjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
+        100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
+        >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate subjective metric scores
+        >>> score = model(waveform, reference)
+        >>> print(f"MOS: {score}.")
+    """  # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def _get_state_dict(self, dl_kwargs):
+        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        return state_dict
+
+    def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
+        """Construct the SquimSubjective model, and load the pretrained weight.
+
+        The weight file is downloaded from the internet and cached with
+        :func:`torch.hub.load_state_dict_from_url`
+
+        Args:
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimSubjective`.
+        """
+        model = squim_subjective_base()
+        model.load_state_dict(self._get_state_dict(dl_kwargs))
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
+    "squim_subjective_bvcc_daps.pth",
+    _sample_rate=16000,
+)
+SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
+    as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
+    on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
+
+    The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
+    The weights are under `Creative Commons Attribution Non Commercial 4.0 International License
+    <https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.
+
+    Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
+    """
@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+import torch
+import torchaudio.functional as F
+from torch import Tensor
+from torchaudio.functional import TokenSpan
+
+
+class ITokenizer(ABC):
+    @abstractmethod
+    def __call__(self, transcript: List[str]) -> List[List[int]]:
+        """Tokenize the given transcript (list of words).
+
+        .. note::
+
+            The transcript must be normalized.
+
+        Args:
+            transcript (list of str): Transcript (list of words).
+
+        Returns:
+            (list of list of int): List of token sequences, one per word.
+        """
+
+
+class Tokenizer(ITokenizer):
+    def __init__(self, dictionary: Dict[str, int]):
+        self.dictionary = dictionary
+
+    def __call__(self, transcript: List[str]) -> List[List[int]]:
+        return [[self.dictionary[c] for c in word] for word in transcript]
+
+
+def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0):
+    device = emission.device
+    emission = emission.unsqueeze(0)
+    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
+
+    aligned_tokens, scores = F.forced_align(emission, targets, blank=blank)
+
+    scores = scores.exp()  # convert back to probability
+    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
+    return aligned_tokens, scores
+
+
+class IAligner(ABC):
+    @abstractmethod
+    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+        """Generate a list of time-stamped token sequences.
+
+        Args:
+            emission (Tensor): Sequence of token probability distributions in log-domain.
+                Shape: `(time, tokens)`.
+            tokens (list of integer sequence): Tokenized transcript.
+                Output from :py:class:`torchaudio.pipelines.Wav2Vec2FABundle.Tokenizer`.
+
+        Returns:
+            (list of TokenSpan sequence): Tokens with time stamps and scores.
+        """
+
+
+def _unflatten(list_, lengths):
+    assert len(list_) == sum(lengths)
+    i = 0
+    ret = []
+    for l in lengths:
+        ret.append(list_[i : i + l])
+        i += l
+    return ret
+
+
+def _flatten(nested_list):
+    return [item for list_ in nested_list for item in list_]
+
+
+class Aligner(IAligner):
+    def __init__(self, blank):
+        self.blank = blank
+
+    def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]:
+        if emission.ndim != 2:
+            raise ValueError(f"The input emission must be 2D. Found: {emission.shape}")
+
+        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
+        spans = F.merge_tokens(aligned_tokens, scores)
+        return _unflatten(spans, [len(ts) for ts in tokens])
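A minimal sketch of how the `Tokenizer` and `Aligner` above cooperate. The label dictionary and the emission are made up here, and the classes are imported from the private module purely for illustration; in normal use they come from the `MMS_FA` bundle rather than being constructed by hand.

import torch
from torchaudio.pipelines._wav2vec2.aligner import Aligner, Tokenizer

# Hypothetical 3-symbol vocabulary with the CTC blank at index 0.
dictionary = {"-": 0, "a": 1, "b": 2}
tokenizer = Tokenizer(dictionary)
aligner = Aligner(blank=0)

tokens = tokenizer(["ab", "ba"])  # [[1, 2], [2, 1]]

# Fake log-probability emission with shape (time, num_tokens).
emission = torch.randn(20, len(dictionary)).log_softmax(dim=-1)

word_spans = aligner(emission, tokens)  # list of TokenSpan lists, one per word
for word in word_spans:
    for span in word:
        print(span.token, span.start, span.end, span.score)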