phoonnx 0.2.0a1__py3-none-any.whl → 0.2.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
phoonnx/phonemizers/mwl.py CHANGED
@@ -1,5 +1,4 @@
  from phoonnx.phonemizers.base import BasePhonemizer, Alphabet
- from mwl_phonemizer import CRFOrthoCorrector


  class MirandesePhonemizer(BasePhonemizer):
@@ -7,6 +6,7 @@ class MirandesePhonemizer(BasePhonemizer):

      def __init__(self):
          super().__init__(Alphabet.IPA)
+         from mwl_phonemizer import CRFOrthoCorrector
          self.pho = CRFOrthoCorrector()

      @classmethod
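
The only change to this module is that the `mwl_phonemizer` import moves from module level into `MirandesePhonemizer.__init__`: importing `phoonnx.phonemizers.mwl` no longer fails when the optional `mwl_phonemizer` package is missing, and the dependency is only required once the phonemizer is actually constructed. A minimal sketch of the same deferred-import pattern, assuming you want a clearer error at instantiation time (the try/except guard and its message are illustrative, not part of phoonnx):

    from phoonnx.phonemizers.base import BasePhonemizer, Alphabet


    class MirandesePhonemizer(BasePhonemizer):
        def __init__(self):
            super().__init__(Alphabet.IPA)
            try:
                # Deferred import: only needed when this phonemizer is constructed,
                # not when the module is imported.
                from mwl_phonemizer import CRFOrthoCorrector
            except ImportError as exc:  # hypothetical guard, not in phoonnx
                raise ImportError(
                    "MirandesePhonemizer requires the optional 'mwl_phonemizer' package"
                ) from exc
            self.pho = CRFOrthoCorrector()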
phoonnx/version.py CHANGED
@@ -1,7 +1,7 @@
  # START_VERSION_BLOCK
  VERSION_MAJOR = 0
  VERSION_MINOR = 2
- VERSION_BUILD = 0
+ VERSION_BUILD = 1
  VERSION_ALPHA = 1
  # END_VERSION_BLOCK

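This one-field bump (VERSION_BUILD 0 to 1) is what moves the published version from 0.2.0a1 to 0.2.1a1. The rest of phoonnx/version.py, including how the VERSION_STR imported by phoonnx_train/preprocess.py is assembled, is not part of this diff; the snippet below is only a hedged sketch of a typical composition:

    VERSION_MAJOR = 0
    VERSION_MINOR = 2
    VERSION_BUILD = 1
    VERSION_ALPHA = 1

    # Hypothetical composition; phoonnx/version.py may assemble VERSION_STR differently.
    VERSION_STR = f"{VERSION_MAJOR}.{VERSION_MINOR}.{VERSION_BUILD}"
    if VERSION_ALPHA:
        VERSION_STR += f"a{VERSION_ALPHA}"

    assert VERSION_STR == "0.2.1a1"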
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: phoonnx
- Version: 0.2.0a1
+ Version: 0.2.1a1
  Home-page: https://github.com/TigreGotico/phoonnx
  Author: JarbasAi
  Author-email: jarbasai@mailfence.com
@@ -2,7 +2,7 @@ phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  phoonnx/config.py,sha256=DKgsU03g8jrAuMcVqbu-w3MWPXOUihFtRnavg6WGQ1Y,19983
  phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
  phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
- phoonnx/version.py,sha256=pjMhhxCQpOnjMJmb_1XE4wD6sGrE3QOU71jk4mAIZTQ,237
+ phoonnx/version.py,sha256=hzUBKPGD2cJBEmHoNugbPPHTO3ITl0OwgO9W0oQfjho,237
  phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
  phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
  phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
@@ -18,7 +18,7 @@ phoonnx/phonemizers/he.py,sha256=49OFS34wSFvvR9B3z2bGSzSLmlIvnn2HtkHBOkHS9Ns,138
  phoonnx/phonemizers/ja.py,sha256=Xojsrt715ihnIiEk9K6giYqDo9Iykw-SHfIidrHtHSU,3834
  phoonnx/phonemizers/ko.py,sha256=kwWoOFqanCB8kv2JRx17A0hP78P1wbXlX6e8VBn1ezQ,2989
  phoonnx/phonemizers/mul.py,sha256=Y_M5BUY4Yka6Ba62Eea1HvgC6FTrrigaulo4KNRi1vE,99580
- phoonnx/phonemizers/mwl.py,sha256=9bwKmKQ-fXQQKK04fmKbT9QiraD0r3rKdNFZkWZP-eI,999
+ phoonnx/phonemizers/mwl.py,sha256=xAOB1Bz_uVO14WbYlSFgvPxsezxzUKFwy6GT2mDgP2w,1007
  phoonnx/phonemizers/vi.py,sha256=_XJc-Xeawr1Lxr7o8mE_hJao1aGcj4g01XYAOxC_Scg,1311
  phoonnx/phonemizers/zh.py,sha256=88Ywq8h9LDanlyz8RHjRSCY_PRK_Dq808tBADyrgaP8,9657
  phoonnx/thirdparty/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -62,9 +62,9 @@ phoonnx/thirdparty/tashkeel/hint_id_map.json,sha256=gJMdtTsfEDFgmmbyO2Shw315rkqK
  phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771ltzkv1EnYmHKc2w8M,628
  phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
  phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
- phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
  phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
- phoonnx_train/preprocess.py,sha256=8_Opy5QVNjVmSVmh1_IF23bcNebVIEXuK2KcollIy28,15793
+ phoonnx_train/preprocess.py,sha256=nSkcjThEGfKuwmOurhoXwXfnksqtb-hcdCbx1byDGRI,19890
+ phoonnx_train/train.py,sha256=6ydKJb1sqZy6wJpfSAlkXZ8Y6W2GKsg7c8KPodj3j-Q,5989
  phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
  phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
  phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
@@ -83,7 +83,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
  phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
  phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
  phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
- phoonnx-0.2.0a1.dist-info/METADATA,sha256=YzTNDisiyAKoRj_Ig13nUXwtcL0mV4AaABva5c7OYOo,8250
- phoonnx-0.2.0a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- phoonnx-0.2.0a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
- phoonnx-0.2.0a1.dist-info/RECORD,,
+ phoonnx-0.2.1a1.dist-info/METADATA,sha256=0nAZevi9ypEWjg_QQpByxdQX3KGDKkfNlyb7rhLHmww,8250
+ phoonnx-0.2.1a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ phoonnx-0.2.1a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
+ phoonnx-0.2.1a1.dist-info/RECORD,,
phoonnx_train/preprocess.py CHANGED
@@ -1,5 +1,4 @@
  #!/usr/bin/env python3
- import argparse
  import csv
  import dataclasses
  import itertools
@@ -10,13 +9,16 @@ from collections import Counter
  from dataclasses import dataclass
  from multiprocessing import JoinableQueue, Process, Queue
  from pathlib import Path
- from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union
+ from typing import Dict, Iterable, List, Optional, Tuple, Any, Set, Union, Callable

+ import click
  from phoonnx.util import normalize
  from phoonnx.config import PhonemeType, get_phonemizer, Alphabet
  from phoonnx.phonemizers import Phonemizer
- from phoonnx.phoneme_ids import (phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
-                                  DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN)
+ from phoonnx.phoneme_ids import (
+     phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DEFAULT_PAD_TOKEN,
+     DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN
+ )
  from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
  from tqdm import tqdm
  from phoonnx.version import VERSION_STR
@@ -46,7 +48,7 @@ class Utterance:
      audio_spec_path: Optional[Path] = None

      def asdict(self) -> Dict[str, Any]:
-         """Custom asdict to handle Path objects."""
+         """Custom asdict to handle Path objects for JSON serialization."""
          data = dataclasses.asdict(self)
          for key, value in data.items():
              if isinstance(value, Path):
@@ -57,14 +59,31 @@ class Utterance:
  class PathEncoder(json.JSONEncoder):
      """JSON encoder for Path objects."""

-     def default(self, o):
+     def default(self, o: Any) -> Union[str, Any]:
+         """
+         Converts Path objects to strings for serialization.
+
+         Args:
+             o: The object to serialize.
+
+         Returns:
+             The serialized string representation or the default JSON serialization.
+         """
          if isinstance(o, Path):
              return str(o)
          return super().default(o)


- def get_text_casing(casing: str):
-     """Returns a function to apply text casing based on a string."""
+ def get_text_casing(casing: str) -> Callable[[str], str]:
+     """
+     Returns a function to apply text casing based on a string name.
+
+     Args:
+         casing: The name of the casing function ('lower', 'upper', 'casefold', or 'ignore').
+
+     Returns:
+         A callable function (str) -> str.
+     """
      if casing == "lower":
          return str.lower
      if casing == "upper":
@@ -74,18 +93,46 @@ def get_text_casing(casing: str):
      return lambda s: s


- def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
+ @dataclass
+ class PreprocessorConfig:
+     """Dataclass to hold all runtime configuration, mimicking argparse.Namespace."""
+     input_dir: Path
+     output_dir: Path
+     language: str
+     sample_rate: int
+     cache_dir: Path
+     max_workers: int
+     single_speaker: bool
+     speaker_id: Optional[int]
+     phoneme_type: PhonemeType
+     alphabet: Alphabet
+     phonemizer_model: str
+     text_casing: str
+     dataset_name: Optional[str]
+     audio_quality: Optional[str]
+     skip_audio: bool
+     debug: bool
+     add_diacritics: bool
+
+
+ def ljspeech_dataset(config: PreprocessorConfig) -> Iterable[Utterance]:
      """
      Generator for LJSpeech-style dataset.
      Loads metadata and resolves audio file paths.
+
+     Args:
+         config: The configuration object containing dataset parameters.
+
+     Yields:
+         Utterance: A fully populated Utterance object.
      """
-     dataset_dir = args.input_dir
+     dataset_dir = config.input_dir
      metadata_path = dataset_dir / "metadata.csv"
      if not metadata_path.exists():
          _LOGGER.error(f"Missing metadata file: {metadata_path}")
          return

-     wav_dirs = [dataset_dir / "wav", dataset_dir / "wavs"]
+     wav_dirs: List[Path] = [dataset_dir / "wav", dataset_dir / "wavs"]

      with open(metadata_path, "r", encoding="utf-8") as csv_file:
          reader = csv.reader(csv_file, delimiter="|")
@@ -98,16 +145,18 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
              text: str = row[-1]
              speaker: Optional[str] = None

-             if not args.single_speaker and len(row) > 2:
+             if not config.single_speaker and len(row) > 2:
                  speaker = row[1]
              else:
                  speaker = None

-             wav_path = None
+             wav_path: Optional[Path] = None
              for wav_dir in wav_dirs:
-                 potential_paths = [wav_dir / filename,
-                                    wav_dir / f"{filename}.wav",
-                                    wav_dir / f"{filename.lstrip('0')}.wav"]
+                 potential_paths: List[Path] = [
+                     wav_dir / filename,
+                     wav_dir / f"{filename}.wav",
+                     wav_dir / f"{filename.lstrip('0')}.wav"
+                 ]
                  for path in potential_paths:
                      if path.exists():
                          wav_path = path
@@ -115,34 +164,40 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:
                  if wav_path:
                      break

-             if not args.skip_audio and not wav_path:
+             if not config.skip_audio and not wav_path:
                  _LOGGER.warning("Missing audio file for filename: %s", filename)
                  continue

-             if not args.skip_audio and wav_path and wav_path.stat().st_size == 0:
+             if not config.skip_audio and wav_path and wav_path.stat().st_size == 0:
                  _LOGGER.warning("Empty audio file: %s", wav_path)
                  continue

+             # Ensure wav_path is Path or None, and is never accessed if skip_audio is true
              yield Utterance(
                  text=text,
-                 audio_path=wav_path,
+                 audio_path=wav_path or Path(""),  # Use empty path if skipping audio, should not be used
                  speaker=speaker,
-                 speaker_id=args.speaker_id,
+                 speaker_id=config.speaker_id,
              )


  def phonemize_worker(
-     args: argparse.Namespace,
+     config: PreprocessorConfig,
      task_queue: JoinableQueue,
      result_queue: Queue,
      phonemizer: Phonemizer,
- ):
+ ) -> None:
      """
      Worker process for phonemization and audio processing.
-     Returns the utterance and the unique phonemes found in its batch.
+
+     Args:
+         config: The configuration object containing runtime parameters.
+         task_queue: Queue for receiving batches of Utterance objects.
+         result_queue: Queue for sending processed results (Utterance, set of phonemes).
+         phonemizer: The initialized Phonemizer instance.
      """
      try:
-         casing = get_text_casing(args.text_casing)
+         casing: Callable[[str], str] = get_text_casing(config.text_casing)
          silence_detector = make_silence_detector()

          while True:
@@ -155,28 +210,29 @@ def phonemize_worker(

              for utt in utterance_batch:
                  try:
-                     # normalize text (case, numbers....)
-                     utterance = casing(normalize( utt.text, args.language))
+                     # Normalize text (case, numbers, etc.)
+                     utterance: str = casing(normalize(utt.text, config.language))

-                     # add diacritics
-                     if args.add_diacritics:
-                         utterance = phonemizer.add_diacritics(utterance, args.language)
+                     # Add diacritics
+                     if config.add_diacritics:
+                         utterance = phonemizer.add_diacritics(utterance, config.language)

                      # Phonemize the text
-                     utt.phonemes = phonemizer.phonemize_to_list(utterance, args.language)
+                     utt.phonemes = phonemizer.phonemize_to_list(utterance, config.language)
                      if not utt.phonemes:
                          raise RuntimeError(f"Phonemes not found for '{utterance}'")

                      # Process audio if not skipping
-                     if not args.skip_audio:
+                     if not config.skip_audio:
                          utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
                              utt.audio_path,
-                             args.cache_dir,
+                             config.cache_dir,
                              silence_detector,
-                             args.sample_rate,
+                             config.sample_rate,
                          )

                      # Put the processed utterance and its phonemes into the result queue
+                     # The result is a tuple of (Utterance, set of unique phonemes in that utterance)
                      result_queue.put((utt, set(utt.phonemes)))
                  except Exception:
                      _LOGGER.exception("Failed to process utterance: %s", utt.audio_path)
@@ -188,109 +244,195 @@ def phonemize_worker(
          _LOGGER.exception("Worker process failed")


- def main() -> None:
-     parser = argparse.ArgumentParser(
-         description="Preprocess a TTS dataset for training a VITS-style model."
-     )
-     parser.add_argument(
-         "--input-dir", required=True, help="Directory with audio dataset"
-     )
-     parser.add_argument(
-         "--output-dir",
-         required=True,
-         help="Directory to write output files for training",
-     )
-     parser.add_argument("--language", required=True, help="eSpeak-ng voice")
-     parser.add_argument(
-         "--sample-rate",
-         type=int,
-         required=True,
-         help="Target sample rate for voice (hertz)",
-     )
-     parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
-     parser.add_argument("--max-workers", type=int)
-     parser.add_argument(
-         "--single-speaker", action="store_true", help="Force single speaker dataset"
-     )
-     parser.add_argument(
-         "--speaker-id", type=int, help="Add speaker id to single speaker dataset"
-     )
-     parser.add_argument(
-         "--phoneme-type",
-         choices=list(PhonemeType),
-         default=PhonemeType.ESPEAK,
-         help="Type of phonemes to use (default: espeak)",
-     )
-     parser.add_argument(
-         "--alphabet",
-         choices=list(Alphabet),
-         default=Alphabet.IPA,
-         help="Casing applied to utterance text",
-     )
-     parser.add_argument(
-         "--phonemizer-model",
-         default="",
-         help="phonemizer model, if applicable",
-     )
-     parser.add_argument(
-         "--text-casing",
-         choices=("ignore", "lower", "upper", "casefold"),
-         default="ignore",
-         help="Casing applied to utterance text",
-     )
-     parser.add_argument(
-         "--dataset-name",
-         help="Name of dataset to put in config (default: name of <ouput_dir>/../)",
-     )
-     parser.add_argument(
-         "--audio-quality",
-         help="Audio quality to put in config (default: name of <output_dir>)",
-     )
-     parser.add_argument(
-         "--skip-audio", action="store_true", help="Don't preprocess audio"
-     )
-     parser.add_argument(
-         "--debug", action="store_true", help="Print DEBUG messages to the console"
-     )
-     parser.add_argument(
-         "--add-diacritics", action="store_true", help="Add diacritics to text (phonemizer specific)"
+ @click.command(context_settings={"help_option_names": ["-h", "--help"]})
+ @click.option(
+     "-i",
+     "--input-dir",
+     "input_dir",
+     type=click.Path(exists=True, file_okay=False, path_type=Path),
+     required=True,
+     help="Directory with audio dataset (e.g., containing metadata.csv and wavs/)",
+ )
+ @click.option(
+     "-o",
+     "--output-dir",
+     "output_dir",
+     type=click.Path(file_okay=False, path_type=Path),
+     required=True,
+     help="Directory to write output files for training (config.json, dataset.jsonl)",
+ )
+ @click.option(
+     "-l",
+     "--language",
+     "language",
+     required=True,
+     help="phonemizer language code (e.g., 'en', 'es', 'fr')",
+ )
+ @click.option(
+     "-r",
+     "--sample-rate",
+     "sample_rate",
+     type=int,
+     required=True,
+     help="Target sample rate for voice (hertz, e.g., 22050)",
+ )
+ @click.option(
+     "--cache-dir",
+     "cache_dir",
+     type=click.Path(file_okay=False, path_type=Path),
+     default=None,
+     help="Directory to cache processed audio files. Defaults to <output-dir>/cache/<sample-rate>.",
+ )
+ @click.option(
+     "-w",
+     "--max-workers",
+     "max_workers",
+     type=click.IntRange(min=1),
+     default=os.cpu_count() or 1,
+     help="Maximum number of worker processes to use for parallel processing. Defaults to CPU count.",
+ )
+ @click.option(
+     "--single-speaker",
+     "single_speaker",
+     is_flag=True,
+     help="Force treating the dataset as single speaker, ignoring metadata speaker columns.",
+ )
+ @click.option(
+     "--speaker-id",
+     "speaker_id",
+     type=int,
+     default=None,
+     help="Specify a fixed speaker ID (0, 1, etc.) for a single speaker dataset.",
+ )
+ @click.option(
+     "--phoneme-type",
+     "phoneme_type",
+     type=click.Choice([p.value for p in PhonemeType]),
+     default=PhonemeType.ESPEAK.value,
+     help="Type of phonemes to use.",
+ )
+ @click.option(
+     "--alphabet",
+     "alphabet",
+     type=click.Choice([a.value for a in Alphabet]),
+     default=Alphabet.IPA.value,
+     help="Phoneme alphabet to use (e.g., IPA).",
+ )
+ @click.option(
+     "--phonemizer-model",
+     "phonemizer_model",
+     default="",
+     help="Path or name of a custom phonemizer model, if applicable.",
+ )
+ @click.option(
+     "--text-casing",
+     "text_casing",
+     type=click.Choice(("ignore", "lower", "upper", "casefold")),
+     default="ignore",
+     help="Casing applied to utterance text before phonemization.",
+ )
+ @click.option(
+     "--dataset-name",
+     "dataset_name",
+     default=None,
+     help="Name of dataset to put in config (default: name of <output_dir>/../).",
+ )
+ @click.option(
+     "--audio-quality",
+     "audio_quality",
+     default=None,
+     help="Audio quality description to put in config (default: name of <output_dir>).",
+ )
+ @click.option(
+     "--skip-audio",
+     "skip_audio",
+     is_flag=True,
+     help="Do not preprocess or cache audio files.",
+ )
+ @click.option(
+     "--debug",
+     "debug",
+     is_flag=True,
+     help="Print DEBUG messages to the console.",
+ )
+ @click.option(
+     "--add-diacritics",
+     "add_diacritics",
+     is_flag=True,
+     help="Add diacritics to text (phonemizer specific, e.g., to denote stress).",
+ )
+ def cli(
+     input_dir: Path,
+     output_dir: Path,
+     language: str,
+     sample_rate: int,
+     cache_dir: Optional[Path],
+     max_workers: Optional[int],
+     single_speaker: bool,
+     speaker_id: Optional[int],
+     phoneme_type: str,
+     alphabet: str,
+     phonemizer_model: str,
+     text_casing: str,
+     dataset_name: Optional[str],
+     audio_quality: Optional[str],
+     skip_audio: bool,
+     debug: bool,
+     add_diacritics: bool,
+ ) -> None:
+     """
+     Preprocess a TTS dataset (e.g., LJSpeech format) for training a VITS-style model.
+     This script handles text normalization, phonemization, and optional audio caching.
+     """
+     # Create a config object from click arguments for easier passing
+     config = PreprocessorConfig(
+         input_dir=input_dir,
+         output_dir=output_dir,
+         language=language,
+         sample_rate=sample_rate,
+         cache_dir=cache_dir or output_dir / "cache" / str(sample_rate),
+         max_workers=max_workers or os.cpu_count() or 1,
+         single_speaker=single_speaker,
+         speaker_id=speaker_id,
+         phoneme_type=PhonemeType(phoneme_type),
+         alphabet=Alphabet(alphabet),
+         phonemizer_model=phonemizer_model,
+         text_casing=text_casing,
+         dataset_name=dataset_name,
+         audio_quality=audio_quality,
+         skip_audio=skip_audio,
+         debug=debug,
+         add_diacritics=add_diacritics,
      )
-     args = parser.parse_args()

-     # Setup
-     level = logging.DEBUG if args.debug else logging.INFO
+     # Setup logging
+     level = logging.DEBUG if config.debug else logging.INFO
      logging.basicConfig(level=level)
      logging.getLogger().setLevel(level)
      logging.getLogger("numba").setLevel(logging.WARNING)

-     if args.single_speaker and (args.speaker_id is not None):
+     # Validation
+     if config.single_speaker and (config.speaker_id is not None):
          _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided")
-         return
+         raise click.Abort()

-     args.input_dir = Path(args.input_dir)
-     args.output_dir = Path(args.output_dir)
-     args.output_dir.mkdir(parents=True, exist_ok=True)
-     args.cache_dir = (
-         Path(args.cache_dir)
-         if args.cache_dir
-         else args.output_dir / "cache" / str(args.sample_rate)
-     )
-     args.cache_dir.mkdir(parents=True, exist_ok=True)
-     args.phoneme_type = PhonemeType(args.phoneme_type)
+     # Create directories
+     config.output_dir.mkdir(parents=True, exist_ok=True)
+     config.cache_dir.mkdir(parents=True, exist_ok=True)

      # Load all utterances from the dataset
      _LOGGER.info("Loading utterances from dataset...")
-     utterances = list(ljspeech_dataset(args))
+     utterances: List[Utterance] = list(ljspeech_dataset(config))
      if not utterances:
          _LOGGER.error("No valid utterances found in dataset.")
          return

-     num_utterances = len(utterances)
+     num_utterances: int = len(utterances)
      _LOGGER.info("Found %d utterances.", num_utterances)

-     # Count speakers
+     # Count speakers and assign IDs
      speaker_counts: Counter[str] = Counter(u.speaker for u in utterances if u.speaker)
-     is_multispeaker = len(speaker_counts) > 1
+     is_multispeaker: bool = len(speaker_counts) > 1
      speaker_ids: Dict[str, int] = {}
      if is_multispeaker:
          _LOGGER.info("%s speakers detected", len(speaker_counts))
@@ -301,48 +443,47 @@ def main() -> None:
          _LOGGER.info("Single speaker dataset")

      # --- Single Pass: Process audio/phonemes and collect results ---
-     # Set up multiprocessing
-     args.max_workers = args.max_workers if args.max_workers is not None and args.max_workers > 0 else os.cpu_count()
-     _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)
+     _LOGGER.info("Starting single pass processing with %d workers...", config.max_workers)

      # Initialize the phonemizer only once in the main process
-     phonemizer = get_phonemizer(args.phoneme_type,
-                                 args.alphabet,
-                                 args.phonemizer_model)
+     phonemizer: Phonemizer = get_phonemizer(config.phoneme_type,
+                                             config.alphabet,
+                                             config.phonemizer_model)

-     batch_size = max(1, int(num_utterances / (args.max_workers * 2)))
+     batch_size: int = max(1, int(num_utterances / (config.max_workers * 2)))

-     task_queue: "Queue[Optional[List[Utterance]]]" = JoinableQueue()
+     task_queue: "JoinableQueue[Optional[List[Utterance]]]" = JoinableQueue()
      # The result queue will hold tuples of (Utterance, set(phonemes))
-     result_queue: "Queue[Optional[Tuple[Utterance, Set[str]]]]" = Queue()
+     result_queue: "Queue[Tuple[Optional[Utterance], Set[str]]]" = Queue()

      # Start workers
-     processes = [
+     processes: List[Process] = [
          Process(
              target=phonemize_worker,
-             args=(args, task_queue, result_queue, phonemizer)
+             args=(config, task_queue, result_queue, phonemizer)
          )
-         for _ in range(args.max_workers)
+         for _ in range(config.max_workers)
      ]

      for proc in processes:
          proc.start()

      # Populate the task queue with batches
-     task_count = 0
+     task_count: int = 0
      for utt_batch in batched(utterances, batch_size):
          task_queue.put(utt_batch)
          task_count += len(utt_batch)

      # Signal workers to stop
-     for _ in range(args.max_workers):
+     for _ in range(config.max_workers):
          task_queue.put(None)

      # Collect results from the queue with a progress bar
      processed_utterances: List[Utterance] = []
      all_phonemes: Set[str] = set()
      for _ in tqdm(range(task_count), desc="Processing utterances"):
-         utt, unique_phonemes = result_queue.get()
+         result: Tuple[Optional[Utterance], Set[str]] = result_queue.get()
+         utt, unique_phonemes = result
          if utt is not None:
              processed_utterances.append(utt)
              all_phonemes.update(unique_phonemes)
@@ -355,15 +496,15 @@ def main() -> None:
      # --- Build the final phoneme map from the collected phonemes ---
      _LOGGER.info("Building a complete phoneme map from collected phonemes...")

-     final_phoneme_id_map = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
+     final_phoneme_id_map: Dict[str, int] = DEFAULT_SPECIAL_PHONEME_ID_MAP.copy()
      if phonemizer.alphabet == Alphabet.IPA:
          all_phonemes.update(DEFAULT_IPA_PHONEME_ID_MAP.keys())

      # Filter out special tokens that are already in the map
-     existing_keys = set(final_phoneme_id_map.keys())
-     new_phonemes = sorted([p for p in all_phonemes if p not in existing_keys])
+     existing_keys: Set[str] = set(final_phoneme_id_map.keys())
+     new_phonemes: List[str] = sorted([p for p in all_phonemes if p not in existing_keys])

-     current_id = len(final_phoneme_id_map)
+     current_id: int = len(final_phoneme_id_map)
      for pho in new_phonemes:
          final_phoneme_id_map[pho] = current_id
          current_id += 1
@@ -372,23 +513,23 @@ def main() -> None:

      # --- Write the final config.json ---
      _LOGGER.info("Writing dataset config...")
-     audio_quality = args.audio_quality or args.output_dir.name
-     dataset_name = args.dataset_name or args.output_dir.parent.name
+     audio_quality = config.audio_quality or config.output_dir.name
+     dataset_name = config.dataset_name or config.output_dir.parent.name

-     config = {
+     config_data: Dict[str, Any] = {
          "dataset": dataset_name,
          "audio": {
-             "sample_rate": args.sample_rate,
+             "sample_rate": config.sample_rate,
              "quality": audio_quality,
          },
-         "lang_code": args.language,
+         "lang_code": config.language,
          "inference": {"noise_scale": 0.667,
                        "length_scale": 1,
                        "noise_w": 0.8,
-                       "add_diacritics": args.add_diacritics},
+                       "add_diacritics": config.add_diacritics},
          "alphabet": phonemizer.alphabet.value,
-         "phoneme_type": args.phoneme_type.value,
-         "phonemizer_model": args.phonemizer_model,
+         "phoneme_type": config.phoneme_type.value,
+         "phonemizer_model": config.phonemizer_model,
          "phoneme_id_map": final_phoneme_id_map,
          "num_symbols": len(final_phoneme_id_map),
          "num_speakers": len(speaker_counts) if is_multispeaker else 1,
@@ -396,13 +537,13 @@ def main() -> None:
          "phoonnx_version": VERSION_STR,
      }

-     with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
-         json.dump(config, config_file, ensure_ascii=False, indent=2)
+     with open(config.output_dir / "config.json", "w", encoding="utf-8") as config_file:
+         json.dump(config_data, config_file, ensure_ascii=False, indent=2)

      # --- Apply final phoneme IDs and write dataset.jsonl ---
      _LOGGER.info("Writing dataset.jsonl...")
-     valid_utterances_count = 0
-     with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
+     valid_utterances_count: int = 0
+     with open(config.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
          for utt in processed_utterances:
              if is_multispeaker and utt.speaker is not None:
                  if utt.speaker not in speaker_ids:
@@ -432,8 +573,17 @@

  # -----------------------------------------------------------------------------

- def batched(iterable, n):
-     "Batch data into lists of length n. The last batch may be shorter."
+ def batched(iterable: Iterable[Any], n: int) -> Iterable[List[Any]]:
+     """
+     Batch data from an iterable into lists of length n. The last batch may be shorter.
+
+     Args:
+         iterable: The input iterable to be batched.
+         n: The desired size of each batch.
+
+     Yields:
+         List[Any]: A list representing a batch of items.
+     """
      if n < 1:
          raise ValueError("n must be at least one")
      it = iter(iterable)
@@ -444,4 +594,4 @@


  if __name__ == "__main__":
-     main()
+     cli()
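
With argparse replaced by click, preprocessing is now driven by the `cli` command defined above, and the module keeps its `if __name__ == "__main__"` guard. A hedged usage sketch follows; the dataset paths are placeholders and the `python -m phoonnx_train.preprocess` invocation is an assumption, but every option shown is taken from the decorators in this diff:

    # Command-line form (paths are placeholders):
    #   python -m phoonnx_train.preprocess -i /data/my_dataset -o /data/my_dataset/output \
    #       -l en -r 22050 --single-speaker
    #
    # The same invocation through click's test runner, e.g. as a quick smoke test:
    from click.testing import CliRunner

    from phoonnx_train.preprocess import cli

    runner = CliRunner()
    result = runner.invoke(cli, [
        "-i", "/data/my_dataset",
        "-o", "/data/my_dataset/output",
        "-l", "en",
        "-r", "22050",
        "--single-speaker",
    ])
    print(result.exit_code, result.output)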
phoonnx_train/train.py ADDED
@@ -0,0 +1,151 @@
+ import json
+ import logging
+ from pathlib import Path
+
+ import torch
+ import click
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ from phoonnx_train.vits.lightning import VitsModel
+
+ _LOGGER = logging.getLogger(__package__)
+
+
+ def load_state_dict(model, saved_state_dict):
+     state_dict = model.state_dict()
+     new_state_dict = {}
+
+     for k, v in state_dict.items():
+         if k in saved_state_dict:
+             new_state_dict[k] = saved_state_dict[k]
+         else:
+             _LOGGER.debug("%s is not in the checkpoint", k)
+             new_state_dict[k] = v
+
+     model.load_state_dict(new_state_dict)
+
+
+ @click.command(context_settings=dict(ignore_unknown_options=True))
+ @click.option('--dataset-dir', required=True, type=click.Path(exists=True, file_okay=False), help='Path to pre-processed dataset directory')
+ @click.option('--checkpoint-epochs', default=1, type=int, help='Save checkpoint every N epochs (default: 1)')
+ @click.option('--quality', default='medium', type=click.Choice(['x-low', 'medium', 'high']), help='Quality/size of model (default: medium)')
+ @click.option('--resume-from-checkpoint', default=None, help='Load an existing checkpoint and resume training')
+ @click.option('--resume-from-single-speaker-checkpoint', help='For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training')
+ @click.option('--seed', type=int, default=1234, help='Random seed (default: 1234)')
+ # Common Trainer options
+ @click.option('--max-epochs', type=int, default=1000, help='Stop training once this number of epochs is reached (default: 1000)')
+ @click.option('--devices', default=1, help='Number of devices or list of device IDs to train on (default: 1)')
+ @click.option('--accelerator', default='auto', help='Hardware accelerator to use (cpu, gpu, tpu, mps, etc.) (default: "auto")')
+ @click.option('--default-root-dir', type=click.Path(file_okay=False), default=None, help='Default root directory for logs and checkpoints (default: None)')
+ @click.option('--precision', default=32, help='Precision used in training (e.g. 16, 32, bf16) (default: 32)')
+ # Model-specific arguments
+ @click.option('--learning-rate', type=float, default=2e-4, help='Learning rate for optimizer (default: 2e-4)')
+ @click.option('--batch-size', type=int, default=16, help='Training batch size (default: 16)')
+ @click.option('--num-workers', type=click.IntRange(min=1), default=1, help='Number of data loader workers (default: 1)')
+ @click.option('--validation-split', type=float, default=0.05, help='Proportion of data used for validation (default: 0.05)')
+ def main(
+     dataset_dir,
+     checkpoint_epochs,
+     quality,
+     resume_from_checkpoint,
+     resume_from_single_speaker_checkpoint,
+     seed,
+     max_epochs,
+     devices,
+     accelerator,
+     default_root_dir,
+     precision,
+     learning_rate,
+     batch_size,
+     num_workers,
+     validation_split,
+ ):
+     logging.basicConfig(level=logging.DEBUG)
+
+     dataset_dir = Path(dataset_dir)
+     if default_root_dir is None:
+         default_root_dir = dataset_dir
+
+     torch.backends.cudnn.benchmark = True
+     torch.manual_seed(seed)
+
+     config_path = dataset_dir / 'config.json'
+     dataset_path = dataset_dir / 'dataset.jsonl'
+
+     print(f"INFO - config_path: '{config_path}'")
+     print(f"INFO - dataset_path: '{dataset_path}'")
+
+     with open(config_path, 'r', encoding='utf-8') as config_file:
+         config = json.load(config_file)
+         num_symbols = int(config['num_symbols'])
+         num_speakers = int(config['num_speakers'])
+         sample_rate = int(config['audio']['sample_rate'])
+
+     trainer = Trainer(
+         max_epochs=max_epochs,
+         devices=devices,
+         accelerator=accelerator,
+         default_root_dir=default_root_dir,
+         precision=precision,
+         resume_from_checkpoint=resume_from_checkpoint
+     )
+
+     if checkpoint_epochs is not None:
+         trainer.callbacks = [ModelCheckpoint(every_n_epochs=checkpoint_epochs)]
+         _LOGGER.info('Checkpoints will be saved every %s epoch(s)', checkpoint_epochs)
+
+     dict_args = dict(
+         seed=seed,
+         learning_rate=learning_rate,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         validation_split=validation_split,
+     )
+
+     if quality == 'x-low':
+         dict_args.update({
+             'hidden_channels': 96,
+             'inter_channels': 96,
+             'filter_channels': 384,
+         })
+     elif quality == 'high':
+         dict_args.update({
+             'resblock': '1',
+             'resblock_kernel_sizes': (3, 7, 11),
+             'resblock_dilation_sizes': ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+             'upsample_rates': (8, 8, 2, 2),
+             'upsample_initial_channel': 512,
+             'upsample_kernel_sizes': (16, 16, 4, 4),
+         })
+
+     print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
+     model = VitsModel(
+         num_symbols=num_symbols,
+         num_speakers=num_speakers,
+         sample_rate=sample_rate,
+         dataset=[dataset_path],
+         **dict_args,
+     )
+
+     if resume_from_single_speaker_checkpoint:
+         assert num_speakers > 1, "--resume-from-single-speaker-checkpoint is only for multi-speaker models."
+         _LOGGER.info('Resuming from single-speaker checkpoint: %s', resume_from_single_speaker_checkpoint)
+
+         model_single = VitsModel.load_from_checkpoint(resume_from_single_speaker_checkpoint, dataset=None)
+         g_dict = model_single.model_g.state_dict()
+
+         for key in list(g_dict.keys()):
+             if key.startswith('dec.cond') or key.startswith('dp.cond') or ('enc.cond_layer' in key):
+                 g_dict.pop(key, None)
+
+         load_state_dict(model.model_g, g_dict)
+         load_state_dict(model.model_d, model_single.model_d.state_dict())
+         _LOGGER.info('Successfully converted single-speaker checkpoint to multi-speaker')
+
+     print('training started!!')
+     trainer.fit(model)
+
+
+ if __name__ == '__main__':
+     main()
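
This new module replaces the argparse-based `phoonnx_train/__main__.py` (deleted below): the pytorch_lightning Trainer is now configured through explicit click options instead of `Trainer.add_argparse_args`, and the quality-preset hyperparameters are passed as a plain dict rather than `vars(args)`. A hedged invocation sketch; the dataset path is a placeholder, and running the module as `python -m phoonnx_train.train` (instead of the old `python -m phoonnx_train`) is an assumption based on its `__main__` guard:

    # Command-line form (path is a placeholder):
    #   python -m phoonnx_train.train --dataset-dir /data/my_dataset/output \
    #       --quality medium --batch-size 16 --max-epochs 1000
    #
    # Programmatic invocation through click's test runner:
    from click.testing import CliRunner

    from phoonnx_train.train import main

    runner = CliRunner()
    result = runner.invoke(main, [
        "--dataset-dir", "/data/my_dataset/output",
        "--quality", "medium",
        "--batch-size", "16",
        "--max-epochs", "1",
    ])
    print(result.exit_code, result.output)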
phoonnx_train/__main__.py DELETED
@@ -1,151 +0,0 @@
- import argparse
- import json
- import logging
- from pathlib import Path
-
- import torch
- from pytorch_lightning import Trainer
- from pytorch_lightning.callbacks import ModelCheckpoint
-
- from phoonnx_train.vits.lightning import VitsModel
-
- _LOGGER = logging.getLogger(__package__)
-
-
- def main():
-     logging.basicConfig(level=logging.DEBUG)
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
-     )
-     parser.add_argument(
-         "--checkpoint-epochs",
-         type=int,
-         help="Save checkpoint every N epochs (default: 1)",
-     )
-     parser.add_argument(
-         "--quality",
-         default="medium",
-         choices=("x-low", "medium", "high"),
-         help="Quality/size of model (default: medium)",
-     )
-     parser.add_argument(
-         "--resume_from_single_speaker_checkpoint",
-         help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
-     )
-     Trainer.add_argparse_args(parser)
-     VitsModel.add_model_specific_args(parser)
-     parser.add_argument("--seed", type=int, default=1234)
-     args = parser.parse_args()
-     _LOGGER.debug(args)
-
-     args.dataset_dir = Path(args.dataset_dir)
-     if not args.default_root_dir:
-         args.default_root_dir = args.dataset_dir
-
-     torch.backends.cudnn.benchmark = True
-     torch.manual_seed(args.seed)
-
-     config_path = args.dataset_dir / "config.json"
-     dataset_path = args.dataset_dir / "dataset.jsonl"
-
-     print(f"INFO - config_path: '{config_path}'")
-     print(f"INFO - dataset_path: '{dataset_path}'")
-
-     with open(config_path, "r", encoding="utf-8") as config_file:
-         # See preprocess.py for format
-         config = json.load(config_file)
-         num_symbols = int(config["num_symbols"])
-         num_speakers = int(config["num_speakers"])
-         sample_rate = int(config["audio"]["sample_rate"])
-
-     trainer = Trainer.from_argparse_args(args)
-     if args.checkpoint_epochs is not None:
-         trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
-         _LOGGER.info(
-             "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
-         )
-
-     dict_args = vars(args)
-     if args.quality == "x-low":
-         dict_args["hidden_channels"] = 96
-         dict_args["inter_channels"] = 96
-         dict_args["filter_channels"] = 384
-     elif args.quality == "high":
-         dict_args["resblock"] = "1"
-         dict_args["resblock_kernel_sizes"] = (3, 7, 11)
-         dict_args["resblock_dilation_sizes"] = (
-             (1, 3, 5),
-             (1, 3, 5),
-             (1, 3, 5),
-         )
-         dict_args["upsample_rates"] = (8, 8, 2, 2)
-         dict_args["upsample_initial_channel"] = 512
-         dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)
-
-     print(f"VitsModel params: num_symbols={num_symbols} num_speakers={num_speakers} sample_rate={sample_rate}")
-     model = VitsModel(
-         num_symbols=num_symbols,
-         num_speakers=num_speakers,
-         sample_rate=sample_rate,
-         dataset=[dataset_path],
-         **dict_args,
-     )
-
-     if args.resume_from_single_speaker_checkpoint:
-         assert (
-             num_speakers > 1
-         ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
-
-         # Load single-speaker checkpoint
-         _LOGGER.info(
-             "Resuming from single-speaker checkpoint: %s",
-             args.resume_from_single_speaker_checkpoint,
-         )
-         model_single = VitsModel.load_from_checkpoint(
-             args.resume_from_single_speaker_checkpoint,
-             dataset=None,
-         )
-         g_dict = model_single.model_g.state_dict()
-         for key in list(g_dict.keys()):
-             # Remove keys that can't be copied over due to missing speaker embedding
-             if (
-                 key.startswith("dec.cond")
-                 or key.startswith("dp.cond")
-                 or ("enc.cond_layer" in key)
-             ):
-                 g_dict.pop(key, None)
-
-         # Copy over the multi-speaker model, excluding keys related to the
-         # speaker embedding (which is missing from the single-speaker model).
-         load_state_dict(model.model_g, g_dict)
-         load_state_dict(model.model_d, model_single.model_d.state_dict())
-         _LOGGER.info(
-             "Successfully converted single-speaker checkpoint to multi-speaker"
-         )
-     print("training started!!")
-     trainer.fit(model)
-
-
- def load_state_dict(model, saved_state_dict):
-     state_dict = model.state_dict()
-     new_state_dict = {}
-
-     for k, v in state_dict.items():
-         if k in saved_state_dict:
-             # Use saved value
-             new_state_dict[k] = saved_state_dict[k]
-         else:
-             # Use initialized value
-             _LOGGER.debug("%s is not in the checkpoint", k)
-             new_state_dict[k] = v
-
-     model.load_state_dict(new_state_dict)
-
-
- # -----------------------------------------------------------------------------
-
-
- if __name__ == "__main__":
-     main()