phoonnx 0.2.0a2__py3-none-any.whl → 0.2.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phoonnx/version.py +2 -2
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/METADATA +1 -1
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/RECORD +7 -7
- phoonnx_train/preprocess.py +302 -152
- phoonnx_train/train.py +151 -0
- phoonnx_train/__main__.py +0 -151
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/WHEEL +0 -0
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/top_level.txt +0 -0
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 phoonnx/config.py,sha256=DKgsU03g8jrAuMcVqbu-w3MWPXOUihFtRnavg6WGQ1Y,19983
 phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
 phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
-phoonnx/version.py,sha256=
+phoonnx/version.py,sha256=hzUBKPGD2cJBEmHoNugbPPHTO3ITl0OwgO9W0oQfjho,237
 phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
 phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
 phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
@@ -62,9 +62,9 @@ phoonnx/thirdparty/tashkeel/hint_id_map.json,sha256=gJMdtTsfEDFgmmbyO2Shw315rkqK
 phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771ltzkv1EnYmHKc2w8M,628
 phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
 phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
-phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
 phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
-phoonnx_train/preprocess.py,sha256=
+phoonnx_train/preprocess.py,sha256=nSkcjThEGfKuwmOurhoXwXfnksqtb-hcdCbx1byDGRI,19890
+phoonnx_train/train.py,sha256=6ydKJb1sqZy6wJpfSAlkXZ8Y6W2GKsg7c8KPodj3j-Q,5989
 phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
 phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
 phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
@@ -83,7 +83,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
 phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
 phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
 phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
-phoonnx-0.2.
-phoonnx-0.2.
-phoonnx-0.2.
-phoonnx-0.2.
+phoonnx-0.2.1a1.dist-info/METADATA,sha256=0nAZevi9ypEWjg_QQpByxdQX3KGDKkfNlyb7rhLHmww,8250
+phoonnx-0.2.1a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+phoonnx-0.2.1a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
+phoonnx-0.2.1a1.dist-info/RECORD,,
phoonnx_train/preprocess.py
CHANGED
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import argparse
 import csv
 import dataclasses
 import itertools
@@ -10,13 +9,16 @@ from collections import Counter
 from dataclasses import dataclass
 from multiprocessing import JoinableQueue, Process, Queue
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union, Callable

+import click
 from phoonnx.util import normalize
 from phoonnx.config import PhonemeType, get_phonemizer, Alphabet
 from phoonnx.phonemizers import Phonemizer
-from phoonnx.phoneme_ids import (
-
+from phoonnx.phoneme_ids import (
+    phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
+    DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN
+)
 from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
 from tqdm import tqdm
 from phoonnx.version import VERSION_STR
@@ -46,7 +48,7 @@ class Utterance:
     audio_spec_path: Optional[Path] = None

     def asdict(self) -> Dict[str, Any]:
-        """Custom asdict to handle Path objects."""
+        """Custom asdict to handle Path objects for JSON serialization."""
         data = dataclasses.asdict(self)
         for key, value in data.items():
             if isinstance(value, Path):
@@ -57,14 +59,31 @@ class Utterance:
 class PathEncoder(json.JSONEncoder):
     """JSON encoder for Path objects."""

-    def default(self, o):
+    def default(self, o: Any) -> Union[str, Any]:
+        """
+        Converts Path objects to strings for serialization.
+
+        Args:
+            o: The object to serialize.
+
+        Returns:
+            The serialized string representation or the default JSON serialization.
+        """
         if isinstance(o, Path):
             return str(o)
         return super().default(o)


-def get_text_casing(casing: str):
-    """
+def get_text_casing(casing: str) -> Callable[[str], str]:
+    """
+    Returns a function to apply text casing based on a string name.
+
+    Args:
+        casing: The name of the casing function ('lower', 'upper', 'casefold', or 'ignore').
+
+    Returns:
+        A callable function (str) -> str.
+    """
     if casing == "lower":
         return str.lower
     if casing == "upper":
@@ -74,18 +93,46 @@ def get_text_casing(casing: str):
     return lambda s: s


-
+@dataclass
+class PreprocessorConfig:
+    """Dataclass to hold all runtime configuration, mimicking argparse.Namespace."""
+    input_dir: Path
+    output_dir: Path
+    language: str
+    sample_rate: int
+    cache_dir: Path
+    max_workers: int
+    single_speaker: bool
+    speaker_id: Optional[int]
+    phoneme_type: PhonemeType
+    alphabet: Alphabet
+    phonemizer_model: str
+    text_casing: str
+    dataset_name: Optional[str]
+    audio_quality: Optional[str]
+    skip_audio: bool
+    debug: bool
+    add_diacritics: bool
+
+
+def ljspeech_dataset(config: PreprocessorConfig) -> Iterable[Utterance]:
     """
     Generator for LJSpeech-style dataset.
     Loads metadata and resolves audio file paths.
+
+    Args:
+        config: The configuration object containing dataset parameters.
+
+    Yields:
+        Utterance: A fully populated Utterance object.
     """
-    dataset_dir =
+    dataset_dir = config.input_dir
     metadata_path = dataset_dir / "metadata.csv"
     if not metadata_path.exists():
         _LOGGER.error(f"Missing metadata file: {metadata_path}")
         return

-    wav_dirs = [dataset_dir / "wav", dataset_dir / "wavs"]
+    wav_dirs: List[Path] = [dataset_dir / "wav", dataset_dir / "wavs"]

     with open(metadata_path, "r", encoding="utf-8") as csv_file:
         reader = csv.reader(csv_file, delimiter="|")
@@ -98,16 +145,18 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
             text: str = row[-1]
             speaker: Optional[str] = None

-            if not
+            if not config.single_speaker and len(row) > 2:
                 speaker = row[1]
             else:
                 speaker = None

-            wav_path = None
+            wav_path: Optional[Path] = None
             for wav_dir in wav_dirs:
-                potential_paths = [
-
-
+                potential_paths: List[Path] = [
+                    wav_dir / filename,
+                    wav_dir / f"{filename}.wav",
+                    wav_dir / f"{filename.lstrip('0')}.wav"
+                ]
                 for path in potential_paths:
                     if path.exists():
                         wav_path = path
@@ -115,34 +164,40 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
                 if wav_path:
                     break

-            if not
+            if not config.skip_audio and not wav_path:
                 _LOGGER.warning("Missing audio file for filename: %s", filename)
                 continue

-            if not
+            if not config.skip_audio and wav_path and wav_path.stat().st_size == 0:
                 _LOGGER.warning("Empty audio file: %s", wav_path)
                 continue

+            # Ensure wav_path is Path or None, and is never accessed if skip_audio is true
             yield Utterance(
                 text=text,
-                audio_path=wav_path,
+                audio_path=wav_path or Path(""), # Use empty path if skipping audio, should not be used
                 speaker=speaker,
-                speaker_id=
+                speaker_id=config.speaker_id,
             )


 def phonemize_worker(
-
+    config: PreprocessorConfig,
     task_queue: JoinableQueue,
     result_queue: Queue,
     phonemizer: Phonemizer,
-):
+) -> None:
     """
     Worker process for phonemization and audio processing.
-
+
+    Args:
+        config: The configuration object containing runtime parameters.
+        task_queue: Queue for receiving batches of Utterance objects.
+        result_queue: Queue for sending processed results (Utterance, set of phonemes).
+        phonemizer: The initialized Phonemizer instance.
     """
     try:
-        casing = get_text_casing(
+        casing: Callable[[str], str] = get_text_casing(config.text_casing)
         silence_detector = make_silence_detector()

         while True:
@@ -155,28 +210,29 @@ def phonemize_worker(

             for utt in utterance_batch:
                 try:
-                    #
-                    utterance = casing(normalize(
+                    # Normalize text (case, numbers, etc.)
+                    utterance: str = casing(normalize(utt.text, config.language))

-                    #
-                    if
-                        utterance = phonemizer.add_diacritics(utterance,
+                    # Add diacritics
+                    if config.add_diacritics:
+                        utterance = phonemizer.add_diacritics(utterance, config.language)

                     # Phonemize the text
-                    utt.phonemes = phonemizer.phonemize_to_list(utterance,
+                    utt.phonemes = phonemizer.phonemize_to_list(utterance, config.language)
                     if not utt.phonemes:
                         raise RuntimeError(f"Phonemes not found for '{utterance}'")

                     # Process audio if not skipping
-                    if not
+                    if not config.skip_audio:
                         utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                             utt.audio_path,
-
+                            config.cache_dir,
                             silence_detector,
-
+                            config.sample_rate,
                         )

                     # Put the processed utterance and its phonemes into the result queue
+                    # The result is a tuple of (Utterance, set of unique phonemes in that utterance)
                     result_queue.put((utt, set(utt.phonemes)))
                 except Exception:
                     _LOGGER.exception("Failed to process utterance: %s", utt.audio_path)
@@ -188,109 +244,195 @@ def phonemize_worker(
         _LOGGER.exception("Worker process failed")


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-
+@click.command(context_settings={"help_option_names": ["-h", "--help"]})
+@click.option(
+    "-i",
+    "--input-dir",
+    "input_dir",
+    type=click.Path(exists=True, file_okay=False, path_type=Path),
+    required=True,
+    help="Directory with audio dataset (e.g., containing metadata.csv and wavs/)",
+)
+@click.option(
+    "-o",
+    "--output-dir",
+    "output_dir",
+    type=click.Path(file_okay=False, path_type=Path),
+    required=True,
+    help="Directory to write output files for training (config.json, dataset.jsonl)",
+)
+@click.option(
+    "-l",
+    "--language",
+    "language",
+    required=True,
+    help="phonemizer language code (e.g., 'en', 'es', 'fr')",
+)
+@click.option(
+    "-r",
+    "--sample-rate",
+    "sample_rate",
+    type=int,
+    required=True,
+    help="Target sample rate for voice (hertz, e.g., 22050)",
+)
+@click.option(
+    "--cache-dir",
+    "cache_dir",
+    type=click.Path(file_okay=False, path_type=Path),
+    default=None,
+    help="Directory to cache processed audio files. Defaults to <output-dir>/cache/<sample-rate>.",
+)
+@click.option(
+    "-w",
+    "--max-workers",
+    "max_workers",
+    type=click.IntRange(min=1),
+    default=os.cpu_count() or 1,
+    help="Maximum number of worker processes to use for parallel processing. Defaults to CPU count.",
+)
+@click.option(
+    "--single-speaker",
+    "single_speaker",
+    is_flag=True,
+    help="Force treating the dataset as single speaker, ignoring metadata speaker columns.",
+)
+@click.option(
+    "--speaker-id",
+    "speaker_id",
+    type=int,
+    default=None,
+    help="Specify a fixed speaker ID (0, 1, etc.) for a single speaker dataset.",
+)
+@click.option(
+    "--phoneme-type",
+    "phoneme_type",
+    type=click.Choice([p.value for p in PhonemeType]),
+    default=PhonemeType.ESPEAK.value,
+    help="Type of phonemes to use.",
+)
+@click.option(
+    "--alphabet",
+    "alphabet",
+    type=click.Choice([a.value for a in Alphabet]),
+    default=Alphabet.IPA.value,
+    help="Phoneme alphabet to use (e.g., IPA).",
+)
+@click.option(
+    "--phonemizer-model",
+    "phonemizer_model",
+    default="",
+    help="Path or name of a custom phonemizer model, if applicable.",
+)
+@click.option(
+    "--text-casing",
+    "text_casing",
+    type=click.Choice(("ignore", "lower", "upper", "casefold")),
+    default="ignore",
+    help="Casing applied to utterance text before phonemization.",
+)
+@click.option(
+    "--dataset-name",
+    "dataset_name",
+    default=None,
+    help="Name of dataset to put in config (default: name of <output_dir>/../).",
+)
+@click.option(
+    "--audio-quality",
+    "audio_quality",
+    default=None,
+    help="Audio quality description to put in config (default: name of <output_dir>).",
+)
+@click.option(
+    "--skip-audio",
+    "skip_audio",
+    is_flag=True,
+    help="Do not preprocess or cache audio files.",
+)
+@click.option(
+    "--debug",
+    "debug",
+    is_flag=True,
+    help="Print DEBUG messages to the console.",
+)
+@click.option(
+    "--add-diacritics",
+    "add_diacritics",
+    is_flag=True,
+    help="Add diacritics to text (phonemizer specific, e.g., to denote stress).",
+)
+def cli(
+    input_dir: Path,
+    output_dir: Path,
+    language: str,
+    sample_rate: int,
+    cache_dir: Optional[Path],
+    max_workers: Optional[int],
+    single_speaker: bool,
+    speaker_id: Optional[int],
+    phoneme_type: str,
+    alphabet: str,
+    phonemizer_model: str,
+    text_casing: str,
+    dataset_name: Optional[str],
+    audio_quality: Optional[str],
+    skip_audio: bool,
+    debug: bool,
+    add_diacritics: bool,
+) -> None:
+    """
+    Preprocess a TTS dataset (e.g., LJSpeech format) for training a VITS-style model.
+    This script handles text normalization, phonemization, and optional audio caching.
+    """
+    # Create a config object from click arguments for easier passing
+    config = PreprocessorConfig(
+        input_dir=input_dir,
+        output_dir=output_dir,
+        language=language,
+        sample_rate=sample_rate,
+        cache_dir=cache_dir or output_dir / "cache" / str(sample_rate),
+        max_workers=max_workers or os.cpu_count() or 1,
+        single_speaker=single_speaker,
+        speaker_id=speaker_id,
+        phoneme_type=PhonemeType(phoneme_type),
+        alphabet=Alphabet(alphabet),
+        phonemizer_model=phonemizer_model,
+        text_casing=text_casing,
+        dataset_name=dataset_name,
+        audio_quality=audio_quality,
+        skip_audio=skip_audio,
+        debug=debug,
+        add_diacritics=add_diacritics,
     )
-    args = parser.parse_args()

-    # Setup
-    level = logging.DEBUG if
+    # Setup logging
+    level = logging.DEBUG if config.debug else logging.INFO
     logging.basicConfig(level=level)
     logging.getLogger().setLevel(level)
     logging.getLogger("numba").setLevel(logging.WARNING)

-
+    # Validation
+    if config.single_speaker and (config.speaker_id is not None):
         _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
-
+        raise click.Abort()

-
-
-
-    args.cache_dir = (
-        Path(args.cache_dir)
-        if args.cache_dir
-        else args.output_dir / "cache" / str(args.sample_rate)
-    )
-    args.cache_dir.mkdir(parents=True, exist_ok=True)
-    args.phoneme_type = PhonemeType(args.phoneme_type)
+    # Create directories
+    config.output_dir.mkdir(parents=True, exist_ok=True)
+    config.cache_dir.mkdir(parents=True, exist_ok=True)

     # Load all utterances from the dataset
     _LOGGER.info("Loading utterances from dataset...")
-    utterances = list(ljspeech_dataset(
+    utterances: List[Utterance] = list(ljspeech_dataset(config))
     if not utterances:
         _LOGGER.error("No valid utterances found in dataset.")
         return

-    num_utterances = len(utterances)
+    num_utterances: int = len(utterances)
     _LOGGER.info("Found %d utterances.", num_utterances)

-    # Count speakers
+    # Count speakers and assign IDs
     speaker_counts: Counter[str] = Counter(u.speaker for u in utterances if u.speaker)
-    is_multispeaker = len(speaker_counts) > 1
+    is_multispeaker: bool = len(speaker_counts) > 1
     speaker_ids: Dict[str, int] = {}
     if is_multispeaker:
         _LOGGER.info("%s speakers detected", len(speaker_counts))
@@ -301,48 +443,47 @@ def main() -> None:
         _LOGGER.info("Single speaker dataset")

     # --- Single Pass: Process audio/phonemes and collect results ---
-
-    args.max_workers = args.max_workers if args.max_workers is not None and args.max_workers > 0 else os.cpu_count()
-    _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)
+    _LOGGER.info("Starting single pass processing with %d workers...", config.max_workers)

     # Initialize the phonemizer only once in the main process
-    phonemizer = get_phonemizer(
-
-
+    phonemizer: Phonemizer = get_phonemizer(config.phoneme_type,
+                                            config.alphabet,
+                                            config.phonemizer_model)

-    batch_size = max(1, int(num_utterances / (
+    batch_size: int = max(1, int(num_utterances / (config.max_workers * 2)))

-    task_queue: "
+    task_queue: "JoinableQueue[Optional[List[Utterance]]]" = JoinableQueue()
     # The result queue will hold tuples of (Utterance, set(phonemes))
-    result_queue: "Queue[Optional[
+    result_queue: "Queue[Tuple[Optional[Utterance], Set[str]]]" = Queue()

     # Start workers
-    processes = [
+    processes: List[Process] = [
         Process(
             target=phonemize_worker,
-            args=(
+            args=(config, task_queue, result_queue, phonemizer)
         )
-        for _ in range(
+        for _ in range(config.max_workers)
     ]

     for proc in processes:
         proc.start()

     # Populate the task queue with batches
-    task_count = 0
+    task_count: int = 0
     for utt_batch in batched(utterances, batch_size):
         task_queue.put(utt_batch)
         task_count += len(utt_batch)

     # Signal workers to stop
-    for _ in range(
+    for _ in range(config.max_workers):
         task_queue.put(None)

     # Collect results from the queue with a progress bar
     processed_utterances: List[Utterance] = []
     all_phonemes: Set[str] = set()
     for _ in tqdm(range(task_count), desc="Processing utterances"):
-
+        result: Tuple[Optional[Utterance], Set[str]] = result_queue.get()
+        utt, unique_phonemes = result
         if utt is not None:
             processed_utterances.append(utt)
             all_phonemes.update(unique_phonemes)
@@ -355,15 +496,15 @@ def main() -> None:
     # --- Build the final phoneme map from the collected phonemes ---
     _LOGGER.info("Building a complete phoneme map from collected phonemes...")

-    final_phoneme_id_map = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
+    final_phoneme_id_map: Dict[str, int] = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
     if phonemizer.alphabet == Alphabet.IPA:
         all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())

     # Filter out special tokens that are already in the map
-    existing_keys = set(final_phoneme_id_map.keys())
-    new_phonemes = sorted([p for p in all_phonemes if p not in existing_keys])
+    existing_keys: Set[str] = set(final_phoneme_id_map.keys())
+    new_phonemes: List[str] = sorted([p for p in all_phonemes if p not in existing_keys])

-    current_id = len(final_phoneme_id_map)
+    current_id: int = len(final_phoneme_id_map)
     for pho in new_phonemes:
         final_phoneme_id_map[pho] = current_id
         current_id += 1
@@ -372,23 +513,23 @@ def main() -> None:

     # --- Write the final config.json ---
     _LOGGER.info("Writing dataset config...")
-    audio_quality =
-    dataset_name =
+    audio_quality = config.audio_quality or config.output_dir.name
+    dataset_name = config.dataset_name or config.output_dir.parent.name

-
+    config_data: Dict[str, Any] = {
         "dataset": dataset_name,
         "audio": {
-            "sample_rate":
+            "sample_rate": config.sample_rate,
             "quality": audio_quality,
         },
-        "lang_code":
+        "lang_code": config.language,
         "inference": {"noise_scale": 0.667,
                       "length_scale": 1,
                       "noise_w": 0.8,
-                      "add_diacritics":
+                      "add_diacritics": config.add_diacritics},
         "alphabet": phonemizer.alphabet.value,
-        "phoneme_type":
-        "phonemizer_model":
+        "phoneme_type": config.phoneme_type.value,
+        "phonemizer_model": config.phonemizer_model,
         "phoneme_id_map": final_phoneme_id_map,
         "num_symbols": len(final_phoneme_id_map),
         "num_speakers": len(speaker_counts) if is_multispeaker else 1,
@@ -396,13 +537,13 @@ def main() -> None:
         "phoonnx_version": VERSION_STR,
     }

-    with open(
-        json.dump(
+    with open(config.output_dir / "config.json", "w", encoding="utf-8") as config_file:
+        json.dump(config_data, config_file, ensure_ascii=False, indent=2)

     # --- Apply final phoneme IDs and write dataset.jsonl ---
     _LOGGER.info("Writing dataset.jsonl...")
-    valid_utterances_count = 0
-    with open(
+    valid_utterances_count: int = 0
+    with open(config.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
         for utt in processed_utterances:
             if is_multispeaker and utt.speaker is not None:
                 if utt.speaker not in speaker_ids:
@@ -432,8 +573,17 @@ def main() -> None:

 # -----------------------------------------------------------------------------

-def batched(iterable, n):
-    "
+def batched(iterable: Iterable[Any], n: int) -> Iterable[List[Any]]:
+    """
+    Batch data from an iterable into lists of length n. The last batch may be shorter.
+
+    Args:
+        iterable: The input iterable to be batched.
+        n: The desired size of each batch.
+
+    Yields:
+        List[Any]: A list representing a batch of items.
+    """
     if n < 1:
         raise ValueError("n must be at least one")
     it = iter(iterable)
@@ -444,4 +594,4 @@ def batched(iterable, n):


 if __name__ == "__main__":
-
+    cli()
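The rewritten preprocess.py replaces the argparse-driven main() with a click command (cli) that packs its options into the new PreprocessorConfig dataclass before any processing starts. The fragment below is a minimal usage sketch, not part of the released files: it only uses names visible in this diff, the paths are placeholders, and skip_audio=True is chosen so that only metadata.csv is read.

from pathlib import Path

from phoonnx.config import PhonemeType, Alphabet
from phoonnx_train.preprocess import PreprocessorConfig, ljspeech_dataset

# Placeholder paths; point these at a real LJSpeech-style dataset.
dataset_root = Path("/data/my_voice")
out_dir = Path("/data/my_voice_out")

config = PreprocessorConfig(
    input_dir=dataset_root,
    output_dir=out_dir,
    language="en",
    sample_rate=22050,
    cache_dir=out_dir / "cache" / "22050",
    max_workers=1,
    single_speaker=True,
    speaker_id=None,
    phoneme_type=PhonemeType.ESPEAK,
    alphabet=Alphabet.IPA,
    phonemizer_model="",
    text_casing="ignore",
    dataset_name=None,
    audio_quality=None,
    skip_audio=True,  # only read metadata.csv, do not touch or cache audio
    debug=False,
    add_diacritics=False,
)

# ljspeech_dataset() yields Utterance objects built from metadata.csv.
print(sum(1 for _ in ljspeech_dataset(config)))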
phoonnx_train/train.py
ADDED
@@ -0,0 +1,151 @@
+import json
+import logging
+from pathlib import Path
+
+import torch
+import click
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from phoonnx_train.vits.lightning import VitsModel
+
+_LOGGER = logging.getLogger(__package__)
+
+
+def load_state_dict(model, saved_state_dict):
+    state_dict = model.state_dict()
+    new_state_dict = {}
+
+    for k, v in state_dict.items():
+        if k in saved_state_dict:
+            new_state_dict[k] = saved_state_dict[k]
+        else:
+            _LOGGER.debug("%s is not in the checkpoint", k)
+            new_state_dict[k] = v
+
+    model.load_state_dict(new_state_dict)
+
+
+@click.command(context_settings=dict(ignore_unknown_options=True))
+@click.option('--dataset-dir', required=True, type=click.Path(exists=True, file_okay=False), help='Path to pre-processed dataset directory')
+@click.option('--checkpoint-epochs', default=1, type=int, help='Save checkpoint every N epochs (default: 1)')
+@click.option('--quality', default='medium', type=click.Choice(['x-low', 'medium', 'high']), help='Quality/size of model (default: medium)')
+@click.option('--resume-from-checkpoint', default=None, help='Load an existing checkpoint and resume training')
+@click.option('--resume-from-single-speaker-checkpoint', help='For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training')
+@click.option('--seed', type=int, default=1234, help='Random seed (default: 1234)')
+# Common Trainer options
+@click.option('--max-epochs', type=int, default=1000, help='Stop training once this number of epochs is reached (default: 1000)')
+@click.option('--devices', default=1, help='Number of devices or list of device IDs to train on (default: 1)')
+@click.option('--accelerator', default='auto', help='Hardware accelerator to use (cpu, gpu, tpu, mps, etc.) (default: "auto")')
+@click.option('--default-root-dir', type=click.Path(file_okay=False), default=None, help='Default root directory for logs and checkpoints (default: None)')
+@click.option('--precision', default=32, help='Precision used in training (e.g. 16, 32, bf16) (default: 32)')
+# Model-specific arguments
+@click.option('--learning-rate', type=float, default=2e-4, help='Learning rate for optimizer (default: 2e-4)')
+@click.option('--batch-size', type=int, default=16, help='Training batch size (default: 16)')
+@click.option('--num-workers', type=click.IntRange(min=1), default=1, help='Number of data loader workers (default: 1)')
+@click.option('--validation-split', type=float, default=0.05, help='Proportion of data used for validation (default: 0.05)')
+def main(
+    dataset_dir,
+    checkpoint_epochs,
+    quality,
+    resume_from_checkpoint,
+    resume_from_single_speaker_checkpoint,
+    seed,
+    max_epochs,
+    devices,
+    accelerator,
+    default_root_dir,
+    precision,
+    learning_rate,
+    batch_size,
+    num_workers,
+    validation_split,
+):
+    logging.basicConfig(level=logging.DEBUG)
+
+    dataset_dir = Path(dataset_dir)
+    if default_root_dir is None:
+        default_root_dir = dataset_dir
+
+    torch.backends.cudnn.benchmark = True
+    torch.manual_seed(seed)
+
+    config_path = dataset_dir / 'config.json'
+    dataset_path = dataset_dir / 'dataset.jsonl'
+
+    print(f"INFO - config_path: '{config_path}'")
+    print(f"INFO - dataset_path: '{dataset_path}'")
+
+    with open(config_path, 'r', encoding='utf-8') as config_file:
+        config = json.load(config_file)
+        num_symbols = int(config['num_symbols'])
+        num_speakers = int(config['num_speakers'])
+        sample_rate = int(config['audio']['sample_rate'])
+
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        devices=devices,
+        accelerator=accelerator,
+        default_root_dir=default_root_dir,
+        precision=precision,
+        resume_from_checkpoint=resume_from_checkpoint
+    )
+
+    if checkpoint_epochs is not None:
+        trainer.callbacks = [ModelCheckpoint(every_n_epochs=checkpoint_epochs)]
+        _LOGGER.info('Checkpoints will be saved every %s epoch(s)', checkpoint_epochs)
+
+    dict_args = dict(
+        seed=seed,
+        learning_rate=learning_rate,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        validation_split=validation_split,
+    )
+
+    if quality == 'x-low':
+        dict_args.update({
+            'hidden_channels': 96,
+            'inter_channels': 96,
+            'filter_channels': 384,
+        })
+    elif quality == 'high':
+        dict_args.update({
+            'resblock': '1',
+            'resblock_kernel_sizes': (3, 7, 11),
+            'resblock_dilation_sizes': ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+            'upsample_rates': (8, 8, 2, 2),
+            'upsample_initial_channel': 512,
+            'upsample_kernel_sizes': (16, 16, 4, 4),
+        })
+
+    print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
+    model = VitsModel(
+        num_symbols=num_symbols,
+        num_speakers=num_speakers,
+        sample_rate=sample_rate,
+        dataset=[dataset_path],
+        **dict_args,
+    )
+
+    if resume_from_single_speaker_checkpoint:
+        assert num_speakers > 1, "--resume-from-single-speaker-checkpoint is only for multi-speaker models."
+        _LOGGER.info('Resuming from single-speaker checkpoint: %s', resume_from_single_speaker_checkpoint)
+
+        model_single = VitsModel.load_from_checkpoint(resume_from_single_speaker_checkpoint, dataset=None)
+        g_dict = model_single.model_g.state_dict()
+
+        for key in list(g_dict.keys()):
+            if key.startswith('dec.cond') or key.startswith('dp.cond') or ('enc.cond_layer' in key):
+                g_dict.pop(key, None)
+
+        load_state_dict(model.model_g, g_dict)
+        load_state_dict(model.model_d, model_single.model_d.state_dict())
+        _LOGGER.info('Successfully converted single-speaker checkpoint to multi-speaker')
+
+    print('training started!!')
+    trainer.fit(model)
+
+
+if __name__ == '__main__':
+    main()
phoonnx_train/__main__.py
DELETED
@@ -1,151 +0,0 @@
-import argparse
-import json
-import logging
-from pathlib import Path
-
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-
-from phoonnx_train.vits.lightning import VitsModel
-
-_LOGGER = logging.getLogger(__package__)
-
-
-def main():
-    logging.basicConfig(level=logging.DEBUG)
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
-    )
-    parser.add_argument(
-        "--checkpoint-epochs",
-        type=int,
-        help="Save checkpoint every N epochs (default: 1)",
-    )
-    parser.add_argument(
-        "--quality",
-        default="medium",
-        choices=("x-low", "medium", "high"),
-        help="Quality/size of model (default: medium)",
-    )
-    parser.add_argument(
-        "--resume_from_single_speaker_checkpoint",
-        help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
-    )
-    Trainer.add_argparse_args(parser)
-    VitsModel.add_model_specific_args(parser)
-    parser.add_argument("--seed", type=int, default=1234)
-    args = parser.parse_args()
-    _LOGGER.debug(args)
-
-    args.dataset_dir = Path(args.dataset_dir)
-    if not args.default_root_dir:
-        args.default_root_dir = args.dataset_dir
-
-    torch.backends.cudnn.benchmark = True
-    torch.manual_seed(args.seed)
-
-    config_path = args.dataset_dir / "config.json"
-    dataset_path = args.dataset_dir / "dataset.jsonl"
-
-    print(f"INFO - config_path: '{config_path}'")
-    print(f"INFO - dataset_path: '{dataset_path}'")
-
-    with open(config_path, "r", encoding="utf-8") as config_file:
-        # See preprocess.py for format
-        config = json.load(config_file)
-        num_symbols = int(config["num_symbols"])
-        num_speakers = int(config["num_speakers"])
-        sample_rate = int(config["audio"]["sample_rate"])
-
-    trainer = Trainer.from_argparse_args(args)
-    if args.checkpoint_epochs is not None:
-        trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
-        _LOGGER.info(
-            "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
-        )
-
-    dict_args = vars(args)
-    if args.quality == "x-low":
-        dict_args["hidden_channels"] = 96
-        dict_args["inter_channels"] = 96
-        dict_args["filter_channels"] = 384
-    elif args.quality == "high":
-        dict_args["resblock"] = "1"
-        dict_args["resblock_kernel_sizes"] = (3, 7, 11)
-        dict_args["resblock_dilation_sizes"] = (
-            (1, 3, 5),
-            (1, 3, 5),
-            (1, 3, 5),
-        )
-        dict_args["upsample_rates"] = (8, 8, 2, 2)
-        dict_args["upsample_initial_channel"] = 512
-        dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)
-
-    print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
-    model = VitsModel(
-        num_symbols=num_symbols,
-        num_speakers=num_speakers,
-        sample_rate=sample_rate,
-        dataset=[dataset_path],
-        **dict_args,
-    )
-
-    if args.resume_from_single_speaker_checkpoint:
-        assert (
-            num_speakers > 1
-        ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
-
-        # Load single-speaker checkpoint
-        _LOGGER.info(
-            "Resuming from single-speaker checkpoint: %s",
-            args.resume_from_single_speaker_checkpoint,
-        )
-        model_single = VitsModel.load_from_checkpoint(
-            args.resume_from_single_speaker_checkpoint,
-            dataset=None,
-        )
-        g_dict = model_single.model_g.state_dict()
-        for key in list(g_dict.keys()):
-            # Remove keys that can't be copied over due to missing speaker embedding
-            if (
-                key.startswith("dec.cond")
-                or key.startswith("dp.cond")
-                or ("enc.cond_layer" in key)
-            ):
-                g_dict.pop(key, None)
-
-        # Copy over the multi-speaker model, excluding keys related to the
-        # speaker embedding (which is missing from the single-speaker model).
-        load_state_dict(model.model_g, g_dict)
-        load_state_dict(model.model_d, model_single.model_d.state_dict())
-        _LOGGER.info(
-            "Successfully converted single-speaker checkpoint to multi-speaker"
-        )
-    print("training started!!")
-    trainer.fit(model)
-
-
-def load_state_dict(model, saved_state_dict):
-    state_dict = model.state_dict()
-    new_state_dict = {}
-
-    for k, v in state_dict.items():
-        if k in saved_state_dict:
-            # Use saved value
-            new_state_dict[k] = saved_state_dict[k]
-        else:
-            # Use initialized value
-            _LOGGER.debug("%s is not in the checkpoint", k)
-            new_state_dict[k] = v
-
-    model.load_state_dict(new_state_dict)
-
-
-# -----------------------------------------------------------------------------
-
-
-if __name__ == "__main__":
-    main()
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/WHEEL
File without changes
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.1a1.dist-info}/top_level.txt
File without changes