phoonnx 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phoonnx/config.py +4 -1
- phoonnx/phonemizers/ar.py +36 -44
- phoonnx/phonemizers/base.py +27 -1
- phoonnx/phonemizers/he.py +6 -25
- phoonnx/phonemizers/mul.py +617 -4
- phoonnx/thirdparty/hangul2ipa.py +1 -0
- phoonnx/thirdparty/mantoq/__init__.py +1 -26
- phoonnx/thirdparty/phonikud/__init__.py +24 -0
- phoonnx/version.py +5 -1
- phoonnx/voice.py +4 -16
- {phoonnx-0.1.0a1.dist-info → phoonnx-0.1.0a3.dist-info}/METADATA +2 -1
- {phoonnx-0.1.0a1.dist-info → phoonnx-0.1.0a3.dist-info}/RECORD +17 -16
- phoonnx_train/export_onnx.py +307 -56
- phoonnx_train/preprocess.py +36 -9
- phoonnx_train/vits/dataset.py +4 -0
- {phoonnx-0.1.0a1.dist-info → phoonnx-0.1.0a3.dist-info}/WHEEL +0 -0
- {phoonnx-0.1.0a1.dist-info → phoonnx-0.1.0a3.dist-info}/top_level.txt +0 -0
phoonnx/thirdparty/hangul2ipa.py
CHANGED
@@ -3,14 +3,6 @@ from phoonnx.thirdparty.mantoq.buck.tokenization import (arabic_to_phonemes, pho
|
|
3
3
|
phonemes_to_tokens, simplify_phonemes)
|
4
4
|
from phoonnx.thirdparty.mantoq.buck.tokenization import tokens_to_ids as _tokens_to_id
|
5
5
|
from phoonnx.thirdparty.mantoq.num2words import num2words
|
6
|
-
import warnings
|
7
|
-
from phoonnx.thirdparty.tashkeel import TashkeelDiacritizer
|
8
|
-
try:
|
9
|
-
import onnxruntime
|
10
|
-
|
11
|
-
_TASHKEEL_AVAILABLE = True
|
12
|
-
except ImportError:
|
13
|
-
_TASHKEEL_AVAILABLE = False
|
14
6
|
|
15
7
|
_DIACRITIZER_INST = None
|
16
8
|
|
@@ -29,29 +21,12 @@ QUOTES_TABLE = str.maketrans(QUOTES, '"' * len(QUOTES))
|
|
29
21
|
BRACKETS_TABLE = str.maketrans("[]{}", "()()")
|
30
22
|
|
31
23
|
|
32
|
-
|
33
|
-
|
34
|
-
def tashkeel(text: str) -> str:
|
35
|
-
global _DIACRITIZER_INST
|
36
|
-
if not _TASHKEEL_AVAILABLE:
|
37
|
-
warnings.warn(
|
38
|
-
"Warning: The Tashkeel feature will not be available. Please re-install with the `libtashkeel` extra.",
|
39
|
-
UserWarning,
|
40
|
-
)
|
41
|
-
return text
|
42
|
-
if _DIACRITIZER_INST is None:
|
43
|
-
_DIACRITIZER_INST = TashkeelDiacritizer()
|
44
|
-
return _DIACRITIZER_INST.diacritize(text)
|
45
|
-
|
46
24
|
def g2p(
|
47
25
|
text: str,
|
48
|
-
add_tashkeel: bool = True,
|
49
26
|
process_numbers: bool = True,
|
50
27
|
append_eos: bool = False,
|
51
|
-
) -> list[str]:
|
28
|
+
) -> tuple[str, list[str]]:
|
52
29
|
text = text.translate(AR_SPECIAL_PUNCS_TABLE).translate(QUOTES_TABLE).translate(BRACKETS_TABLE)
|
53
|
-
if add_tashkeel:
|
54
|
-
text = tashkeel(text)
|
55
30
|
if process_numbers:
|
56
31
|
text = num2words(text)
|
57
32
|
normalized_text = text
|
@@ -0,0 +1,24 @@
|
|
1
|
+
import os
|
2
|
+
import requests
|
3
|
+
|
4
|
+
|
5
|
+
class PhonikudDiacritizer:
    """Hebrew diacritizer (nikud) backed by the Phonikud ONNX model.

    On first instantiation the model file named by ``dl_url`` is downloaded
    into ``~/.local/share/phonikud``; later instantiations reuse that file.
    """

    dl_url = "https://huggingface.co/thewh1teagle/phonikud-onnx/resolve/main/phonikud-1.0.int8.onnx"

    def __init__(self):
        base_path = os.path.expanduser("~/.local/share/phonikud")
        fname = self.dl_url.split("/")[-1]
        model = os.path.join(base_path, fname)
        if not os.path.isfile(model):
            os.makedirs(base_path, exist_ok=True)
            self._download(self.dl_url, model)

        # Imported lazily so the optional dependency is only required when
        # Hebrew diacritization is actually used.
        from phonikud_onnx import Phonikud
        self.phonikud = Phonikud(model)

    @staticmethod
    def _download(url: str, dest: str, timeout: float = 60.0) -> None:
        """Stream ``url`` into ``dest`` atomically.

        The body is written to ``dest + ".part"`` in chunks and renamed onto
        ``dest`` only after a successful HTTP status and a complete transfer.
        Because ``__init__`` only checks whether ``dest`` exists, a partial
        download or an HTTP error page must never land at ``dest``.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.RequestException / OSError: on network or disk failure.
        """
        tmp = dest + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as resp:
                resp.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=1 << 20):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            os.replace(tmp, dest)
        except BaseException:
            # Don't leave a half-written file behind; re-raise the original error.
            try:
                os.remove(tmp)
            except OSError:
                pass
            raise

    def diacritize(self, text: str) -> str:
        """Return ``text`` with Hebrew diacritics added by the Phonikud model."""
        return self.phonikud.add_diacritics(text)
