phoonnx 0.1.0a1__py3-none-any.whl → 0.1.1a1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
@@ -1,109 +1,360 @@
  #!/usr/bin/env python3
- import argparse
+ import click
  import logging
+ import json
+ import os
  from pathlib import Path
- from typing import Optional
+ from typing import Optional, Dict, Any, Tuple

  import torch
-
  from phoonnx_train.vits.lightning import VitsModel
+ from phoonnx.version import VERSION_STR

- _LOGGER = logging.getLogger("piper_train.export_onnx")
+ # Basic logging configuration
+ logging.basicConfig(level=logging.DEBUG)
+ _LOGGER = logging.getLogger("phoonnx_train.export_onnx")

+ # ONNX opset version
  OPSET_VERSION = 15


- def main() -> None:
-     """Main entry point"""
-     torch.manual_seed(1234)
+ # --- Utility Functions ---

-     parser = argparse.ArgumentParser()
-     parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)")
-     parser.add_argument("output", help="Path to output model (.onnx)")
+ def add_meta_data(filename: Path, meta_data: Dict[str, Any]) -> None:
+     """
+     Add metadata to an ONNX model. The file is modified in-place.

-     parser.add_argument(
-         "--debug", action="store_true", help="Print DEBUG messages to the console"
-     )
-     args = parser.parse_args()
+     Args:
+         filename:
+             Path to the ONNX model file to be changed.
+         meta_data:
+             Key-value pairs to be stored as metadata. Values will be converted to strings.
+     """
+     try:
+         import onnx
+
+         # Load the ONNX model
+         model = onnx.load(str(filename))
+
+         # Clear existing metadata and add new properties
+         del model.metadata_props[:]
+
+         for key, value in meta_data.items():
+             meta = model.metadata_props.add()
+             meta.key = key
+             # Convert all values to string for ONNX metadata
+             meta.value = str(value)
+
+         onnx.save(model, str(filename))
+         _LOGGER.info(f"Added {len(meta_data)} metadata key/value pairs to ONNX model: {filename}")
+
+     except ImportError:
+         _LOGGER.error("The 'onnx' package is required to add metadata. Please install it with 'pip install onnx'.")
+     except Exception as e:
+         _LOGGER.error(f"Failed to add metadata to ONNX file {filename}: {e}")
+
+
+ def export_tokens(config_path: Path, output_path: Path = Path("tokens.txt")) -> None:
+     """
+     Generates a tokens.txt file containing the phoneme-to-id mapping from the model configuration.
+
+     The format is one `<phoneme> <id>` pair per line.
+
+     Args:
+         config_path: Path to the model configuration JSON file.
+         output_path: Path to save the resulting tokens.txt file.
+     """
+     try:
+         with open(config_path, "r", encoding="utf-8") as file:
+             config: Dict[str, Any] = json.load(file)
+     except Exception as e:
+         _LOGGER.error(f"Failed to load config file at {config_path}: {e}")
+         return
+
+     id_map: Optional[Dict[str, int]] = config.get("phoneme_id_map")
+     if not id_map:
+         _LOGGER.error("Could not find 'phoneme_id_map' in the config file.")
+         return
+
+     tokens_path = output_path
+     try:
+         with open(tokens_path, "w", encoding="utf-8") as f:
+             # Sort by ID to ensure a consistent output order
+             sorted_items: list[Tuple[str, int]] = sorted(id_map.items(), key=lambda item: item[1])
+
+             for s, i in sorted_items:
+                 # Skip newlines or other invalid tokens if present in the map
+                 if s == "\n" or s == "":
+                     continue
+                 f.write(f"{s} {i}\n")
+
+         _LOGGER.info(f"Generated tokens file at {tokens_path}")
+     except Exception as e:
+         _LOGGER.error(f"Failed to write tokens file to {tokens_path}: {e}")
+
+
+ def convert_to_piper(config_path: Path, output_path: Path = Path("piper.json")) -> None:
+     """
+     Generates a Piper-compatible JSON configuration file from the VITS model configuration.
+
+     This function currently serves as a placeholder for full Piper conversion logic.
+
+     Args:
+         config_path: Path to the VITS model configuration JSON file.
+         output_path: Path to save the resulting Piper JSON file.
+     """
+     with open(config_path, "r", encoding="utf-8") as file:
+         config: Dict[str, Any] = json.load(file)
+
+     piper_config = {
+         "phoneme_type": "espeak" if config.get("phoneme_type", "") == "espeak" else "raw",
+         "phoneme_map": {},
+         "audio": config.get("audio", {}),
+         "inference": config.get("inference", {}),
+         "phoneme_id_map": {k: [v] for k, v in config.get("phoneme_id_map", {}).items()},
+         "espeak": {
+             "voice": config.get("lang_code", "")
+         },
+         "language": {
+             "code": config.get("lang_code", "")
+         },
+         "num_symbols": config.get("num_symbols", 256),
+         "num_speakers": config.get("num_speakers", 1),
+         "speaker_id_map": {},
+         "piper_version": "phoonnx-" + config.get("phoonnx_version", "0.0.0")
+     }
+
+     with open(output_path, "w", encoding="utf-8") as f:
+         json.dump(piper_config, f, indent=4, ensure_ascii=False)

-     if args.debug:
-         logging.basicConfig(level=logging.DEBUG)
-     else:
-         logging.basicConfig(level=logging.INFO)

-     _LOGGER.debug(args)
+ # --- Main Logic using Click ---
+ @click.command(help="Export a VITS model checkpoint to ONNX format.")
+ @click.argument(
+     "checkpoint",
+     type=click.Path(exists=True, path_type=Path),
+     # help="Path to the PyTorch checkpoint file (*.ckpt)."
+ )
+ @click.option(
+     "-c",
+     "--config",
+     type=click.Path(exists=True, path_type=Path),
+     help="Path to the model configuration JSON file."
+ )
+ @click.option(
+     "-o",
+     "--output-dir",
+     type=click.Path(path_type=Path),
+     default=Path(os.getcwd()),  # Default to the current working directory
+     help="Output directory for the ONNX model. (Default: current directory)"
+ )
+ @click.option(
+     "-t",
+     "--generate-tokens",
+     is_flag=True,
+     help="Generate tokens.txt alongside the ONNX model. Some inference engines (e.g. sherpa) need this."
+ )
+ @click.option(
+     "-p",
+     "--piper",
+     is_flag=True,
+     help="Generate a Piper-compatible .json file alongside the ONNX model."
+ )
+ def cli(
+     checkpoint: Path,
+     config: Path,
+     output_dir: Path,
+     generate_tokens: bool,
+     piper: bool,
+ ) -> None:
+     """
+     Main entry point for exporting a VITS model checkpoint to ONNX format.
+
+     Args:
+         checkpoint: Path to the PyTorch checkpoint file (*.ckpt).
+         config: Path to the model configuration JSON file.
+         output_dir: Output directory for the ONNX model and associated files.
+         generate_tokens: Flag to generate a tokens.txt file.
+         piper: Flag to generate a Piper-compatible .json file.
+     """
+     torch.manual_seed(1234)
+
+     _LOGGER.debug(f"Arguments: {checkpoint=}, {config=}, {output_dir=}, {generate_tokens=}, {piper=}")

      # -------------------------------------------------------------------------
+     # Paths and Setup
+
+     # Create output directory if it doesn't exist
+     output_dir.mkdir(parents=True, exist_ok=True)
+     _LOGGER.debug(f"Output directory ensured: {output_dir}")
+
+     # Load the phoonnx configuration
+     try:
+         with open(config, "r", encoding="utf-8") as f:
+             model_config: Dict[str, Any] = json.load(f)
+         _LOGGER.info(f"Loaded phoonnx config from {config}")
+     except Exception as e:
+         _LOGGER.error(f"Error loading config file {config}: {e}")
+         return
+
+     alphabet: str = model_config.get("alphabet", "")
+     phoneme_type: str = model_config.get("phoneme_type", "")
+     phonemizer_model: str = model_config.get("phonemizer_model", "")  # depends on the phonemizer (e.g. byt5)
+     piper_compatible: bool = alphabet == "ipa" and phoneme_type == "espeak"

-     args.checkpoint = Path(args.checkpoint)
-     args.output = Path(args.output)
-     args.output.parent.mkdir(parents=True, exist_ok=True)
+     # Ensure mandatory keys exist before accessing
+     sample_rate: int = model_config.get("audio", {}).get("sample_rate", 22050)
+     phoneme_id_map: Dict[str, int] = model_config.get("phoneme_id_map", {})

-     model = VitsModel.load_from_checkpoint(args.checkpoint, dataset=None)
-     model_g = model.model_g
+     if piper:
+         if not piper_compatible:
+             _LOGGER.warning("Only models trained with ipa + espeak should be exported to Piper; phonemization is not included in the exported model.")
+         # Generate the piper.json file
+         piper_output_path = output_dir / f"{checkpoint.name}.piper.json"
+         convert_to_piper(config, piper_output_path)

-     num_symbols = model_g.n_vocab
-     num_speakers = model_g.n_speakers
+     if generate_tokens:
+         # Generate the tokens.txt file
+         tokens_output_path = output_dir / f"{checkpoint.name}.tokens.txt"
+         export_tokens(config, tokens_output_path)

-     # Inference only
+     # -------------------------------------------------------------------------
+     # Model Loading and Preparation
+     try:
+         model: VitsModel = VitsModel.load_from_checkpoint(
+             checkpoint,
+             dataset=None
+         )
+     except Exception as e:
+         _LOGGER.error(f"Error loading model checkpoint {checkpoint}: {e}")
+         return
+
+     model_g: torch.nn.Module = model.model_g
+     num_symbols: int = model_g.n_vocab
+     num_speakers: int = model_g.n_speakers
+
+     # Inference only setup
      model_g.eval()

      with torch.no_grad():
+         # Remove weight norm for inference mode
          model_g.dec.remove_weight_norm()
+         _LOGGER.debug("Removed weight normalization from decoder.")
+
+     # -------------------------------------------------------------------------
+     # Define ONNX-compatible forward function
+
+     def infer_forward(text: torch.Tensor, text_lengths: torch.Tensor, scales: torch.Tensor, sid: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         Custom forward pass for ONNX export, simplifying the input scales and
+         returning only the audio tensor with shape [B, 1, T].

-     # old_forward = model_g.infer
+         Args:
+             text: Input phoneme sequence tensor, shape [B, T_in].
+             text_lengths: Tensor of sequence lengths, shape [B].
+             scales: Tensor containing [noise_scale, length_scale, noise_scale_w], shape [3].
+             sid: Optional speaker ID tensor, shape [B], for multi-speaker models.

-     def infer_forward(text, text_lengths, scales, sid=None):
-         noise_scale = scales[0]
-         length_scale = scales[1]
-         noise_scale_w = scales[2]
-         audio = model_g.infer(
+         Returns:
+             Generated audio tensor, shape [B, 1, T_out].
+         """
+         noise_scale: float = scales[0]
+         length_scale: float = scales[1]
+         noise_scale_w: float = scales[2]
+
+         # model_g.infer returns a tuple: (audio, attn, ids_slice, x_mask, z, z_mask, g)
+         audio: torch.Tensor = model_g.infer(
              text,
              text_lengths,
              noise_scale=noise_scale,
             length_scale=length_scale,
             noise_scale_w=noise_scale_w,
             sid=sid,
-         )[0].unsqueeze(1)
+         )[0].unsqueeze(1)  # [0] selects the audio tensor; unsqueeze(1) makes it [B, 1, T]

          return audio

+     # Replace the default forward with the inference one for ONNX export
      model_g.forward = infer_forward

-     dummy_input_length = 50
-     sequences = torch.randint(
+     # -------------------------------------------------------------------------
+     # Dummy Input Generation
+
+     dummy_input_length: int = 50
+     sequences: torch.Tensor = torch.randint(
          low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
      )
-     sequence_lengths = torch.LongTensor([sequences.size(1)])
+     sequence_lengths: torch.Tensor = torch.LongTensor([sequences.size(1)])

      sid: Optional[torch.LongTensor] = None
+     input_names: list[str] = ["input", "input_lengths", "scales"]
+     dynamic_axes_map: Dict[str, Dict[int, str]] = {
+         "input": {0: "batch_size", 1: "phonemes"},
+         "input_lengths": {0: "batch_size"},
+         "output": {0: "batch_size", 1: "time"},
+     }
+
      if num_speakers > 1:
          sid = torch.LongTensor([0])
+         input_names.append("sid")
+         dynamic_axes_map["sid"] = {0: "batch_size"}
+         _LOGGER.debug(f"Multi-speaker model detected (n_speakers={num_speakers}). 'sid' included.")

-     # noise, noise_w, length
-     scales = torch.FloatTensor([0.667, 1.0, 0.8])
-     dummy_input = (sequences, sequence_lengths, scales, sid)
+     # noise, length, noise_w scales (hardcoded defaults)
+     scales: torch.Tensor = torch.FloatTensor([0.667, 1.0, 0.8])
+     dummy_input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.LongTensor]] = (
+         sequences, sequence_lengths, scales, sid
+     )

+     # -------------------------------------------------------------------------
      # Export
-     torch.onnx.export(
-         model=model_g,
-         args=dummy_input,
-         f=str(args.output),
-         verbose=False,
-         opset_version=OPSET_VERSION,
-         input_names=["input", "input_lengths", "scales", "sid"],
-         output_names=["output"],
-         dynamic_axes={
-             "input": {0: "batch_size", 1: "phonemes"},
-             "input_lengths": {0: "batch_size"},
-             "output": {0: "batch_size", 1: "time"},
-         },
-     )
+     model_output: Path = output_dir / f"{checkpoint.name}.onnx"
+     _LOGGER.info(f"Starting ONNX export to {model_output} (opset={OPSET_VERSION})...")
+
+     try:
+         torch.onnx.export(
+             model=model_g,
+             args=dummy_input,
+             f=str(model_output),
+             verbose=False,
+             opset_version=OPSET_VERSION,
+             input_names=input_names,
+             output_names=["output"],
+             dynamic_axes=dynamic_axes_map,
+         )
+         _LOGGER.info(f"Successfully exported model to {model_output}")
+     except Exception as e:
+         _LOGGER.error(f"Failed during torch.onnx.export: {e}")
+         return
+
+     # -------------------------------------------------------------------------
+     # Add Metadata
+     metadata_dict: Dict[str, Any] = {
+         "model_type": "vits",
+         "n_speakers": num_speakers,
+         "n_vocab": num_symbols,
+         "sample_rate": sample_rate,
+         "alphabet": alphabet,
+         "phoneme_type": phoneme_type,
+         "phonemizer_model": phonemizer_model,
+         "phoneme_id_map": json.dumps(phoneme_id_map),
+         "has_espeak": phoneme_type == "espeak"
+     }
+     if piper_compatible:
+         metadata_dict["comment"] = "piper"
+
+     try:
+         add_meta_data(model_output, metadata_dict)
+     except Exception as e:
+         _LOGGER.error(f"Failed to add metadata to exported model {model_output}: {e}")

-     _LOGGER.info("Exported model to %s", args.output)
+     _LOGGER.info("Export complete.")


  # -----------------------------------------------------------------------------

  if __name__ == "__main__":
-     main()
+     cli()
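
The export script now embeds the voice configuration directly in the `.onnx` file via `add_meta_data`, so downstream tools can recover it without shipping the JSON config. A minimal sketch of reading that metadata back (assumes the `onnx` package; the key names come from `metadata_dict` above, and the filename is hypothetical):

```python
import json
import onnx

# Read back the key/value pairs written by add_meta_data()
model = onnx.load("model.ckpt.onnx")  # hypothetical output filename
meta = {prop.key: prop.value for prop in model.metadata_props}

# Values were stringified on export, so cast them back as needed
sample_rate = int(meta["sample_rate"])
phoneme_id_map = json.loads(meta["phoneme_id_map"])  # serialized with json.dumps above
print(meta["model_type"], sample_rate, len(phoneme_id_map))
```
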
@@ -19,8 +19,8 @@ from phoonnx.phoneme_ids import (phonemes_to_ids, DEFAULT_IPA_PHONEME_ID_MAP, DE
                                  DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BLANK_WORD_TOKEN)
  from phoonnx_train.norm_audio import cache_norm_audio, make_silence_detector
  from tqdm import tqdm
+ from phoonnx.version import VERSION_STR

- _VERSION = "0.0.0"
  _LOGGER = logging.getLogger("preprocess")

  # Base phoneme map
@@ -105,7 +105,9 @@ def ljspeech_dataset(args: argparse.Namespace) -> Iterable[Utterance]:

      wav_path = None
      for wav_dir in wav_dirs:
-         potential_paths = [wav_dir / filename, wav_dir / f"{filename}.wav"]
+         potential_paths = [wav_dir / filename,
+                            wav_dir / f"{filename}.wav",
+                            wav_dir / f"{filename.lstrip('0')}.wav"]
          for path in potential_paths:
              if path.exists():
                  wav_path = path
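
The extra candidate path covers datasets whose audio files drop leading zeros from the ids in metadata.csv, e.g. a row naming `0001` while the file on disk is `1.wav`. A quick illustration of the fallback (directory and filename are hypothetical):

```python
from pathlib import Path

wav_dir = Path("wavs")   # hypothetical dataset layout
filename = "0001"        # id as written in metadata.csv
candidates = [wav_dir / filename,
              wav_dir / f"{filename}.wav",
              wav_dir / f"{filename.lstrip('0')}.wav"]
for p in candidates:
    print(p)             # wavs/0001, wavs/0001.wav, wavs/1.wav
```
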
@@ -153,9 +155,17 @@ def phonemize_worker(

      for utt in utterance_batch:
          try:
+             # normalize text (case, numbers...)
+             utterance = casing(normalize(utt.text, args.language))
+
+             # add diacritics
+             if args.add_diacritics:
+                 utterance = phonemizer.add_diacritics(utterance, args.language)
+
              # Phonemize the text
-             norm_utt = casing(normalize(utt.text, args.language))
-             utt.phonemes = phonemizer.phonemize_to_list(norm_utt, args.language)
+             utt.phonemes = phonemizer.phonemize_to_list(utterance, args.language)
+             if not utt.phonemes:
+                 raise RuntimeError(f"Phonemes not found for '{utterance}'")

              # Process audio if not skipping
              if not args.skip_audio:
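
Note the ordering the worker now enforces: normalize first, optionally add diacritics, then phonemize, and fail loudly when the phonemizer returns nothing. A runnable toy version of that control flow (the stand-in functions below are hypothetical; only the ordering and the empty-result check mirror the diff):

```python
# Toy stand-ins; the real normalize/casing/phonemizer live inside phoonnx
def normalize(text: str, lang: str) -> str:
    return text.strip()   # real code also handles case, numbers, etc.

def casing(text: str) -> str:
    return text.lower()

def phonemize_to_list(text: str, lang: str) -> list:
    return list(text)     # real phonemizers return phoneme strings

utterance = casing(normalize("Hello, World!", "en"))
phonemes = phonemize_to_list(utterance, "en")
if not phonemes:          # same guard the worker now applies
    raise RuntimeError(f"Phonemes not found for '{utterance}'")
print(phonemes)
```
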
@@ -242,6 +252,9 @@ def main() -> None:
      parser.add_argument(
          "--debug", action="store_true", help="Print DEBUG messages to the console"
      )
+     parser.add_argument(
+         "--add-diacritics", action="store_true", help="Add diacritics to text (phonemizer specific)"
+     )
      args = parser.parse_args()

      # Setup
@@ -293,7 +306,9 @@ def main() -> None:
      _LOGGER.info("Starting single pass processing with %d workers...", args.max_workers)

      # Initialize the phonemizer only once in the main process
-     phonemizer = get_phonemizer(args.phoneme_type, args.alphabet, args.phonemizer_model)
+     phonemizer = get_phonemizer(args.phoneme_type,
+                                 args.alphabet,
+                                 args.phonemizer_model)

      batch_size = max(1, int(num_utterances / (args.max_workers * 2)))

@@ -367,7 +382,10 @@ def main() -> None:
              "quality": audio_quality,
          },
          "lang_code": args.language,
-         "inference": {"noise_scale": 0.667, "length_scale": 1, "noise_w": 0.8},
+         "inference": {"noise_scale": 0.667,
+                       "length_scale": 1,
+                       "noise_w": 0.8,
+                       "add_diacritics": args.add_diacritics},
          "alphabet": phonemizer.alphabet.value,
          "phoneme_type": args.phoneme_type.value,
          "phonemizer_model": args.phonemizer_model,
@@ -375,7 +393,7 @@ def main() -> None:
          "num_symbols": len(final_phoneme_id_map),
          "num_speakers": len(speaker_counts) if is_multispeaker else 1,
          "speaker_id_map": speaker_ids,
-         "phoonnx_version": _VERSION,
+         "phoonnx_version": VERSION_STR,
      }

      with open(args.output_dir / "config.json", "w", encoding="utf-8") as config_file:
@@ -383,15 +401,23 @@ def main() -> None:

      # --- Apply final phoneme IDs and write dataset.jsonl ---
      _LOGGER.info("Writing dataset.jsonl...")
+     valid_utterances_count = 0
      with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
          for utt in processed_utterances:
-             if utt.speaker is not None:
+             if is_multispeaker and utt.speaker is not None:
+                 if utt.speaker not in speaker_ids:
+                     _LOGGER.error("Speaker '%s' not in speaker_id_map. This indicates an issue with your metadata.csv file.", utt.speaker)
+                     continue
                  utt.speaker_id = speaker_ids[utt.speaker]

              # Apply the final phoneme ID map to each utterance
              if utt.phonemes:
                  utt.phoneme_ids = phonemes_to_ids(utt.phonemes, id_map=final_phoneme_id_map)

+             if not utt.phoneme_ids:
+                 _LOGGER.warning("Skipping utterance with invalid phoneme_ids before writing: %s", utt.audio_path)
+                 continue
+
              json.dump(
                  utt.asdict(),
                  dataset_file,
@@ -399,8 +425,9 @@ def main() -> None:
                  cls=PathEncoder,
              )
              print("", file=dataset_file)
+             valid_utterances_count += 1

-     _LOGGER.info("Preprocessing complete.")
+     _LOGGER.info("Preprocessing complete. Wrote %d valid utterances to dataset.jsonl.", valid_utterances_count)


  # -----------------------------------------------------------------------------
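
With the new guards, every line written to dataset.jsonl is one JSON object that is guaranteed to carry a non-empty `phoneme_ids` list (and, for multi-speaker corpora, a resolvable speaker). A small consumer-side check mirroring that invariant (the field name is taken from the diff; the path is hypothetical):

```python
import json

valid = 0
with open("dataset.jsonl", encoding="utf-8") as f:
    for line in f:
        utt = json.loads(line)
        assert utt["phoneme_ids"], "preprocess should have skipped this row"
        valid += 1
print(f"{valid} valid utterances")
```
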
@@ -69,6 +69,8 @@ class PiperDataset(Dataset):
              self.utterances.extend(
                  PiperDataset.load_dataset(dataset_path, max_phoneme_ids=max_phoneme_ids)
              )
+         if not self.utterances:
+             raise ValueError("No utterances loaded")

      def __len__(self):
          return len(self.utterances)
@@ -120,6 +122,8 @@ class PiperDataset(Dataset):
      @staticmethod
      def load_utterance(line: str) -> Utterance:
          utt_dict = json.loads(line)
+         if not utt_dict["phoneme_ids"]:
+             raise ValueError(f"invalid utterance line - phoneme_ids not set ({line})")
          return Utterance(
              phoneme_ids=utt_dict["phoneme_ids"],
              audio_norm_path=Path(utt_dict["audio_norm_path"]),
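
With this check, `PiperDataset.load_utterance` fails fast on rows that lost their phoneme ids, instead of surfacing as a shape error later in training. A standalone reproduction of just the added guard (not the real class, which also builds an `Utterance` with audio paths):

```python
import json

def load_utterance(line: str) -> dict:
    utt_dict = json.loads(line)
    # Mirror of the added check: reject rows without phoneme ids
    if not utt_dict["phoneme_ids"]:
        raise ValueError(f"invalid utterance line - phoneme_ids not set ({line})")
    return utt_dict

load_utterance('{"phoneme_ids": [1, 2, 3]}')   # passes
try:
    load_utterance('{"phoneme_ids": []}')      # raises
except ValueError as e:
    print(e)
```
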