phoonnx 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
+ # taken from https://github.com/stannam/hangul_to_ipa
1
2
  import csv
2
3
  import math
3
4
  import os.path
@@ -3,14 +3,6 @@ from phoonnx.thirdparty.mantoq.buck.tokenization import (arabic_to_phonemes, pho
3
3
  phonemes_to_tokens, simplify_phonemes)
4
4
  from phoonnx.thirdparty.mantoq.buck.tokenization import tokens_to_ids as _tokens_to_id
5
5
  from phoonnx.thirdparty.mantoq.num2words import num2words
6
- import warnings
7
- from phoonnx.thirdparty.tashkeel import TashkeelDiacritizer
8
- try:
9
- import onnxruntime
10
-
11
- _TASHKEEL_AVAILABLE = True
12
- except ImportError:
13
- _TASHKEEL_AVAILABLE = False
14
6
 
15
7
  _DIACRITIZER_INST = None
16
8
 
@@ -29,29 +21,12 @@ QUOTES_TABLE = str.maketrans(QUOTES, '"' * len(QUOTES))
29
21
  BRACKETS_TABLE = str.maketrans("[]{}", "()()")
30
22
 
31
23
 
32
-
33
-
34
- def tashkeel(text: str) -> str:
35
- global _DIACRITIZER_INST
36
- if not _TASHKEEL_AVAILABLE:
37
- warnings.warn(
38
- "Warning: The Tashkeel feature will not be available. Please re-install with the `libtashkeel` extra.",
39
- UserWarning,
40
- )
41
- return text
42
- if _DIACRITIZER_INST is None:
43
- _DIACRITIZER_INST = TashkeelDiacritizer()
44
- return _DIACRITIZER_INST.diacritize(text)
45
-
46
24
  def g2p(
47
25
  text: str,
48
- add_tashkeel: bool = True,
49
26
  process_numbers: bool = True,
50
27
  append_eos: bool = False,
51
- ) -> list[str]:
28
+ ) -> tuple[str, list[str]]:
52
29
  text = text.translate(AR_SPECIAL_PUNCS_TABLE).translate(QUOTES_TABLE).translate(BRACKETS_TABLE)
53
- if add_tashkeel:
54
- text = tashkeel(text)
55
30
  if process_numbers:
56
31
  text = num2words(text)
57
32
  normalized_text = text
@@ -0,0 +1,24 @@
1
+ import os
2
+ import requests
3
+
4
+
5
class PhonikudDiacritizer:
    """Thin wrapper around the ``phonikud_onnx`` model that adds diacritics to text.

    On construction the ONNX model file named by :attr:`dl_url` is looked up
    under ``~/.local/share/phonikud`` and downloaded there if it is missing.
    """

    dl_url = "https://huggingface.co/thewh1teagle/phonikud-onnx/resolve/main/phonikud-1.0.int8.onnx"

    def __init__(self):
        model = self._model_path()
        if not os.path.isfile(model):
            self._download(self.dl_url, model)

        # NOTE(review): imported lazily, presumably so the optional
        # `phonikud_onnx` dependency is only needed when this class is used.
        from phonikud_onnx import Phonikud
        self.phonikud = Phonikud(model)

    @classmethod
    def _model_path(cls) -> str:
        """Return the local cache path of the model file named by ``dl_url``."""
        base_path = os.path.expanduser("~/.local/share/phonikud")
        fname = cls.dl_url.split("/")[-1]
        return f"{base_path}/{fname}"

    @staticmethod
    def _download(url: str, dest: str, timeout: float = 60.0) -> None:
        """Stream ``url`` into ``dest`` without ever leaving a partial file at ``dest``.

        The payload is written to ``<dest>.part`` and only renamed to ``dest``
        once the whole response has been received, so an interrupted or failed
        download can never be mistaken for a valid cached model.

        Args:
            url: HTTP(S) location of the file.
            dest: Final path of the downloaded file.
            timeout: Timeout in seconds handed to ``requests.get``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            OSError: If the file cannot be written.
        """
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        tmp_path = f"{dest}.part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Fail loudly instead of caching an HTML error page as the model.
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
            os.replace(tmp_path, dest)
        finally:
            # After a successful rename the temp file no longer exists; this
            # only fires when the download was interrupted or failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def diacritize(self, text: str) -> str:
        """Return ``text`` with diacritics added by the loaded Phonikud model."""
        return self.phonikud.add_diacritics(text)