|
phoonnx/version.py
CHANGED
phoonnx/voice.py
CHANGED
@@ -14,7 +14,6 @@ from phoonnx.config import PhonemeType, VoiceConfig, SynthesisConfig, get_phonem
|
|
14
14
|
from phoonnx.phoneme_ids import phonemes_to_ids, BlankBetween
|
15
15
|
from phoonnx.phonemizers import Phonemizer
|
16
16
|
from phoonnx.phonemizers.base import PhonemizedChunks
|
17
|
-
from phoonnx.thirdparty.tashkeel import TashkeelDiacritizer
|
18
17
|
|
19
18
|
_PHONEME_BLOCK_PATTERN = re.compile(r"(\[\[.*?\]\])")
|
20
19
|
|
@@ -113,11 +112,6 @@ class TTSVoice:
|
|
113
112
|
|
114
113
|
phonemizer: Optional[Phonemizer] = None
|
115
114
|
|
116
|
-
# For Arabic text only
|
117
|
-
use_tashkeel: bool = True
|
118
|
-
tashkeel_diacritizier: Optional[TashkeelDiacritizer] = None # For Arabic text only
|
119
|
-
taskeen_threshold: Optional[float] = 0.8
|
120
|
-
|
121
115
|
def __post_init__(self):
|
122
116
|
try:
|
123
117
|
self.phonetic_spellings = PhoneticSpellings.from_lang(self.config.lang_code)
|
@@ -128,10 +122,6 @@ class TTSVoice:
|
|
128
122
|
self.config.alphabet,
|
129
123
|
self.config.phonemizer_model)
|
130
124
|
|
131
|
-
# compat with piper arabic models - TODO move to espeak phonemizer
|
132
|
-
if self.config.lang_code.split("-")[0] == "ar" and self.use_tashkeel and self.tashkeel_diacritizier is None:
|
133
|
-
self.tashkeel_diacritizier = TashkeelDiacritizer()
|
134
|
-
|
135
125
|
@staticmethod
|
136
126
|
def load(
|
137
127
|
model_path: Union[str, Path],
|
@@ -209,12 +199,6 @@ class TTSVoice:
|
|
209
199
|
|
210
200
|
continue
|
211
201
|
|
212
|
-
# Arabic diacritization
|
213
|
-
if self.config.lang_code.split("-")[0] == "ar" and self.use_tashkeel:
|
214
|
-
text_part = self.tashkeel_diacritizier(
|
215
|
-
text_part, taskeen_threshold=self.taskeen_threshold
|
216
|
-
)
|
217
|
-
|
218
202
|
# Phonemization
|
219
203
|
phonemes = self.phonemizer.phonemize(
|
220
204
|
text_part, self.config.lang_code
|
@@ -267,6 +251,10 @@ class TTSVoice:
|
|
267
251
|
if self.phonetic_spellings and syn_config.enable_phonetic_spellings:
|
268
252
|
text = self.phonetic_spellings.apply(text)
|
269
253
|
|
254
|
+
if syn_config.add_diacritics:
|
255
|
+
text = self.phonemizer.add_diacritics(text, self.config.lang_code)
|
256
|
+
LOG.debug("text+diacritics=%s", text)
|
257
|
+
|
270
258
|
# All phonemization goes through the unified self.phonemize method
|
271
259
|
sentence_phonemes = self.phonemize(text)
|
272
260
|
LOG.debug("phonemes=%s", sentence_phonemes)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: phoonnx
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.0a3
|
4
4
|
Home-page: https://github.com/TigreGotico/phoonnx
|
5
5
|
Author: JarbasAi
|
6
6
|
Author-email: jarbasai@mailfence.com
|
@@ -220,6 +220,7 @@ Requires-Dist: librosa<1,>=0.9.2; extra == "train"
|
|
220
220
|
Requires-Dist: numpy<2,>=1.19.0; extra == "train"
|
221
221
|
Requires-Dist: pytorch-lightning<2.0; extra == "train"
|
222
222
|
Requires-Dist: torch<2,>=1.11.0; extra == "train"
|
223
|
+
Requires-Dist: click; extra == "train"
|
223
224
|
Provides-Extra: uew
|
224
225
|
Requires-Dist: epitran; extra == "uew"
|
225
226
|
Provides-Extra: ug
|
@@ -1,29 +1,29 @@
|
|
1
1
|
phoonnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
phoonnx/config.py,sha256=
|
2
|
+
phoonnx/config.py,sha256=IYhC-kYjLgYmBroId6YeOE2Vp7SMNGtiGqIIe_09NJk,19531
|
3
3
|
phoonnx/phoneme_ids.py,sha256=FiNgZwV6naEsBh6XwFLh3_FyOgPiCsK9qo7S0v-CmI4,13667
|
4
4
|
phoonnx/util.py,sha256=XSjFEoqSFcujFTHxednacgC9GrSYyF-Il5L6Utmxmu4,25909
|
5
|
-
phoonnx/version.py,sha256=
|
6
|
-
phoonnx/voice.py,sha256=
|
5
|
+
phoonnx/version.py,sha256=WnY5J2wtSTore9QbKwfk04gQhBsYq4HVmV5CBjEhGnk,236
|
6
|
+
phoonnx/voice.py,sha256=JXjmbrhJd4mmTiLgz4O_Pa5_rKGUC9xzuBfqxYDw3Mg,19420
|
7
7
|
phoonnx/locale/ca/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
|
8
8
|
phoonnx/locale/en/phonetic_spellings.txt,sha256=xGQlWOABLzbttpQvopl9CU-NnwEJRqKx8iuylsdUoQA,27
|
9
9
|
phoonnx/locale/gl/phonetic_spellings.txt,sha256=igv3t7jxLSRE5GHsdn57HOpxiWNcEmECPql6m02wbO0,47
|
10
10
|
phoonnx/locale/pt/phonetic_spellings.txt,sha256=KntS8QMynEJ5A3Clvcjq4qlmL-ThSbhfD6v0nKSrlqs,49
|
11
11
|
phoonnx/phonemizers/__init__.py,sha256=QGBZk0QUgJdg2MwUWY9Kpk6ucwrEJYtHb07YcNvXCV4,1647
|
12
|
-
phoonnx/phonemizers/ar.py,sha256=
|
13
|
-
phoonnx/phonemizers/base.py,sha256=
|
12
|
+
phoonnx/phonemizers/ar.py,sha256=xxILq5iyH0kcI-NqFfRK4abGtpdUbykBjt_dZmPuO2w,3216
|
13
|
+
phoonnx/phonemizers/base.py,sha256=FHvAsvSjAl_oSa1GoeEi96CQ_JO_xkKXWq0ukuMxiuo,8660
|
14
14
|
phoonnx/phonemizers/en.py,sha256=N2SVoVhplQao7Ej5TXbxJU-YkAgkY0Fr9iYBFnsjFSE,9271
|
15
15
|
phoonnx/phonemizers/fa.py,sha256=d_DZM2wqomf4gcRH_rFcNA3VkQWKHru8vwBwaNG8Ll8,1452
|
16
16
|
phoonnx/phonemizers/gl.py,sha256=jEFKJJViHufZtB7lGNwWQCdWGiNKDCVZ_GRYXTaw_2c,6614
|
17
|
-
phoonnx/phonemizers/he.py,sha256=
|
17
|
+
phoonnx/phonemizers/he.py,sha256=49OFS34wSFvvR9B3z2bGSzSLmlIvnn2HtkHBOkHS9Ns,1383
|
18
18
|
phoonnx/phonemizers/ja.py,sha256=Xojsrt715ihnIiEk9K6giYqDo9Iykw-SHfIidrHtHSU,3834
|
19
19
|
phoonnx/phonemizers/ko.py,sha256=kwWoOFqanCB8kv2JRx17A0hP78P1wbXlX6e8VBn1ezQ,2989
|
20
|
-
phoonnx/phonemizers/mul.py,sha256
|
20
|
+
phoonnx/phonemizers/mul.py,sha256=-h6uN_laUD-unNRGThzjyiOZpN6pSl4uinCndg5-0TA,94184
|
21
21
|
phoonnx/phonemizers/vi.py,sha256=_XJc-Xeawr1Lxr7o8mE_hJao1aGcj4g01XYAOxC_Scg,1311
|
22
22
|
phoonnx/phonemizers/zh.py,sha256=88Ywq8h9LDanlyz8RHjRSCY_PRK_Dq808tBADyrgaP8,9657
|
23
23
|
phoonnx/thirdparty/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
24
24
|
phoonnx/thirdparty/arpa2ipa.py,sha256=Uj1G5NgP5oBBfSm26LGB8QoumdT-NqCLQTZHT165-_o,5850
|
25
25
|
phoonnx/thirdparty/bw2ipa.py,sha256=5FiWC4AP4KXkqtbclbinoXEsUnSYEjk4VWAPasMMcbg,2328
|
26
|
-
phoonnx/thirdparty/hangul2ipa.py,sha256=
|
26
|
+
phoonnx/thirdparty/hangul2ipa.py,sha256=Pj06lL-GkOH4ZkLuakwQAT045fEVsijGhwoY_EEEVKc,27572
|
27
27
|
phoonnx/thirdparty/zh_num.py,sha256=SESA6gvSJW3LZ0FLoybXn2SpbxqhQTi9Tg_U2IZ5JYY,7147
|
28
28
|
phoonnx/thirdparty/cotovia/cotovia_aarch64,sha256=BsAWZN452Lm9kDU4i6rQGHFSlmxP3GfHRKhbJMUQrfA,6764592
|
29
29
|
phoonnx/thirdparty/cotovia/cotovia_x86_64,sha256=-6BNx_cd49nnDreOAsGtVtePs_X76esrqcNAfmksN1o,1379832
|
@@ -37,7 +37,7 @@ phoonnx/thirdparty/ko_tables/tensification.csv,sha256=V4Xf3A1G1iMBzwZevBKQuk_lPa
|
|
37
37
|
phoonnx/thirdparty/ko_tables/yale.csv,sha256=UhtDbPXRAAyAKoQMXmwhVBwJ5pfZQ_Duk28qBtRUdsU,297
|
38
38
|
phoonnx/thirdparty/kog2p/__init__.py,sha256=yLizadg7RXM-3dQyftD4XSk8r2jb0QOlHQ6as9uUa4U,10267
|
39
39
|
phoonnx/thirdparty/kog2p/rulebook.txt,sha256=FQE3nej8wojl6ilVUBYo7f8bIk0Hjci-B7HPXhM-xNc,9303
|
40
|
-
phoonnx/thirdparty/mantoq/__init__.py,sha256=
|
40
|
+
phoonnx/thirdparty/mantoq/__init__.py,sha256=02FftO4Onmp_S-XdukbBQ3aRVvqEQyo1frCLWgcF9cY,1428
|
41
41
|
phoonnx/thirdparty/mantoq/num2words.py,sha256=9-ncMtxV1FusD9rNur1lu7l2DWhwUwI1mFiqiPSMH_Q,1264
|
42
42
|
phoonnx/thirdparty/mantoq/unicode_symbol2label.py,sha256=CeZNv7qWeQS4Ejvz-sKgK--5eNYdVVv04WHPaOeK4gk,259409
|
43
43
|
phoonnx/thirdparty/mantoq/buck/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -53,6 +53,7 @@ phoonnx/thirdparty/mantoq/pyarabic/number.py,sha256=NjFZPWRu-9dZDLgxfv9oDjmh-kWY
|
|
53
53
|
phoonnx/thirdparty/mantoq/pyarabic/number_const.py,sha256=vAvRVENxTrl9gWPllSXF-yqK9fAW6htuA2d041btC_A,42361
|
54
54
|
phoonnx/thirdparty/mantoq/pyarabic/stack.py,sha256=aJeSzQxVNdomDTWXuxIXWXVOc2BW_3iRWnwmBLkB8jM,1022
|
55
55
|
phoonnx/thirdparty/mantoq/pyarabic/trans.py,sha256=cusyHk9Y01iuvMLJXxgCnIiGyAORzEdSosDKX4cAhPc,13713
|
56
|
+
phoonnx/thirdparty/phonikud/__init__.py,sha256=g1dCelCZbwlKT0Ibaky6Ckp59wMw5g_1DDyDXauqFTg,760
|
56
57
|
phoonnx/thirdparty/tashkeel/LICENSE,sha256=mQjTJ6MGAXzmYkO7x4O2VuEeSwCMx7lncbc26TnrVjw,1067
|
57
58
|
phoonnx/thirdparty/tashkeel/SOURCE,sha256=SmnRz-Am5EXv-n2-RokJVEhnn8zeF1QZJVvMQDA_Qds,38
|
58
59
|
phoonnx/thirdparty/tashkeel/__init__.py,sha256=FRdGNCTQaai9X077vlNh4tFOvWgm1U2lIUgnQKO5q0s,7119
|
@@ -61,8 +62,8 @@ phoonnx/thirdparty/tashkeel/input_id_map.json,sha256=cnpJqjx-k53AbzKyfC4GxMS771l
|
|
61
62
|
phoonnx/thirdparty/tashkeel/model.onnx,sha256=UsQNQsoJT_n_B6CR0KHq_XuqXPI4jmCpzIm6zY5elV8,4788213
|
62
63
|
phoonnx/thirdparty/tashkeel/target_id_map.json,sha256=baNAJL_UwP9U91mLt01aAEBRRNdGr-csFB_O6roh7TA,181
|
63
64
|
phoonnx_train/__main__.py,sha256=FUAIsbQ-w2i_hoNiBuriQFk4uoryhL4ydyVY-hVjw1U,5086
|
64
|
-
phoonnx_train/export_onnx.py,sha256=
|
65
|
-
phoonnx_train/preprocess.py,sha256=
|
65
|
+
phoonnx_train/export_onnx.py,sha256=CPfgNEm0hnXPSlgme0R9jr-6jZ5fKFpG5DZJFMkC-h4,12820
|
66
|
+
phoonnx_train/preprocess.py,sha256=8_Opy5QVNjVmSVmh1_IF23bcNebVIEXuK2KcollIy28,15793
|
66
67
|
phoonnx_train/norm_audio/__init__.py,sha256=Al_YwqMnENXRWp0c79cDZqbdd7pFYARXKxCfBaedr1c,3030
|
67
68
|
phoonnx_train/norm_audio/trim.py,sha256=_ZsE3SYhahQSdEdBLeSwyFJGcvEbt-5E_lnWwTT4tcY,1698
|
68
69
|
phoonnx_train/norm_audio/vad.py,sha256=DXHfRD0qqFJ52FjPvrL5LlN6keJWuc9Nf6TNhxpwC_4,1600
|
@@ -70,7 +71,7 @@ phoonnx_train/vits/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuF
|
|
70
71
|
phoonnx_train/vits/attentions.py,sha256=yc_ViF8zR8z68DzphmVVVn27f9xK_5wi8S4ITLXVQL0,15134
|
71
72
|
phoonnx_train/vits/commons.py,sha256=JsD8CdZ3ZcYYubYhw8So5hICBziFlCrKLrv1lMDRCDM,4645
|
72
73
|
phoonnx_train/vits/config.py,sha256=oSuUIhw9Am7BQ5JwDgtCO-P1zRyN7nPgR-U1XuncJls,10789
|
73
|
-
phoonnx_train/vits/dataset.py,sha256=
|
74
|
+
phoonnx_train/vits/dataset.py,sha256=1V1tVh5dSLjFMBsuzrAsoGtYWSBT4iU64Jdqi8oG-y0,7016
|
74
75
|
phoonnx_train/vits/lightning.py,sha256=ZBuSIiJ7EUU1Za2V8Uh6-_HGGRW_qwpXLLs1cEDirHA,12301
|
75
76
|
phoonnx_train/vits/losses.py,sha256=j-uINhBcYxVXFvFutiewQpTuw-qF-J6M6hdJVeOKqNE,1401
|
76
77
|
phoonnx_train/vits/mel_processing.py,sha256=huIjbQgewSmM39hdzRZvZUCI7fTNSMmLcAv3f8zYb8k,3956
|
@@ -81,7 +82,7 @@ phoonnx_train/vits/utils.py,sha256=exiyrtPHbnnGvcHWSbaH9-gR6srH5ZPHlKiqV2IHUrQ,4
|
|
81
82
|
phoonnx_train/vits/wavfile.py,sha256=oQZiTIrdw0oLTbcVwKfGXye1WtKte6qK_52qVwiMvfc,26396
|
82
83
|
phoonnx_train/vits/monotonic_align/__init__.py,sha256=5IdAOD1Z7UloMb6d_9NRFsXoNIjEQ3h9mvOSh_AtO3k,636
|
83
84
|
phoonnx_train/vits/monotonic_align/setup.py,sha256=0K5iJJ2mKIklx6ncEfCQS34skm5hHPiz9vRlQEvevvY,266
|
84
|
-
phoonnx-0.1.
|
85
|
-
phoonnx-0.1.
|
86
|
-
phoonnx-0.1.
|
87
|
-
phoonnx-0.1.
|
85
|
+
phoonnx-0.1.0a3.dist-info/METADATA,sha256=3U1Ea0g2HxtWPsIs7NCxzPdo7ZTr4s_lRs9gIOC6MWY,8184
|
86
|
+
phoonnx-0.1.0a3.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
87
|
+
phoonnx-0.1.0a3.dist-info/top_level.txt,sha256=ZrnHXe-4HqbOSX6fbdY-JiP7YEu2Bok9T0ji351MrmM,22
|
88
|
+
phoonnx-0.1.0a3.dist-info/RECORD,,
|
phoonnx_train/export_onnx.py
CHANGED
@@ -1,109 +1,360 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
|
-
import
|
2
|
+
import click
|
3
3
|
import logging
|
4
|
+
import json
|
5
|
+
import os
|
4
6
|
from pathlib import Path
|
5
|
-
from typing import Optional
|
7
|
+
from typing import Optional, Dict, Any, Tuple
|
6
8
|
|
7
9
|
import torch
|
8
|
-
|
9
10
|
from phoonnx_train.vits.lightning import VitsModel
|
11
|
+
from phoonnx.version import VERSION_STR
|
10
12
|
|
11
|
-
|
13
|
+
# Basic logging configuration
|
14
|
+
logging.basicConfig(level=logging.DEBUG)
|
15
|
+
_LOGGER = logging.getLogger("phoonnx_train.export_onnx")
|
12
16
|
|
17
|
+
# ONNX opset version
|
13
18
|
OPSET_VERSION = 15
|
14
19
|
|
15
20
|
|
16
|
-
|
17
|
-
"""Main entry point"""
|
18
|
-
torch.manual_seed(1234)
|
21
|
+
# --- Utility Functions ---
|
19
22
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
+
def add_meta_data(filename: Path, meta_data: Dict[str, Any]) -> None:
    """
    Add meta data to an ONNX model. The file is modified in-place.

    Any metadata already stored in the model is discarded and replaced by
    ``meta_data``. This is best-effort: failures (including a missing
    ``onnx`` package) are logged and the function returns without raising.

    Args:
        filename:
            Path to the ONNX model file to be changed.
        meta_data:
            Key-value pairs to be stored as metadata. Values will be converted to strings.
    """
    # Keep the ImportError handler scoped to the import itself so that
    # unrelated ImportErrors raised while loading/saving are not misreported.
    try:
        import onnx
    except ImportError:
        _LOGGER.error("The 'onnx' package is required to add metadata. Please install it with 'pip install onnx'.")
        return

    try:
        model = onnx.load(str(filename))

        # Clear existing metadata and add new properties
        del model.metadata_props[:]
        for key, value in meta_data.items():
            meta = model.metadata_props.add()
            meta.key = key
            # ONNX metadata values are strings
            meta.value = str(value)

        onnx.save(model, str(filename))
    except Exception:
        # Best-effort: the exported model is still usable without metadata,
        # but keep the traceback so the failure can be diagnosed.
        _LOGGER.exception("Failed to add metadata to ONNX file %s", filename)
        return

    _LOGGER.info("Added %d metadata key/value pairs to ONNX model: %s", len(meta_data), filename)
|
55
|
+
|
56
|
+
|
57
|
+
def export_tokens(config_path: Path, output_path: Path = Path("tokens.txt")) -> None:
    """
    Generates a tokens.txt file containing phoneme-to-id mapping from the model configuration.

    The format is: `<phoneme> <id>` per line, sorted by id. Entries whose
    symbol is empty or a newline are skipped. Both phoonnx-style maps
    (``{"a": 5}``) and piper-style maps (``{"a": [5]}``) are accepted; for a
    list the first id is used and empty lists are skipped.

    Failures to read the config or write the output are logged and the
    function returns without raising.

    Args:
        config_path: Path to the model configuration JSON file.
        output_path: Path to save the resulting tokens.txt file.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as file:
            config: Dict[str, Any] = json.load(file)
    except Exception as e:
        _LOGGER.error(f"Failed to load config file at {config_path}: {e}")
        return

    id_map: Optional[Dict[str, Any]] = config.get("phoneme_id_map")
    if not id_map:
        _LOGGER.error("Could not find 'phoneme_id_map' in the config file.")
        return

    entries: list[Tuple[str, int]] = []
    for symbol, value in id_map.items():
        # Skip newlines or other invalid tokens if present in map
        if symbol in ("\n", ""):
            continue
        if isinstance(value, (list, tuple)):
            if not value:
                continue
            value = value[0]
        entries.append((symbol, value))

    # Sort by ID to ensure a consistent output order (stable for equal ids)
    entries.sort(key=lambda item: item[1])

    tokens_path = output_path
    try:
        with open(tokens_path, "w", encoding="utf-8") as f:
            f.writelines(f"{s} {i}\n" for s, i in entries)
        _LOGGER.info(f"Generated tokens file at {tokens_path}")
    except Exception as e:
        _LOGGER.error(f"Failed to write tokens file to {tokens_path}: {e}")
|
95
|
+
|
96
|
+
|
97
|
+
def convert_to_piper(config_path: Path, output_path: Path = Path("piper.json")) -> None:
    """
    Generates a Piper compatible JSON configuration file from the VITS model configuration.

    Only inference-time fields are emitted; phonemization itself is not part
    of the exported model.

    Args:
        config_path: Path to the VITS model configuration JSON file.
        output_path: Path to save the resulting Piper JSON file.

    Raises:
        OSError: if the config cannot be read or the output cannot be written.
        json.JSONDecodeError: if ``config_path`` is not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config: Dict[str, Any] = json.load(file)

    lang_code = config.get("lang_code", "")

    # The Piper config maps each symbol to a *list* of ids while phoonnx stores
    # a single int; values that are already lists are passed through unchanged
    # instead of being double-wrapped.
    phoneme_id_map = {
        symbol: ids if isinstance(ids, list) else [ids]
        for symbol, ids in config.get("phoneme_id_map", {}).items()
    }

    piper_config = {
        # NOTE(review): upstream Piper's phoneme types are, as far as known,
        # "espeak" and "text" -- confirm the intended consumer accepts "raw".
        "phoneme_type": "espeak" if config.get("phoneme_type", "") == "espeak" else "raw",
        "phoneme_map": {},
        "audio": config.get("audio", {}),
        "inference": config.get("inference", {}),
        "phoneme_id_map": phoneme_id_map,
        "espeak": {
            "voice": lang_code
        },
        "language": {
            "code": lang_code
        },
        "num_symbols": config.get("num_symbols", 256),
        "num_speakers": config.get("num_speakers", 1),
        "speaker_id_map": config.get("speaker_id_map", {}),
        "piper_version": "phoonnx-" + str(config.get("phoonnx_version", "0.0.0")),
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(piper_config, f, indent=4, ensure_ascii=False)
|
28
131
|
|
29
|
-
if args.debug:
|
30
|
-
logging.basicConfig(level=logging.DEBUG)
|
31
|
-
else:
|
32
|
-
logging.basicConfig(level=logging.INFO)
|
33
132
|
|
34
|
-
|
133
|
+
# --- Main Logic using Click ---
|
134
|
+
@click.command(help="Export a VITS model checkpoint to ONNX format.")
@click.argument(
    "checkpoint",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "-c",
    "--config",
    type=click.Path(exists=True, path_type=Path),
    required=True,  # the config is read unconditionally below
    help="Path to the model configuration JSON file.",
)
@click.option(
    "-o",
    "--output-dir",
    type=click.Path(path_type=Path),
    default=Path(os.getcwd()),  # Set default to current working directory
    help="Output directory for the ONNX model. (Default: current directory)",
)
@click.option(
    "-t",
    "--generate-tokens",
    is_flag=True,
    help="Generate tokens.txt alongside the ONNX model. Some inference engines need this (eg. sherpa)",
)
@click.option(
    "-p",
    "--piper",
    is_flag=True,
    help="Generate a piper compatible .json file alongside the ONNX model.",
)
@click.option(
    "--debug",
    is_flag=True,
    help="Enable DEBUG level logging.",
)
def cli(
    checkpoint: Path,
    config: Path,
    output_dir: Path,
    generate_tokens: bool,
    piper: bool,
    debug: bool = False,
) -> None:
    """
    Main entry point for exporting a VITS model checkpoint to ONNX format.

    Any fatal error (unreadable config, unloadable checkpoint, failed export)
    is raised as ``click.ClickException``, so the process exits non-zero.

    Args:
        checkpoint: Path to the PyTorch checkpoint file (*.ckpt).
        config: Path to the model configuration JSON file.
        output_dir: Output directory for the ONNX model and associated files.
        generate_tokens: Flag to generate a tokens.txt file.
        piper: Flag to generate a piper compatible .json file.
        debug: Flag to enable DEBUG level logging (INFO otherwise).
    """
    # force=True replaces any root handler configured at import time so the
    # requested level actually takes effect.
    logging.basicConfig(level=logging.DEBUG if debug else logging.INFO, force=True)
    torch.manual_seed(1234)

    _LOGGER.debug(f"Arguments: {checkpoint=}, {config=}, {output_dir=}, {generate_tokens=}, {piper=}")

    # -------------------------------------------------------------------------
    # Paths and Setup

    output_dir.mkdir(parents=True, exist_ok=True)
    _LOGGER.debug(f"Output directory ensured: {output_dir}")

    try:
        with open(config, "r", encoding="utf-8") as f:
            model_config: Dict[str, Any] = json.load(f)
    except (OSError, ValueError) as e:
        raise click.ClickException(f"Error loading config file {config}: {e}") from e
    _LOGGER.info(f"Loaded phoonnx config from {config}")

    alphabet: str = model_config.get("alphabet", "")
    phoneme_type: str = model_config.get("phoneme_type", "")
    phonemizer_model: str = model_config.get("phonemizer_model", "")  # depends on phonemizer (eg. byt5)
    piper_compatible: bool = alphabet == "ipa" and phoneme_type == "espeak"

    sample_rate: int = model_config.get("audio", {}).get("sample_rate", 22050)
    phoneme_id_map: Dict[str, int] = model_config.get("phoneme_id_map", {})

    if piper:
        if not piper_compatible:
            _LOGGER.warning("only models trained with ipa + espeak should be exported to piper. phonemization is not included in exported model.")
        piper_output_path = output_dir / f"{checkpoint.name}.piper.json"
        convert_to_piper(config, piper_output_path)

    if generate_tokens:
        tokens_output_path = output_dir / f"{checkpoint.name}.tokens.txt"
        export_tokens(config, tokens_output_path)

    # -------------------------------------------------------------------------
    # Model Loading and Preparation
    try:
        model: VitsModel = VitsModel.load_from_checkpoint(checkpoint, dataset=None)
    except Exception as e:
        raise click.ClickException(f"Error loading model checkpoint {checkpoint}: {e}") from e

    model_g: torch.nn.Module = model.model_g
    num_symbols: int = model_g.n_vocab
    num_speakers: int = model_g.n_speakers

    # Inference only setup
    model_g.eval()

    with torch.no_grad():
        # Apply weight norm removal for inference mode
        model_g.dec.remove_weight_norm()
        _LOGGER.debug("Removed weight normalization from decoder.")

    # -------------------------------------------------------------------------
    # Define ONNX-compatible forward function

    def infer_forward(
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        scales: torch.Tensor,
        sid: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Custom forward pass for ONNX export, taking the three scales as a
        single tensor and returning only the audio tensor with shape [B, 1, T].

        Args:
            text: Input phoneme sequence tensor, shape [B, T_in].
            text_lengths: Tensor of sequence lengths, shape [B].
            scales: Tensor containing [noise_scale, length_scale, noise_scale_w], shape [3].
            sid: Optional speaker ID tensor, shape [B], for multi-speaker models.

        Returns:
            Generated audio tensor, shape [B, 1, T_out].
        """
        # Indexing yields tensors (not Python floats), keeping the values tied
        # to the "scales" graph input during export.
        noise_scale: torch.Tensor = scales[0]
        length_scale: torch.Tensor = scales[1]
        noise_scale_w: torch.Tensor = scales[2]

        # [0] selects the first element of model_g.infer's return value (the
        # audio); unsqueeze(1) makes it [B, 1, T].
        audio: torch.Tensor = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0].unsqueeze(1)

        return audio

    # Replace the default forward with the inference one for ONNX export
    model_g.forward = infer_forward

    # -------------------------------------------------------------------------
    # Dummy Input Generation

    dummy_input_length: int = 50
    sequences: torch.Tensor = torch.randint(
        low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths: torch.Tensor = torch.LongTensor([sequences.size(1)])

    sid: Optional[torch.LongTensor] = None
    input_names: list[str] = ["input", "input_lengths", "scales"]
    dynamic_axes_map: Dict[str, Dict[int, str]] = {
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        # The output is [B, 1, T]: the variable-length audio axis is 2, not 1.
        "output": {0: "batch_size", 2: "time"},
    }

    if num_speakers > 1:
        sid = torch.LongTensor([0])
        input_names.append("sid")
        dynamic_axes_map["sid"] = {0: "batch_size"}
        _LOGGER.debug(f"Multi-speaker model detected (n_speakers={num_speakers}). 'sid' included.")

    # noise, length, noise_w scales (hardcoded defaults)
    scales: torch.Tensor = torch.FloatTensor([0.667, 1.0, 0.8])
    dummy_input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.LongTensor]] = (
        sequences, sequence_lengths, scales, sid
    )

    # -------------------------------------------------------------------------
    # Export
    model_output: Path = output_dir / f"{checkpoint.name}.onnx"
    _LOGGER.info(f"Starting ONNX export to {model_output} (opset={OPSET_VERSION})...")

    try:
        torch.onnx.export(
            model=model_g,
            args=dummy_input,
            f=str(model_output),
            verbose=False,
            opset_version=OPSET_VERSION,
            input_names=input_names,
            output_names=["output"],
            dynamic_axes=dynamic_axes_map,
        )
    except Exception as e:
        raise click.ClickException(f"Failed during torch.onnx.export: {e}") from e
    _LOGGER.info(f"Successfully exported model to {model_output}")

    # -------------------------------------------------------------------------
    # Add Metadata
    metadata_dict: Dict[str, Any] = {
        "model_type": "vits",
        "n_speakers": num_speakers,
        "n_vocab": num_symbols,
        "sample_rate": sample_rate,
        "alphabet": alphabet,
        "phoneme_type": phoneme_type,
        "phonemizer_model": phonemizer_model,
        "phoneme_id_map": json.dumps(phoneme_id_map),
        "has_espeak": phoneme_type == "espeak",
        "phoonnx_version": VERSION_STR,
    }
    if piper_compatible:
        metadata_dict["comment"] = "piper"

    try:
        add_meta_data(model_output, metadata_dict)
    except Exception as e:
        # Metadata is optional; the exported model is still valid without it.
        _LOGGER.error(f"Failed to add metadata to exported model {model_output}: {e}")

    _LOGGER.info("Export complete.")
|
104
355
|
|
105
356
|
|
106
357
|
# -----------------------------------------------------------------------------
|
107
358
|
|
108
359
|
if __name__ == "__main__":
|
109
|
-
|
360
|
+
cli()
|