torchaudio-2.9.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchaudio might be problematic; see the package's registry page for details.
- torchaudio/.dylibs/libc++.1.0.dylib +0 -0
- torchaudio/__init__.py +204 -0
- torchaudio/_extension/__init__.py +61 -0
- torchaudio/_extension/utils.py +133 -0
- torchaudio/_internal/__init__.py +10 -0
- torchaudio/_internal/module_utils.py +171 -0
- torchaudio/_torchcodec.py +340 -0
- torchaudio/compliance/__init__.py +5 -0
- torchaudio/compliance/kaldi.py +813 -0
- torchaudio/datasets/__init__.py +47 -0
- torchaudio/datasets/cmuarctic.py +157 -0
- torchaudio/datasets/cmudict.py +186 -0
- torchaudio/datasets/commonvoice.py +86 -0
- torchaudio/datasets/dr_vctk.py +121 -0
- torchaudio/datasets/fluentcommands.py +108 -0
- torchaudio/datasets/gtzan.py +1118 -0
- torchaudio/datasets/iemocap.py +147 -0
- torchaudio/datasets/librilight_limited.py +111 -0
- torchaudio/datasets/librimix.py +133 -0
- torchaudio/datasets/librispeech.py +174 -0
- torchaudio/datasets/librispeech_biasing.py +189 -0
- torchaudio/datasets/libritts.py +168 -0
- torchaudio/datasets/ljspeech.py +107 -0
- torchaudio/datasets/musdb_hq.py +139 -0
- torchaudio/datasets/quesst14.py +136 -0
- torchaudio/datasets/snips.py +157 -0
- torchaudio/datasets/speechcommands.py +183 -0
- torchaudio/datasets/tedlium.py +218 -0
- torchaudio/datasets/utils.py +54 -0
- torchaudio/datasets/vctk.py +143 -0
- torchaudio/datasets/voxceleb1.py +309 -0
- torchaudio/datasets/yesno.py +89 -0
- torchaudio/functional/__init__.py +130 -0
- torchaudio/functional/_alignment.py +128 -0
- torchaudio/functional/filtering.py +1685 -0
- torchaudio/functional/functional.py +2505 -0
- torchaudio/lib/__init__.py +0 -0
- torchaudio/lib/_torchaudio.so +0 -0
- torchaudio/lib/libtorchaudio.so +0 -0
- torchaudio/models/__init__.py +85 -0
- torchaudio/models/_hdemucs.py +1008 -0
- torchaudio/models/conformer.py +293 -0
- torchaudio/models/conv_tasnet.py +330 -0
- torchaudio/models/decoder/__init__.py +64 -0
- torchaudio/models/decoder/_ctc_decoder.py +568 -0
- torchaudio/models/decoder/_cuda_ctc_decoder.py +187 -0
- torchaudio/models/deepspeech.py +84 -0
- torchaudio/models/emformer.py +884 -0
- torchaudio/models/rnnt.py +816 -0
- torchaudio/models/rnnt_decoder.py +339 -0
- torchaudio/models/squim/__init__.py +11 -0
- torchaudio/models/squim/objective.py +326 -0
- torchaudio/models/squim/subjective.py +150 -0
- torchaudio/models/tacotron2.py +1046 -0
- torchaudio/models/wav2letter.py +72 -0
- torchaudio/models/wav2vec2/__init__.py +45 -0
- torchaudio/models/wav2vec2/components.py +1167 -0
- torchaudio/models/wav2vec2/model.py +1579 -0
- torchaudio/models/wav2vec2/utils/__init__.py +7 -0
- torchaudio/models/wav2vec2/utils/import_fairseq.py +213 -0
- torchaudio/models/wav2vec2/utils/import_huggingface.py +134 -0
- torchaudio/models/wav2vec2/wavlm_attention.py +214 -0
- torchaudio/models/wavernn.py +409 -0
- torchaudio/pipelines/__init__.py +102 -0
- torchaudio/pipelines/_source_separation_pipeline.py +109 -0
- torchaudio/pipelines/_squim_pipeline.py +156 -0
- torchaudio/pipelines/_tts/__init__.py +16 -0
- torchaudio/pipelines/_tts/impl.py +385 -0
- torchaudio/pipelines/_tts/interface.py +255 -0
- torchaudio/pipelines/_tts/utils.py +230 -0
- torchaudio/pipelines/_wav2vec2/__init__.py +0 -0
- torchaudio/pipelines/_wav2vec2/aligner.py +87 -0
- torchaudio/pipelines/_wav2vec2/impl.py +1699 -0
- torchaudio/pipelines/_wav2vec2/utils.py +346 -0
- torchaudio/pipelines/rnnt_pipeline.py +380 -0
- torchaudio/transforms/__init__.py +78 -0
- torchaudio/transforms/_multi_channel.py +467 -0
- torchaudio/transforms/_transforms.py +2138 -0
- torchaudio/utils/__init__.py +4 -0
- torchaudio/utils/download.py +89 -0
- torchaudio/version.py +2 -0
- torchaudio-2.9.0.dist-info/LICENSE +25 -0
- torchaudio-2.9.0.dist-info/METADATA +122 -0
- torchaudio-2.9.0.dist-info/RECORD +86 -0
- torchaudio-2.9.0.dist-info/WHEEL +5 -0
- torchaudio-2.9.0.dist-info/top_level.txt +1 -0
torchaudio/pipelines/_squim_pipeline.py
@@ -0,0 +1,156 @@
+from dataclasses import dataclass
+
+import torch
+import torchaudio
+
+from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
+
+
+@dataclass
+class SquimObjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimObjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.
+    A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.
+
+    Example: Estimate the objective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimObjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"
+        100%|████████████| 28.2M/28.2M [00:03<00:00, 9.24MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate objective metric scores
+        >>> scores = model(waveform)
+        >>> print(f"STOI: {scores[0].item()}, PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")
+    """ # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def get_model(self) -> SquimObjective:
+        """Construct the SquimObjective model, and load the pretrained weight.
+
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimObjective`.
+        """
+        model = squim_objective_base()
+        path = torchaudio.utils._download_asset(f"models/{self._path}")
+        state_dict = torch.load(path, weights_only=True)
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+    "squim_objective_dns2020.pth",
+    _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
+:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+The weights are under `Creative Commons Attribution 4.0 International License
+<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
+
+Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+"""
+
+
+@dataclass
+class SquimSubjectiveBundle:
+    """Data class that bundles associated information to use pretrained
+    :py:class:`~torchaudio.models.SquimSubjective` model.
+
+    This class provides interfaces for instantiating the pretrained model along with
+    the information necessary to retrieve pretrained weights and additional data
+    to be used with the model.
+
+    Torchaudio library instantiates objects of this class, each of which represents
+    a different pretrained model. Client code should access pretrained models via these
+    instances.
+
+    This bundle can estimate subjective metric scores for speech enhancement, such as MOS.
+    A typical use case would be a flow like `waveform -> score`. Please see below for the code example.
+
+    Example: Estimate the subjective metric scores for the input waveform.
+        >>> import torch
+        >>> import torchaudio
+        >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle
+        >>>
+        >>> # Load the SquimSubjective bundle
+        >>> model = bundle.get_model()
+        Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"
+        100%|████████████| 360M/360M [00:09<00:00, 41.1MB/s]
+        >>>
+        >>> # Resample audio to the expected sampling rate
+        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
+        >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input
+        >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)
+        >>>
+        >>> # Estimate subjective metric scores
+        >>> score = model(waveform, reference)
+        >>> print(f"MOS: {score}.")
+    """ # noqa: E501
+
+    _path: str
+    _sample_rate: float
+
+    def get_model(self) -> SquimSubjective:
+        """Construct the SquimSubjective model, and load the pretrained weight.
+        Returns:
+            Variation of :py:class:`~torchaudio.models.SquimObjective`.
+        """
+        model = squim_subjective_base()
+        path = torchaudio.utils._download_asset(f"models/{self._path}")
+        state_dict = torch.load(path, weights_only=True)
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+    @property
+    def sample_rate(self):
+        """Sample rate of the audio that the model is trained on.
+
+        :type: float
+        """
+        return self._sample_rate
+
+
+SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
+    "squim_subjective_bvcc_daps.pth",
+    _sample_rate=16000,
+)
+SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
+as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
+on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.
+
+The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
+The weights are under `Creative Commons Attribution Non Commercial 4.0 International
+<https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.
+
+Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
+"""
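The two bundles defined above are complementary: `SQUIM_OBJECTIVE` produces reference-free STOI/PESQ/Si-SDR estimates, while `SQUIM_SUBJECTIVE` produces a MOS estimate against any clean, non-matching reference. A minimal sketch combining them, assuming a 16 kHz mono waveform is already available (the tensors `degraded` and `clean_reference` below are illustrative placeholders, not part of the package):

>>> import torch
>>> from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
>>>
>>> # Placeholder 16 kHz inputs; real audio would be loaded and resampled as in the docstrings above.
>>> degraded = torch.randn(1, 16000)
>>> clean_reference = torch.randn(1, 16000)
>>>
>>> objective_model = SQUIM_OBJECTIVE.get_model()
>>> scores = objective_model(degraded)                  # STOI, PESQ, Si-SDR estimates, no reference needed
>>>
>>> subjective_model = SQUIM_SUBJECTIVE.get_model()
>>> mos = subjective_model(degraded, clean_reference)   # MOS estimate against a non-matching clean reference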
torchaudio/pipelines/_tts/__init__.py
@@ -0,0 +1,16 @@
+from .impl import (
+    TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+    TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+    TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+    TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+)
+from .interface import Tacotron2TTSBundle
+
+
+__all__ = [
+    "Tacotron2TTSBundle",
+    "TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH",
+    "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
+    "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
+    "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+]
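These names are re-exported through the public `torchaudio.pipelines` namespace (see `torchaudio/pipelines/__init__.py` in the file list above), so client code would normally import them from there rather than from the private `_tts` package, for example:

>>> from torchaudio.pipelines import Tacotron2TTSBundle, TACOTRON2_WAVERNN_CHAR_LJSPEECH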
torchaudio/pipelines/_tts/impl.py
@@ -0,0 +1,385 @@
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torchaudio._internal import load_state_dict_from_url
+from torchaudio.functional import mu_law_decoding
+from torchaudio.models import Tacotron2, WaveRNN
+from torchaudio.transforms import GriffinLim, InverseMelScale
+
+from . import utils
+from .interface import Tacotron2TTSBundle
+
+__all__ = []
+
+_BASE_URL = "https://download.pytorch.org/torchaudio/models"
+
+
+################################################################################
+# Pipeline implementation - Text Processor
+################################################################################
+
+
+class _EnglishCharProcessor(Tacotron2TTSBundle.TextProcessor):
+    def __init__(self):
+        super().__init__()
+        self._tokens = utils._get_chars()
+        self._mapping = {s: i for i, s in enumerate(self._tokens)}
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+        if isinstance(texts, str):
+            texts = [texts]
+        indices = [[self._mapping[c] for c in t.lower() if c in self._mapping] for t in texts]
+        return utils._to_tensor(indices)
+
+
+class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor):
+    def __init__(self, *, dl_kwargs=None):
+        super().__init__()
+        self._tokens = utils._get_phones()
+        self._mapping = {p: i for i, p in enumerate(self._tokens)}
+        self._phonemizer = utils._load_phonemizer("en_us_cmudict_forward.pt", dl_kwargs=dl_kwargs)
+        self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])"
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]:
+        if isinstance(texts, str):
+            texts = [texts]
+
+        indices = []
+        for phones in self._phonemizer(texts, lang="en_us"):
+            # '[F][UW][B][AA][R]!' -> ['F', 'UW', 'B', 'AA', 'R', '!']
+            ret = [re.sub(r"[\[\]]", "", r) for r in re.findall(self._pattern, phones)]
+            indices.append([self._mapping[p] for p in ret])
+        return utils._to_tensor(indices)
+
+
+################################################################################
+# Pipeline implementation - Vocoder
+################################################################################
+
+
+class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+    def __init__(self, model: WaveRNN, min_level_db: Optional[float] = -100):
+        super().__init__()
+        self._sample_rate = 22050
+        self._model = model
+        self._min_level_db = min_level_db
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
+    def forward(self, mel_spec, lengths=None):
+        mel_spec = torch.exp(mel_spec)
+        mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
+        if self._min_level_db is not None:
+            mel_spec = (self._min_level_db - mel_spec) / self._min_level_db
+            mel_spec = torch.clamp(mel_spec, min=0, max=1)
+        waveform, lengths = self._model.infer(mel_spec, lengths)
+        waveform = utils._unnormalize_waveform(waveform, self._model.n_bits)
+        waveform = mu_law_decoding(waveform, self._model.n_classes)
+        waveform = waveform.squeeze(1)
+        return waveform, lengths
+
+
+class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
+    def __init__(self):
+        super().__init__()
+        self._sample_rate = 22050
+        self._inv_mel = InverseMelScale(
+            n_stft=(1024 // 2 + 1),
+            n_mels=80,
+            sample_rate=self.sample_rate,
+            f_min=0.0,
+            f_max=8000.0,
+            mel_scale="slaney",
+            norm="slaney",
+        )
+        self._griffin_lim = GriffinLim(
+            n_fft=1024,
+            power=1,
+            hop_length=256,
+            win_length=1024,
+        )
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
+    def forward(self, mel_spec, lengths=None):
+        mel_spec = torch.exp(mel_spec)
+        mel_spec = mel_spec.clone().detach().requires_grad_(True)
+        spec = self._inv_mel(mel_spec)
+        spec = spec.detach().requires_grad_(False)
+        waveforms = self._griffin_lim(spec)
+        return waveforms, lengths
+
+
+################################################################################
+# Bundle classes mixins
+################################################################################
+
+
+class _CharMixin:
+    def get_text_processor(self) -> Tacotron2TTSBundle.TextProcessor:
+        return _EnglishCharProcessor()
+
+
+class _PhoneMixin:
+    def get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor:
+        return _EnglishPhoneProcessor(dl_kwargs=dl_kwargs)
+
+
+@dataclass
+class _Tacotron2Mixin:
+    _tacotron2_path: str
+    _tacotron2_params: Dict[str, Any]
+
+    def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2:
+        model = Tacotron2(**self._tacotron2_params)
+        url = f"{_BASE_URL}/{self._tacotron2_path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+
+@dataclass
+class _WaveRNNMixin:
+    _wavernn_path: Optional[str]
+    _wavernn_params: Optional[Dict[str, Any]]
+
+    def get_vocoder(self, *, dl_kwargs=None):
+        wavernn = self._get_wavernn(dl_kwargs=dl_kwargs)
+        return _WaveRNNVocoder(wavernn)
+
+    def _get_wavernn(self, *, dl_kwargs=None):
+        model = WaveRNN(**self._wavernn_params)
+        url = f"{_BASE_URL}/{self._wavernn_path}"
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(url, **dl_kwargs)
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+
+class _GriffinLimMixin:
+    def get_vocoder(self, **_):
+        return _GriffinLimVocoder()
+
+
+################################################################################
+# Bundle classes
+################################################################################
+
+
+@dataclass
+class _Tacotron2WaveRNNCharBundle(_WaveRNNMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+    pass
+
+
+@dataclass
+class _Tacotron2WaveRNNPhoneBundle(_WaveRNNMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+    pass
+
+
+@dataclass
+class _Tacotron2GriffinLimCharBundle(_GriffinLimMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle):
+    pass
+
+
+@dataclass
+class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle):
+    pass
+
+
+################################################################################
+# Instantiate bundle objects
+################################################################################
+
+
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle(
+    _tacotron2_path="tacotron2_english_characters_1500_epochs_ljspeech.pth",
+    _tacotron2_params=utils._get_taco_params(n_symbols=38),
+)
+TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
+The default parameters were used.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH_v2.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+""" # noqa: E501
+
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle(
+    _tacotron2_path="tacotron2_english_phonemes_1500_epochs_ljspeech.pth",
+    _tacotron2_params=utils._get_taco_params(n_symbols=96),
+)
+TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and
+:py:class:`~torchaudio.transforms.GriffinLim` as vocoder.
+
+The text processor encodes the input texts based on phoneme.
+It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict <http://www.speech.cs.cmu.edu/cgi-bin/cmudict>`__.
+
+You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
+The text processor is set to the *"english_phonemes"*.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH_v2.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH_v2.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
+    _tacotron2_path="tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth",
+    _tacotron2_params=utils._get_taco_params(n_symbols=38),
+    _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+    _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = """Character-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs and :py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8 bits depth waveform of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts character-by-character.
+
+You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
+The following parameters were used; ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_wavernn>`__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH_v2.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+""" # noqa: E501
+
+TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
+    _tacotron2_path="tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth",
+    _tacotron2_params=utils._get_taco_params(n_symbols=96),
+    _wavernn_path="wavernn_10k_epochs_8bits_ljspeech.pth",
+    _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = """Phoneme-based TTS pipeline with :py:class:`~torchaudio.models.Tacotron2` trained on *LJSpeech* :cite:`ljspeech17` for 1,500 epochs, and
+:py:class:`~torchaudio.models.WaveRNN` vocoder trained on 8 bits depth waveform of *LJSpeech* :cite:`ljspeech17` for 10,000 epochs.
+
+The text processor encodes the input texts based on phoneme.
+It uses `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict <http://www.speech.cs.cmu.edu/cgi-bin/cmudict>`__.
+
+You can find the training script for Tacotron2 `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_tacotron2>`__.
+The following parameters were used; ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+You can find the training script for WaveRNN `here <https://github.com/pytorch/audio/tree/main/examples/pipeline_wavernn>`__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+
+
+Example - "The examination and testimony of the experts enabled the Commission to conclude that five shots may have been fired,"
+
+   .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.png
+      :alt: Spectrogram generated by Tacotron2
+
+   .. raw:: html
+
+      <audio controls="controls">
+         <source src="https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH_v2.wav" type="audio/wav">
+         Your browser does not support the <code>audio</code> element.
+      </audio>
+""" # noqa: E501
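Putting the pieces of `impl.py` together, a Tacotron2 TTS bundle is a three-stage chain: the text processor turns text into token indices, `get_tacotron2()` turns tokens into a mel spectrogram, and `get_vocoder()` turns the spectrogram into audio at 22,050 Hz. The sketch below follows the bundle interfaces shown in this diff; it assumes network access for the weight downloads, and the `Tacotron2.infer` call (returning spectrogram, lengths, and alignments) is taken from the upstream torchaudio documentation rather than from this file:

>>> import torch
>>> from torchaudio.pipelines import TACOTRON2_WAVERNN_CHAR_LJSPEECH as bundle
>>>
>>> processor = bundle.get_text_processor()   # _EnglishCharProcessor
>>> tacotron2 = bundle.get_tacotron2()        # downloads the Tacotron2 checkpoint from _BASE_URL
>>> vocoder = bundle.get_vocoder()            # _WaveRNNVocoder wrapping a pretrained WaveRNN
>>>
>>> tokens, token_lengths = processor("Hello world! T T S stands for Text to Speech!")
>>> with torch.inference_mode():
...     spec, spec_lengths, _ = tacotron2.infer(tokens, token_lengths)
...     waveforms, lengths = vocoder(spec, spec_lengths)
>>> # waveforms: (batch, time) samples at vocoder.sample_rate (22050)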