phoonnx/version.py CHANGED
@@ -2,5 +2,9 @@
2
2
  VERSION_MAJOR = 0
3
3
  VERSION_MINOR = 1
4
4
  VERSION_BUILD = 0
5
- VERSION_ALPHA = 1
5
+ VERSION_ALPHA = 3
6
6
  # END_VERSION_BLOCK
7
+
8
+ VERSION_STR = f"{VERSION_MAJOR}.{VERSION_MINOR}.{VERSION_BUILD}"
9
+ if VERSION_ALPHA:
10
+ VERSION_STR += f"a{VERSION_ALPHA}"
phoonnx/voice.py CHANGED
@@ -14,7 +14,6 @@ from phoonnx.config import PhonemeType, VoiceConfig, SynthesisConfig, get_phonem
14
14
  from phoonnx.phoneme_ids import phonemes_to_ids, BlankBetween
15
15
  from phoonnx.phonemizers import Phonemizer
16
16
  from phoonnx.phonemizers.base import PhonemizedChunks
17
- from phoonnx.thirdparty.tashkeel import TashkeelDiacritizer
18
17
 
19
18
  _PHONEME_BLOCK_PATTERN = re.compile(r"(\[\[.*?\]\])")
20
19
 
@@ -113,11 +112,6 @@ class TTSVoice:
113
112
 
114
113
  phonemizer: Optional[Phonemizer] = None
115
114
 
116
- # For Arabic text only
117
- use_tashkeel: bool = True
118
- tashkeel_diacritizier: Optional[TashkeelDiacritizer] = None # For Arabic text only
119
- taskeen_threshold: Optional[float] = 0.8
120
-
121
115
  def __post_init__(self):
122
116
  try:
123
117
  self.phonetic_spellings = PhoneticSpellings.from_lang(self.config.lang_code)
@@ -128,10 +122,6 @@ class TTSVoice:
128
122
  self.config.alphabet,
129
123
  self.config.phonemizer_model)
130
124
 
131
- # compat with piper arabic models - TODO move to espeak phonemizer
132
- if self.config.lang_code.split("-")[0] == "ar" and self.use_tashkeel and self.tashkeel_diacritizier is None:
133
- self.tashkeel_diacritizier = TashkeelDiacritizer()
134
-
135
125
  @staticmethod
