phoonnx-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. phoonnx/__init__.py +0 -0
  2. phoonnx/config.py +490 -0
  3. phoonnx/locale/ca/phonetic_spellings.txt +2 -0
  4. phoonnx/locale/en/phonetic_spellings.txt +1 -0
  5. phoonnx/locale/gl/phonetic_spellings.txt +2 -0
  6. phoonnx/locale/pt/phonetic_spellings.txt +2 -0
  7. phoonnx/phoneme_ids.py +453 -0
  8. phoonnx/phonemizers/__init__.py +45 -0
  9. phoonnx/phonemizers/ar.py +42 -0
  10. phoonnx/phonemizers/base.py +216 -0
  11. phoonnx/phonemizers/en.py +250 -0
  12. phoonnx/phonemizers/fa.py +46 -0
  13. phoonnx/phonemizers/gl.py +142 -0
  14. phoonnx/phonemizers/he.py +67 -0
  15. phoonnx/phonemizers/ja.py +119 -0
  16. phoonnx/phonemizers/ko.py +97 -0
  17. phoonnx/phonemizers/mul.py +606 -0
  18. phoonnx/phonemizers/vi.py +44 -0
  19. phoonnx/phonemizers/zh.py +308 -0
  20. phoonnx/thirdparty/__init__.py +0 -0
  21. phoonnx/thirdparty/arpa2ipa.py +249 -0
  22. phoonnx/thirdparty/cotovia/cotovia_aarch64 +0 -0
  23. phoonnx/thirdparty/cotovia/cotovia_x86_64 +0 -0
  24. phoonnx/thirdparty/hangul2ipa.py +783 -0
  25. phoonnx/thirdparty/ko_tables/aspiration.csv +20 -0
  26. phoonnx/thirdparty/ko_tables/assimilation.csv +31 -0
  27. phoonnx/thirdparty/ko_tables/double_coda.csv +17 -0
  28. phoonnx/thirdparty/ko_tables/hanja.tsv +8525 -0
  29. phoonnx/thirdparty/ko_tables/ipa.csv +22 -0
  30. phoonnx/thirdparty/ko_tables/neutralization.csv +11 -0
  31. phoonnx/thirdparty/ko_tables/tensification.csv +56 -0
  32. phoonnx/thirdparty/ko_tables/yale.csv +22 -0
  33. phoonnx/thirdparty/kog2p/__init__.py +385 -0
  34. phoonnx/thirdparty/kog2p/rulebook.txt +212 -0
  35. phoonnx/thirdparty/mantoq/__init__.py +67 -0
  36. phoonnx/thirdparty/mantoq/buck/__init__.py +0 -0
  37. phoonnx/thirdparty/mantoq/buck/phonetise_buckwalter.py +569 -0
  38. phoonnx/thirdparty/mantoq/buck/symbols.py +64 -0
  39. phoonnx/thirdparty/mantoq/buck/tokenization.py +105 -0
  40. phoonnx/thirdparty/mantoq/num2words.py +37 -0
  41. phoonnx/thirdparty/mantoq/pyarabic/__init__.py +12 -0
  42. phoonnx/thirdparty/mantoq/pyarabic/arabrepr.py +64 -0
  43. phoonnx/thirdparty/mantoq/pyarabic/araby.py +1647 -0
  44. phoonnx/thirdparty/mantoq/pyarabic/named_const.py +227 -0
  45. phoonnx/thirdparty/mantoq/pyarabic/normalize.py +161 -0
  46. phoonnx/thirdparty/mantoq/pyarabic/number.py +826 -0
  47. phoonnx/thirdparty/mantoq/pyarabic/number_const.py +1704 -0
  48. phoonnx/thirdparty/mantoq/pyarabic/stack.py +52 -0
  49. phoonnx/thirdparty/mantoq/pyarabic/trans.py +517 -0
  50. phoonnx/thirdparty/mantoq/unicode_symbol2label.py +4173 -0
  51. phoonnx/thirdparty/tashkeel/LICENSE +22 -0
  52. phoonnx/thirdparty/tashkeel/SOURCE +1 -0
  53. phoonnx/thirdparty/tashkeel/__init__.py +212 -0
  54. phoonnx/thirdparty/tashkeel/hint_id_map.json +18 -0
  55. phoonnx/thirdparty/tashkeel/input_id_map.json +56 -0
  56. phoonnx/thirdparty/tashkeel/model.onnx +0 -0
  57. phoonnx/thirdparty/tashkeel/target_id_map.json +17 -0
  58. phoonnx/thirdparty/zh_num.py +238 -0
  59. phoonnx/util.py +705 -0
  60. phoonnx/version.py +6 -0
  61. phoonnx/voice.py +521 -0
  62. phoonnx-0.0.0.dist-info/METADATA +255 -0
  63. phoonnx-0.0.0.dist-info/RECORD +86 -0
  64. phoonnx-0.0.0.dist-info/WHEEL +5 -0
  65. phoonnx-0.0.0.dist-info/top_level.txt +2 -0
  66. phoonnx_train/__main__.py +151 -0
  67. phoonnx_train/export_onnx.py +109 -0
  68. phoonnx_train/norm_audio/__init__.py +92 -0
  69. phoonnx_train/norm_audio/trim.py +54 -0
  70. phoonnx_train/norm_audio/vad.py +54 -0
  71. phoonnx_train/preprocess.py +420 -0
  72. phoonnx_train/vits/__init__.py +0 -0
  73. phoonnx_train/vits/attentions.py +427 -0
  74. phoonnx_train/vits/commons.py +147 -0
  75. phoonnx_train/vits/config.py +330 -0
  76. phoonnx_train/vits/dataset.py +214 -0
  77. phoonnx_train/vits/lightning.py +352 -0
  78. phoonnx_train/vits/losses.py +58 -0
  79. phoonnx_train/vits/mel_processing.py +139 -0
  80. phoonnx_train/vits/models.py +732 -0
  81. phoonnx_train/vits/modules.py +527 -0
  82. phoonnx_train/vits/monotonic_align/__init__.py +20 -0
  83. phoonnx_train/vits/monotonic_align/setup.py +13 -0
  84. phoonnx_train/vits/transforms.py +212 -0
  85. phoonnx_train/vits/utils.py +16 -0
  86. phoonnx_train/vits/wavfile.py +860 -0
phoonnx_train/vits/config.py
@@ -0,0 +1,330 @@
+ """Configuration classes"""
+ from dataclasses import dataclass, field
+ from typing import Optional, Tuple
+
+
+ @dataclass
+ class MelAudioConfig:
+     filter_length: int = 1024
+     hop_length: int = 256
+     win_length: int = 1024
+     mel_channels: int = 80
+     sample_rate: int = 22050
+     sample_bytes: int = 2
+     channels: int = 1
+     mel_fmin: float = 0.0
+     mel_fmax: Optional[float] = None
+
+
+ @dataclass
+ class ModelAudioConfig:
+     resblock: str
+     resblock_kernel_sizes: Tuple[int, ...]
+     resblock_dilation_sizes: Tuple[Tuple[int, ...], ...]
+     upsample_rates: Tuple[int, ...]
+     upsample_initial_channel: int
+     upsample_kernel_sizes: Tuple[int, ...]
+
+     @staticmethod
+     def low_quality() -> "ModelAudioConfig":
+         return ModelAudioConfig(
+             resblock="2",
+             resblock_kernel_sizes=(3, 5, 7),
+             resblock_dilation_sizes=(
+                 (1, 2),
+                 (2, 6),
+                 (3, 12),
+             ),
+             upsample_rates=(8, 8, 4),
+             upsample_initial_channel=256,
+             upsample_kernel_sizes=(16, 16, 8),
+         )
+
+     @staticmethod
+     def high_quality() -> "ModelAudioConfig":
+         return ModelAudioConfig(
+             resblock="1",
+             resblock_kernel_sizes=(3, 7, 11),
+             resblock_dilation_sizes=(
+                 (1, 3, 5),
+                 (1, 3, 5),
+                 (1, 3, 5),
+             ),
+             upsample_rates=(8, 8, 2, 2),
+             upsample_initial_channel=512,
+             upsample_kernel_sizes=(16, 16, 4, 4),
+         )
+
+
+ @dataclass
+ class ModelConfig:
+     num_symbols: int
+     n_speakers: int
+     audio: ModelAudioConfig
+     mel: MelAudioConfig = field(default_factory=MelAudioConfig)
+
+     inter_channels: int = 192
+     hidden_channels: int = 192
+     filter_channels: int = 768
+     n_heads: int = 2
+     n_layers: int = 6
+     kernel_size: int = 3
+     p_dropout: float = 0.1
+     n_layers_q: int = 3
+     use_spectral_norm: bool = False
+     gin_channels: int = 0  # single speaker
+     use_sdp: bool = True  # StochasticDurationPredictor
+     segment_size: int = 8192
+
+     @property
+     def is_multispeaker(self) -> bool:
+         return self.n_speakers > 1
+
+     @property
+     def resblock(self) -> str:
+         return self.audio.resblock
+
+     @property
+     def resblock_kernel_sizes(self) -> Tuple[int, ...]:
+         return self.audio.resblock_kernel_sizes
+
+     @property
+     def resblock_dilation_sizes(self) -> Tuple[Tuple[int, ...], ...]:
+         return self.audio.resblock_dilation_sizes
+
+     @property
+     def upsample_rates(self) -> Tuple[int, ...]:
+         return self.audio.upsample_rates
+
+     @property
+     def upsample_initial_channel(self) -> int:
+         return self.audio.upsample_initial_channel
+
+     @property
+     def upsample_kernel_sizes(self) -> Tuple[int, ...]:
+         return self.audio.upsample_kernel_sizes
+
+     def __post_init__(self):
+         if self.is_multispeaker and (self.gin_channels == 0):
+             self.gin_channels = 512
+
+
+ @dataclass
+ class TrainingConfig:
+     learning_rate: float = 2e-4
+     betas: Tuple[float, float] = field(default=(0.8, 0.99))
+     eps: float = 1e-9
+     # batch_size: int = 32
+     fp16_run: bool = False
+     lr_decay: float = 0.999875
+     init_lr_ratio: float = 1.0
+     warmup_epochs: int = 0
+     c_mel: int = 45
+     c_kl: float = 1.0
+     grad_clip: Optional[float] = None
+
+
+ # @dataclass
+ # class PhonemesConfig(DataClassJsonMixin):
+ #     phoneme_separator: str = " "
+ #     """Separator between individual phonemes in CSV input"""
+
+ #     word_separator: str = "#"
+ #     """Separator between word phonemes in CSV input (must not match phoneme_separator)"""
+
+ #     phoneme_to_id: typing.Optional[typing.Dict[str, int]] = None
+ #     pad: typing.Optional[str] = "_"
+ #     bos: typing.Optional[str] = None
+ #     eos: typing.Optional[str] = None
+ #     blank: typing.Optional[str] = "#"
+ #     blank_word: typing.Optional[str] = None
+ #     blank_between: typing.Union[str, BlankBetween] = BlankBetween.WORDS
+ #     blank_at_start: bool = True
+ #     blank_at_end: bool = True
+ #     simple_punctuation: bool = True
+ #     punctuation_map: typing.Optional[typing.Dict[str, str]] = None
+ #     separate: typing.Optional[typing.List[str]] = None
+ #     separate_graphemes: bool = False
+ #     separate_tones: bool = False
+ #     tone_before: bool = False
+ #     phoneme_map: typing.Optional[typing.Dict[str, typing.List[str]]] = None
+ #     auto_bos_eos: bool = False
+ #     minor_break: typing.Optional[str] = IPA.BREAK_MINOR.value
+ #     major_break: typing.Optional[str] = IPA.BREAK_MAJOR.value
+ #     break_phonemes_into_graphemes: bool = False
+ #     break_phonemes_into_codepoints: bool = False
+ #     drop_stress: bool = False
+ #     symbols: typing.Optional[typing.List[str]] = None
+
+ #     def split_word_phonemes(self, phonemes_str: str) -> typing.List[typing.List[str]]:
+ #         """Split phonemes string into a list of lists (outer is words, inner is individual phonemes in each word)"""
+ #         return [
+ #             word_phonemes_str.split(self.phoneme_separator)
+ #             if self.phoneme_separator
+ #             else list(word_phonemes_str)
+ #             for word_phonemes_str in phonemes_str.split(self.word_separator)
+ #         ]
+
+ #     def join_word_phonemes(self, word_phonemes: typing.List[typing.List[str]]) -> str:
+ #         """Join per-word phoneme lists back into a single string (inverse of split_word_phonemes)"""
+ #         return self.word_separator.join(
+ #             self.phoneme_separator.join(wp) for wp in word_phonemes
+ #         )
+
+
+ # class Phonemizer(str, Enum):
+ #     SYMBOLS = "symbols"
+ #     GRUUT = "gruut"
+ #     ESPEAK = "espeak"
+ #     EPITRAN = "epitran"
+
+
+ # class Aligner(str, Enum):
+ #     KALDI_ALIGN = "kaldi_align"
+
+
+ # class TextCasing(str, Enum):
+ #     LOWER = "lower"
+ #     UPPER = "upper"
+
+
+ # class MetadataFormat(str, Enum):
+ #     TEXT = "text"
+ #     PHONEMES = "phonemes"
+ #     PHONEME_IDS = "ids"
+
+
+ # @dataclass
+ # class DatasetConfig:
+ #     name: str
+ #     metadata_format: MetadataFormat = MetadataFormat.TEXT
+ #     multispeaker: bool = False
+ #     text_language: typing.Optional[str] = None
+ #     audio_dir: typing.Optional[typing.Union[str, Path]] = None
+ #     cache_dir: typing.Optional[typing.Union[str, Path]] = None
+
+ #     def get_cache_dir(self, output_dir: typing.Union[str, Path]) -> Path:
+ #         if self.cache_dir is not None:
+ #             cache_dir = Path(self.cache_dir)
+ #         else:
+ #             cache_dir = Path("cache") / self.name
+
+ #         if not cache_dir.is_absolute():
+ #             cache_dir = Path(output_dir) / str(cache_dir)
+
+ #         return cache_dir
+
+
+ # @dataclass
+ # class AlignerConfig:
+ #     aligner: typing.Optional[Aligner] = None
+ #     casing: typing.Optional[TextCasing] = None
+
+
+ # @dataclass
+ # class InferenceConfig:
+ #     length_scale: float = 1.0
+ #     noise_scale: float = 0.667
+ #     noise_w: float = 0.8
+
+
+ # @dataclass
+ # class TrainingConfig(DataClassJsonMixin):
+ #     seed: int = 1234
+ #     epochs: int = 10000
+ #     learning_rate: float = 2e-4
+ #     betas: typing.Tuple[float, float] = field(default=(0.8, 0.99))
+ #     eps: float = 1e-9
+ #     batch_size: int = 32
+ #     fp16_run: bool = False
+ #     lr_decay: float = 0.999875
+ #     segment_size: int = 8192
+ #     init_lr_ratio: float = 1.0
+ #     warmup_epochs: int = 0
+ #     c_mel: int = 45
+ #     c_kl: float = 1.0
+ #     grad_clip: typing.Optional[float] = None
+
+ #     min_seq_length: typing.Optional[int] = None
+ #     max_seq_length: typing.Optional[int] = None
+
+ #     min_spec_length: typing.Optional[int] = None
+ #     max_spec_length: typing.Optional[int] = None
+
+ #     min_speaker_utterances: typing.Optional[int] = None
+
+ #     last_epoch: int = 1
+ #     global_step: int = 1
+ #     best_loss: typing.Optional[float] = None
+ #     audio: AudioConfig = field(default_factory=AudioConfig)
+ #     model: ModelConfig = field(default_factory=ModelConfig)
+ #     phonemes: PhonemesConfig = field(default_factory=PhonemesConfig)
+ #     text_aligner: AlignerConfig = field(default_factory=AlignerConfig)
+ #     text_language: typing.Optional[str] = None
+ #     phonemizer: typing.Optional[Phonemizer] = None
+ #     datasets: typing.List[DatasetConfig] = field(default_factory=list)
+ #     inference: InferenceConfig = field(default_factory=InferenceConfig)
+
+ #     version: int = 1
+ #     git_commit: str = ""
+
+ #     @property
+ #     def is_multispeaker(self):
+ #         return self.model.is_multispeaker or any(d.multispeaker for d in self.datasets)
+
+ #     def save(self, config_file: typing.TextIO):
+ #         """Save config as JSON to a file"""
+ #         json.dump(self.to_dict(), config_file, indent=4)
+
+ #     def get_speaker_id(self, dataset_name: str, speaker_name: str) -> int:
+ #         if self.speaker_id_map is None:
+ #             self.speaker_id_map = {}
+
+ #         full_speaker_name = f"{dataset_name}_{speaker_name}"
+ #         speaker_id = self.speaker_id_map.get(full_speaker_name)
+ #         if speaker_id is None:
+ #             speaker_id = len(self.speaker_id_map)
+ #             self.speaker_id_map[full_speaker_name] = speaker_id
+
+ #         return speaker_id
+
+ #     @staticmethod
+ #     def load(config_file: typing.TextIO) -> "TrainingConfig":
+ #         """Load config from a JSON file"""
+ #         return TrainingConfig.from_json(config_file.read())
+
+ #     @staticmethod
+ #     def load_and_merge(
+ #         config: "TrainingConfig",
+ #         config_files: typing.Iterable[typing.Union[str, Path, typing.TextIO]],
+ #     ) -> "TrainingConfig":
+ #         """Loads one or more JSON configuration files and overlays them on top of an existing config"""
+ #         base_dict = config.to_dict()
+ #         for maybe_config_file in config_files:
+ #             if isinstance(maybe_config_file, (str, Path)):
+ #                 # File path
+ #                 config_file = open(maybe_config_file, "r", encoding="utf-8")
+ #             else:
+ #                 # File object
+ #                 config_file = maybe_config_file
+
+ #             with config_file:
+ #                 # Load new config and overlay on existing config
+ #                 new_dict = json.load(config_file)
+ #                 TrainingConfig.recursive_update(base_dict, new_dict)
+
+ #         return TrainingConfig.from_dict(base_dict)
+
+ #     @staticmethod
+ #     def recursive_update(
+ #         base_dict: typing.Dict[typing.Any, typing.Any],
+ #         new_dict: typing.Mapping[typing.Any, typing.Any],
+ #     ) -> None:
+ #         """Recursively overwrites values in base dictionary with values from new dictionary"""
+ #         for key, value in new_dict.items():
+ #             if isinstance(value, collections.Mapping) and (
+ #                 base_dict.get(key) is not None
+ #             ):
+ #                 TrainingConfig.recursive_update(base_dict[key], value)
+ #             else:
+ #                 base_dict[key] = value
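
The config.py hunk above is self-contained, so a short usage sketch may help readers of the diff. It is illustrative only and not part of the package: the class and attribute names come from the hunk, while the example values (num_symbols=256, the speaker counts) are made up.

    from phoonnx_train.vits.config import ModelAudioConfig, ModelConfig, TrainingConfig

    # Single-speaker model: gin_channels keeps its default of 0.
    single = ModelConfig(
        num_symbols=256,
        n_speakers=1,
        audio=ModelAudioConfig.high_quality(),
    )
    assert not single.is_multispeaker and single.gin_channels == 0

    # Multi-speaker model: __post_init__ raises gin_channels from 0 to 512.
    multi = ModelConfig(
        num_symbols=256,
        n_speakers=4,
        audio=ModelAudioConfig.low_quality(),
    )
    assert multi.is_multispeaker and multi.gin_channels == 512

    # Generator settings are re-exported from the nested audio config.
    assert multi.upsample_rates == (8, 8, 4)

    # Optimizer hyperparameters live in a separate dataclass.
    train = TrainingConfig(learning_rate=2e-4, grad_clip=None)
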
phoonnx_train/vits/dataset.py
@@ -0,0 +1,214 @@
+ import json
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable, List, Optional, Sequence, Union
+
+ import torch
+ from torch import FloatTensor, LongTensor
+ from torch.utils.data import Dataset
+
+ _LOGGER = logging.getLogger("vits.dataset")
+
+
+ @dataclass
+ class Utterance:
+     phoneme_ids: List[int]
+     audio_norm_path: Path
+     audio_spec_path: Path
+     speaker_id: Optional[int] = None
+     text: Optional[str] = None
+
+
+ @dataclass
+ class UtteranceTensors:
+     phoneme_ids: LongTensor
+     spectrogram: FloatTensor
+     audio_norm: FloatTensor
+     speaker_id: Optional[LongTensor] = None
+     text: Optional[str] = None
+
+     @property
+     def spec_length(self) -> int:
+         return self.spectrogram.size(1)
+
+
+ @dataclass
+ class Batch:
+     phoneme_ids: LongTensor
+     phoneme_lengths: LongTensor
+     spectrograms: FloatTensor
+     spectrogram_lengths: LongTensor
+     audios: FloatTensor
+     audio_lengths: LongTensor
+     speaker_ids: Optional[LongTensor] = None
+
+
+ class PiperDataset(Dataset):
+     """
+     Dataset format:
+
+     * phoneme_ids (required)
+     * audio_norm_path (required)
+     * audio_spec_path (required)
+     * text (optional)
+     * phonemes (optional)
+     * audio_path (optional)
+     """
+
+     def __init__(
+         self,
+         dataset_paths: List[Union[str, Path]],
+         max_phoneme_ids: Optional[int] = None,
+     ):
+         self.utterances: List[Utterance] = []
+
+         for dataset_path in dataset_paths:
+             dataset_path = Path(dataset_path)
+             _LOGGER.debug("Loading dataset: %s", dataset_path)
+             self.utterances.extend(
+                 PiperDataset.load_dataset(dataset_path, max_phoneme_ids=max_phoneme_ids)
+             )
+
+     def __len__(self):
+         return len(self.utterances)
+
+     def __getitem__(self, idx) -> UtteranceTensors:
+         utt = self.utterances[idx]
+         return UtteranceTensors(
+             phoneme_ids=LongTensor(utt.phoneme_ids),
+             audio_norm=torch.load(utt.audio_norm_path),
+             spectrogram=torch.load(utt.audio_spec_path),
+             speaker_id=LongTensor([utt.speaker_id])
+             if utt.speaker_id is not None
+             else None,
+             text=utt.text,
+         )
+
+     @staticmethod
+     def load_dataset(
+         dataset_path: Path,
+         max_phoneme_ids: Optional[int] = None,
+     ) -> Iterable[Utterance]:
+         num_skipped = 0
+
+         with open(dataset_path, "r", encoding="utf-8") as dataset_file:
+             for line_idx, line in enumerate(dataset_file):
+                 line = line.strip()
+                 if not line:
+                     continue
+
+                 try:
+                     utt = PiperDataset.load_utterance(line)
+                     if (max_phoneme_ids is None) or (
+                         len(utt.phoneme_ids) <= max_phoneme_ids
+                     ):
+                         yield utt
+                     else:
+                         num_skipped += 1
+                 except Exception:
+                     _LOGGER.exception(
+                         "Error on line %s of %s: %s",
+                         line_idx + 1,
+                         dataset_path,
+                         line,
+                     )
+
+         if num_skipped > 0:
+             _LOGGER.warning("Skipped %s utterance(s)", num_skipped)
+
+     @staticmethod
+     def load_utterance(line: str) -> Utterance:
+         utt_dict = json.loads(line)
+         return Utterance(
+             phoneme_ids=utt_dict["phoneme_ids"],
+             audio_norm_path=Path(utt_dict["audio_norm_path"]),
+             audio_spec_path=Path(utt_dict["audio_spec_path"]),
+             speaker_id=utt_dict.get("speaker_id"),
+             text=utt_dict.get("text"),
+         )
+
+
+ class UtteranceCollate:
+     def __init__(self, is_multispeaker: bool, segment_size: int):
+         self.is_multispeaker = is_multispeaker
+         self.segment_size = segment_size
+
+     def __call__(self, utterances: Sequence[UtteranceTensors]) -> Batch:
+         num_utterances = len(utterances)
+         assert num_utterances > 0, "No utterances"
+
+         max_phonemes_length = 0
+         max_spec_length = 0
+         max_audio_length = 0
+
+         num_mels = 0
+
+         # Determine lengths
+         for utt_idx, utt in enumerate(utterances):
+             assert utt.spectrogram is not None
+             assert utt.audio_norm is not None
+
+             phoneme_length = utt.phoneme_ids.size(0)
+             spec_length = utt.spectrogram.size(1)
+             audio_length = utt.audio_norm.size(1)
+
+             max_phonemes_length = max(max_phonemes_length, phoneme_length)
+             max_spec_length = max(max_spec_length, spec_length)
+             max_audio_length = max(max_audio_length, audio_length)
+
+             num_mels = utt.spectrogram.size(0)
+             if self.is_multispeaker:
+                 assert utt.speaker_id is not None, "Missing speaker id"
+
+         # Audio cannot be smaller than segment size (8192)
+         max_audio_length = max(max_audio_length, self.segment_size)
+
+         # Create padded tensors
+         phonemes_padded = LongTensor(num_utterances, max_phonemes_length)
+         spec_padded = FloatTensor(num_utterances, num_mels, max_spec_length)
+         audio_padded = FloatTensor(num_utterances, 1, max_audio_length)
+
+         phonemes_padded.zero_()
+         spec_padded.zero_()
+         audio_padded.zero_()
+
+         phoneme_lengths = LongTensor(num_utterances)
+         spec_lengths = LongTensor(num_utterances)
+         audio_lengths = LongTensor(num_utterances)
+
+         speaker_ids: Optional[LongTensor] = None
+         if self.is_multispeaker:
+             speaker_ids = LongTensor(num_utterances)
+
+         # Sort by decreasing spectrogram length
+         sorted_utterances = sorted(
+             utterances, key=lambda u: u.spectrogram.size(1), reverse=True
+         )
+         for utt_idx, utt in enumerate(sorted_utterances):
+             phoneme_length = utt.phoneme_ids.size(0)
+             spec_length = utt.spectrogram.size(1)
+             audio_length = utt.audio_norm.size(1)
+
+             phonemes_padded[utt_idx, :phoneme_length] = utt.phoneme_ids
+             phoneme_lengths[utt_idx] = phoneme_length
+
+             spec_padded[utt_idx, :, :spec_length] = utt.spectrogram
+             spec_lengths[utt_idx] = spec_length
+
+             audio_padded[utt_idx, :, :audio_length] = utt.audio_norm
+             audio_lengths[utt_idx] = audio_length
+
+             if utt.speaker_id is not None:
+                 assert speaker_ids is not None
+                 speaker_ids[utt_idx] = utt.speaker_id
+
+         return Batch(
+             phoneme_ids=phonemes_padded,
+             phoneme_lengths=phoneme_lengths,
+             spectrograms=spec_padded,
+             spectrogram_lengths=spec_lengths,
+             audios=audio_padded,
+             audio_lengths=audio_lengths,
+             speaker_ids=speaker_ids,
+         )
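
The dataset.py hunk reads a JSONL manifest (one utterance per line with phoneme_ids, audio_norm_path, audio_spec_path, and optional speaker_id/text) and UtteranceCollate pads each batch to its longest phoneme, spectrogram, and audio lengths. A minimal consumption sketch follows; the DataLoader wiring, the dataset.jsonl path, and the example values are illustrative assumptions, not taken from the diff.

    from torch.utils.data import DataLoader

    from phoonnx_train.vits.dataset import PiperDataset, UtteranceCollate

    # Hypothetical manifest; each line is one JSON utterance, e.g.:
    # {"phoneme_ids": [1, 5, 9], "audio_norm_path": "utt0.norm.pt",
    #  "audio_spec_path": "utt0.spec.pt", "speaker_id": 0}
    dataset = PiperDataset(["dataset.jsonl"], max_phoneme_ids=400)

    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        collate_fn=UtteranceCollate(is_multispeaker=True, segment_size=8192),
    )

    for batch in loader:
        # Zero-padded tensors: (B, T_phonemes), (B, n_mels, T_spec), (B, 1, T_audio)
        print(batch.phoneme_ids.shape, batch.spectrograms.shape, batch.audios.shape)
        break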