phoonnx 0.2.0a2__py3-none-any.whl → 0.2.2a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phoonnx/version.py +2 -2
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/METADATA +1 -1
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/RECORD +7 -7
- phoonnx_train/preprocess.py +356 -162
- phoonnx_train/train.py +151 -0
- phoonnx_train/__main__.py +0 -151
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/WHEEL +0 -0
- {phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/top_level.txt +0 -0
phoonnx/version.py
CHANGED
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 phoonnx/config.py,sha256=DKgsU03g8jrAuMcVqbu-w3MWPXOUihFtRnavg6WGQ1Y,19983
 phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
 phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
-phoonnx/version.py,sha256=
+phoonnx/version.py,sha256=WiLyUm-i8r69uJR8Yj-q7I3TuC27Ha0HVgB22_lpeR4,237
 phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
 phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
 phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
@@ -62,9 +62,9 @@ phoonnx/thirdparty/tashkeel/hint_id_map.json,sha256=gJMdtTsfEDFgmmbyO2Shw315rkqK
 phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771ltzkv1EnYmHKc2w8M,628
 phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
 phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
-phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
 phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
-phoonnx_train/preprocess.py,sha256=
+phoonnx_train/preprocess.py,sha256=4FJFi7KL-ZUmrbN2NyhxBNpEjDlPRLSDJo2JoyvpR14,21700
+phoonnx_train/train.py,sha256=EUePlnNdBuo9IFIxHsxZ4CZ27IwOCIc1ySsJiIo-dkI,6015
 phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
 phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
 phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
@@ -83,7 +83,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
 phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
 phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
 phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
-phoonnx-0.2.
-phoonnx-0.2.
-phoonnx-0.2.
-phoonnx-0.2.
+phoonnx-0.2.2a1.dist-info/METADATA,sha256=j6Tw5hEBAq4iuLqVeAswSzEERwVI-_uci2w1oYJPyfM,8250
+phoonnx-0.2.2a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+phoonnx-0.2.2a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
+phoonnx-0.2.2a1.dist-info/RECORD,,
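For reference, each RECORD row above has the form path,sha256=<digest>,<size in bytes>, where the digest is the urlsafe base64 encoding of the file's SHA-256 hash with the trailing '=' padding stripped. A minimal standard-library sketch that reproduces such an entry (the file path below is a placeholder):

import base64
import hashlib
from pathlib import Path

def record_entry(path: Path) -> str:
    # Build a wheel RECORD line: <path>,sha256=<urlsafe b64 digest without padding>,<size>
    data = path.read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path.as_posix()},sha256={digest},{len(data)}"

# e.g. record_entry(Path("phoonnx/version.py")) reproduces a line of the form shown above.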
phoonnx_train/preprocess.py
CHANGED
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-import argparse
 import csv
 import dataclasses
 import itertools
@@ -10,13 +9,16 @@ from collections import Counter
 from dataclasses import dataclass
 from multiprocessing import JoinableQueue, Process, Queue
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union, Callable

+import click
 from phoonnx.util import normalize
 from phoonnx.config import PhonemeType, get_phonemizer, Alphabet
 from phoonnx.phonemizers import Phonemizer
-from phoonnx.phoneme_ids import (
+from phoonnx.phoneme_ids import (
+    phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
+    DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN
+)
 from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
 from tqdm import tqdm
 from phoonnx.version import VERSION_STR
@@ -46,7 +48,7 @@ class Utterance:
     audio_spec_path: Optional[Path] = None

     def asdict(self) -> Dict[str, Any]:
-        """Custom asdict to handle Path objects."""
+        """Custom asdict to handle Path objects for JSON serialization."""
         data = dataclasses.asdict(self)
         for key, value in data.items():
             if isinstance(value, Path):
@@ -57,14 +59,31 @@ class Utterance:
 class PathEncoder(json.JSONEncoder):
     """JSON encoder for Path objects."""

-    def default(self, o):
+    def default(self, o: Any) -> Union[str, Any]:
+        """
+        Converts Path objects to strings for serialization.
+
+        Args:
+            o: The object to serialize.
+
+        Returns:
+            The serialized string representation or the default JSON serialization.
+        """
         if isinstance(o, Path):
             return str(o)
         return super().default(o)


-def get_text_casing(casing: str):
-    """
+def get_text_casing(casing: str) -> Callable[[str], str]:
+    """
+    Returns a function to apply text casing based on a string name.
+
+    Args:
+        casing: The name of the casing function ('lower', 'upper', 'casefold', or 'ignore').
+
+    Returns:
+        A callable function (str) -> str.
+    """
     if casing == "lower":
         return str.lower
     if casing == "upper":
@@ -74,18 +93,46 @@ def get_text_casing(casing: str):
     return lambda s: s


+@dataclass
+class PreprocessorConfig:
+    """Dataclass to hold all runtime configuration, mimicking argparse.Namespace."""
+    input_dir: Path
+    output_dir: Path
+    language: str
+    sample_rate: int
+    cache_dir: Path
+    max_workers: int
+    single_speaker: bool
+    speaker_id: Optional[int]
+    phoneme_type: PhonemeType
+    alphabet: Alphabet
+    phonemizer_model: str
+    text_casing: str
+    dataset_name: Optional[str]
+    audio_quality: Optional[str]
+    skip_audio: bool
+    debug: bool
+    add_diacritics: bool
+
+
+def ljspeech_dataset(config: PreprocessorConfig) -> Iterable[Utterance]:
     """
     Generator for LJSpeech-style dataset.
     Loads metadata and resolves audio file paths.
+
+    Args:
+        config: The configuration object containing dataset parameters.
+
+    Yields:
+        Utterance: A fully populated Utterance object.
     """
-    dataset_dir =
+    dataset_dir = config.input_dir
     metadata_path = dataset_dir / "metadata.csv"
     if not metadata_path.exists():
         _LOGGER.error(f"Missing metadata file: {metadata_path}")
         return

-    wav_dirs = [dataset_dir / "wav", dataset_dir / "wavs"]
+    wav_dirs: List[Path] = [dataset_dir / "wav", dataset_dir / "wavs"]

     with open(metadata_path, "r", encoding="utf-8") as csv_file:
         reader = csv.reader(csv_file, delimiter="|")
@@ -98,16 +145,18 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
             text: str = row[-1]
             speaker: Optional[str] = None

-            if not
+            if not config.single_speaker and len(row) > 2:
                 speaker = row[1]
             else:
                 speaker = None

-            wav_path = None
+            wav_path: Optional[Path] = None
             for wav_dir in wav_dirs:
-                potential_paths = [
+                potential_paths: List[Path] = [
+                    wav_dir / filename,
+                    wav_dir / f"{filename}.wav",
+                    wav_dir / f"{filename.lstrip('0')}.wav"
+                ]
                 for path in potential_paths:
                     if path.exists():
                         wav_path = path
@@ -115,34 +164,40 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
                 if wav_path:
                     break

-            if not
+            if not config.skip_audio and not wav_path:
                 _LOGGER.warning("Missing audio file for filename: %s", filename)
                 continue

-            if not
+            if not config.skip_audio and wav_path and wav_path.stat().st_size == 0:
                 _LOGGER.warning("Empty audio file: %s", wav_path)
                 continue

+            # Ensure wav_path is Path or None, and is never accessed if skip_audio is true
             yield Utterance(
                 text=text,
-                audio_path=wav_path,
+                audio_path=wav_path or Path(""),  # Use empty path if skipping audio, should not be used
                 speaker=speaker,
-                speaker_id=
+                speaker_id=config.speaker_id,
             )


 def phonemize_worker(
+    config: PreprocessorConfig,
     task_queue: JoinableQueue,
     result_queue: Queue,
     phonemizer: Phonemizer,
-):
+) -> None:
     """
     Worker process for phonemization and audio processing.
+
+    Args:
+        config: The configuration object containing runtime parameters.
+        task_queue: Queue for receiving batches of Utterance objects.
+        result_queue: Queue for sending processed results (Utterance, set of phonemes).
+        phonemizer: The initialized Phonemizer instance.
     """
     try:
-        casing = get_text_casing(
+        casing: Callable[[str], str] = get_text_casing(config.text_casing)
         silence_detector = make_silence_detector()

         while True:
@@ -155,28 +210,30 @@ def phonemize_worker(

             for utt in utterance_batch:
                 try:
-                    #
-                    utterance = casing(normalize(
+                    # Normalize text (case, numbers, etc.)
+                    utterance: str = casing(normalize(utt.text, config.language))

-                    #
-                    if
-                        utterance = phonemizer.add_diacritics(utterance,
+                    # Add diacritics
+                    if config.add_diacritics:
+                        utterance = phonemizer.add_diacritics(utterance, config.language)

                     # Phonemize the text
-                    utt.phonemes = phonemizer.phonemize_to_list(utterance,
+                    utt.phonemes = [p for p in phonemizer.phonemize_to_list(utterance, config.language)
+                                    if p != "\n"]  # HACK: not sure where this is coming from
                     if not utt.phonemes:
                         raise RuntimeError(f"Phonemes not found for '{utterance}'")

                     # Process audio if not skipping
-                    if not
+                    if not config.skip_audio:
                         utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                             utt.audio_path,
+                            config.cache_dir,
                             silence_detector,
+                            config.sample_rate,
                         )

                     # Put the processed utterance and its phonemes into the result queue
+                    # The result is a tuple of (Utterance, set of unique phonemes in that utterance)
                     result_queue.put((utt, set(utt.phonemes)))
                 except Exception:
                     _LOGGER.exception("Failed to process utterance: %s", utt.audio_path)
@@ -188,109 +245,212 @@ def phonemize_worker(
         _LOGGER.exception("Worker process failed")


+@click.command(context_settings={"help_option_names": ["-h", "--help"]})
+@click.option(
+    "-i",
+    "--input-dir",
+    "input_dir",
+    type=click.Path(exists=True, file_okay=False, path_type=Path),
+    required=True,
+    help="Directory with audio dataset (e.g., containing metadata.csv and wavs/)",
+)
+@click.option(
+    "-o",
+    "--output-dir",
+    "output_dir",
+    type=click.Path(file_okay=False, path_type=Path),
+    required=True,
+    help="Directory to write output files for training (config.json, dataset.jsonl)",
+)
+@click.option(
+    "-l",
+    "--language",
+    "language",
+    required=True,
+    help="phonemizer language code (e.g., 'en', 'es', 'fr')",
+)
+@click.option(
+    "-c",
+    "--prev-config",
+    "prev_config",
+    type=click.Path(exists=True, dir_okay=False, path_type=Path),
+    default=None,
+    help="Optional path to a previous config.json from which to reuse phoneme_id_map. (for fine-tuning only)",
+)
+@click.option(
+    "--drop-extra-phonemes",
+    "drop_extra_phonemes",
+    type=bool,
+    default=True,
+    help="If training data has more symbols than base model, discard new symbols. (for fine-tuning only)",
+)
+@click.option(
+    "-r",
+    "--sample-rate",
+    "sample_rate",
+    type=int,
+    default=22050,
+    help="Target sample rate for voice (hertz, Default: 22050)",
+)
+@click.option(
+    "--cache-dir",
+    "cache_dir",
+    type=click.Path(file_okay=False, path_type=Path),
+    default=None,
+    help="Directory to cache processed audio files. Defaults to <output-dir>/cache/<sample-rate>.",
+)
+@click.option(
+    "-w",
+    "--max-workers",
+    "max_workers",
+    type=click.IntRange(min=1),
+    default=os.cpu_count() or 1,
+    help="Maximum number of worker processes to use for parallel processing. Defaults to CPU count.",
+)
+@click.option(
+    "--single-speaker",
+    "single_speaker",
+    is_flag=True,
+    help="Force treating the dataset as single speaker, ignoring metadata speaker columns.",
+)
+@click.option(
+    "--speaker-id",
+    "speaker_id",
+    type=int,
+    default=None,
+    help="Specify a fixed speaker ID (0, 1, etc.) for a single speaker dataset.",
+)
+@click.option(
+    "--phoneme-type",
+    "phoneme_type",
+    type=click.Choice([p.value for p in PhonemeType]),
+    default=PhonemeType.ESPEAK.value,
+    help="Type of phonemes to use.",
+)
+@click.option(
+    "--alphabet",
+    "alphabet",
+    type=click.Choice([a.value for a in Alphabet]),
+    default=Alphabet.IPA.value,
+    help="Phoneme alphabet to use (e.g., IPA).",
+)
+@click.option(
+    "--phonemizer-model",
+    "phonemizer_model",
+    default="",
+    help="Path or name of a custom phonemizer model, if applicable.",
+)
+@click.option(
+    "--text-casing",
+    "text_casing",
+    type=click.Choice(("ignore", "lower", "upper", "casefold")),
+    default="ignore",
+    help="Casing applied to utterance text before phonemization.",
+)
+@click.option(
+    "--dataset-name",
+    "dataset_name",
+    default=None,
+    help="Name of dataset to put in config (default: name of <output_dir>/../).",
+)
+@click.option(
+    "--audio-quality",
+    "audio_quality",
+    default=None,
+    help="Audio quality description to put in config (default: name of <output_dir>).",
+)
+@click.option(
+    "--skip-audio",
+    "skip_audio",
+    is_flag=True,
+    help="Do not preprocess or cache audio files.",
+)
+@click.option(
+    "--debug",
+    "debug",
+    is_flag=True,
+    help="Print DEBUG messages to the console.",
+)
+@click.option(
+    "--add-diacritics",
+    "add_diacritics",
+    is_flag=True,
+    help="Add diacritics to text (phonemizer specific, e.g., to denote stress).",
+)
+def cli(
+    input_dir: Path,
+    output_dir: Path,
+    language: str,
+    prev_config: Path,
+    drop_extra_phonemes: bool,
+    sample_rate: int,
+    cache_dir: Optional[Path],
+    max_workers: Optional[int],
+    single_speaker: bool,
+    speaker_id: Optional[int],
+    phoneme_type: str,
+    alphabet: str,
+    phonemizer_model: str,
+    text_casing: str,
+    dataset_name: Optional[str],
+    audio_quality: Optional[str],
+    skip_audio: bool,
+    debug: bool,
+    add_diacritics: bool,
+) -> None:
+    """
+    Preprocess a TTS dataset (e.g., LJSpeech format) for training a VITS-style model.
+    This script handles text normalization, phonemization, and optional audio caching.
+    """
+    # Create a config object from click arguments for easier passing
+    config = PreprocessorConfig(
+        input_dir=input_dir,
+        output_dir=output_dir,
+        language=language,
+        sample_rate=sample_rate,
+        cache_dir=cache_dir or output_dir / "cache" / str(sample_rate),
+        max_workers=max_workers or os.cpu_count() or 1,
+        single_speaker=single_speaker,
+        speaker_id=speaker_id,
+        phoneme_type=PhonemeType(phoneme_type),
+        alphabet=Alphabet(alphabet),
+        phonemizer_model=phonemizer_model,
+        text_casing=text_casing,
+        dataset_name=dataset_name,
+        audio_quality=audio_quality,
+        skip_audio=skip_audio,
+        debug=debug,
+        add_diacritics=add_diacritics,
     )
-    args = parser.parse_args()

-    # Setup
-    level = logging.DEBUG if
+    # Setup logging
+    level = logging.DEBUG if config.debug else logging.INFO
     logging.basicConfig(level=level)
     logging.getLogger().setLevel(level)
     logging.getLogger("numba").setLevel(logging.WARNING)

+    # Validation
+    if config.single_speaker and (config.speaker_id is not None):
         _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
+        raise click.Abort()

-    args.cache_dir = (
-        Path(args.cache_dir)
-        if args.cache_dir
-        else args.output_dir / "cache" / str(args.sample_rate)
-    )
-    args.cache_dir.mkdir(parents=True, exist_ok=True)
-    args.phoneme_type = PhonemeType(args.phoneme_type)
+    # Create directories
+    config.output_dir.mkdir(parents=True, exist_ok=True)
+    config.cache_dir.mkdir(parents=True, exist_ok=True)

     # Load all utterances from the dataset
     _LOGGER.info("Loading utterances from dataset...")
-    utterances = list(ljspeech_dataset(
+    utterances: List[Utterance] = list(ljspeech_dataset(config))
     if not utterances:
         _LOGGER.error("No valid utterances found in dataset.")
         return

-    num_utterances = len(utterances)
+    num_utterances: int = len(utterances)
     _LOGGER.info("Found %d utterances.", num_utterances)

-    # Count speakers
+    # Count speakers and assign IDs
     speaker_counts: Counter[str] = Counter(u.speaker for u in utterances if u.speaker)
-    is_multispeaker = len(speaker_counts) > 1
+    is_multispeaker: bool = len(speaker_counts) > 1
     speaker_ids: Dict[str, int] = {}
     if is_multispeaker:
         _LOGGER.info("%s speakers detected", len(speaker_counts))
@@ -301,48 +461,47 @@ def main() -> None:
         _LOGGER.info("Single speaker dataset")

     # --- Single Pass: Process audio/phonemes and collect results ---
-    args.max_workers = args.max_workers if args.max_workers is not None and args.max_workers > 0 else os.cpu_count()
-    _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)
+    _LOGGER.info("Starting single pass processing with %d workers...", config.max_workers)

     # Initialize the phonemizer only once in the main process
-    phonemizer = get_phonemizer(
+    phonemizer: Phonemizer = get_phonemizer(config.phoneme_type,
+                                            config.alphabet,
+                                            config.phonemizer_model)

-    batch_size = max(1, int(num_utterances / (
+    batch_size: int = max(1, int(num_utterances / (config.max_workers * 2)))

-    task_queue: "
+    task_queue: "JoinableQueue[Optional[List[Utterance]]]" = JoinableQueue()
     # The result queue will hold tuples of (Utterance, set(phonemes))
-    result_queue: "Queue[Optional[
+    result_queue: "Queue[Tuple[Optional[Utterance], Set[str]]]" = Queue()

     # Start workers
-    processes = [
+    processes: List[Process] = [
         Process(
             target=phonemize_worker,
-            args=(
+            args=(config, task_queue, result_queue, phonemizer)
         )
-        for _ in range(
+        for _ in range(config.max_workers)
     ]

     for proc in processes:
         proc.start()

     # Populate the task queue with batches
-    task_count = 0
+    task_count: int = 0
     for utt_batch in batched(utterances, batch_size):
         task_queue.put(utt_batch)
         task_count += len(utt_batch)

     # Signal workers to stop
-    for _ in range(
+    for _ in range(config.max_workers):
         task_queue.put(None)

     # Collect results from the queue with a progress bar
     processed_utterances: List[Utterance] = []
     all_phonemes: Set[str] = set()
     for _ in tqdm(range(task_count), desc="Processing utterances"):
+        result: Tuple[Optional[Utterance], Set[str]] = result_queue.get()
+        utt, unique_phonemes = result
         if utt is not None:
             processed_utterances.append(utt)
             all_phonemes.update(unique_phonemes)
@@ -352,43 +511,69 @@ def main() -> None:
     for proc in processes:
         proc.join()

-    # --- Build the final phoneme map from the collected phonemes ---
-    _LOGGER.info("Building a complete phoneme map from collected phonemes...")
-    final_phoneme_id_map = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
-    if phonemizer.alphabet == Alphabet.IPA:
-        all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())
-    # Filter out special tokens that are already in the map
-    existing_keys = set(final_phoneme_id_map.keys())
-    new_phonemes = sorted([p for p in all_phonemes if p not in existing_keys])

+    # --- Build the final phoneme map from the collected phonemes ---
+    _LOGGER.info("Building a phoneme map from collected dataset phonemes...")
+
+    if prev_config:
+        with open(prev_config) as f:
+            prev_phoneme_id_map = json.load(f)["phoneme_id_map"]
+        _LOGGER.info(f"Loaded phoneme map from previous config: '{prev_config}'")
+        all_phonemes.update(prev_phoneme_id_map.keys())
+        final_phoneme_id_map = prev_phoneme_id_map
+        _LOGGER.info("previous phoneme map contains %d symbols.", len(final_phoneme_id_map))
+    else:
+        final_phoneme_id_map: Dict[str, int] = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
+        if phonemizer.alphabet == Alphabet.IPA:
+            all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())
+
+    # Filter out tokens that are already in the map
+    existing_keys: Set[str] = set(final_phoneme_id_map.keys())
+    new_phonemes: List[str] = sorted([p for p in all_phonemes
+                                      if p not in existing_keys]
+                                     )
+
+    _LOGGER.info("Collected %d new symbols.", len(new_phonemes))
+
+    finetune_error = prev_config and len(new_phonemes)
+    if finetune_error:
+        if not drop_extra_phonemes:
+            raise ValueError("training data contains different phonemes than previous phoneme map! Can not finetune model")
+        else:
+            _LOGGER.error("training data contains different phonemes than previous phoneme map! "
+                          "Discarding new phonemes to still allow model finetuning")
+
+    current_id: int = len(final_phoneme_id_map)
     for pho in new_phonemes:
+        if finetune_error:
+            _LOGGER.info(f"Discarded phoneme: {pho}")
+        else:
+            final_phoneme_id_map[pho] = current_id
+            current_id += 1
+            _LOGGER.debug(f"New phoneme: {pho}")

+    if new_phonemes:
+        _LOGGER.info("Final phoneme map contains %d symbols.", len(final_phoneme_id_map))

     # --- Write the final config.json ---
     _LOGGER.info("Writing dataset config...")
-    audio_quality =
-    dataset_name =
+    audio_quality = config.audio_quality or config.output_dir.name
+    dataset_name = config.dataset_name or config.output_dir.parent.name

+    config_data: Dict[str, Any] = {
         "dataset": dataset_name,
         "audio": {
-            "sample_rate":
+            "sample_rate": config.sample_rate,
             "quality": audio_quality,
         },
-        "lang_code":
+        "lang_code": config.language,
         "inference": {"noise_scale": 0.667,
                       "length_scale": 1,
                       "noise_w": 0.8,
-                      "add_diacritics":
+                      "add_diacritics": config.add_diacritics},
         "alphabet": phonemizer.alphabet.value,
-        "phoneme_type":
-        "phonemizer_model":
+        "phoneme_type": config.phoneme_type.value,
+        "phonemizer_model": config.phonemizer_model,
         "phoneme_id_map": final_phoneme_id_map,
         "num_symbols": len(final_phoneme_id_map),
         "num_speakers": len(speaker_counts) if is_multispeaker else 1,
@@ -396,13 +581,13 @@ def main() -> None:
         "phoonnx_version": VERSION_STR,
     }

-    with open(
-        json.dump(
+    with open(config.output_dir / "config.json", "w", encoding="utf-8") as config_file:
+        json.dump(config_data, config_file, ensure_ascii=False, indent=2)

     # --- Apply final phoneme IDs and write dataset.jsonl ---
     _LOGGER.info("Writing dataset.jsonl...")
-    valid_utterances_count = 0
-    with open(
+    valid_utterances_count: int = 0
+    with open(config.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
         for utt in processed_utterances:
             if is_multispeaker and utt.speaker is not None:
                 if utt.speaker not in speaker_ids:
@@ -432,8 +617,17 @@ def main() -> None:

 # -----------------------------------------------------------------------------

-def batched(iterable, n):
-    "
+def batched(iterable: Iterable[Any], n: int) -> Iterable[List[Any]]:
+    """
+    Batch data from an iterable into lists of length n. The last batch may be shorter.
+
+    Args:
+        iterable: The input iterable to be batched.
+        n: The desired size of each batch.
+
+    Yields:
+        List[Any]: A list representing a batch of items.
+    """
     if n < 1:
         raise ValueError("n must be at least one")
     it = iter(iterable)
@@ -444,4 +638,4 @@ def batched(iterable, n):


 if __name__ == "__main__":
+    cli()
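With this change the preprocessing entry point is a click command (cli) rather than an argparse main(). A minimal sketch of driving it programmatically through click's test runner; the dataset and output directories are placeholders, and only options defined by the decorators above are passed:

from click.testing import CliRunner

from phoonnx_train.preprocess import cli

runner = CliRunner()
# Roughly equivalent to: python -m phoonnx_train.preprocess -i ./my_dataset -o ./output -l en --single-speaker
result = runner.invoke(
    cli,
    [
        "-i", "./my_dataset",   # must contain metadata.csv and a wav/ or wavs/ directory
        "-o", "./output",       # config.json and dataset.jsonl are written here
        "-l", "en",             # phonemizer language code
        "--single-speaker",
    ],
)
print(result.exit_code)
print(result.output)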
phoonnx_train/train.py
ADDED
@@ -0,0 +1,151 @@
+import json
+import logging
+from pathlib import Path
+
+import torch
+import click
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from phoonnx_train.vits.lightning import VitsModel
+
+_LOGGER = logging.getLogger(__package__)
+
+
+def load_state_dict(model, saved_state_dict):
+    state_dict = model.state_dict()
+    new_state_dict = {}
+
+    for k, v in state_dict.items():
+        if k in saved_state_dict:
+            new_state_dict[k] = saved_state_dict[k]
+        else:
+            _LOGGER.debug("%s is not in the checkpoint", k)
+            new_state_dict[k] = v
+
+    model.load_state_dict(new_state_dict)
+
+
+@click.command(context_settings=dict(ignore_unknown_options=True))
+@click.option('--dataset-dir', required=True, type=click.Path(exists=True, file_okay=False), help='Path to pre-processed dataset directory')
+@click.option('--checkpoint-epochs', default=1, type=int, help='Save checkpoint every N epochs (default: 1)')
+@click.option('--quality', default='medium', type=click.Choice(['x-low', 'medium', 'high']), help='Quality/size of model (default: medium)')
+@click.option('--resume-from-checkpoint', default=None, help='Load an existing checkpoint and resume training')
+@click.option('--resume-from-single-speaker-checkpoint', help='For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training')
+@click.option('--seed', type=int, default=1234, help='Random seed (default: 1234)')
+# Common Trainer options
+@click.option('--max-epochs', type=int, default=1000, help='Stop training once this number of epochs is reached (default: 1000)')
+@click.option('--devices', default=1, help='Number of devices or list of device IDs to train on (default: 1)')
+@click.option('--accelerator', default='auto', help='Hardware accelerator to use (cpu, gpu, tpu, mps, etc.) (default: "auto")')
+@click.option('--default-root-dir', type=click.Path(file_okay=False), default=None, help='Default root directory for logs and checkpoints (default: None)')
+@click.option('--precision', default=32, help='Precision used in training (e.g. 16, 32, bf16) (default: 32)')
+# Model-specific arguments
+@click.option('--learning-rate', type=float, default=2e-4, help='Learning rate for optimizer (default: 2e-4)')
+@click.option('--batch-size', type=int, default=16, help='Training batch size (default: 16)')
+@click.option('--num-workers', type=click.IntRange(min=1), default=os.cpu_count() or 1, help='Number of data loader workers (default: CPU count)')
+@click.option('--validation-split', type=float, default=0.05, help='Proportion of data used for validation (default: 0.05)')
+def main(
+    dataset_dir,
+    checkpoint_epochs,
+    quality,
+    resume_from_checkpoint,
+    resume_from_single_speaker_checkpoint,
+    seed,
+    max_epochs,
+    devices,
+    accelerator,
+    default_root_dir,
+    precision,
+    learning_rate,
+    batch_size,
+    num_workers,
+    validation_split,
+):
+    logging.basicConfig(level=logging.DEBUG)
+
+    dataset_dir = Path(dataset_dir)
+    if default_root_dir is None:
+        default_root_dir = dataset_dir
+
+    torch.backends.cudnn.benchmark = True
+    torch.manual_seed(seed)
+
+    config_path = dataset_dir / 'config.json'
+    dataset_path = dataset_dir / 'dataset.jsonl'
+
+    print(f"INFO - config_path: '{config_path}'")
+    print(f"INFO - dataset_path: '{dataset_path}'")
+
+    with open(config_path, 'r', encoding='utf-8') as config_file:
+        config = json.load(config_file)
+        num_symbols = int(config['num_symbols'])
+        num_speakers = int(config['num_speakers'])
+        sample_rate = int(config['audio']['sample_rate'])
+
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        devices=devices,
+        accelerator=accelerator,
+        default_root_dir=default_root_dir,
+        precision=precision,
+        resume_from_checkpoint=resume_from_checkpoint
+    )
+
+    if checkpoint_epochs is not None:
+        trainer.callbacks = [ModelCheckpoint(every_n_epochs=checkpoint_epochs)]
+        _LOGGER.info('Checkpoints will be saved every %s epoch(s)', checkpoint_epochs)
+
+    dict_args = dict(
+        seed=seed,
+        learning_rate=learning_rate,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        validation_split=validation_split,
+    )
+
+    if quality == 'x-low':
+        dict_args.update({
+            'hidden_channels': 96,
+            'inter_channels': 96,
+            'filter_channels': 384,
+        })
+    elif quality == 'high':
+        dict_args.update({
+            'resblock': '1',
+            'resblock_kernel_sizes': (3, 7, 11),
+            'resblock_dilation_sizes': ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+            'upsample_rates': (8, 8, 2, 2),
+            'upsample_initial_channel': 512,
+            'upsample_kernel_sizes': (16, 16, 4, 4),
+        })
+
+    print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
+    model = VitsModel(
+        num_symbols=num_symbols,
+        num_speakers=num_speakers,
+        sample_rate=sample_rate,
+        dataset=[dataset_path],
+        **dict_args,
+    )
+
+    if resume_from_single_speaker_checkpoint:
+        assert num_speakers > 1, "--resume-from-single-speaker-checkpoint is only for multi-speaker models."
+        _LOGGER.info('Resuming from single-speaker checkpoint: %s', resume_from_single_speaker_checkpoint)
+
+        model_single = VitsModel.load_from_checkpoint(resume_from_single_speaker_checkpoint, dataset=None)
+        g_dict = model_single.model_g.state_dict()
+
+        for key in list(g_dict.keys()):
+            if key.startswith('dec.cond') or key.startswith('dp.cond') or ('enc.cond_layer' in key):
+                g_dict.pop(key, None)
+
+        load_state_dict(model.model_g, g_dict)
+        load_state_dict(model.model_d, model_single.model_d.state_dict())
+        _LOGGER.info('Successfully converted single-speaker checkpoint to multi-speaker')
+
+    print('training started!!')
+    trainer.fit(model)
+
+
+if __name__ == '__main__':
+    main()
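The load_state_dict helper above performs a tolerant, key-by-key copy: parameters present in the checkpoint are taken from it, and anything missing keeps its freshly initialised value, which is what lets a single-speaker checkpoint seed a multi-speaker model once the speaker-conditioning keys are dropped. A self-contained sketch of the same idea on plain tensors; the key names are made up for illustration:

import torch

def merge_state_dicts(current: dict, saved: dict) -> dict:
    # Keep saved values where the keys match, fall back to the current (initialised) values otherwise.
    return {k: saved.get(k, v) for k, v in current.items()}

current = {"enc.weight": torch.zeros(2), "dec.cond.weight": torch.zeros(2)}
saved = {"enc.weight": torch.ones(2)}  # e.g. a checkpoint with speaker-conditioning keys removed
merged = merge_state_dicts(current, saved)
print(merged["enc.weight"])       # tensor([1., 1.]), taken from the checkpoint
print(merged["dec.cond.weight"])  # tensor([0., 0.]), kept from the new model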
phoonnx_train/__main__.py
DELETED
@@ -1,151 +0,0 @@
-import argparse
-import json
-import logging
-from pathlib import Path
-
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-
-from phoonnx_train.vits.lightning import VitsModel
-
-_LOGGER = logging.getLogger(__package__)
-
-
-def main():
-    logging.basicConfig(level=logging.DEBUG)
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
-    )
-    parser.add_argument(
-        "--checkpoint-epochs",
-        type=int,
-        help="Save checkpoint every N epochs (default: 1)",
-    )
-    parser.add_argument(
-        "--quality",
-        default="medium",
-        choices=("x-low", "medium", "high"),
-        help="Quality/size of model (default: medium)",
-    )
-    parser.add_argument(
-        "--resume_from_single_speaker_checkpoint",
-        help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
-    )
-    Trainer.add_argparse_args(parser)
-    VitsModel.add_model_specific_args(parser)
-    parser.add_argument("--seed", type=int, default=1234)
-    args = parser.parse_args()
-    _LOGGER.debug(args)
-
-    args.dataset_dir = Path(args.dataset_dir)
-    if not args.default_root_dir:
-        args.default_root_dir = args.dataset_dir
-
-    torch.backends.cudnn.benchmark = True
-    torch.manual_seed(args.seed)
-
-    config_path = args.dataset_dir / "config.json"
-    dataset_path = args.dataset_dir / "dataset.jsonl"
-
-    print(f"INFO - config_path: '{config_path}'")
-    print(f"INFO - dataset_path: '{dataset_path}'")
-
-    with open(config_path, "r", encoding="utf-8") as config_file:
-        # See preprocess.py for format
-        config = json.load(config_file)
-        num_symbols = int(config["num_symbols"])
-        num_speakers = int(config["num_speakers"])
-        sample_rate = int(config["audio"]["sample_rate"])
-
-    trainer = Trainer.from_argparse_args(args)
-    if args.checkpoint_epochs is not None:
-        trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
-        _LOGGER.info(
-            "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
-        )
-
-    dict_args = vars(args)
-    if args.quality == "x-low":
-        dict_args["hidden_channels"] = 96
-        dict_args["inter_channels"] = 96
-        dict_args["filter_channels"] = 384
-    elif args.quality == "high":
-        dict_args["resblock"] = "1"
-        dict_args["resblock_kernel_sizes"] = (3, 7, 11)
-        dict_args["resblock_dilation_sizes"] = (
-            (1, 3, 5),
-            (1, 3, 5),
-            (1, 3, 5),
-        )
-        dict_args["upsample_rates"] = (8, 8, 2, 2)
-        dict_args["upsample_initial_channel"] = 512
-        dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)
-
-    print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
-    model = VitsModel(
-        num_symbols=num_symbols,
-        num_speakers=num_speakers,
-        sample_rate=sample_rate,
-        dataset=[dataset_path],
-        **dict_args,
-    )
-
-    if args.resume_from_single_speaker_checkpoint:
-        assert (
-            num_speakers > 1
-        ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
-
-        # Load single-speaker checkpoint
-        _LOGGER.info(
-            "Resuming from single-speaker checkpoint: %s",
-            args.resume_from_single_speaker_checkpoint,
-        )
-        model_single = VitsModel.load_from_checkpoint(
-            args.resume_from_single_speaker_checkpoint,
-            dataset=None,
-        )
-        g_dict = model_single.model_g.state_dict()
-        for key in list(g_dict.keys()):
-            # Remove keys that can't be copied over due to missing speaker embedding
-            if (
-                key.startswith("dec.cond")
-                or key.startswith("dp.cond")
-                or ("enc.cond_layer" in key)
-            ):
-                g_dict.pop(key, None)
-
-        # Copy over the multi-speaker model, excluding keys related to the
-        # speaker embedding (which is missing from the single-speaker model).
-        load_state_dict(model.model_g, g_dict)
-        load_state_dict(model.model_d, model_single.model_d.state_dict())
-        _LOGGER.info(
-            "Successfully converted single-speaker checkpoint to multi-speaker"
-        )
-    print("training started!!")
-    trainer.fit(model)
-
-
-def load_state_dict(model, saved_state_dict):
-    state_dict = model.state_dict()
-    new_state_dict = {}
-
-    for k, v in state_dict.items():
-        if k in saved_state_dict:
-            # Use saved value
-            new_state_dict[k] = saved_state_dict[k]
-        else:
-            # Use initialized value
-            _LOGGER.debug("%s is not in the checkpoint", k)
-            new_state_dict[k] = v
-
-    model.load_state_dict(new_state_dict)
-
-
-# -----------------------------------------------------------------------------
-
-
-if __name__ == "__main__":
-    main()
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/WHEEL
File without changes
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/top_level.txt
File without changes