phoonnx 0.2.0a2__py3-none-any.whl → 0.2.2a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
phoonnx/version.py CHANGED
@@ -1,8 +1,8 @@
  # START_VERSION_BLOCK
  VERSION_MAJOR = 0
  VERSION_MINOR = 2
- VERSION_BUILD = 0
- VERSION_ALPHA = 2
+ VERSION_BUILD = 2
+ VERSION_ALPHA = 1
  # END_VERSION_BLOCK

  VERSION_STR = f"{VERSION_MAJOR}.{VERSION_MINOR}.{VERSION_BUILD}"
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: phoonnx
- Version: 0.2.0a2
+ Version: 0.2.2a1
  Home-page: https://github.com/TigreGotico/phoonnx
  Author: JarbasAi
  Author-email: jarbasai@mailfence.com
{phoonnx-0.2.0a2.dist-info → phoonnx-0.2.2a1.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  phoonnx/config.py,sha256=DKgsU03g8jrAuMcVqbu-w3MWPXOUihFtRnavg6WGQ1Y,19983
  phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
  phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
- phoonnx/version.py,sha256=vWIFSqAEXucauJWaMPz6YHjKZz41mk4S80fAg4MMQIA,237
+ phoonnx/version.py,sha256=WiLyUm-i8r69uJR8Yj-q7I3TuC27Ha0HVgB22_lpeR4,237
  phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
  phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
  phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
@@ -62,9 +62,9 @@ phoonnx/thirdparty/tashkeel/hint_id_map.json,sha256=gJMdtTsfEDFgmmbyO2Shw315rkqK
  phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771ltzkv1EnYmHKc2w8M,628
  phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
  phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
- phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
  phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
- phoonnx_train/preprocess.py,sha256=8_Opy5QVNjVmSVmh1_IF23bcNebVIEXuK2KcollIy28,15793
+ phoonnx_train/preprocess.py,sha256=4FJFi7KL-ZUmrbN2NyhxBNpEjDlPRLSDJo2JoyvpR14,21700
+ phoonnx_train/train.py,sha256=EUePlnNdBuo9IFIxHsxZ4CZ27IwOCIc1ySsJiIo-dkI,6015
  phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
  phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
  phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
@@ -83,7 +83,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
  phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
  phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
  phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
- phoonnx-0.2.0a2.dist-info/METADATA,sha256=APrlFEorw1xKK9RXDfQyGLMtibn69yu7_3kTV0YitaQ,8250
- phoonnx-0.2.0a2.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- phoonnx-0.2.0a2.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
- phoonnx-0.2.0a2.dist-info/RECORD,,
+ phoonnx-0.2.2a1.dist-info/METADATA,sha256=j6Tw5hEBAq4iuLqVeAswSzEERwVI-_uci2w1oYJPyfM,8250
+ phoonnx-0.2.2a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ phoonnx-0.2.2a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
+ phoonnx-0.2.2a1.dist-info/RECORD,,
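Each RECORD row has the standard wheel format `path,sha256=<urlsafe-base64 digest, no padding>,size-in-bytes`, which is why version.py gets a new hash while its size stays 237 bytes. A small sketch of how such an entry can be recomputed for verification (the helper name is ours, not part of the package):

    import base64
    import hashlib
    from pathlib import Path

    def record_entry(path: str) -> str:
        # Build a wheel-RECORD style line for a file, per PEP 376/427.
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return f"{path},sha256={digest.decode()},{len(data)}"

    # e.g. record_entry("phoonnx/version.py")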
phoonnx_train/preprocess.py CHANGED
@@ -1,5 +1,4 @@
  #!/usr/bin/env python3
- import argparse
  import csv
  import dataclasses
  import itertools
@@ -10,13 +9,16 @@ from collections import Counter
  from dataclasses import dataclass
  from multiprocessing import JoinableQueue, Process, Queue
  from pathlib import Path
- from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union
+ from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union, Callable

+ import click
  from phoonnx.util import normalize
  from phoonnx.config import PhonemeType, get_phonemizer, Alphabet
  from phoonnx.phonemizers import Phonemizer
- from phoonnx.phoneme_ids import (phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
-                                  DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN)
+ from phoonnx.phoneme_ids import (
+     phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
+     DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN
+ )
  from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
  from tqdm import tqdm
  from phoonnx.version import VERSION_STR
@@ -46,7 +48,7 @@ class Utterance:
      audio_spec_path: Optional[Path] = None

      def asdict(self) -> Dict[str, Any]:
-         """Custom asdict to handle Path objects."""
+         """Custom asdict to handle Path objects for JSON serialization."""
          data = dataclasses.asdict(self)
          for key, value in data.items():
              if isinstance(value, Path):
@@ -57,14 +59,31 @@ class Utterance:
  class PathEncoder(json.JSONEncoder):
      """JSON encoder for Path objects."""

-     def default(self, o):
+     def default(self, o: Any) -> Union[str, Any]:
+         """
+         Converts Path objects to strings for serialization.
+
+         Args:
+             o: The object to serialize.
+
+         Returns:
+             The serialized string representation or the default JSON serialization.
+         """
          if isinstance(o, Path):
              return str(o)
          return super().default(o)


- def get_text_casing(casing: str):
-     """Returns a function to apply text casing based on a string."""
+ def get_text_casing(casing: str) -> Callable[[str], str]:
+     """
+     Returns a function to apply text casing based on a string name.
+
+     Args:
+         casing: The name of the casing function ('lower', 'upper', 'casefold', or 'ignore').
+
+     Returns:
+         A callable function (str) -> str.
+     """
      if casing == "lower":
          return str.lower
      if casing == "upper":
@@ -74,18 +93,46 @@ def get_text_casing(casing: str):
      return lambda s: s


- def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
+ @dataclass
+ class PreprocessorConfig:
+     """Dataclass to hold all runtime configuration, mimicking argparse.Namespace."""
+     input_dir: Path
+     output_dir: Path
+     language: str
+     sample_rate: int
+     cache_dir: Path
+     max_workers: int
+     single_speaker: bool
+     speaker_id: Optional[int]
+     phoneme_type: PhonemeType
+     alphabet: Alphabet
+     phonemizer_model: str
+     text_casing: str
+     dataset_name: Optional[str]
+     audio_quality: Optional[str]
+     skip_audio: bool
+     debug: bool
+     add_diacritics: bool
+
+
+ def ljspeech_dataset(config: PreprocessorConfig) -> Iterable[Utterance]:
      """
      Generator for LJSpeech-style dataset.
      Loads metadata and resolves audio file paths.
+
+     Args:
+         config: The configuration object containing dataset parameters.
+
+     Yields:
+         Utterance: A fully populated Utterance object.
      """
-     dataset_dir = args.input_dir
+     dataset_dir = config.input_dir
      metadata_path = dataset_dir / "metadata.csv"
      if not metadata_path.exists():
          _LOGGER.error(f"Missing metadata file: {metadata_path}")
          return

-     wav_dirs = [dataset_dir / "wav", dataset_dir / "wavs"]
+     wav_dirs: List[Path] = [dataset_dir / "wav", dataset_dir / "wavs"]

      with open(metadata_path, "r", encoding="utf-8") as csv_file:
          reader = csv.reader(csv_file, delimiter="|")
@@ -98,16 +145,18 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
              text: str = row[-1]
              speaker: Optional[str] = None

-             if not args.single_speaker and len(row) > 2:
+             if not config.single_speaker and len(row) > 2:
                  speaker = row[1]
              else:
                  speaker = None

-             wav_path = None
+             wav_path: Optional[Path] = None
              for wav_dir in wav_dirs:
-                 potential_paths = [wav_dir / filename,
-                                    wav_dir / f"{filename}.wav",
-                                    wav_dir / f"{filename.lstrip('0')}.wav"]
+                 potential_paths: List[Path] = [
+                     wav_dir / filename,
+                     wav_dir / f"{filename}.wav",
+                     wav_dir / f"{filename.lstrip('0')}.wav"
+                 ]
                  for path in potential_paths:
                      if path.exists():
                          wav_path = path
@@ -115,34 +164,40 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
                  if wav_path:
                      break

-             if not args.skip_audio and not wav_path:
+             if not config.skip_audio and not wav_path:
                  _LOGGER.warning("Missing audio file for filename: %s", filename)
                  continue

-             if not args.skip_audio and wav_path and wav_path.stat().st_size == 0:
+             if not config.skip_audio and wav_path and wav_path.stat().st_size == 0:
                  _LOGGER.warning("Empty audio file: %s", wav_path)
                  continue

+             # Ensure wav_path is Path or None, and is never accessed if skip_audio is true
              yield Utterance(
                  text=text,
-                 audio_path=wav_path,
+                 audio_path=wav_path or Path(""),  # Empty path when skipping audio; it should never be read
                  speaker=speaker,
-                 speaker_id=args.speaker_id,
+                 speaker_id=config.speaker_id,
              )


  def phonemize_worker(
-     args: argparse.Namespace,
+     config: PreprocessorConfig,
      task_queue: JoinableQueue,
      result_queue: Queue,
      phonemizer: Phonemizer,
- ):
+ ) -> None:
      """
      Worker process for phonemization and audio processing.
-     Returns the utterance and the unique phonemes found in its batch.
+
+     Args:
+         config: The configuration object containing runtime parameters.
+         task_queue: Queue for receiving batches of Utterance objects.
+         result_queue: Queue for sending processed results (Utterance, set of phonemes).
+         phonemizer: The initialized Phonemizer instance.
      """
      try:
-         casing = get_text_casing(args.text_casing)
+         casing: Callable[[str], str] = get_text_casing(config.text_casing)
          silence_detector = make_silence_detector()

          while True:
@@ -155,28 +210,30 @@ def phonemize_worker(

              for utt in utterance_batch:
                  try:
-                     # normalize text (case, numbers....)
-                     utterance = casing(normalize( utt.text, args.language))
+                     # Normalize text (case, numbers, etc.)
+                     utterance: str = casing(normalize(utt.text, config.language))

-                     # add diacritics
-                     if args.add_diacritics:
-                         utterance = phonemizer.add_diacritics(utterance, args.language)
+                     # Add diacritics
+                     if config.add_diacritics:
+                         utterance = phonemizer.add_diacritics(utterance, config.language)

                      # Phonemize the text
-                     utt.phonemes = phonemizer.phonemize_to_list(utterance, args.language)
+                     utt.phonemes = [p for p in phonemizer.phonemize_to_list(utterance, config.language)
+                                     if p != "\n"]  # HACK: not sure where this is coming from
                      if not utt.phonemes:
                          raise RuntimeError(f"Phonemes not found for '{utterance}'")

                      # Process audio if not skipping
-                     if not args.skip_audio:
+                     if not config.skip_audio:
                          utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                              utt.audio_path,
-                             args.cache_dir,
+                             config.cache_dir,
                              silence_detector,
-                             args.sample_rate,
+                             config.sample_rate,
                          )

                      # Put the processed utterance and its phonemes into the result queue
+                     # The result is a tuple of (Utterance, set of unique phonemes in that utterance)
                      result_queue.put((utt, set(utt.phonemes)))
                  except Exception:
                      _LOGGER.exception("Failed to process utterance: %s", utt.audio_path)
@@ -188,109 +245,212 @@ def phonemize_worker(
          _LOGGER.exception("Worker process failed")


- def main() -> None:
-     parser = argparse.ArgumentParser(
-         description="Preprocess a TTS dataset for training a VITS-style model."
-     )
-     parser.add_argument(
-         "--input-dir", required=True, help="Directory with audio dataset"
-     )
-     parser.add_argument(
-         "--output-dir",
-         required=True,
-         help="Directory to write output files for training",
-     )
-     parser.add_argument("--language", required=True, help="eSpeak-ng voice")
-     parser.add_argument(
-         "--sample-rate",
-         type=int,
-         required=True,
-         help="Target sample rate for voice (hertz)",
-     )
-     parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
-     parser.add_argument("--max-workers", type=int)
-     parser.add_argument(
-         "--single-speaker", action="store_true", help="Force single speaker dataset"
-     )
-     parser.add_argument(
-         "--speaker-id", type=int, help="Add speaker id to single speaker dataset"
-     )
-     parser.add_argument(
-         "--phoneme-type",
-         choices=list(PhonemeType),
-         default=PhonemeType.ESPEAK,
-         help="Type of phonemes to use (default: espeak)",
-     )
-     parser.add_argument(
-         "--alphabet",
-         choices=list(Alphabet),
-         default=Alphabet.IPA,
-         help="Casing applied to utterance text",
-     )
-     parser.add_argument(
-         "--phonemizer-model",
-         default="",
-         help="phonemizer model, if applicable",
-     )
-     parser.add_argument(
-         "--text-casing",
-         choices=("ignore", "lower", "upper", "casefold"),
-         default="ignore",
-         help="Casing applied to utterance text",
-     )
-     parser.add_argument(
-         "--dataset-name",
-         help="Name of dataset to put in config (default: name of <ouput_dir>/../)",
-     )
-     parser.add_argument(
-         "--audio-quality",
-         help="Audio quality to put in config (default: name of <output_dir>)",
-     )
-     parser.add_argument(
-         "--skip-audio", action="store_true", help="Don't preprocess audio"
-     )
-     parser.add_argument(
-         "--debug", action="store_true", help="Print DEBUG messages to the console"
-     )
-     parser.add_argument(
-         "--add-diacritics", action="store_true", help="Add diacritics to text (phonemizer specific)"
+ @click.command(context_settings={"help_option_names": ["-h", "--help"]})
+ @click.option(
+     "-i",
+     "--input-dir",
+     "input_dir",
+     type=click.Path(exists=True, file_okay=False, path_type=Path),
+     required=True,
+     help="Directory with audio dataset (e.g., containing metadata.csv and wavs/)",
+ )
+ @click.option(
+     "-o",
+     "--output-dir",
+     "output_dir",
+     type=click.Path(file_okay=False, path_type=Path),
+     required=True,
+     help="Directory to write output files for training (config.json, dataset.jsonl)",
+ )
+ @click.option(
+     "-l",
+     "--language",
+     "language",
+     required=True,
+     help="phonemizer language code (e.g., 'en', 'es', 'fr')",
+ )
+ @click.option(
+     "-c",
+     "--prev-config",
+     "prev_config",
+     type=click.Path(exists=True, dir_okay=False, path_type=Path),
+     default=None,
+     help="Optional path to a previous config.json from which to reuse phoneme_id_map. (for fine-tuning only)",
+ )
+ @click.option(
+     "--drop-extra-phonemes",
+     "drop_extra_phonemes",
+     type=bool,
+     default=True,
+     help="If training data has more symbols than base model, discard new symbols. (for fine-tuning only)",
+ )
+ @click.option(
+     "-r",
+     "--sample-rate",
+     "sample_rate",
+     type=int,
+     default=22050,
+     help="Target sample rate for voice (hertz, Default: 22050)",
+ )
+ @click.option(
+     "--cache-dir",
+     "cache_dir",
+     type=click.Path(file_okay=False, path_type=Path),
+     default=None,
+     help="Directory to cache processed audio files. Defaults to <output-dir>/cache/<sample-rate>.",
+ )
+ @click.option(
+     "-w",
+     "--max-workers",
+     "max_workers",
+     type=click.IntRange(min=1),
+     default=os.cpu_count() or 1,
+     help="Maximum number of worker processes to use for parallel processing. Defaults to CPU count.",
+ )
+ @click.option(
+     "--single-speaker",
+     "single_speaker",
+     is_flag=True,
+     help="Force treating the dataset as single speaker, ignoring metadata speaker columns.",
+ )
+ @click.option(
+     "--speaker-id",
+     "speaker_id",
+     type=int,
+     default=None,
+     help="Specify a fixed speaker ID (0, 1, etc.) for a single speaker dataset.",
+ )
+ @click.option(
+     "--phoneme-type",
+     "phoneme_type",
+     type=click.Choice([p.value for p in PhonemeType]),
+     default=PhonemeType.ESPEAK.value,
+     help="Type of phonemes to use.",
+ )
+ @click.option(
+     "--alphabet",
+     "alphabet",
+     type=click.Choice([a.value for a in Alphabet]),
+     default=Alphabet.IPA.value,
+     help="Phoneme alphabet to use (e.g., IPA).",
+ )
+ @click.option(
+     "--phonemizer-model",
+     "phonemizer_model",
+     default="",
+     help="Path or name of a custom phonemizer model, if applicable.",
+ )
+ @click.option(
+     "--text-casing",
+     "text_casing",
+     type=click.Choice(("ignore", "lower", "upper", "casefold")),
+     default="ignore",
+     help="Casing applied to utterance text before phonemization.",
+ )
+ @click.option(
+     "--dataset-name",
+     "dataset_name",
+     default=None,
+     help="Name of dataset to put in config (default: name of <output_dir>/../).",
+ )
+ @click.option(
+     "--audio-quality",
+     "audio_quality",
+     default=None,
+     help="Audio quality description to put in config (default: name of <output_dir>).",
+ )
+ @click.option(
+     "--skip-audio",
+     "skip_audio",
+     is_flag=True,
+     help="Do not preprocess or cache audio files.",
+ )
+ @click.option(
+     "--debug",
+     "debug",
+     is_flag=True,
+     help="Print DEBUG messages to the console.",
+ )
+ @click.option(
+     "--add-diacritics",
+     "add_diacritics",
+     is_flag=True,
+     help="Add diacritics to text (phonemizer specific, e.g., to denote stress).",
+ )
+ def cli(
+     input_dir: Path,
+     output_dir: Path,
+     language: str,
+     prev_config: Path,
+     drop_extra_phonemes: bool,
+     sample_rate: int,
+     cache_dir: Optional[Path],
+     max_workers: Optional[int],
+     single_speaker: bool,
+     speaker_id: Optional[int],
+     phoneme_type: str,
+     alphabet: str,
+     phonemizer_model: str,
+     text_casing: str,
+     dataset_name: Optional[str],
+     audio_quality: Optional[str],
+     skip_audio: bool,
+     debug: bool,
+     add_diacritics: bool,
+ ) -> None:
+     """
+     Preprocess a TTS dataset (e.g., LJSpeech format) for training a VITS-style model.
+     This script handles text normalization, phonemization, and optional audio caching.
+     """
+     # Create a config object from click arguments for easier passing
+     config = PreprocessorConfig(
+         input_dir=input_dir,
+         output_dir=output_dir,
+         language=language,
+         sample_rate=sample_rate,
+         cache_dir=cache_dir or output_dir / "cache" / str(sample_rate),
+         max_workers=max_workers or os.cpu_count() or 1,
+         single_speaker=single_speaker,
+         speaker_id=speaker_id,
+         phoneme_type=PhonemeType(phoneme_type),
+         alphabet=Alphabet(alphabet),
+         phonemizer_model=phonemizer_model,
+         text_casing=text_casing,
+         dataset_name=dataset_name,
+         audio_quality=audio_quality,
+         skip_audio=skip_audio,
+         debug=debug,
+         add_diacritics=add_diacritics,
      )
-     args = parser.parse_args()

-     # Setup
-     level = logging.DEBUG if args.debug else logging.INFO
+     # Setup logging
+     level = logging.DEBUG if config.debug else logging.INFO
      logging.basicConfig(level=level)
      logging.getLogger().setLevel(level)
      logging.getLogger("numba").setLevel(logging.WARNING)

-     if args.single_speaker and (args.speaker_id is not None):
+     # Validation
+     if config.single_speaker and (config.speaker_id is not None):
          _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
-         return
+         raise click.Abort()

-     args.input_dir = Path(args.input_dir)
-     args.output_dir = Path(args.output_dir)
-     args.output_dir.mkdir(parents=True, exist_ok=True)
-     args.cache_dir = (
-         Path(args.cache_dir)
-         if args.cache_dir
-         else args.output_dir / "cache" / str(args.sample_rate)
-     )
-     args.cache_dir.mkdir(parents=True, exist_ok=True)
-     args.phoneme_type = PhonemeType(args.phoneme_type)
+     # Create directories
+     config.output_dir.mkdir(parents=True, exist_ok=True)
+     config.cache_dir.mkdir(parents=True, exist_ok=True)

      # Load all utterances from the dataset
      _LOGGER.info("Loading utterances from dataset...")
-     utterances = list(ljspeech_dataset(args))
+     utterances: List[Utterance] = list(ljspeech_dataset(config))
      if not utterances:
          _LOGGER.error("No valid utterances found in dataset.")
          return

-     num_utterances = len(utterances)
+     num_utterances: int = len(utterances)
      _LOGGER.info("Found %d utterances.", num_utterances)

-     # Count speakers
+     # Count speakers and assign IDs
      speaker_counts: Counter[str] = Counter(u.speaker for u in utterances if u.speaker)
-     is_multispeaker = len(speaker_counts) > 1
+     is_multispeaker: bool = len(speaker_counts) > 1
      speaker_ids: Dict[str, int] = {}
      if is_multispeaker:
          _LOGGER.info("%s speakers detected", len(speaker_counts))
301
461
  _LOGGER.info("Single speaker dataset")
302
462
 
303
463
  # --- Single Pass: Process audio/phonemes and collect results ---
304
- # Set up multiprocessing
305
- args.max_workers = args.max_workers if args.max_workers is not None and args.max_workers > 0 else os.cpu_count()
306
- _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)
464
+ _LOGGER.info("Starting single pass processing with %d workers...", config.max_workers)
307
465
 
308
466
  # Initialize the phonemizer only once in the main process
309
- phonemizer = get_phonemizer(args.phoneme_type,
310
- args.alphabet,
311
- args.phonemizer_model)
467
+ phonemizer: Phonemizer = get_phonemizer(config.phoneme_type,
468
+ config.alphabet,
469
+ config.phonemizer_model)
312
470
 
313
- batch_size = max(1, int(num_utterances / (args.max_workers * 2)))
471
+ batch_size: int = max(1, int(num_utterances / (config.max_workers * 2)))
314
472
 
315
- task_queue: "Queue[Optional[List[Utterance]]]" = JoinableQueue()
473
+ task_queue: "JoinableQueue[Optional[List[Utterance]]]" = JoinableQueue()
316
474
  # The result queue will hold tuples of (Utterance, set(phonemes))
317
- result_queue: "Queue[Optional[Tuple[Utterance, Set[str]]]]" = Queue()
475
+ result_queue: "Queue[Tuple[Optional[Utterance], Set[str]]]" = Queue()
318
476
 
319
477
  # Start workers
320
- processes = [
478
+ processes: List[Process] = [
321
479
  Process(
322
480
  target=phonemize_worker,
323
- args=(args, task_queue, result_queue, phonemizer)
481
+ args=(config, task_queue, result_queue, phonemizer)
324
482
  )
325
- for _ in range(args.max_workers)
483
+ for _ in range(config.max_workers)
326
484
  ]
327
485
 
328
486
  for proc in processes:
329
487
  proc.start()
330
488
 
331
489
  # Populate the task queue with batches
332
- task_count = 0
490
+ task_count: int = 0
333
491
  for utt_batch in batched(utterances, batch_size):
334
492
  task_queue.put(utt_batch)
335
493
  task_count += len(utt_batch)
336
494
 
337
495
  # Signal workers to stop
338
- for _ in range(args.max_workers):
496
+ for _ in range(config.max_workers):
339
497
  task_queue.put(None)
340
498
 
341
499
  # Collect results from the queue with a progress bar
342
500
  processed_utterances: List[Utterance] = []
343
501
  all_phonemes: Set[str] = set()
344
502
  for _ in tqdm(range(task_count), desc="Processing utterances"):
345
- utt, unique_phonemes = result_queue.get()
503
+ result: Tuple[Optional[Utterance], Set[str]] = result_queue.get()
504
+ utt, unique_phonemes = result
346
505
  if utt is not None:
347
506
  processed_utterances.append(utt)
348
507
  all_phonemes.update(unique_phonemes)
@@ -352,43 +511,69 @@ def main() -> None:
      for proc in processes:
          proc.join()

-     # --- Build the final phoneme map from the collected phonemes ---
-     _LOGGER.info("Building a complete phoneme map from collected phonemes...")
-
-     final_phoneme_id_map = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
-     if phonemizer.alphabet == Alphabet.IPA:
-         all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())
-
-     # Filter out special tokens that are already in the map
-     existing_keys = set(final_phoneme_id_map.keys())
-     new_phonemes = sorted([p for p in all_phonemes if p not in existing_keys])

-     current_id = len(final_phoneme_id_map)
+     # --- Build the final phoneme map from the collected phonemes ---
+     _LOGGER.info("Building a phoneme map from collected dataset phonemes...")
+
+     if prev_config:
+         with open(prev_config) as f:
+             prev_phoneme_id_map = json.load(f)["phoneme_id_map"]
+         _LOGGER.info(f"Loaded phoneme map from previous config: '{prev_config}'")
+         all_phonemes.update(prev_phoneme_id_map.keys())
+         final_phoneme_id_map = prev_phoneme_id_map
+         _LOGGER.info("previous phoneme map contains %d symbols.", len(final_phoneme_id_map))
+     else:
+         final_phoneme_id_map: Dict[str, int] = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
+         if phonemizer.alphabet == Alphabet.IPA:
+             all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())
+
+     # Filter out tokens that are already in the map
+     existing_keys: Set[str] = set(final_phoneme_id_map.keys())
+     new_phonemes: List[str] = sorted([p for p in all_phonemes
+                                       if p not in existing_keys]
+                                      )
+
+     _LOGGER.info("Collected %d new symbols.", len(new_phonemes))
+
+     finetune_error = prev_config and len(new_phonemes)
+     if finetune_error:
+         if not drop_extra_phonemes:
+             raise ValueError("training data contains different phonemes than previous phoneme map! Can not finetune model")
+         else:
+             _LOGGER.error("training data contains different phonemes than previous phoneme map! "
+                           "Discarding new phonemes to still allow model finetuning")
+
+     current_id: int = len(final_phoneme_id_map)
      for pho in new_phonemes:
-         final_phoneme_id_map[pho] = current_id
-         current_id += 1
+         if finetune_error:
+             _LOGGER.info(f"Discarded phoneme: {pho}")
+         else:
+             final_phoneme_id_map[pho] = current_id
+             current_id += 1
+             _LOGGER.debug(f"New phoneme: {pho}")

-     _LOGGER.info("Final phoneme map contains %d symbols.", len(final_phoneme_id_map))
+     if new_phonemes:
+         _LOGGER.info("Final phoneme map contains %d symbols.", len(final_phoneme_id_map))

      # --- Write the final config.json ---
      _LOGGER.info("Writing dataset config...")
-     audio_quality = args.audio_quality or args.output_dir.name
-     dataset_name = args.dataset_name or args.output_dir.parent.name
+     audio_quality = config.audio_quality or config.output_dir.name
+     dataset_name = config.dataset_name or config.output_dir.parent.name

-     config = {
+     config_data: Dict[str, Any] = {
          "dataset": dataset_name,
          "audio": {
-             "sample_rate": args.sample_rate,
+             "sample_rate": config.sample_rate,
              "quality": audio_quality,
          },
-         "lang_code": args.language,
+         "lang_code": config.language,
          "inference": {"noise_scale": 0.667,
                        "length_scale": 1,
                        "noise_w": 0.8,
-                       "add_diacritics": args.add_diacritics},
+                       "add_diacritics": config.add_diacritics},
          "alphabet": phonemizer.alphabet.value,
-         "phoneme_type": args.phoneme_type.value,
-         "phonemizer_model": args.phonemizer_model,
+         "phoneme_type": config.phoneme_type.value,
+         "phonemizer_model": config.phonemizer_model,
          "phoneme_id_map": final_phoneme_id_map,
          "num_symbols": len(final_phoneme_id_map),
          "num_speakers": len(speaker_counts) if is_multispeaker else 1,
@@ -396,13 +581,13 @@ def main() -> None:
          "phoonnx_version": VERSION_STR,
      }

-     with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
-         json.dump(config, config_file, ensure_ascii=False, indent=2)
+     with open(config.output_dir / "config.json", "w", encoding="utf-8") as config_file:
+         json.dump(config_data, config_file, ensure_ascii=False, indent=2)

      # --- Apply final phoneme IDs and write dataset.jsonl ---
      _LOGGER.info("Writing dataset.jsonl...")
-     valid_utterances_count = 0
-     with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
+     valid_utterances_count: int = 0
+     with open(config.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
          for utt in processed_utterances:
              if is_multispeaker and utt.speaker is not None:
                  if utt.speaker not in speaker_ids:
@@ -432,8 +617,17 @@ def main() -> None:

  # -----------------------------------------------------------------------------

- def batched(iterable, n):
-     "Batch data into lists of length n. The last batch may be shorter."
+ def batched(iterable: Iterable[Any], n: int) -> Iterable[List[Any]]:
+     """
+     Batch data from an iterable into lists of length n. The last batch may be shorter.
+
+     Args:
+         iterable: The input iterable to be batched.
+         n: The desired size of each batch.
+
+     Yields:
+         List[Any]: A list representing a batch of items.
+     """
      if n < 1:
          raise ValueError("n must be at least one")
      it = iter(iterable)
@@ -444,4 +638,4 @@ def batched(iterable, n):


  if __name__ == "__main__":
-     main()
+     cli()
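With `main()` replaced by the click-based `cli()`, the module can still be run directly; an example invocation (option names are taken from the decorators above, paths are placeholders):

    python -m phoonnx_train.preprocess \
        --input-dir /data/ljspeech \
        --output-dir /data/out/medium \
        --language en \
        --sample-rate 22050 \
        --single-speaker

The new `PreprocessorConfig` dataclass also makes programmatic use possible without click. A minimal sketch, assuming the package is importable; every field name comes from the diff above, the values are illustrative:

    from pathlib import Path
    from phoonnx.config import Alphabet, PhonemeType
    from phoonnx_train.preprocess import PreprocessorConfig, ljspeech_dataset

    config = PreprocessorConfig(
        input_dir=Path("/data/ljspeech"),
        output_dir=Path("/data/out/medium"),
        language="en",
        sample_rate=22050,
        cache_dir=Path("/data/out/medium/cache/22050"),
        max_workers=4,
        single_speaker=True,
        speaker_id=None,
        phoneme_type=PhonemeType.ESPEAK,
        alphabet=Alphabet.IPA,
        phonemizer_model="",
        text_casing="ignore",
        dataset_name=None,
        audio_quality=None,
        skip_audio=True,  # text-only pass, no audio caching
        debug=False,
        add_diacritics=False,
    )
    utterances = list(ljspeech_dataset(config))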
phoonnx_train/train.py ADDED
@@ -0,0 +1,152 @@
+ import json
+ import logging
+ import os
+ from pathlib import Path
+
+ import torch
+ import click
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ from phoonnx_train.vits.lightning import VitsModel
+
+ _LOGGER = logging.getLogger(__package__)
+
+
+ def load_state_dict(model, saved_state_dict):
+     state_dict = model.state_dict()
+     new_state_dict = {}
+
+     for k, v in state_dict.items():
+         if k in saved_state_dict:
+             new_state_dict[k] = saved_state_dict[k]
+         else:
+             _LOGGER.debug("%s is not in the checkpoint", k)
+             new_state_dict[k] = v
+
+     model.load_state_dict(new_state_dict)
+
+
+ @click.command(context_settings=dict(ignore_unknown_options=True))
+ @click.option('--dataset-dir', required=True, type=click.Path(exists=True, file_okay=False), help='Path to pre-processed dataset directory')
+ @click.option('--checkpoint-epochs', default=1, type=int, help='Save checkpoint every N epochs (default: 1)')
+ @click.option('--quality', default='medium', type=click.Choice(['x-low', 'medium', 'high']), help='Quality/size of model (default: medium)')
+ @click.option('--resume-from-checkpoint', default=None, help='Load an existing checkpoint and resume training')
+ @click.option('--resume-from-single-speaker-checkpoint', help='For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training')
+ @click.option('--seed', type=int, default=1234, help='Random seed (default: 1234)')
+ # Common Trainer options
+ @click.option('--max-epochs', type=int, default=1000, help='Stop training once this number of epochs is reached (default: 1000)')
+ @click.option('--devices', default=1, help='Number of devices or list of device IDs to train on (default: 1)')
+ @click.option('--accelerator', default='auto', help='Hardware accelerator to use (cpu, gpu, tpu, mps, etc.) (default: "auto")')
+ @click.option('--default-root-dir', type=click.Path(file_okay=False), default=None, help='Default root directory for logs and checkpoints (default: None)')
+ @click.option('--precision', default=32, help='Precision used in training (e.g. 16, 32, bf16) (default: 32)')
+ # Model-specific arguments
+ @click.option('--learning-rate', type=float, default=2e-4, help='Learning rate for optimizer (default: 2e-4)')
+ @click.option('--batch-size', type=int, default=16, help='Training batch size (default: 16)')
+ @click.option('--num-workers', type=click.IntRange(min=1), default=os.cpu_count() or 1, help='Number of data loader workers (default: CPU count)')
+ @click.option('--validation-split', type=float, default=0.05, help='Proportion of data used for validation (default: 0.05)')
+ def main(
+     dataset_dir,
+     checkpoint_epochs,
+     quality,
+     resume_from_checkpoint,
+     resume_from_single_speaker_checkpoint,
+     seed,
+     max_epochs,
+     devices,
+     accelerator,
+     default_root_dir,
+     precision,
+     learning_rate,
+     batch_size,
+     num_workers,
+     validation_split,
+ ):
+     logging.basicConfig(level=logging.DEBUG)
+
+     dataset_dir = Path(dataset_dir)
+     if default_root_dir is None:
+         default_root_dir = dataset_dir
+
+     torch.backends.cudnn.benchmark = True
+     torch.manual_seed(seed)
+
+     config_path = dataset_dir / 'config.json'
+     dataset_path = dataset_dir / 'dataset.jsonl'
+
+     print(f"INFO - config_path: '{config_path}'")
+     print(f"INFO - dataset_path: '{dataset_path}'")
+
+     with open(config_path, 'r', encoding='utf-8') as config_file:
+         config = json.load(config_file)
+         num_symbols = int(config['num_symbols'])
+         num_speakers = int(config['num_speakers'])
+         sample_rate = int(config['audio']['sample_rate'])
+
+     trainer = Trainer(
+         max_epochs=max_epochs,
+         devices=devices,
+         accelerator=accelerator,
+         default_root_dir=default_root_dir,
+         precision=precision,
+         resume_from_checkpoint=resume_from_checkpoint
+     )
+
+     if checkpoint_epochs is not None:
+         trainer.callbacks = [ModelCheckpoint(every_n_epochs=checkpoint_epochs)]
+         _LOGGER.info('Checkpoints will be saved every %s epoch(s)', checkpoint_epochs)
+
+     dict_args = dict(
+         seed=seed,
+         learning_rate=learning_rate,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         validation_split=validation_split,
+     )
+
+     if quality == 'x-low':
+         dict_args.update({
+             'hidden_channels': 96,
+             'inter_channels': 96,
+             'filter_channels': 384,
+         })
+     elif quality == 'high':
+         dict_args.update({
+             'resblock': '1',
+             'resblock_kernel_sizes': (3, 7, 11),
+             'resblock_dilation_sizes': ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+             'upsample_rates': (8, 8, 2, 2),
+             'upsample_initial_channel': 512,
+             'upsample_kernel_sizes': (16, 16, 4, 4),
+         })
+
+     print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
+     model = VitsModel(
+         num_symbols=num_symbols,
+         num_speakers=num_speakers,
+         sample_rate=sample_rate,
+         dataset=[dataset_path],
+         **dict_args,
+     )
+
+     if resume_from_single_speaker_checkpoint:
+         assert num_speakers > 1, "--resume-from-single-speaker-checkpoint is only for multi-speaker models."
+         _LOGGER.info('Resuming from single-speaker checkpoint: %s', resume_from_single_speaker_checkpoint)
+
+         model_single = VitsModel.load_from_checkpoint(resume_from_single_speaker_checkpoint, dataset=None)
+         g_dict = model_single.model_g.state_dict()
+
+         for key in list(g_dict.keys()):
+             if key.startswith('dec.cond') or key.startswith('dp.cond') or ('enc.cond_layer' in key):
+                 g_dict.pop(key, None)
+
+         load_state_dict(model.model_g, g_dict)
+         load_state_dict(model.model_d, model_single.model_d.state_dict())
+         _LOGGER.info('Successfully converted single-speaker checkpoint to multi-speaker')
+
+     print('training started!!')
+     trainer.fit(model)
+
+
+ if __name__ == '__main__':
+     main()
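Since the argparse `__main__.py` entry point is removed (see below), training is now invoked through this module's click command; an example invocation against a preprocessed dataset directory (paths are placeholders):

    python -m phoonnx_train.train \
        --dataset-dir /data/out/medium \
        --quality medium \
        --batch-size 16 \
        --max-epochs 1000

Note that `Trainer.add_argparse_args`/`from_argparse_args` are replaced by an explicit `Trainer(...)` call, so only the options listed in the decorators are forwarded; the quality presets are applied through `dict_args` just as in the old script.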
phoonnx_train/__main__.py DELETED
@@ -1,151 +0,0 @@
- import argparse
- import json
- import logging
- from pathlib import Path
-
- import torch
- from pytorch_lightning import Trainer
- from pytorch_lightning.callbacks import ModelCheckpoint
-
- from phoonnx_train.vits.lightning import VitsModel
-
- _LOGGER = logging.getLogger(__package__)
-
-
- def main():
-     logging.basicConfig(level=logging.DEBUG)
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
-     )
-     parser.add_argument(
-         "--checkpoint-epochs",
-         type=int,
-         help="Save checkpoint every N epochs (default: 1)",
-     )
-     parser.add_argument(
-         "--quality",
-         default="medium",
-         choices=("x-low", "medium", "high"),
-         help="Quality/size of model (default: medium)",
-     )
-     parser.add_argument(
-         "--resume_from_single_speaker_checkpoint",
-         help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
-     )
-     Trainer.add_argparse_args(parser)
-     VitsModel.add_model_specific_args(parser)
-     parser.add_argument("--seed", type=int, default=1234)
-     args = parser.parse_args()
-     _LOGGER.debug(args)
-
-     args.dataset_dir = Path(args.dataset_dir)
-     if not args.default_root_dir:
-         args.default_root_dir = args.dataset_dir
-
-     torch.backends.cudnn.benchmark = True
-     torch.manual_seed(args.seed)
-
-     config_path = args.dataset_dir / "config.json"
-     dataset_path = args.dataset_dir / "dataset.jsonl"
-
-     print(f"INFO - config_path: '{config_path}'")
-     print(f"INFO - dataset_path: '{dataset_path}'")
-
-     with open(config_path, "r", encoding="utf-8") as config_file:
-         # See preprocess.py for format
-         config = json.load(config_file)
-         num_symbols = int(config["num_symbols"])
-         num_speakers = int(config["num_speakers"])
-         sample_rate = int(config["audio"]["sample_rate"])
-
-     trainer = Trainer.from_argparse_args(args)
-     if args.checkpoint_epochs is not None:
-         trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
-         _LOGGER.info(
-             "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
-         )
-
-     dict_args = vars(args)
-     if args.quality == "x-low":
-         dict_args["hidden_channels"] = 96
-         dict_args["inter_channels"] = 96
-         dict_args["filter_channels"] = 384
-     elif args.quality == "high":
-         dict_args["resblock"] = "1"
-         dict_args["resblock_kernel_sizes"] = (3, 7, 11)
-         dict_args["resblock_dilation_sizes"] = (
-             (1, 3, 5),
-             (1, 3, 5),
-             (1, 3, 5),
-         )
-         dict_args["upsample_rates"] = (8, 8, 2, 2)
-         dict_args["upsample_initial_channel"] = 512
-         dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)
-
-     print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
-     model = VitsModel(
-         num_symbols=num_symbols,
-         num_speakers=num_speakers,
-         sample_rate=sample_rate,
-         dataset=[dataset_path],
-         **dict_args,
-     )
-
-     if args.resume_from_single_speaker_checkpoint:
-         assert (
-             num_speakers > 1
-         ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
-
-         # Load single-speaker checkpoint
-         _LOGGER.info(
-             "Resuming from single-speaker checkpoint: %s",
-             args.resume_from_single_speaker_checkpoint,
-         )
-         model_single = VitsModel.load_from_checkpoint(
-             args.resume_from_single_speaker_checkpoint,
-             dataset=None,
-         )
-         g_dict = model_single.model_g.state_dict()
-         for key in list(g_dict.keys()):
-             # Remove keys that can't be copied over due to missing speaker embedding
-             if (
-                 key.startswith("dec.cond")
-                 or key.startswith("dp.cond")
-                 or ("enc.cond_layer" in key)
-             ):
-                 g_dict.pop(key, None)
-
-         # Copy over the multi-speaker model, excluding keys related to the
-         # speaker embedding (which is missing from the single-speaker model).
-         load_state_dict(model.model_g, g_dict)
-         load_state_dict(model.model_d, model_single.model_d.state_dict())
-         _LOGGER.info(
-             "Successfully converted single-speaker checkpoint to multi-speaker"
-         )
-     print("training started!!")
-     trainer.fit(model)
-
-
- def load_state_dict(model, saved_state_dict):
-     state_dict = model.state_dict()
-     new_state_dict = {}
-
-     for k, v in state_dict.items():
-         if k in saved_state_dict:
-             # Use saved value
-             new_state_dict[k] = saved_state_dict[k]
-         else:
-             # Use initialized value
-             _LOGGER.debug("%s is not in the checkpoint", k)
-             new_state_dict[k] = v
-
-     model.load_state_dict(new_state_dict)
-
-
- # -----------------------------------------------------------------------------
-
-
- if __name__ == "__main__":
-     main()