xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -7
- xinference/client/handlers.py +3 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +2 -0
- xinference/core/scheduler.py +4 -7
- xinference/core/supervisor.py +114 -23
- xinference/core/worker.py +70 -4
- xinference/deploy/local.py +2 -1
- xinference/model/audio/core.py +11 -0
- xinference/model/audio/cosyvoice.py +16 -5
- xinference/model/audio/kokoro.py +139 -0
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +80 -0
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/llm/llama_cpp/core.py +21 -14
- xinference/model/llm/llm_family.json +527 -1
- xinference/model/llm/llm_family.py +4 -1
- xinference/model/llm/llm_family_modelscope.json +495 -3
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +24 -6
- xinference/model/llm/transformers/core.py +9 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -3
- xinference/model/llm/transformers/utils.py +22 -11
- xinference/model/llm/utils.py +115 -1
- xinference/model/llm/vllm/core.py +14 -4
- xinference/model/llm/vllm/xavier/block.py +3 -4
- xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/executor.py +18 -16
- xinference/model/llm/vllm/xavier/scheduler.py +79 -63
- xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
- xinference/model/llm/vllm/xavier/transfer.py +53 -32
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/melo/__init__.py +0 -0
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
- /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/melo/preprocess_text.py
@@ -0,0 +1,135 @@
+import json
+from collections import defaultdict
+from random import shuffle
+from typing import Optional
+
+from tqdm import tqdm
+import click
+from text.cleaner import clean_text_bert
+import os
+import torch
+from text.symbols import symbols, num_languages, num_tones
+
+@click.command()
+@click.option(
+    "--metadata",
+    default="data/example/metadata.list",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False),
+)
+@click.option("--cleaned-path", default=None)
+@click.option("--train-path", default=None)
+@click.option("--val-path", default=None)
+@click.option(
+    "--config_path",
+    default="configs/config.json",
+    type=click.Path(exists=True, file_okay=True, dir_okay=False),
+)
+@click.option("--val-per-spk", default=4)
+@click.option("--max-val-total", default=8)
+@click.option("--clean/--no-clean", default=True)
+def main(
+    metadata: str,
+    cleaned_path: Optional[str],
+    train_path: str,
+    val_path: str,
+    config_path: str,
+    val_per_spk: int,
+    max_val_total: int,
+    clean: bool,
+):
+    if train_path is None:
+        train_path = os.path.join(os.path.dirname(metadata), 'train.list')
+    if val_path is None:
+        val_path = os.path.join(os.path.dirname(metadata), 'val.list')
+    out_config_path = os.path.join(os.path.dirname(metadata), 'config.json')
+
+    if cleaned_path is None:
+        cleaned_path = metadata + ".cleaned"
+
+    if clean:
+        out_file = open(cleaned_path, "w", encoding="utf-8")
+        new_symbols = []
+        for line in tqdm(open(metadata, encoding="utf-8").readlines()):
+            try:
+                utt, spk, language, text = line.strip().split("|")
+                norm_text, phones, tones, word2ph, bert = clean_text_bert(text, language, device='cuda:0')
+                for ph in phones:
+                    if ph not in symbols and ph not in new_symbols:
+                        new_symbols.append(ph)
+                        print('update!, now symbols:')
+                        print(new_symbols)
+                        with open(f'{language}_symbol.txt', 'w') as f:
+                            f.write(f'{new_symbols}')
+
+                assert len(phones) == len(tones)
+                assert len(phones) == sum(word2ph)
+                out_file.write(
+                    "{}|{}|{}|{}|{}|{}|{}\n".format(
+                        utt,
+                        spk,
+                        language,
+                        norm_text,
+                        " ".join(phones),
+                        " ".join([str(i) for i in tones]),
+                        " ".join([str(i) for i in word2ph]),
+                    )
+                )
+                bert_path = utt.replace(".wav", ".bert.pt")
+                os.makedirs(os.path.dirname(bert_path), exist_ok=True)
+                torch.save(bert.cpu(), bert_path)
+            except Exception as error:
+                print("err!", line, error)
+
+        out_file.close()
+
+        metadata = cleaned_path
+
+    spk_utt_map = defaultdict(list)
+    spk_id_map = {}
+    current_sid = 0
+
+    with open(metadata, encoding="utf-8") as f:
+        for line in f.readlines():
+            utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
+            spk_utt_map[spk].append(line)
+
+            if spk not in spk_id_map.keys():
+                spk_id_map[spk] = current_sid
+                current_sid += 1
+
+    train_list = []
+    val_list = []
+
+    for spk, utts in spk_utt_map.items():
+        shuffle(utts)
+        val_list += utts[:val_per_spk]
+        train_list += utts[val_per_spk:]
+
+    if len(val_list) > max_val_total:
+        train_list += val_list[max_val_total:]
+        val_list = val_list[:max_val_total]
+
+    with open(train_path, "w", encoding="utf-8") as f:
+        for line in train_list:
+            f.write(line)
+
+    with open(val_path, "w", encoding="utf-8") as f:
+        for line in val_list:
+            f.write(line)
+
+    config = json.load(open(config_path, encoding="utf-8"))
+    config["data"]["spk2id"] = spk_id_map
+
+    config["data"]["training_files"] = train_path
+    config["data"]["validation_files"] = val_path
+    config["data"]["n_speakers"] = len(spk_id_map)
+    config["num_languages"] = num_languages
+    config["num_tones"] = num_tones
+    config["symbols"] = symbols
+
+    with open(out_config_path, "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+    main()
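
For orientation: this preprocess_text.py hunk reads a metadata.list whose lines have the form utt|spk|language|text, writes cleaned train/val lists next to it, and rewrites config.json with the speaker map and symbol inventory. Below is a minimal sketch of driving the click command from Python; the import location and data paths are hypothetical examples, and the script's own imports (the melo text package, plus a CUDA device for clean_text_bert) must resolve for it to run.

    # Hypothetical invocation sketch; paths and the import location are examples only.
    from click.testing import CliRunner
    from preprocess_text import main  # assumes the melo directory is on sys.path

    runner = CliRunner()
    result = runner.invoke(main, [
        "--metadata", "data/example/metadata.list",   # lines of "utt|spk|language|text"
        "--config_path", "configs/config.json",
        "--val-per-spk", "4",
        "--max-val-total", "8",
    ])
    print(result.output)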
xinference/thirdparty/melo/split_utils.py
@@ -0,0 +1,174 @@
+import re
+import os
+import glob
+import numpy as np
+import soundfile as sf
+import torchaudio
+import re
+
+def split_sentence(text, min_len=10, language_str='EN'):
+    if language_str in ['EN', 'FR', 'ES', 'SP']:
+        sentences = split_sentences_latin(text, min_len=min_len)
+    else:
+        sentences = split_sentences_zh(text, min_len=min_len)
+    return sentences
+
+
+def split_sentences_latin(text, min_len=10):
+    text = re.sub('[。!?;]', '.', text)
+    text = re.sub('[,]', ',', text)
+    text = re.sub('[“”]', '"', text)
+    text = re.sub('[‘’]', "'", text)
+    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
+    return [item.strip() for item in txtsplit(text, 256, 512) if item.strip()]
+
+
+def split_sentences_zh(text, min_len=10):
+    text = re.sub('[。!?;]', '.', text)
+    text = re.sub('[,]', ',', text)
+    # Replace newlines, spaces and tabs in the text with a single space
+    text = re.sub('[\n\t ]+', ' ', text)
+    # Add a split marker after each punctuation mark
+    text = re.sub('([,.!?;])', r'\1 $#!', text)
+    # Split into sentences and strip surrounding whitespace
+    # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
+    sentences = [s.strip() for s in text.split('$#!')]
+    if len(sentences[-1]) == 0: del sentences[-1]
+
+    new_sentences = []
+    new_sent = []
+    count_len = 0
+    for ind, sent in enumerate(sentences):
+        new_sent.append(sent)
+        count_len += len(sent)
+        if count_len > min_len or ind == len(sentences) - 1:
+            count_len = 0
+            new_sentences.append(' '.join(new_sent))
+            new_sent = []
+    return merge_short_sentences_zh(new_sentences)
+
+
+def merge_short_sentences_en(sens):
+    """Avoid short sentences by merging them with the following sentence.
+
+    Args:
+        List[str]: list of input sentences.
+
+    Returns:
+        List[str]: list of output sentences.
+    """
+    sens_out = []
+    for s in sens:
+        # If the previous sentence is too short, merge it with
+        # the current sentence.
+        if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
+            sens_out[-1] = sens_out[-1] + " " + s
+        else:
+            sens_out.append(s)
+    try:
+        if len(sens_out[-1].split(" ")) <= 2:
+            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
+            sens_out.pop(-1)
+    except:
+        pass
+    return sens_out
+
+
+def merge_short_sentences_zh(sens):
+    # return sens
+    """Avoid short sentences by merging them with the following sentence.
+
+    Args:
+        List[str]: list of input sentences.
+
+    Returns:
+        List[str]: list of output sentences.
+    """
+    sens_out = []
+    for s in sens:
+        # If the previous sentence is too short, merge it with
+        # the current sentence.
+        if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
+            sens_out[-1] = sens_out[-1] + " " + s
+        else:
+            sens_out.append(s)
+    try:
+        if len(sens_out[-1]) <= 2:
+            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
+            sens_out.pop(-1)
+    except:
+        pass
+    return sens_out
+
+
+
+def txtsplit(text, desired_length=100, max_length=200):
+    """Split text into chunks of a desired length, trying to keep sentences intact."""
+    text = re.sub(r'\n\n+', '\n', text)
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'[“”]', '"', text)
+    text = re.sub(r'([,.?!])', r'\1 ', text)
+    text = re.sub(r'\s+', ' ', text)
+
+    rv = []
+    in_quote = False
+    current = ""
+    split_pos = []
+    pos = -1
+    end_pos = len(text) - 1
+    def seek(delta):
+        nonlocal pos, in_quote, current
+        is_neg = delta < 0
+        for _ in range(abs(delta)):
+            if is_neg:
+                pos -= 1
+                current = current[:-1]
+            else:
+                pos += 1
+                current += text[pos]
+            if text[pos] == '"':
+                in_quote = not in_quote
+        return text[pos]
+    def peek(delta):
+        p = pos + delta
+        return text[p] if p < end_pos and p >= 0 else ""
+    def commit():
+        nonlocal rv, current, split_pos
+        rv.append(current)
+        current = ""
+        split_pos = []
+    while pos < end_pos:
+        c = seek(1)
+        if len(current) >= max_length:
+            if len(split_pos) > 0 and len(current) > (desired_length / 2):
+                d = pos - split_pos[-1]
+                seek(-d)
+            else:
+                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
+                    c = seek(-1)
+            commit()
+        elif not in_quote and (c in '!?\n' or (c in '.,' and peek(1) in '\n ')):
+            while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
+                c = seek(1)
+            split_pos.append(pos)
+            if len(current) >= desired_length:
+                commit()
+        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
+            seek(2)
+            split_pos.append(pos)
+    rv.append(current)
+    rv = [s.strip() for s in rv]
+    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
+    return rv
+
+
+if __name__ == '__main__':
+    zh_text = "好的,我来给你讲一个故事吧。从前有一个小姑娘,她叫做小红。小红非常喜欢在森林里玩耍,她经常会和她的小伙伴们一起去探险。有一天,小红和她的小伙伴们走到了森林深处,突然遇到了一只凶猛的野兽。小红的小伙伴们都吓得不敢动弹,但是小红并没有被吓倒,她勇敢地走向野兽,用她的智慧和勇气成功地制服了野兽,保护了她的小伙伴们。从那以后,小红变得更加勇敢和自信,成为了她小伙伴们心中的英雄。"
+    en_text = "I didn’t know what to do. I said please kill her because it would be better than being kidnapped,” Ben, whose surname CNN is not using for security concerns, said on Wednesday. “It’s a nightmare. I said ‘please kill her, don’t take her there.’"
+    sp_text = "¡Claro! ¿En qué tema te gustaría que te hable en español? Puedo proporcionarte información o conversar contigo sobre una amplia variedad de temas, desde cultura y comida hasta viajes y tecnología. ¿Tienes alguna preferencia en particular?"
+    fr_text = "Bien sûr ! En quelle matière voudriez-vous que je vous parle en français ? Je peux vous fournir des informations ou discuter avec vous sur une grande variété de sujets, que ce soit la culture, la nourriture, les voyages ou la technologie. Avez-vous une préférence particulière ?"
+
+    print(split_sentence(zh_text, language_str='ZH'))
+    print(split_sentence(en_text, language_str='EN'))
+    print(split_sentence(sp_text, language_str='SP'))
+    print(split_sentence(fr_text, language_str='FR'))
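
A small usage sketch of the splitter above, assuming the melo directory is on sys.path so the module imports under its plain name; the sample strings are illustrative.

    from split_utils import split_sentence

    long_en = "This is a sentence. " * 40
    for chunk in split_sentence(long_en, language_str="EN"):
        # Latin-script input is routed through txtsplit(text, 256, 512),
        # which targets ~256-character chunks split at sentence boundaries.
        print(len(chunk), repr(chunk[:40]))

    zh = "今天天气不错。我们去公园散步吧!好的,出发。"
    print(split_sentence(zh, min_len=10, language_str="ZH"))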
xinference/thirdparty/melo/text/__init__.py
@@ -0,0 +1,35 @@
+from .symbols import *
+
+
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+
+
+def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    Args:
+        text: string to convert to a sequence
+    Returns:
+        List of integers corresponding to the symbols in the text
+    """
+    symbol_to_id_map = symbol_to_id if symbol_to_id else _symbol_to_id
+    phones = [symbol_to_id_map[symbol] for symbol in cleaned_text]
+    tone_start = language_tone_start_map[language]
+    tones = [i + tone_start for i in tones]
+    lang_id = language_id_map[language]
+    lang_ids = [lang_id for i in phones]
+    return phones, tones, lang_ids
+
+
+def get_bert(norm_text, word2ph, language, device):
+    from .chinese_bert import get_bert_feature as zh_bert
+    from .english_bert import get_bert_feature as en_bert
+    from .japanese_bert import get_bert_feature as jp_bert
+    from .chinese_mix import get_bert_feature as zh_mix_en_bert
+    from .spanish_bert import get_bert_feature as sp_bert
+    from .french_bert import get_bert_feature as fr_bert
+    from .korean import get_bert_feature as kr_bert
+
+    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
+                          'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
+    bert = lang_bert_func_map[language](norm_text, word2ph, device)
+    return bert
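
To make the ID mapping concrete, here is a self-contained toy version of cleaned_text_to_sequence; the three maps below are stand-ins for the real ones exported by .symbols.

    # Toy stand-ins; the real values come from melo's text.symbols module.
    symbols = ["_", "a", "b", "AA"]
    language_tone_start_map = {"ZH": 0, "EN": 7}
    language_id_map = {"ZH": 0, "EN": 2}

    _symbol_to_id = {s: i for i, s in enumerate(symbols)}

    def cleaned_text_to_sequence(cleaned_text, tones, language):
        phones = [_symbol_to_id[s] for s in cleaned_text]
        tones = [t + language_tone_start_map[language] for t in tones]  # per-language tone offset
        lang_ids = [language_id_map[language]] * len(phones)            # one language ID per phone
        return phones, tones, lang_ids

    # "a", "b" with tones 0, 1 land in the EN tone range that starts at 7:
    print(cleaned_text_to_sequence(["a", "b"], [0, 1], "EN"))  # ([1, 2], [7, 8], [2, 2])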
xinference/thirdparty/melo/text/chinese.py
@@ -0,0 +1,199 @@
+import os
+import re
+
+import cn2an
+from pypinyin import lazy_pinyin, Style
+
+from .symbols import punctuation
+from .tone_sandhi import ToneSandhi
+
+current_file_path = os.path.dirname(__file__)
+pinyin_to_symbol_map = {
+    line.split("\t")[0]: line.strip().split("\t")[1]
+    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
+}
+
+import jieba.posseg as psg
+
+
+rep_map = {
+    ":": ",",
+    ";": ",",
+    ",": ",",
+    "。": ".",
+    "!": "!",
+    "?": "?",
+    "\n": ".",
+    "·": ",",
+    "、": ",",
+    "...": "…",
+    "$": ".",
+    "“": "'",
+    "”": "'",
+    "‘": "'",
+    "’": "'",
+    "(": "'",
+    ")": "'",
+    "(": "'",
+    ")": "'",
+    "《": "'",
+    "》": "'",
+    "【": "'",
+    "】": "'",
+    "[": "'",
+    "]": "'",
+    "—": "-",
+    "~": "-",
+    "~": "-",
+    "「": "'",
+    "」": "'",
+}
+
+tone_modifier = ToneSandhi()
+
+
+def replace_punctuation(text):
+    text = text.replace("嗯", "恩").replace("呣", "母")
+    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+
+    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+
+    replaced_text = re.sub(
+        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
+    )
+
+    return replaced_text
+
+
+def g2p(text):
+    pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
+    sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
+    phones, tones, word2ph = _g2p(sentences)
+    assert sum(word2ph) == len(phones)
+    assert len(word2ph) == len(text)  # Sometimes it will crash; you can add a try-catch here.
+    phones = ["_"] + phones + ["_"]
+    tones = [0] + tones + [0]
+    word2ph = [1] + word2ph + [1]
+    return phones, tones, word2ph
+
+
+def _get_initials_finals(word):
+    initials = []
+    finals = []
+    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
+    orig_finals = lazy_pinyin(
+        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
+    )
+    for c, v in zip(orig_initials, orig_finals):
+        initials.append(c)
+        finals.append(v)
+    return initials, finals
+
+
+def _g2p(segments):
+    phones_list = []
+    tones_list = []
+    word2ph = []
+    for seg in segments:
+        # Replace all English words in the sentence
+        seg = re.sub("[a-zA-Z]+", "", seg)
+        seg_cut = psg.lcut(seg)
+        initials = []
+        finals = []
+        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
+        for word, pos in seg_cut:
+            if pos == "eng":
+                import pdb; pdb.set_trace()
+                continue
+            sub_initials, sub_finals = _get_initials_finals(word)
+            sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
+            initials.append(sub_initials)
+            finals.append(sub_finals)
+
+            # assert len(sub_initials) == len(sub_finals) == len(word)
+        initials = sum(initials, [])
+        finals = sum(finals, [])
+        #
+        for c, v in zip(initials, finals):
+            raw_pinyin = c + v
+            # NOTE: post process for pypinyin outputs
+            # we discriminate i, ii and iii
+            if c == v:
+                assert c in punctuation
+                phone = [c]
+                tone = "0"
+                word2ph.append(1)
+            else:
+                v_without_tone = v[:-1]
+                tone = v[-1]
+
+                pinyin = c + v_without_tone
+                assert tone in "12345"
+
+                if c:
+                    # multi-syllable case
+                    v_rep_map = {
+                        "uei": "ui",
+                        "iou": "iu",
+                        "uen": "un",
+                    }
+                    if v_without_tone in v_rep_map.keys():
+                        pinyin = c + v_rep_map[v_without_tone]
+                else:
+                    # single-syllable case
+                    pinyin_rep_map = {
+                        "ing": "ying",
+                        "i": "yi",
+                        "in": "yin",
+                        "u": "wu",
+                    }
+                    if pinyin in pinyin_rep_map.keys():
+                        pinyin = pinyin_rep_map[pinyin]
+                    else:
+                        single_rep_map = {
+                            "v": "yu",
+                            "e": "e",
+                            "i": "y",
+                            "u": "w",
+                        }
+                        if pinyin[0] in single_rep_map.keys():
+                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
+
+                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
+                phone = pinyin_to_symbol_map[pinyin].split(" ")
+                word2ph.append(len(phone))
+
+            phones_list += phone
+            tones_list += [int(tone)] * len(phone)
+    return phones_list, tones_list, word2ph
+
+
+def text_normalize(text):
+    numbers = re.findall(r"\d+(?:\.?\d+)?", text)
+    for number in numbers:
+        text = text.replace(number, cn2an.an2cn(number), 1)
+    text = replace_punctuation(text)
+    return text
+
+
+def get_bert_feature(text, word2ph, device=None):
+    from text import chinese_bert
+
+    return chinese_bert.get_bert_feature(text, word2ph, device=device)
+
+
+if __name__ == "__main__":
+    from text.chinese_bert import get_bert_feature
+
+    text = "啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
+    text = text_normalize(text)
+    print(text)
+    phones, tones, word2ph = g2p(text)
+    bert = get_bert_feature(text, word2ph)
+
+    print(phones, tones, word2ph, bert.shape)
+
+
+    # # Example usage
+    # text = "这是一个示例文本:,你好!这是一个测试...."
+    # print(g2p_paddle(text))  # Output: 这是一个示例文本你好这是一个测试
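
A usage sketch of this g2p pipeline, assuming melo's text package (including its opencpop-strict.txt data file) plus jieba, pypinyin and cn2an are installed; the input string is an example.

    from text.chinese import text_normalize, g2p

    text = text_normalize("今天是2024年!")  # digits become Chinese numerals, punctuation is normalized
    phones, tones, word2ph = g2p(text)
    # After the "_" padding that g2p adds on both ends:
    # sum(word2ph) == len(phones) and len(word2ph) == len(text) + 2.
    print(phones, tones, word2ph)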
xinference/thirdparty/melo/text/chinese_bert.py
@@ -0,0 +1,107 @@
+import torch
+import sys
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+# model_id = 'hfl/chinese-roberta-wwm-ext-large'
+local_path = "./bert/chinese-roberta-wwm-ext-large"
+
+
+tokenizers = {}
+models = {}
+
+def get_bert_feature(text, word2ph, device=None, model_id='hfl/chinese-roberta-wwm-ext-large'):
+    if model_id not in models:
+        models[model_id] = AutoModelForMaskedLM.from_pretrained(
+            model_id
+        ).to(device)
+        tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
+    model = models[model_id]
+    tokenizer = tokenizers[model_id]
+
+    if (
+        sys.platform == "darwin"
+        and torch.backends.mps.is_available()
+        and device == "cpu"
+    ):
+        device = "mps"
+    if not device:
+        device = "cuda"
+
+    with torch.no_grad():
+        inputs = tokenizer(text, return_tensors="pt")
+        for i in inputs:
+            inputs[i] = inputs[i].to(device)
+        res = model(**inputs, output_hidden_states=True)
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+    # import pdb; pdb.set_trace()
+    # assert len(word2ph) == len(text) + 2
+    word2phone = word2ph
+    phone_level_feature = []
+    for i in range(len(word2phone)):
+        repeat_feature = res[i].repeat(word2phone[i], 1)
+        phone_level_feature.append(repeat_feature)
+
+    phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    return phone_level_feature.T
+
+
+if __name__ == "__main__":
+    import torch
+
+    word_level_feature = torch.rand(38, 1024)  # 38 words, each a 1024-dim feature vector
+    word2phone = [
+        1,
+        2,
+        1,
+        2,
+        2,
+        1,
+        2,
+        2,
+        1,
+        2,
+        2,
+        1,
+        2,
+        2,
+        2,
+        2,
+        2,
+        1,
+        1,
+        2,
+        2,
+        1,
+        2,
+        2,
+        2,
+        2,
+        1,
+        2,
+        2,
+        2,
+        2,
+        2,
+        1,
+        2,
+        2,
+        2,
+        2,
+        1,
+    ]
+
+    # Compute the total number of frames
+    total_frames = sum(word2phone)
+    print(word_level_feature.shape)
+    print(word2phone)
+    phone_level_feature = []
+    for i in range(len(word2phone)):
+        print(word_level_feature[i].shape)
+
+        # Repeat each word's feature word2phone[i] times
+        repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
+        phone_level_feature.append(repeat_feature)
+
+    phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    print(phone_level_feature.shape)  # torch.Size([65, 1024])
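
The word-to-phone upsampling at the core of get_bert_feature, isolated as a self-contained sketch: each word's hidden vector is repeated once per phone so the BERT features align one-to-one with the phone sequence.

    import torch

    word_feats = torch.rand(3, 1024)   # one 1024-dim vector per word
    word2phone = [2, 1, 3]             # phones per word
    phone_feats = torch.cat(
        [word_feats[i].repeat(n, 1) for i, n in enumerate(word2phone)], dim=0
    )
    print(phone_feats.shape)  # torch.Size([6, 1024]); get_bert_feature returns the transpose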