sigilyph-0.5.2-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. The information is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sigilyph/__init__.py +18 -0
- sigilyph/core/__init__.py +0 -0
- sigilyph/core/bert_align.py +163 -0
- sigilyph/core/g2p_func.py +47 -0
- sigilyph/core/norm_func_bk.py +98 -0
- sigilyph/core/predict.py +64 -0
- sigilyph/core/preprocess.py +16 -0
- sigilyph/core/py2phone.dict +2165 -0
- sigilyph/core/sigilyph_class.py +246 -0
- sigilyph/core/special_dict.json +26 -0
- sigilyph/core/symbols.py +445 -0
- sigilyph/core/text_process.py +328 -0
- sigilyph/fst_tool/__init__.py +0 -0
- sigilyph/fst_tool/infer_normalizer.py +49 -0
- sigilyph/fst_tool/processor.py +122 -0
- sigilyph/fst_tool/token_parser.py +159 -0
- sigilyph/text_norm/__init__.py +0 -0
- sigilyph/text_norm/norm_func.py +155 -0
- sigilyph/text_norm/norm_func_new.py +89 -0
- sigilyph/text_norm/sigilyph_norm.py +179 -0
- sigilyph-0.5.2.dist-info/METADATA +24 -0
- sigilyph-0.5.2.dist-info/RECORD +24 -0
- sigilyph-0.5.2.dist-info/WHEEL +5 -0
- sigilyph-0.5.2.dist-info/top_level.txt +1 -0
sigilyph/core/text_process.py
@@ -0,0 +1,328 @@
+'''
+FilePath: /python-Sigilyph/sigilyph/core/text_process.py
+Descripttion:
+Author: Yixiang Chen
+version:
+Date: 2025-03-31 16:31:26
+LastEditors: Yixiang Chen
+LastEditTime: 2026-01-16 17:36:55
+'''
+
+
+import langid
+import re
+
+import jieba
+import jieba.posseg
+
+from sigilyph.core.g2p_func import g2p_en, g2p_cn
+#from sigilyph.core.norm_func import preprocess_first, text_norm_en, text_norm_cn
+from sigilyph.text_norm.norm_func import preprocess_first, text_norm_en, text_norm_cn
+
+from sigilyph.core.symbols import base_phone_set, cn_phone_set, en_phone_set, punctuation, special_phrase
+
+#all_phone_set = [] + sorted(set(base_phone_set + cn_phone_set + en_phone_set))
+#all_phone_set = [] + list(set(base_phone_set)) + list(set(cn_phone_set + en_phone_set))
+all_phone_set = [] + sorted(set(base_phone_set)) + sorted(set(cn_phone_set)) + sorted(set(en_phone_set))
+all_phone_dict = {xx:idx for idx, xx in enumerate(all_phone_set)}
+
+norm_func_dict = {
+    'en': text_norm_en,
+    'zh': text_norm_cn
+}
+
+g2p_func_dict = {
+    'en': g2p_en,
+    'zh': g2p_cn
+}
+
+sil1symbol='-'
+
+def text_split_lang_old(text, lang):
+    if lang == 'ZH' or lang == 'zh':
+        multi_lang_text_list = [{'lang':'zh', 'text_split': text}]
+    elif lang == 'en':
+        multi_lang_text_list = [{'lang':'en', 'text_split': text}]
+    else:
+        pattern = r'([a-zA-Z ,.\!\?\"]+|[\u4e00-\u9fa5 ,。\!\?\“\”]+)'
+        text_split = re.findall(pattern, text)
+        multi_lang_text_list = []
+        for idx in range(len(text_split)):
+            tmpts = text_split[idx]
+            tmp_lang = langid.classify(tmpts)[0]
+            multi_lang_text_list.append({'lang':tmp_lang, 'text_split': tmpts})
+    return multi_lang_text_list
+
+def text_split_lang_bk0724(text, lang):
+    if lang == 'ZH' or lang == 'zh':
+        multi_lang_text_list = [{'lang':'zh', 'text_split': text}]
+    elif lang == 'en':
+        multi_lang_text_list = [{'lang':'en', 'text_split': text}]
+    else:
+        pretext_split = re.split("(\[.*?\])", text, re.I|re.M)
+        multi_lang_text_list = []
+        pretext_split = list(filter(None, pretext_split))
+        for utext in pretext_split:
+            if utext[0] != '[':
+                pattern = r'([a-zA-Z ,.\!\?\"]+|[\u4e00-\u9fa5 ,。,.\t \!\?]+)'
+                text_split = re.findall(pattern, utext)
+                for idx in range(len(text_split)):
+                    tmpts = text_split[idx]
+                    tmp_lang = langid.classify(tmpts)[0]
+                    if tmp_lang in ['zh', 'jp', 'ja']:
+                        tmp_lang = 'zh'
+                    else:
+                        tmp_lang = 'en'
+                    if not tmpts.isspace():
+                        multi_lang_text_list.append({'lang':tmp_lang, 'text_split': tmpts})
+            else:
+                phones = utext[1:-1]
+                multi_lang_text_list.append({'lang':'phone', 'text_split': phones})
+    return multi_lang_text_list
+
+
+def search_ele_mid(flaglist, tf = 'v'):
+    nowidx = -1
+    halflen = (len(flaglist))//2
+    for gap in range(len(flaglist)-halflen):
+        nowidx = halflen - gap
+        if flaglist[nowidx]==tf:
+            return nowidx
+        nowidx = halflen + gap
+        if flaglist[nowidx]==tf:
+            return nowidx
+    return nowidx
+
+def add_pause(text, tf='v'):
+    segment = jieba.posseg.cut(text.strip())
+    wlist = []
+    flist = []
+    for x in segment:
+        wlist.append(x.word)
+        flist.append(x.flag)
+    idx = search_ele_mid(flist, tf)
+    if idx != len(flist)-1:
+        wlist.insert(idx, sil1symbol)
+    outtext = ''.join(wlist)
+    return outtext
+
+def has_punc(text):
+    for char in text:
+        if char in [',', '.', '!', '?', ',','。','?','!', sil1symbol]:
+            return True
+    return False
+
+def text_split_lang(text, lang):
+    if lang == 'ZH' or lang == 'zh':
+        multi_lang_text_list = [{'lang':'zh', 'text_split': text}]
+    elif lang == 'en':
+        multi_lang_text_list = [{'lang':'en', 'text_split': text}]
+    else:
+        pretext_split = re.split("(\[.*?\])", text, re.I|re.M)
+        multi_lang_text_list = []
+        pretext_split = list(filter(None, pretext_split))
+        for utext in pretext_split:
+            if utext[0] != '[':
+                pattern = r'([a-zA-Z ,.\!\?]+|[\u4e00-\u9fa5 ,。,.\t \"\!\?\“\”\、]+)'
+                text_split = re.findall(pattern, utext)
+                print(text_split)
+                for idx in range(len(text_split)):
+                    tmpts = text_split[idx]
+                    tmp_lang = langid.classify(tmpts)[0]
+                    if len(tmpts)>20:
+                        if not has_punc(tmpts[:-1]):
+                            tmpts = add_pause(tmpts, 'p')
+                        if not has_punc(tmpts[:-1]):
+                            tmpts = add_pause(tmpts, 'v')
+                    if tmpts in special_phrase:
+                        tmpts = tmpts+sil1symbol
+                    if tmp_lang in ['zh', 'jp', 'ja']:
+                        tmp_lang = 'zh'
+                        tmpts = tmpts.replace(' ', sil1symbol)
+                    else:
+                        tmp_lang = 'en'
+                    if not tmpts.isspace():
+                        multi_lang_text_list.append({'lang':tmp_lang, 'text_split': tmpts})
+            else:
+                phones = utext[1:-1]
+                multi_lang_text_list.append({'lang':'phone', 'text_split': phones})
+    return multi_lang_text_list
+
+def text_norm(text, lang):
+    outtext = norm_func_dict[lang](text)
+    return outtext
+
+def g2p(text, lang):
+    phoneme_list = g2p_func_dict[lang](text)
+    return phoneme_list
+
+def tokenizer(phoneme_list):
+    #token_list = [all_phone_dict[pho] for pho in phoneme_list]
+    token_list = [all_phone_dict[pho] if pho in all_phone_dict.keys() else all_phone_dict['sil'] for pho in phoneme_list]
+    return token_list
+
+def postprocess(phonelist):
+    outlist = [xx if xx not in punctuation else 'sil' for xx in phonelist]
+    return outlist
+
+def postprocess_tts(phonelist):
+    #outlist = ['sil', '<sp>']
+    outlist = []
+    print(phonelist)
+    for idx in range(len(phonelist)):
+        pm = phonelist[idx]
+        if pm not in punctuation:
+            outlist.append(pm)
+        elif pm == sil1symbol:
+            outlist.append('sil_1')
+        else:
+            #outlist.append('sil')
+            outlist.append('sil_punc')
+            #outlist.append('<sp>')
+    #if outlist[-1] == 'sil':
+    #    outlist.append('<sp>')
+    #elif outlist[-2] != 'sil':
+    #    outlist.append('sil')
+    #    outlist.append('<sp>')
+    if phonelist[-2] not in punctuation and outlist[-1].split('_')[0] != 'sil':
+        #outlist.append('sil')
+        outlist.append('sil_end')
+        outlist.append('<sp>')
+    return outlist
+
+def text_process_old(text, lang, spflag=True):
+    multi_lang_text_list = text_split_lang(text, lang)
+
+    all_phone = []
+    for text_split_dict in multi_lang_text_list:
+        use_lang = text_split_dict['lang']
+        if use_lang not in norm_func_dict.keys():
+            use_lang = 'zh'
+        use_text = text_split_dict['text_split']
+        use_text = text_norm(use_text, use_lang)
+        phone_list = g2p(use_text, use_lang)
+        #all_phone.append('sil')
+        all_phone.append('sil_lang')
+        all_phone.append('<sp>')
+        all_phone.extend(phone_list)
+    #all_phone = postprocess(all_phone)
+    all_phone = postprocess_tts(all_phone)
+    if not spflag:
+        while '<sp>' in all_phone:
+            all_phone.remove('<sp>')
+    return all_phone
+
+def text_process(text, lang, spflag=True, use_lang='zh'):
+    text = preprocess_first(text, use_lang=use_lang)
+
+    multi_lang_text_list = text_split_lang(text, lang)
+
+    all_phone = []
+    for text_split_dict in multi_lang_text_list:
+        use_lang = text_split_dict['lang']
+        use_text = text_split_dict['text_split']
+        if use_lang == 'phone':
+            phonelist = use_text.split()
+            all_phone.extend(phonelist)
+        else:
+            if use_lang not in norm_func_dict.keys():
+                use_lang = 'zh'
+            use_text = text_norm(use_text, use_lang)
+            phone_list = g2p(use_text, use_lang)
+            #all_phone.append('sil')
+            all_phone.append('sil_lang')
+            all_phone.append('<sp>')
+            all_phone.extend(phone_list)
+    #all_phone = postprocess(all_phone)
+    all_phone = postprocess_tts(all_phone)
+    if not spflag:
+        while '<sp>' in all_phone:
+            all_phone.remove('<sp>')
+    return all_phone
+
+def replace_sil2label_old(phones):
+    phones = ['sil_1' if xx == 'sil_lang' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_punc' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_end' else xx for xx in phones]
+    phones = ['sil_1' if xx == 'sil' else xx for xx in phones]
+    outphones = []
+    for ele in phones:
+        if outphones == []:
+            outphones.append(ele)
+        else:
+            if ele.split('_')[0] == 'sil' and outphones[-1].split('_')[0] == 'sil':
+                #outphones[-1] = 'sil_2'
+                outphones[-1] = 'sil_1'
+            else:
+                outphones.append(ele)
+    if outphones[-1].split('_')[0] == 'sil':
+        outphones = outphones[:-1]
+    return outphones
+
+def replace_sil2label_0808(phones):
+    #phones = ['sil_1' if xx == 'sil_lang' else xx for xx in phones]
+    phones = ['' if xx == 'sil_lang' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_punc' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_end' else xx for xx in phones]
+    phones = ['sil_1' if xx == 'sil' else xx for xx in phones]
+    phones = list(filter(None, phones))
+    #outphones = []
+    outphones = ['sil_1']
+    for ele in phones:
+        if outphones == []:
+            outphones.append(ele)
+        else:
+            if ele.split('_')[0] == 'sil' and outphones[-1].split('_')[0] == 'sil':
+                #outphones[-1] = 'sil_2'
+                outphones[-1] = 'sil_1'
+            else:
+                outphones.append(ele)
+    if outphones[-1].split('_')[0] == 'sil':
+        outphones = outphones[:-1]
+    return outphones
+
+def replace_sil2label(phones):
+    #phones = ['sil_1' if xx == 'sil_lang' else xx for xx in phones]
+    phones = ['' if xx == 'sil_lang' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_punc' else xx for xx in phones]
+    phones = ['sil_2' if xx == 'sil_end' else xx for xx in phones]
+    phones = ['sil_1' if xx == 'sil' else xx for xx in phones]
+    phones = list(filter(None, phones))
+    #outphones = []
+    outphones = ['sil_1']
+    for ele in phones:
+        if outphones == []:
+            outphones.append(ele)
+        else:
+            if ele.split('_')[0] == 'sil' and outphones[-1].split('_')[0] == 'sil':
+                outphones[-1] = 'sil_2'
+                #outphones[-1] = 'sil_1'
+            else:
+                outphones.append(ele)
+    #if outphones[-1].split('_')[0] == 'sil':
+    #    outphones = outphones[:-1]
+    return outphones
+
+
+def text_process_asr(text, lang):
+    multi_lang_text_list = text_split_lang(text, lang)
+
+    all_phone = []
+    for text_split_dict in multi_lang_text_list:
+        use_lang = text_split_dict['lang']
+        use_text = text_split_dict['text_split']
+        use_text = text_norm(use_text, use_lang)
+        phone_list = g2p(use_text, use_lang)
+        phone_list_new = []
+        for idx in range(len(phone_list)):
+            tmpp = phone_list[idx]
+            if tmpp != '<sp>':
+                phone_list_new.append(tmpp)
+        all_phone.extend(phone_list_new)
+    all_phone = postprocess(all_phone)
+    if all_phone[0] != 'sil':
+        all_phone = ['sil'] + all_phone
+    if all_phone[-1] != 'sil':
+        all_phone = all_phone + ['sil']
+
+    return all_phone
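
In the file above, text_process is the TTS-facing entry point: it applies preprocess_first, splits the input into per-language chunks with text_split_lang (bracketed "[...]" spans pass through as literal phone sequences), then normalizes and G2P-converts each chunk before silence post-processing. One subtlety worth flagging: re.split("(\[.*?\])", text, re.I|re.M) passes the flags bitmask (value 10) as the positional maxsplit argument, so it caps the number of splits at ten rather than enabling case-insensitive/multiline matching; that would require flags=re.I|re.M. The sketch below shows only the intended call shape; the inputs are invented, and running it needs the package's data files plus langid and jieba.

    # Hedged usage sketch: hypothetical inputs, not from the package.
    from sigilyph.core.text_process import text_process, text_process_asr

    # Any lang other than 'zh'/'ZH'/'en' triggers per-chunk language
    # detection; a span like '[AA B]' would be kept as literal phones.
    phones = text_process('今天天气不错, nice weather today.', lang='mix')

    # ASR variant: drops '<sp>' tokens and brackets the result with 'sil'.
    asr_phones = text_process_asr('今天天气不错', lang='zh')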

sigilyph/fst_tool/__init__.py (File without changes)

sigilyph/fst_tool/infer_normalizer.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sigilyph.fst_tool.processor import Processor
+from importlib_resources import files
+
+class ZhNormalizer(Processor):
+
+    def __init__(self,
+                 version_id='v1',
+                 cache_dir=None,
+                 overwrite_cache=False,
+                 remove_interjections=True,
+                 remove_erhua=True,
+                 traditional_to_simple=True,
+                 remove_puncts=False,
+                 full_to_half=True,
+                 tag_oov=False):
+        super().__init__(name='zh_normalizer')
+        self.remove_interjections = remove_interjections
+        self.remove_erhua = remove_erhua
+        self.traditional_to_simple = traditional_to_simple
+        self.remove_puncts = remove_puncts
+        self.full_to_half = full_to_half
+        self.tag_oov = tag_oov
+        if cache_dir is None:
+            cache_dir = files("no_fst")
+        #self.build_fst('zh_tn', cache_dir, overwrite_cache)
+        self.build_fst('textnorm_zh_' + version_id, cache_dir, overwrite_cache)
+
+
+class EnNormalizer(Processor):
+    def __init__(self, version_id='v1', cache_dir=None, overwrite_cache=False):
+        super().__init__(name='en_normalizer', ordertype="en_tn")
+        if cache_dir is None:
+            cache_dir = files("no_fst")
+        #self.build_fst('en_tn', cache_dir, overwrite_cache)
+        self.build_fst('textnorm_en_' + version_id, cache_dir, overwrite_cache)
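
ZhNormalizer and EnNormalizer only ever load prebuilt FSTs: as processor.py below shows, build_fst reads <prefix>_tagger.fst and <prefix>_verbalizer.fst from cache_dir (by default the no_fst package resource directory) and merely logs "NO fst" when they are missing, so no grammar is compiled at runtime. A hedged usage sketch, with a hypothetical cache path and inputs:

    # Assumes textnorm_zh_v1_tagger.fst / textnorm_zh_v1_verbalizer.fst
    # (and the en_ counterparts) already exist under the given directory.
    from sigilyph.fst_tool.infer_normalizer import ZhNormalizer, EnNormalizer

    zh = ZhNormalizer(version_id='v1', cache_dir='/path/to/fst_cache')
    print(zh.normalize('我花了100块钱'))  # tag() then verbalize() via the FSTs

    en = EnNormalizer(cache_dir='/path/to/fst_cache')
    print(en.normalize('It costs $3.50'))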
sigilyph/fst_tool/processor.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import string
+import logging
+
+from sigilyph.fst_tool.token_parser import TokenParser
+
+from pynini import (cdrewrite, cross, difference, escape, Fst, shortestpath,
+                    union, closure, invert)
+from pynini.lib import byte, utf8
+from pynini.lib.pynutil import delete, insert
+
+class Processor:
+
+    def __init__(self, name, ordertype="tn"):
+        self.ALPHA = byte.ALPHA
+        self.DIGIT = byte.DIGIT
+        self.PUNCT = byte.PUNCT
+        self.SPACE = byte.SPACE | u'\u00A0'
+        self.VCHAR = utf8.VALID_UTF8_CHAR
+        self.VSIGMA = self.VCHAR.star
+        self.LOWER = byte.LOWER
+        self.UPPER = byte.UPPER
+
+        CHAR = difference(self.VCHAR, union('\\', '"'))
+        self.CHAR = (CHAR | cross('\\', '\\\\\\') | cross('"', '\\"'))
+        self.SIGMA = (CHAR | cross('\\\\\\', '\\') | cross('\\"', '"')).star
+        self.NOT_QUOTE = difference(self.VCHAR, r'"').optimize()
+        self.NOT_SPACE = difference(self.VCHAR, self.SPACE).optimize()
+        self.INSERT_SPACE = insert(" ")
+        self.DELETE_SPACE = delete(self.SPACE).star
+        self.DELETE_EXTRA_SPACE = cross(closure(self.SPACE, 1), " ")
+        self.DELETE_ZERO_OR_ONE_SPACE = delete(closure(self.SPACE, 0, 1))
+        self.MIN_NEG_WEIGHT = -0.0001
+        self.TO_LOWER = union(*[
+            cross(x, y)
+            for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)
+        ])
+        self.TO_UPPER = invert(self.TO_LOWER)
+
+        self.name = name
+        self.ordertype = ordertype
+        self.tagger = None
+        self.verbalizer = None
+
+    def build_rule(self, fst, l='', r=''):
+        rule = cdrewrite(fst, l, r, self.VSIGMA)
+        return rule
+
+    def add_tokens(self, tagger):
+        tagger = insert(f"{self.name} {{ ") + tagger + insert(' } ')
+        return tagger.optimize()
+
+    def delete_tokens(self, verbalizer):
+        verbalizer = (delete(f"{self.name}") + delete(' { ') + verbalizer +
+                      delete(' }') + delete(' ').ques)
+        return verbalizer.optimize()
+
+    def build_verbalizer(self):
+        verbalizer = delete('value: "') + self.SIGMA + delete('"')
+        self.verbalizer = self.delete_tokens(verbalizer)
+
+    def build_fst(self, prefix, cache_dir, overwrite_cache):
+        logger = logging.getLogger('textnorm-{}'.format(self.name))
+        logger.setLevel(logging.INFO)
+        handler = logging.StreamHandler()
+        #fmt = logging.Formatter('%(asctime)s WETEXT %(levelname)s %(message)s')
+        fmt = logging.Formatter('TextNorm %(levelname)s %(message)s')
+        handler.setFormatter(fmt)
+        logger.addHandler(handler)
+
+        os.makedirs(cache_dir, exist_ok=True)
+        tagger_name = '{}_tagger.fst'.format(prefix)
+        verbalizer_name = '{}_verbalizer.fst'.format(prefix)
+
+        tagger_path = os.path.join(cache_dir, tagger_name)
+        verbalizer_path = os.path.join(cache_dir, verbalizer_name)
+
+        exists = os.path.exists(tagger_path) and os.path.exists(
+            verbalizer_path)
+        if exists and not overwrite_cache:
+            logger.info("found existing fst: {}".format(tagger_path))
+            logger.info("                    {}".format(verbalizer_path))
+            #logger.info("skip building fst for {} ...".format(self.name))
+            self.tagger = Fst.read(tagger_path).optimize()
+            self.verbalizer = Fst.read(verbalizer_path).optimize()
+        else:
+            logger.info("NO fst")
+
+    def tag(self, input):
+        if len(input) == 0:
+            return ''
+        input = escape(input)
+        lattice = input @ self.tagger
+        return shortestpath(lattice, nshortest=1, unique=True).string()
+
+    def verbalize(self, input):
+        # Only words from the blacklist are contained.
+        if len(input) == 0:
+            return ''
+        output = TokenParser(self.ordertype).reorder(input)
+        # We need escape for pynini to build the fst from string.
+        lattice = escape(output) @ self.verbalizer
+        return shortestpath(lattice, nshortest=1, unique=True).string()
+
+    def normalize(self, input):
+        return self.verbalize(self.tag(input))
+
+
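
Processor.normalize chains the two cached FSTs: tag() composes the escaped input with the tagger and takes the shortest path, yielding a token string such as time { hour: "3" minute: "15" }; verbalize() reorders that string's fields via TokenParser and composes it with the verbalizer to strip the markup. build_rule is a thin wrapper over pynini's context-dependent rewrite. Below is a self-contained sketch of the cdrewrite pattern it produces; the 'colour' -> 'color' rule is an invented example, not a sigilyph grammar:

    # Requires pynini; mirrors build_rule(fst, l='', r='') above.
    from pynini import cdrewrite, cross, shortestpath
    from pynini.lib import utf8

    sigma_star = utf8.VALID_UTF8_CHAR.star       # same closure as self.VSIGMA
    rule = cdrewrite(cross('colour', 'color'), '', '', sigma_star)
    lattice = 'my colour' @ rule                 # str @ Fst, as in tag()
    print(shortestpath(lattice).string())        # -> 'my color'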
sigilyph/fst_tool/token_parser.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import string
+
+EOS = '<EOS>'
+TN_ORDERS = {
+    'date': ['year', 'month', 'day'],
+    'fraction': ['denominator', 'numerator'],
+    'measure': ['denominator', 'numerator', 'value'],
+    'money': ['value', 'currency'],
+    'time': ['noon', 'hour', 'minute', 'second']
+}
+EN_TN_ORDERS = {
+    'date': ['preserve_order', 'text', 'day', 'month', 'year'],
+    'money': ['integer_part', 'fractional_part', 'quantity', 'currency_maj'],
+}
+ITN_ORDERS = {
+    'date': ['year', 'month', 'day'],
+    'fraction': ['sign', 'numerator', 'denominator'],
+    'measure': ['numerator', 'denominator', 'value'],
+    'money': ['currency', 'value', 'decimal'],
+    'time': ['hour', 'minute', 'second', 'noon']
+}
+
+class Token:
+
+    def __init__(self, name):
+        self.name = name
+        self.order = []
+        self.members = {}
+
+    def append(self, key, value):
+        self.order.append(key)
+        self.members[key] = value
+
+    def string(self, orders):
+        output = self.name + ' {'
+        if self.name in orders.keys():
+            if "preserve_order" not in self.members.keys() or \
+                    self.members["preserve_order"] != "true":
+                self.order = orders[self.name]
+
+        for key in self.order:
+            if key not in self.members.keys():
+                continue
+            output += ' {}: "{}"'.format(key, self.members[key])
+        return output + ' }'
+
+
+class TokenParser:
+
+    def __init__(self, ordertype="tn"):
+        if ordertype == "tn":
+            self.orders = TN_ORDERS
+        elif ordertype == "itn":
+            self.orders = ITN_ORDERS
+        elif ordertype == "en_tn":
+            self.orders = EN_TN_ORDERS
+        else:
+            raise NotImplementedError()
+
+    def load(self, input):
+        assert len(input) > 0
+        self.index = 0
+        self.text = input
+        self.char = input[0]
+        self.tokens = []
+
+    def read(self):
+        if self.index < len(self.text) - 1:
+            self.index += 1
+            self.char = self.text[self.index]
+            return True
+        self.char = EOS
+        return False
+
+    def parse_ws(self):
+        not_eos = self.char != EOS
+        while not_eos and self.char == ' ':
+            not_eos = self.read()
+        return not_eos
+
+    def parse_char(self, exp):
+        if self.char == exp:
+            self.read()
+            return True
+        return False
+
+    def parse_chars(self, exp):
+        ok = False
+        for x in exp:
+            ok |= self.parse_char(x)
+        return ok
+
+    def parse_key(self):
+        assert self.char != EOS
+        assert self.char not in string.whitespace
+
+        key = ''
+        while self.char in string.ascii_letters + '_':
+            key += self.char
+            self.read()
+        return key
+
+    def parse_value(self):
+        assert self.char != EOS
+        escape = False
+
+        value = ''
+        while self.char != '"':
+            value += self.char
+            escape = self.char == '\\'
+            self.read()
+            if escape:
+                escape = False
+                value += self.char
+                self.read()
+        return value
+
+    def parse(self, input):
+        self.load(input)
+        while self.parse_ws():
+            name = self.parse_key()
+            self.parse_chars(' { ')
+
+            token = Token(name)
+            while self.parse_ws():
+                if self.char == '}':
+                    self.parse_char('}')
+                    break
+                key = self.parse_key()
+                self.parse_chars(': "')
+                value = self.parse_value()
+                self.parse_char('"')
+                token.append(key, value)
+            self.tokens.append(token)
+
+    def reorder(self, input):
+        self.parse(input)
+        output = ''
+        for token in self.tokens:
+            output += token.string(self.orders) + ' '
+        return output.strip()
+
+
+
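
TokenParser is a small recursive-descent parser over these token strings: parse() builds Token objects, and reorder() re-emits each token with its fields in the canonical order from TN_ORDERS, EN_TN_ORDERS, or ITN_ORDERS unless the token carries preserve_order: "true". It is pure Python, so it can be exercised without any FST files; the tagged string below is a hypothetical tagger output:

    from sigilyph.fst_tool.token_parser import TokenParser

    # Hypothetical tag() output with fields in source order.
    tagged = 'time { minute: "15" hour: "3" noon: "PM" }'
    print(TokenParser('tn').reorder(tagged))
    # -> time { noon: "PM" hour: "3" minute: "15" }   (TN_ORDERS['time'])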

sigilyph/text_norm/__init__.py (File without changes)