136
126
  def load(
137
127
  model_path: Union[str, Path],
@@ -209,12 +199,6 @@ class TTSVoice:
209
199
 
210
200
  continue
211
201
 
212
- # Arabic diacritization
213
- if self.config.lang_code.split("-")[0] == "ar" and self.use_tashkeel:
214
- text_part = self.tashkeel_diacritizier(
215
- text_part, taskeen_threshold=self.taskeen_threshold
216
- )
217
-
218
202
  # Phonemization
219
203
  phonemes = self.phonemizer.phonemize(
220
204
  text_part, self.config.lang_code
@@ -267,6 +251,10 @@ class TTSVoice:
267
251
  if self.phonetic_spellings and syn_config.enable_phonetic_spellings:
268
252
  text = self.phonetic_spellings.apply(text)
269
253
 
254
+ if syn_config.add_diacritics:
255
+ text = self.phonemizer.add_diacritics(text, self.config.lang_code)
256
+ LOG.debug("text+diacritics=%s", text)
257
+
270
258
  # All phonemization goes through the unified self.phonemize method
271
259
  sentence_phonemes = self.phonemize(text)
272
260
  LOG.debug("phonemes=%s", sentence_phonemes)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: phoonnx
3
- Version: 0.1.0a1
3
+ Version: 0.1.0a3
4
4
  Home-page: https://github.com/TigreGotico/phoonnx
5
5
  Author: JarbasAi
6
6
  Author-email: jarbasai@mailfence.com
@@ -220,6 +220,7 @@ Requires-Dist: librosa<1,>=0.9.2; extra == "train"
220
220
  Requires-Dist: numpy<2,>=1.19.0; extra == "train"
221
221
  Requires-Dist: pytorch-lightning<2.0; extra == "train"
222
222
  Requires-Dist: torch<2,>=1.11.0; extra == "train"
223
+ Requires-Dist: click; extra == "train"
223
224
  Provides-Extra: uew
224
225
  Requires-Dist: epitran; extra == "uew"
225
226
  Provides-Extra: ug
@@ -1,29 +1,29 @@
1
1
  phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- phoonnx/config.py,sha256=81H34oPG2BaiOA6UM1KapoT341n068LqRprKb5ER6mY,19451
2
+ phoonnx/config.py,sha256=IYhC-kYjLgYmBroId6YeOE2Vp7SMNGtiGqIIe_09NJk,19531
3
3
  phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
4
4
  phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
5
- phoonnx/version.py,sha256=95gLFCt-8xv9DgF7FIF6CljWmhm8SUhevumEBfo7Pl0,114
6
- phoonnx/voice.py,sha256=FR_LafK1vSi_anPERJjZBuH3Bb9vUIof0MAW6TnALlA,20024
5
+ phoonnx/version.py,sha256=WnY5J2wtSTore9QbKwfk04gQhBsYq4HVmV5CBjEhGnk,236
6
+ phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
7
7
  phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
8
8
  phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
9
9
  phoonnx/locale/gl/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
10
10
  phoonnx/locale/pt/phonetic_spellings.txt,sha256=KntS8QMynEJ5A3Clvcjq4qlmL-ThSbhfD6v0nKSrlqs,49
11
11
  phoonnx/phonemizers/__init__.py,sha256=QGBZk0QUgJdg2MwUWY9Kpk6ucwrEJYtHb07YcNvXCV4,1647
12
- phoonnx/phonemizers/ar.py,sha256=29bCfYhlhx0QX3PQyx3EkUghzh8YfkxNAnMAICXX6I8,4148
13
- phoonnx/phonemizers/base.py,sha256=yPg6-dvscYpl3rR3JEULG1PRF-i8DWC_C3HAZGLbxOo,7648
12
+ phoonnx/phonemizers/ar.py,sha256=xxILq5iyH0kcI-NqFfRK4abGtpdUbykBjt_dZmPuO2w,3216
13
+ phoonnx/phonemizers/base.py,sha256=FHvAsvSjAl_oSa1GoeEi96CQ_JO_xkKXWq0ukuMxiuo,8660
14
14
  phoonnx/phonemizers/en.py,sha256=N2SVoVhplQao7Ej5TXbxJU-YkAgkY0Fr9iYBFnsjFSE,9271
15
15
  phoonnx/phonemizers/fa.py,sha256=d_DZM2wqomf4gcRH_rFcNA3VkQWKHru8vwBwaNG8Ll8,1452
16
16
  phoonnx/phonemizers/gl.py,sha256=jEFKJJViHufZtB7lGNwWQCdWGiNKDCVZ_GRYXTaw_2c,6614
17
- phoonnx/phonemizers/he.py,sha256=KbRI3XRZa8UtJdNWmn_fd-t5lmFSIp4Mw8UgcO5l-Po,2211
17
+ phoonnx/phonemizers/he.py,sha256=49OFS34wSFvvR9B3z2bGSzSLmlIvnn2HtkHBOkHS9Ns,1383
18
18
  phoonnx/phonemizers/ja.py,sha256=Xojsrt715ihnIiEk9K6giYqDo9Iykw-SHfIidrHtHSU,3834
19
19
  phoonnx/phonemizers/ko.py,sha256=kwWoOFqanCB8kv2JRx17A0hP78P1wbXlX6e8VBn1ezQ,2989
20
- phoonnx/phonemizers/mul.py,sha256=37G_G58aGnVpdEm9vZEAOdGEHJ9TLBE17bU1HFvQ2rU,27291
20
+ phoonnx/phonemizers/mul.py,sha256=-h6uN_laUD-unNRGThzjyiOZpN6pSl4uinCndg5-0TA,94184
21
21
  phoonnx/phonemizers/vi.py,sha256=_XJc-Xeawr1Lxr7o8mE_hJao1aGcj4g01XYAOxC_Scg,1311
22
22
  phoonnx/phonemizers/zh.py,sha256=88Ywq8h9LDanlyz8RHjRSCY_PRK_Dq808tBADyrgaP8,9657
23
23
  phoonnx/thirdparty/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
24
  phoonnx/thirdparty/arpa2ipa.py,sha256=Uj1G5NgP5oBBfSm26LGB8QoumdT-NqCLQTZHT165-_o,5850
25
25
  phoonnx/thirdparty/bw2ipa.py,sha256=5FiWC4AP4KXkqtbclbinoXEsUnSYEjk4VWAPasMMcbg,2328
26
- phoonnx/thirdparty/hangul2ipa.py,sha256=e2c0WOy5lFMcf6GS7pNqIbauMKBX07S84lCczZAZJGA,27518
26
+ phoonnx/thirdparty/hangul2ipa.py,sha256=Pj06lL-GkOH4ZkLuakwQAT045fEVsijGhwoY_EEEVKc,27572
27
27
  phoonnx/thirdparty/zh_num.py,sha256=SESA6gvSJW3LZ0FLoybXn2SpbxqhQTi9Tg_U2IZ5JYY,7147
28
28
  phoonnx/thirdparty/cotovia/cotovia_aarch64,sha256=BsAWZN452Lm9kDU4i6rQGHFSlmxP3GfHRKhbJMUQrfA,6764592
29
29
  phoonnx/thirdparty/cotovia/cotovia_x86_64,sha256=-6BNx_cd49nnDreOAsGtVtePs_X76esrqcNAfmksN1o,1379832
@@ -37,7 +37,7 @@ phoonnx/thirdparty/ko_tables/tensification.csv,sha256=V4Xf3A1G1iMBzwZevBKQuk_lPa
37
37
  phoonnx/thirdparty/ko_tables/yale.csv,sha256=UhtDbPXRAAyAKoQMXmwhVBwJ5pfZQ_Duk28qBtRUdsU,297
38
38
  phoonnx/thirdparty/kog2p/__init__.py,sha256=yLizadg7RXM-3dQyftD4XSk8r2jb0QOlHQ6as9uUa4U,10267
39
39
  phoonnx/thirdparty/kog2p/rulebook.txt,sha256=FQE3nej8wojl6ilVUBYo7f8bIk0Hjci-B7HPXhM-xNc,9303
40
- phoonnx/thirdparty/mantoq/__init__.py,sha256=4kZuZ3RA5ZhQwTOQGkHF9jQYSvetNTn9uWi5Dsx101k,2106
40
+ phoonnx/thirdparty/mantoq/__init__.py,sha256=02FftO4Onmp_S-XdukbBQ3aRVvqEQyo1frCLWgcF9cY,1428
41
41
  phoonnx/thirdparty/mantoq/num2words.py,sha256=9-ncMtxV1FusD9rNur1lu7l2DWhwUwI1mFiqiPSMH_Q,1264
42
42
  phoonnx/thirdparty/mantoq/unicode_symbol2label.py,sha256=CeZNv7qWeQS4Ejvz-sKgK--5eNYdVVv04WHPaOeK4gk,259409
43
43
  phoonnx/thirdparty/mantoq/buck/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -53,6 +53,7 @@ phoonnx/thirdparty/mantoq/pyarabic/number.py,sha256=NjFZPWRu-9dZDLgxfv9oDjmh-kWY
53
53
  phoonnx/thirdparty/mantoq/pyarabic/number_const.py,sha256=vAvRVENxTrl9gWPllSXF-yqK9fAW6htuA2d041btC_A,42361
54
54
  phoonnx/thirdparty/mantoq/pyarabic/stack.py,sha256=aJeSzQxVNdomDTWXuxIXWXVOc2BW_3iRWnwmBLkB8jM,1022
55
55
  phoonnx/thirdparty/mantoq/pyarabic/trans.py,sha256=cusyHk9Y01iuvMLJXxgCnIiGyAORzEdSosDKX4cAhPc,13713
56
+ phoonnx/thirdparty/phonikud/__init__.py,sha256=g1dCelCZbwlKT0Ibaky6Ckp59wMw5g_1DDyDXauqFTg,760
56
57
  phoonnx/thirdparty/tashkeel/LICENSE,sha256=mQjTJ6MGAXzmYkO7x4O2VuEeSwCMx7lncbc26TnrVjw,1067
57
58
  phoonnx/thirdparty/tashkeel/SOURCE,sha256=SmnRz-Am5EXv-n2-RokJVEhnn8zeF1QZJVvMQDA_Qds,38
58
59
  phoonnx/thirdparty/tashkeel/__init__.py,sha256=FRdGNCTQaai9X077vlNh4tFOvWgm1U2lIUgnQKO5q0s,7119
@@ -61,8 +62,8 @@ phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771l
61
62
  phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
62
63
  phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
63
64
  phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
64
- phoonnx_train/export_onnx.py,sha256=dcFJRZl4YvBk_Dj3j0aNAQVEqKfBHTzV22pzvQwSETQ,2909
65
- phoonnx_train/preprocess.py,sha256=0kto9Holywby6lnoQucBXq2wYEKDItRvdkvYbQnLJeo,14447
65
+ phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
66
+ phoonnx_train/preprocess.py,sha256=8_Opy5QVNjVmSVmh1_IF23bcNebVIEXuK2KcollIy28,15793
66
67
  phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
67
68
  phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
68
69
  phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
@@ -70,7 +71,7 @@ phoonnx_train/vits/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuF
70
71
  phoonnx_train/vits/attentions.py,sha256=yc_ViF8zR8z68DzphmVVVn27f9xK_5wi8S4ITLXVQL0,15134
71
72
  phoonnx_train/vits/commons.py,sha256=JsD8CdZ3ZcYYubYhw8So5hICBziFlCrKLrv1lMDRCDM,4645
72
73
  phoonnx_train/vits/config.py,sha256=oSuUIhw9Am7BQ5JwDgtCO-P1zRyN7nPgR-U1XuncJls,10789
73
- phoonnx_train/vits/dataset.py,sha256=DLLGSCkn3GF9uktoTprH1ERblZ18GO6-QsClQKWa98o,6804
74
+ phoonnx_train/vits/dataset.py,sha256=1V1tVh5dSLjFMBsuzrAsoGtYWSBT4iU64Jdqi8oG-y0,7016
74
75
  phoonnx_train/vits/lightning.py,sha256=ZBuSIiJ7EUU1Za2V8Uh6-_HGGRW_qwpXLLs1cEDirHA,12301
75
76
  phoonnx_train/vits/losses.py,sha256=j-uINhBcYxVXFvFutiewQpTuw-qF-J6M6hdJVeOKqNE,1401
76
77
  phoonnx_train/vits/mel_processing.py,sha256=huIjbQgewSmM39hdzRZvZUCI7fTNSMmLcAv3f8zYb8k,3956
@@ -81,7 +82,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
81
82
  phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
82
83
  phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
83
84
  phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
84
- phoonnx-0.1.0a1.dist-info/METADATA,sha256=9FZiRhA48da6ZbX1qCrKKVqsWMWQwfedz-bUXATd6Sk,8145
85
- phoonnx-0.1.0a1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
86
- phoonnx-0.1.0a1.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
87
- phoonnx-0.1.0a1.dist-info/RECORD,,
85
+ phoonnx-0.1.0a3.dist-info/METADATA,sha256=3U1Ea0g2HxtWPsIs7NCxzPdo7ZTr4s_lRs9gIOC6MWY,8184
86
+ phoonnx-0.1.0a3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
87
+ phoonnx-0.1.0a3.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
88
+ phoonnx-0.1.0a3.dist-info/RECORD,,
@@ -1,109 +1,360 @@
1
1
  #!/usr/bin/env python3
2
- import argparse
2
+ import click
3
3
  import logging
4
+ import json
5
+ import os
4
6
  from pathlib import Path
5
- from typing import Optional
7
+ from typing import Optional, Dict, Any, Tuple
6
8
 
7
9
  import torch
8
-
9
10
  from phoonnx_train.vits.lightning import VitsModel
11
+ from phoonnx.version import VERSION_STR
10
12
 
11
- _LOGGER = logging.getLogger("piper_train.export_onnx")
13
+ # Basic logging configuration
14
+ logging.basicConfig(level=logging.DEBUG)
15
+ _LOGGER = logging.getLogger("phoonnx_train.export_onnx")
12
16
 
17
+ # ONNX opset version
13
18
  OPSET_VERSION = 15
14
19
 
15
20
 
16
- def main() -> None:
17
- """Main entry point"""
18
- torch.manual_seed(1234)
21
+ # --- Utility Functions ---
19
22
 
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)")
22
- parser.add_argument("output", help="Path to output model (.onnx)")
23
def add_meta_data(filename: Path, meta_data: Dict[str, Any]) -> None:
    """
    Replace the metadata of an ONNX model file, rewriting the file in place.

    Any metadata already stored in the model is discarded before the new
    entries are written. Failures (including a missing ``onnx`` package) are
    logged and swallowed; nothing is raised to the caller.

    Args:
        filename:
            Path to the ONNX model file to be changed.
        meta_data:
            Key-value pairs to store. Every value is passed through ``str()``.
    """
    try:
        import onnx

        target = str(filename)
        onnx_model = onnx.load(target)

        # Start from a clean slate so stale keys never survive a re-export.
        del onnx_model.metadata_props[:]

        for name, raw_value in meta_data.items():
            entry = onnx_model.metadata_props.add()
            entry.key = name
            # ONNX metadata values are stored as strings.
            entry.value = str(raw_value)

        onnx.save(onnx_model, target)
        _LOGGER.info(f"Added {len(meta_data)} metadata key/value pairs to ONNX model: {filename}")

    except ImportError:
        _LOGGER.error("The 'onnx' package is required to add metadata. Please install it with 'pip install onnx'.")
    except Exception as err:
        _LOGGER.error(f"Failed to add metadata to ONNX file {filename}: {err}")
55
+
56
+
57
def export_tokens(config_path: Path, output_path: Path = Path("tokens.txt")) -> None:
    """
    Generates a tokens.txt file containing phoneme-to-id mapping from the model configuration.

    The format is: `<phoneme> <id>` per line, ordered by id. Ids may be stored
    in the config either as plain integers or as lists (Piper style,
    ``{"a": [1]}``); for a list the first element is used.

    Errors are logged and the function returns without raising. The output
    file is only opened once all lines have been computed, so a malformed id
    map does not truncate an existing file.

    Args:
        config_path: Path to the model configuration JSON file.
        output_path: Path to save the resulting tokens.txt file.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as file:
            config: Dict[str, Any] = json.load(file)
    except Exception as e:
        _LOGGER.error(f"Failed to load config file at {config_path}: {e}")
        return

    # A JSON document whose root is not an object has no `.get`.
    id_map = config.get("phoneme_id_map") if isinstance(config, dict) else None
    if not id_map:
        _LOGGER.error("Could not find 'phoneme_id_map' in the config file.")
        return

    def _as_id(value: Any) -> Any:
        """Unwrap Piper-style ``[id]`` lists to the bare id."""
        if isinstance(value, (list, tuple)):
            return value[0]
        return value

    try:
        # Sort by ID to ensure a consistent output order.
        entries = sorted(
            ((symbol, _as_id(raw_id)) for symbol, raw_id in id_map.items()),
            key=lambda item: item[1],
        )
        # Newline and empty symbols cannot be represented in the line format.
        lines = [
            f"{symbol} {token_id}\n"
            for symbol, token_id in entries
            if symbol not in ("\n", "")
        ]
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(lines)

        _LOGGER.info(f"Generated tokens file at {output_path}")
    except Exception as e:
        _LOGGER.error(f"Failed to write tokens file to {output_path}: {e}")
95
+
96
+
97
def convert_to_piper(config_path: Path, output_path: Path = Path("piper.json")) -> None:
    """
    Generates a Piper compatible JSON configuration file from the VITS model configuration.

    Only a fixed subset of keys is carried over (see the dict built below);
    ``phoneme_map`` and ``speaker_id_map`` are always written empty. Phoneme
    ids are written as lists: plain integer ids are wrapped, ids that are
    already lists are kept as they are.

    Args:
        config_path: Path to the VITS model configuration JSON file.
        output_path: Path to save the resulting Piper JSON file.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If the configuration file is not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config: Dict[str, Any] = json.load(file)

    lang_code = config.get("lang_code", "")
    # A null version in the config would otherwise break string building.
    version = config.get("phoonnx_version") or "0.0.0"

    piper_config = {
        "phoneme_type": "espeak" if config.get("phoneme_type", "") == "espeak" else "raw",
        "phoneme_map": {},
        "audio": config.get("audio", {}),
        "inference": config.get("inference", {}),
        "phoneme_id_map": {
            # Avoid double-wrapping ids that are already Piper-style lists.
            symbol: list(ids) if isinstance(ids, (list, tuple)) else [ids]
            for symbol, ids in config.get("phoneme_id_map", {}).items()
        },
        "espeak": {
            "voice": lang_code
        },
        "language": {
            "code": lang_code
        },
        "num_symbols": config.get("num_symbols", 256),
        "num_speakers": config.get("num_speakers", 1),
        "speaker_id_map": {},
        "piper_version": f"phoonnx-{version}"
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(piper_config, f, indent=4, ensure_ascii=False)
28
131
 
29
- if args.debug:
30
- logging.basicConfig(level=logging.DEBUG)
31
- else:
32
- logging.basicConfig(level=logging.INFO)
33
132
 
34
- _LOGGER.debug(args)
133
+ # --- Main Logic using Click ---
134
@click.command(help="Export a VITS model checkpoint to ONNX format.")
@click.argument(
    "checkpoint",
    type=click.Path(exists=True, path_type=Path),
    # click arguments take no help text; CHECKPOINT is the PyTorch checkpoint file (*.ckpt).
)
@click.option(
    "-c",
    "--config",
    type=click.Path(exists=True, path_type=Path),
    required=True,  # the export cannot proceed without it; fail with a usage error up front
    help="Path to the model configuration JSON file."
)
@click.option(
    "-o",
    "--output-dir",
    type=click.Path(path_type=Path),
    default=Path(os.getcwd()),  # Set default to current working directory
    help="Output directory for the ONNX model. (Default: current directory)"
)
@click.option(
    "-t",
    "--generate-tokens",
    is_flag=True,
    help="Generate tokens.txt alongside the ONNX model. Some inference engines need this (eg. sherpa)"
)
@click.option(
    "-p",
    "--piper",
    is_flag=True,
    help="Generate a piper compatible .json file alongside the ONNX model."
)
def cli(
        checkpoint: Path,
        config: Path,
        output_dir: Path,
        generate_tokens: bool,
        piper: bool,
) -> None:
    """
    Main entry point for exporting a VITS model checkpoint to ONNX format.

    Writes ``<checkpoint name>.onnx`` into ``output_dir`` and, on request,
    ``<checkpoint name>.piper.json`` and ``<checkpoint name>.tokens.txt``.
    After a successful export the model metadata is replaced with values
    taken from the configuration file.

    Args:
        checkpoint: Path to the PyTorch checkpoint file (*.ckpt).
        config: Path to the model configuration JSON file.
        output_dir: Output directory for the ONNX model and associated files.
        generate_tokens: Flag to generate a tokens.txt file.
        piper: Flag to generate a piper compatible .json file.

    Raises:
        SystemExit: With status 1 when the config cannot be loaded, the
            checkpoint cannot be loaded, or the ONNX export fails.
    """
    torch.manual_seed(1234)

    _LOGGER.debug(f"Arguments: {checkpoint=}, {config=}, {output_dir=}, {generate_tokens=}, {piper=}")

    # -------------------------------------------------------------------------
    # Paths and Setup

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)
    _LOGGER.debug(f"Output directory ensured: {output_dir}")

    # Load the phoonnx configuration
    try:
        with open(config, "r", encoding="utf-8") as f:
            model_config: Dict[str, Any] = json.load(f)
        if not isinstance(model_config, dict):
            raise ValueError("top-level JSON value must be an object")
        _LOGGER.info(f"Loaded phoonnx config from {config}")
    except Exception as e:
        _LOGGER.error(f"Error loading config file {config}: {e}")
        # Non-zero status so shell scripts and pipelines can detect the failure.
        raise SystemExit(1)

    alphabet: str = model_config.get("alphabet", "")
    phoneme_type: str = model_config.get("phoneme_type", "")
    phonemizer_model: str = model_config.get("phonemizer_model", "")  # depends on phonemizer (eg. byt5)
    piper_compatible: bool = alphabet == "ipa" and phoneme_type == "espeak"

    # Fall back to defaults when optional keys are absent
    sample_rate: int = model_config.get("audio", {}).get("sample_rate", 22050)
    phoneme_id_map: Dict[str, int] = model_config.get("phoneme_id_map", {})

    if piper:
        if not piper_compatible:
            _LOGGER.warning("only models trained with ipa + espeak should be exported to piper. phonemization is not included in exported model.")
        # Generate the piper.json file
        piper_output_path = output_dir / f"{checkpoint.name}.piper.json"
        convert_to_piper(config, piper_output_path)

    if generate_tokens:
        # Generate the tokens.txt file
        tokens_output_path = output_dir / f"{checkpoint.name}.tokens.txt"
        export_tokens(config, tokens_output_path)

    # -------------------------------------------------------------------------
    # Model Loading and Preparation
    try:
        model: VitsModel = VitsModel.load_from_checkpoint(
            checkpoint,
            dataset=None
        )
    except Exception as e:
        _LOGGER.error(f"Error loading model checkpoint {checkpoint}: {e}")
        raise SystemExit(1)

    model_g: torch.nn.Module = model.model_g
    num_symbols: int = model_g.n_vocab
    num_speakers: int = model_g.n_speakers

    # Inference only setup
    model_g.eval()

    with torch.no_grad():
        # Apply weight norm removal for inference mode
        model_g.dec.remove_weight_norm()
        _LOGGER.debug("Removed weight normalization from decoder.")

    # -------------------------------------------------------------------------
    # Define ONNX-compatible forward function

    def infer_forward(text: torch.Tensor, text_lengths: torch.Tensor, scales: torch.Tensor,
                      sid: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass used for the ONNX export.

        Unpacks the three scale values from a single tensor, calls
        ``model_g.infer`` and returns only its first output with an extra
        dimension inserted at position 1.

        Args:
            text: Input phoneme id tensor.
            text_lengths: Tensor of sequence lengths.
            scales: Tensor holding [noise_scale, length_scale, noise_scale_w].
            sid: Optional speaker ID tensor for multi-speaker models.

        Returns:
            The generated audio tensor.
        """
        # Indexing a tensor yields 0-dim tensors, not Python floats.
        noise_scale: torch.Tensor = scales[0]
        length_scale: torch.Tensor = scales[1]
        noise_scale_w: torch.Tensor = scales[2]

        # Only the first element of the tuple returned by model_g.infer (the audio) is exported.
        audio: torch.Tensor = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)

        return audio

    # Replace the default forward with the inference one for ONNX export
    model_g.forward = infer_forward

    # -------------------------------------------------------------------------
    # Dummy Input Generation

    dummy_input_length: int = 50
    sequences: torch.Tensor = torch.randint(
        low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths: torch.Tensor = torch.LongTensor([sequences.size(1)])

    sid: Optional[torch.LongTensor] = None
    input_names: list[str] = ["input", "input_lengths", "scales"]
    dynamic_axes_map: Dict[str, Dict[int, str]] = {
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        "output": {0: "batch_size", 1: "time"},
    }

    # The speaker id is only declared as a model input for multi-speaker models.
    if num_speakers > 1:
        sid = torch.LongTensor([0])
        input_names.append("sid")
        dynamic_axes_map["sid"] = {0: "batch_size"}
        _LOGGER.debug(f"Multi-speaker model detected (n_speakers={num_speakers}). 'sid' included.")

    # noise, length, noise_w scales (hardcoded defaults)
    scales: torch.Tensor = torch.FloatTensor([0.667, 1.0, 0.8])
    dummy_input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.LongTensor]] = (
        sequences, sequence_lengths, scales, sid
    )

    # -------------------------------------------------------------------------
    # Export
    model_output: Path = output_dir / f"{checkpoint.name}.onnx"
    _LOGGER.info(f"Starting ONNX export to {model_output} (opset={OPSET_VERSION})...")

    try:
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=str(model_output),
            verbose=False,
            opset_version=OPSET_VERSION,
            input_names=input_names,
            output_names=["output"],
            dynamic_axes=dynamic_axes_map,
        )
        _LOGGER.info(f"Successfully exported model to {model_output}")
    except Exception as e:
        _LOGGER.error(f"Failed during torch.onnx.export: {e}")
        raise SystemExit(1)

    # -------------------------------------------------------------------------
    # Add Metadata
    metadata_dict: Dict[str, Any] = {
        "model_type": "vits",
        "n_speakers": num_speakers,
        "n_vocab": num_symbols,
        "sample_rate": sample_rate,
        "alphabet": alphabet,
        "phoneme_type": phoneme_type,
        "phonemizer_model": phonemizer_model,
        "phoneme_id_map": json.dumps(phoneme_id_map),
        "has_espeak": phoneme_type == "espeak"
    }
    if piper_compatible:
        metadata_dict["comment"] = "piper"

    # add_meta_data logs and swallows its own errors, so no extra guard is needed here.
    add_meta_data(model_output, metadata_dict)

    _LOGGER.info("Export complete.")
104
355
 
105
356
 
106
357
  # -----------------------------------------------------------------------------
107
358
 
108
359
  if __name__ == "__main__":
109
- main()
360
+ cli()