xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +54 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +24 -3
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +51 -1
- xinference/core/worker.py +315 -47
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +4 -6
- xinference/model/llm/core.py +5 -0
- xinference/model/llm/llama_cpp/core.py +46 -17
- xinference/model/llm/llm_family.json +530 -85
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +572 -1
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/reasoning_parser.py +3 -3
- xinference/model/llm/sglang/core.py +111 -13
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +31 -6
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +37 -15
- xinference/model/llm/vllm/core.py +184 -8
- xinference/model/llm/vllm/distributed_executor.py +320 -0
- xinference/model/rerank/core.py +22 -12
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
- xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import contextlib
|
|
16
|
+
import glob
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
import subprocess
|
|
20
|
+
import traceback
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
from torch.nn.parallel import DistributedDataParallel
|
|
24
|
+
import torch.distributed as dist
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@contextlib.contextmanager
def dist_load(path):
    """Yield a path from which *path* can be read, staging it in ``/dev/shm`` under DDP.

    If ``torch.distributed`` is not initialized, the world size is 1, or *path*
    already resolves into ``/dev/shm``, *path* is yielded unchanged.

    Otherwise the function runs in five steps:

    1. Local rank 0 copies *path* to ``/dev/shm/<hparams["exp_name"]>/<basename>``.
       Symlinks are followed, and directories are copied recursively.
    2. All ranks wait on a barrier.
    3. The staged path is yielded.
    4. All ranks wait on a second barrier.
    5. Rank 0 removes the staged copy.

    Args:
        path: File or directory to stage. Must have a non-empty basename.

    Yields:
        The path callers should load from.
    """
    if (not dist.is_initialized() or dist.get_world_size() == 1
            or os.path.realpath(path).startswith('/dev/shm')):
        yield path
        return

    import shutil

    from tts.utils.commons.hparams import hparams
    try:
        from tts.utils.commons.trainer import LOCAL_RANK
    except ImportError:
        # NOTE(review): the vendored MegaTTS3 tree does not ship ``trainer``; fall back
        # to the LOCAL_RANK environment variable exported by torchrun-style launchers.
        LOCAL_RANK = int(os.environ.get('LOCAL_RANK', 0))

    tmpdir = '/dev/shm'
    assert len(os.path.basename(path)) > 0
    shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
    if LOCAL_RANK == 0:
        # Use shutil instead of a shell string so paths containing spaces or shell
        # metacharacters work. This mirrors ``cp -Lr``: follow symlinks, recurse into dirs.
        os.makedirs(os.path.dirname(shm_ckpt_path), exist_ok=True)
        if os.path.isdir(path):
            shutil.copytree(path, shm_ckpt_path, dirs_exist_ok=True)
        else:
            shutil.copy2(path, shm_ckpt_path)
    dist.barrier()
    yield shm_ckpt_path
    dist.barrier()
    if LOCAL_RANK == 0:
        # Equivalent of ``rm -rf``: remove whatever was staged and ignore a missing path.
        if os.path.isdir(shm_ckpt_path) and not os.path.islink(shm_ckpt_path):
            shutil.rmtree(shm_ckpt_path, ignore_errors=True)
        elif os.path.lexists(shm_ckpt_path):
            os.remove(shm_ckpt_path)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def torch_load_dist(path, map_location='cpu'):
    """Load a ``torch.save`` file through ``dist_load`` and return its contents.

    Args:
        path: Checkpoint path. Under multi-process DDP it is staged via ``dist_load``.
        map_location: Forwarded to ``torch.load``.
    """
    # The return value is computed before the context manager exits, so the
    # staged copy is only cleaned up after the load has finished.
    with dist_load(path) as staged_path:
        return torch.load(staged_path, map_location=map_location)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_last_checkpoint(work_dir, steps=None):
    """Load the newest step checkpoint found in *work_dir*.

    Args:
        work_dir: Directory searched by ``get_all_ckpts``.
        steps: Optional step number restricting the search.

    Returns:
        ``(checkpoint, path)``, or ``(None, None)`` when no checkpoint matches.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest_path = candidates[0]
    return torch_load_dist(newest_path, map_location='cpu'), newest_path
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_all_ckpts(work_dir, steps=None):
    """List step checkpoints in *work_dir*, newest (highest step number) first.

    Args:
        work_dir: Directory holding ``model_ckpt_steps_<N>.ckpt`` files. Glob
            metacharacters in the directory name are matched literally.
        steps: If given and non-zero, only ``model_ckpt_steps_{steps}.ckpt`` is considered.

    Returns:
        Matching paths sorted by descending step number. Files that match the glob
        but whose step part is not an integer (e.g. ``model_ckpt_steps_best.ckpt``)
        are ignored rather than raising.
    """
    safe_dir = glob.escape(str(work_dir))
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{safe_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{safe_dir}/model_ckpt_steps_{steps}.ckpt'
    # Raw string: the original non-raw pattern relied on invalid escape sequences.
    step_re = re.compile(r'model_ckpt_steps_(\d+)\.ckpt')
    numbered = []
    for path in glob.glob(ckpt_path_pattern):
        match = step_re.fullmatch(os.path.basename(path))
        if match is not None:
            numbered.append((int(match.group(1)), path))
    # Stable sort, so ties keep glob order (as in the original).
    numbered.sort(key=lambda item: -item[0])
    return [path for _, path in numbered]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
|
|
74
|
+
silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
|
|
75
|
+
if checkpoint is None:
|
|
76
|
+
if os.path.isfile(ckpt_base_dir):
|
|
77
|
+
base_dir = os.path.dirname(ckpt_base_dir)
|
|
78
|
+
ckpt_path = ckpt_base_dir
|
|
79
|
+
checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
|
|
80
|
+
else:
|
|
81
|
+
base_dir = ckpt_base_dir
|
|
82
|
+
if load_opt:
|
|
83
|
+
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
|
|
84
|
+
else:
|
|
85
|
+
ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
|
|
86
|
+
if os.path.exists(ckpt_path):
|
|
87
|
+
checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
|
|
88
|
+
else:
|
|
89
|
+
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
|
|
90
|
+
if checkpoint is not None:
|
|
91
|
+
state_dict_all = {
|
|
92
|
+
k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
|
|
93
|
+
if not isinstance(cur_model, list):
|
|
94
|
+
cur_models = [cur_model]
|
|
95
|
+
model_names = [model_name]
|
|
96
|
+
else:
|
|
97
|
+
cur_models = cur_model
|
|
98
|
+
model_names = model_name
|
|
99
|
+
for model_name, cur_model in zip(model_names, cur_models):
|
|
100
|
+
if isinstance(cur_model, DistributedDataParallel):
|
|
101
|
+
cur_model = cur_model.module
|
|
102
|
+
device = next(cur_model.parameters()).device
|
|
103
|
+
if '.' not in model_name:
|
|
104
|
+
state_dict = state_dict_all[model_name]
|
|
105
|
+
else:
|
|
106
|
+
base_model_name = model_name.split('.')[0]
|
|
107
|
+
rest_model_name = model_name[len(base_model_name) + 1:]
|
|
108
|
+
state_dict = {
|
|
109
|
+
k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
|
|
110
|
+
if k.startswith(f'{rest_model_name}.')}
|
|
111
|
+
state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
|
|
112
|
+
if not strict and delete_unmatch:
|
|
113
|
+
try:
|
|
114
|
+
cur_model.load_state_dict(state_dict, strict=True)
|
|
115
|
+
if not silent:
|
|
116
|
+
print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
|
|
117
|
+
except:
|
|
118
|
+
cur_model_state_dict = cur_model.state_dict()
|
|
119
|
+
cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
|
|
120
|
+
cur_model_state_dict.items()}
|
|
121
|
+
unmatched_keys = []
|
|
122
|
+
for key, param in state_dict.items():
|
|
123
|
+
if key in cur_model_state_dict:
|
|
124
|
+
new_param = cur_model_state_dict[key]
|
|
125
|
+
if new_param.shape != param.shape:
|
|
126
|
+
unmatched_keys.append(key)
|
|
127
|
+
print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
|
|
128
|
+
"ckpt model: ", param.shape)
|
|
129
|
+
for key in unmatched_keys:
|
|
130
|
+
del state_dict[key]
|
|
131
|
+
load_results = cur_model.load_state_dict(state_dict, strict=strict)
|
|
132
|
+
cur_model.to(device)
|
|
133
|
+
if not silent:
|
|
134
|
+
print(f"| loaded '{model_name}' from '{ckpt_path}'.")
|
|
135
|
+
missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
|
|
136
|
+
print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
|
|
137
|
+
if load_opt:
|
|
138
|
+
optimizer_states = checkpoint['optimizer_states']
|
|
139
|
+
assert len(opts) == len(optimizer_states)
|
|
140
|
+
for optimizer, opt_state in zip(opts, optimizer_states):
|
|
141
|
+
opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
|
|
142
|
+
if optimizer is None:
|
|
143
|
+
return
|
|
144
|
+
try:
|
|
145
|
+
optimizer.load_state_dict(opt_state)
|
|
146
|
+
for i, state in enumerate(optimizer.state.values()):
|
|
147
|
+
for k, v in state.items():
|
|
148
|
+
if isinstance(v, torch.Tensor):
|
|
149
|
+
state[k] = v.to(device)
|
|
150
|
+
except ValueError:
|
|
151
|
+
print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
|
|
152
|
+
return checkpoint.get('global_step', 0)
|
|
153
|
+
else:
|
|
154
|
+
e_msg = f"| ckpt not found in {base_dir}."
|
|
155
|
+
if force:
|
|
156
|
+
assert False, e_msg
|
|
157
|
+
else:
|
|
158
|
+
print(e_msg)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def load_with_size_mismatch(model, state_dict, prefix=""):
    """Load the tensors of *state_dict* whose names and sizes match *model*.

    *prefix* is removed only from the start of each checkpoint key. The original code
    stripped the prefix anywhere in the key, which could remap unrelated parameters.

    Keys are skipped when:

    - the name (after prefix removal) is absent from the model, or
    - the tensor size differs from the model's.

    Size mismatches are printed. The missing/unexpected keys reported by
    ``load_state_dict(strict=False)`` are printed when non-empty.
    """
    current_model_dict = model.state_dict()
    mismatch_keys = set()
    new_state_dict = {}
    for key, value in state_dict.items():
        name = key[len(prefix):] if prefix and key.startswith(prefix) else key
        if name not in current_model_dict:
            continue
        if value.size() == current_model_dict[name].size():
            new_state_dict[name] = value
        else:
            mismatch_keys.add(name)
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys: ", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
|
|
20
|
+
import yaml
|
|
21
|
+
|
|
22
|
+
# Process-wide hyper-parameter dict. set_hparams(global_hparams=True) clears and
# refills it in place, so modules that imported this same object see the update.
hparams = {}
# set_hparams() resets this to False after printing the resolved hparams, so the
# dump appears only once per process.
global_print_hparams = True
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Args:
    """Plain attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def override_config(old_config: dict, new_config: dict):
    """Recursively merge *new_config* into *old_config* in place.

    If *new_config* has a truthy ``'__replace'`` entry, *old_config* is cleared first.
    A dict value is merged into an existing dict value under the same key; every other
    value replaces whatever is there. That includes a dict replacing a non-dict such as
    ``None`` or a scalar, where the original recursed into the non-dict and crashed.
    """
    if new_config.get('__replace', False):
        old_config.clear()
    for k, v in new_config.items():
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def traverse_dict(d, func, ctx):
    """Replace every non-dict leaf ``v`` of *d* (recursively, in place) with ``func(v, ctx)``."""
    for key in list(d):
        value = d[key]
        if not isinstance(value, dict):
            d[key] = func(value, ctx)
            continue
        traverse_dict(value, func, ctx)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def parse_config(v, context=None):
    """Resolve the special string forms a config value may take.

    - ``'^path'``: returns the config loaded from ``path`` via ``load_config``.
    - A string beginning with ``'${expr}'``: returns ``eval(expr)``, using *context*
      as the local namespace.
    - Anything else is returned unchanged.

    NOTE(review): ``eval`` runs arbitrary Python taken from the config; load only
    trusted config files.
    """
    context = {} if context is None else context
    if not isinstance(v, str):
        return v
    if v.startswith('^'):
        return load_config(v[1:], [], set())
    expr_match = re.match(r"\${(.*)}", v)
    if expr_match is None:
        return v
    return eval(expr_match.group(1), {}, context)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def remove_meta_key(d):
    """Recursively delete, in place, every non-dict entry whose key starts with ``'__'``.

    Dict values are recursed into but are never deleted themselves, even under a
    ``'__'`` key.
    """
    for key in list(d):
        value = d[key]
        if isinstance(value, dict):
            remove_meta_key(value)
        elif key[:2] == '__':
            del d[key]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_config(config_fn, config_chains, loaded_configs):
    """Load a YAML config, merging in the configs named by its ``base_config`` entry.

    Base configs are loaded depth-first and merged in list order with
    ``override_config``; the file's own values are applied last. A base path starting
    with ``.`` is resolved relative to the directory of *config_fn* (also when that
    directory is the CWD).

    Side effects:

    - Every visited file is added, normalized, to *loaded_configs*, so each file is
      loaded at most once.
    - Each file is appended to *config_chains* after its bases have been loaded.

    Returns:
        The merged mapping. Returns ``{}`` (after printing a warning) when *config_fn*
        does not exist. An empty YAML file counts as an empty mapping.
    """
    if not os.path.exists(config_fn):
        print(f"| WARN: {config_fn} not exist.")
        return {}
    with open(config_fn, encoding='utf-8') as f:
        hparams_ = yaml.safe_load(f)
    if hparams_ is None:
        # yaml.safe_load returns None for an empty document.
        hparams_ = {}
    loaded_configs.add(os.path.normpath(config_fn))

    if 'base_config' not in hparams_:
        config_chains.append(config_fn)
        return hparams_

    ret_hparams = {}
    base_configs = hparams_['base_config']
    if not isinstance(base_configs, list):
        base_configs = [base_configs]
        hparams_['base_config'] = base_configs
    for c in base_configs:
        if c.startswith('.'):
            # os.path.join avoids turning "./x" into "/./x" when dirname(config_fn) == ''.
            c = os.path.normpath(os.path.join(os.path.dirname(config_fn), c))
        if os.path.normpath(c) not in loaded_configs:
            override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
    override_config(ret_hparams, hparams_)
    config_chains.append(config_fn)
    return ret_hparams
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _parse_hparams_args(config, exp_name, hparams_str, print_hparams):
    """Return the option object used by ``set_hparams``.

    If *config* or *exp_name* is given, an ``Args`` object is built from the call
    arguments, with defaults for every other option. Otherwise the options are parsed
    from the command line; unknown arguments are ignored and printed when
    *print_hparams* is true.
    """
    if config != '' or exp_name != '':
        return Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False, remove=False,
                    start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--config', type=str, default='',
                        help='path to the YAML config file')
    parser.add_argument('--exp_name', type=str, default='', help='exp_name')
    parser.add_argument('-hp', '--hparams', type=str, default='',
                        help='comma-separated overrides, e.g. "a=1,b.c=2,d=[1 1 1]"')
    parser.add_argument('--infer', action='store_true', help='infer')
    parser.add_argument('--validate', action='store_true', help='validate')
    parser.add_argument('--reset', action='store_true', help='reset hparams')
    parser.add_argument('--remove', action='store_true', help='remove old ckpt')
    parser.add_argument('--debug', action='store_true', help='debug')
    parser.add_argument('--start_rank', type=int, default=-1,
                        help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
    parser.add_argument('--world_size', type=int, default=-1,
                        help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
    parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
    parser.add_argument('--master_addr', type=str, default='', help='')
    parser.add_argument('--ddp_dir', type=str, default='', help='')
    args, unknown = parser.parse_known_args()
    if print_hparams:
        print("| set_hparams Unknown hparams: ", unknown)
    return args


def _apply_hparams_overrides(hparams_, overrides):
    """Apply ``key=value[,key=value...]`` overrides to *hparams_* in place.

    Dotted keys address nested dicts (``b.c=2``).

    For an existing key:

    - A bool accepts ``true``/``false`` in any case.
    - A list accepts ``a|b|c`` (split; the element type is taken from the first
      existing element, else ``str``) or a space-separated literal such as
      ``[1 1 1]``, which is eval-ed.
    - Lists, dicts, bools and the literal values ``True``/``False`` are otherwise
      eval-ed.
    - Any other type is converted with the current value's type.

    A new key becomes a bool (``true``/``false``, any case), else an int, else a float,
    else the raw string.

    NOTE(review): ``eval`` runs the override text as Python; pass only trusted input.
    """
    for new_hparam in overrides.split(","):
        if not new_hparam.strip():
            continue  # tolerate empty items, e.g. a trailing comma
        # Split on the first '=' only, so values may themselves contain '='.
        k, v = new_hparam.split("=", 1)
        v = v.strip("\'\" ")
        *parents, k = k.split(".")
        config_node = hparams_
        for k_ in parents:
            config_node = config_node[k_]
        if k not in config_node:
            config_node[k] = v
            try:
                config_node[k] = float(v)
            except ValueError:
                pass
            try:
                config_node[k] = int(v)
            except ValueError:
                pass
            if v.lower() in ['false', 'true']:
                config_node[k] = v.lower() == 'true'
            continue
        current = config_node[k]
        if type(current) is bool and v.lower() in ('true', 'false'):
            config_node[k] = v.lower() == 'true'
        elif v in ['True', 'False'] or type(current) in [bool, list, dict]:
            if type(current) is list:
                v = v.replace(" ", ",").replace('^', "\"")
                if '|' in v:
                    tp = type(current[0]) if len(current) else str
                    config_node[k] = [tp(x) for x in v.split("|") if x != '']
                    continue
            config_node[k] = eval(v)
        else:
            config_node[k] = type(current)(v)


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Resolve the hyper-parameters from config files, a saved experiment config and overrides.

    Sources, in increasing priority:

    1. *config* and its ``base_config`` chain (via ``load_config``).
    2. ``<exp_name>/config.yaml``, if present; skipped when the ``reset`` option is set.
    3. ``key=value`` overrides from *hparams_str* (or ``--hparams`` on the command line).

    String values of the form ``^path`` / ``${expr}`` are resolved with
    ``parse_config`` before overrides are applied. When both *config* and *exp_name*
    are empty, every option is read from ``sys.argv`` instead.

    Args:
        config: Path of the YAML config file.
        exp_name: Experiment directory. Also used as ``work_dir``, which is created
            unless ``infer`` is set.
        hparams_str: Comma-separated override string.
        print_hparams: Print unknown CLI args, the config chain, and (first call only)
            the resolved hparams.
        global_hparams: Also replace the contents of the module-level ``hparams`` dict.

    Returns:
        The resolved hyper-parameter dict.

    Raises:
        AssertionError: neither a config nor an exp_name is available, or the config
            file does not exist.
    """
    global global_print_hparams
    args = _parse_hparams_args(config, exp_name, hparams_str, print_hparams)
    # Explicit raises (not asserts) so validation survives ``python -O``.
    if args.config == '' and args.exp_name == '':
        raise AssertionError("either a config file or an exp_name is required")
    if args.config != '' and not os.path.exists(args.config):
        raise AssertionError(f"{args.config} not exists")

    saved_hparams = {}
    args_work_dir = ''
    if args.exp_name != '':
        args_work_dir = f'{args.exp_name}'
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            with open(ckpt_config_path, encoding='utf-8') as f:
                saved_hparams_ = yaml.safe_load(f)
            if saved_hparams_ is not None:
                saved_hparams.update(saved_hparams_)

    hparams_ = {}
    config_chains = []
    if args.config != '':
        hparams_.update(load_config(args.config, config_chains, set()))
        if len(config_chains) > 1 and print_hparams:
            print('| Hparams chains: ', config_chains)
    if not args.reset:
        hparams_.update(saved_hparams)
    traverse_dict(hparams_, parse_config, hparams_)
    hparams_['work_dir'] = args_work_dir

    # Command-line style overrides, e.g. --hparams="a=1,b.c=2,d=[1 1 1]"
    if args.hparams != "":
        _apply_hparams_overrides(hparams_, args.hparams)

    if args_work_dir != '' and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)

    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    hparams_['exp_name'] = args.exp_name

    hparams_['start_rank'] = args.start_rank  # useful for multi-machine training
    hparams_['world_size'] = args.world_size
    hparams_['init_method'] = args.init_method
    hparams_['ddp_dir'] = args.ddp_dir
    hparams_['master_addr'] = args.master_addr

    remove_meta_key(hparams_)
    if global_hparams:
        # Mutate in place so modules that imported ``hparams`` see the new values.
        hparams.clear()
        hparams.update(hparams_)
    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
        global_print_hparams = False
    return hparams_
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", "\"", "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn.functional as F
|
|
17
|
+
|
|
18
|
+
def map_phone_to_tokendict(item, pad_bos_eos=True):
    """Fold Chinese phone ids and their tones into one merged token id.

    Per the original notes, the phone dict ends at 173 and ids 146..173 are
    punctuation. The merge works as follows:

    - Tone-dict ids 4, 11, 12, 13, 14, 15 become tone slots 1..6. Other tone ids
      are left as-is.
    - Phone ids 3..100 (the Chinese phones) become ``(id - 3) * 6 + 200 + slot``,
      i.e. the range 200..788.
    - All other phone ids are copied unchanged.

    Args:
        item: Mapping with tensors ``'txt_token'`` and ``'tone'`` of the same shape.
            Neither tensor is modified.
        pad_bos_eos: If true, 798 is prepended and 799 appended along the last dimension.

    Returns:
        The merged token tensor.
    """
    phone = item['txt_token']
    raw_tone = item['tone']
    merged_phone = phone.clone()

    # Masks are computed from the untouched tone ids; no output slot equals a
    # later input id, so this matches applying the replacements one by one.
    tone_slot = raw_tone.clone()
    for tone_id, slot in ((4, 1), (11, 2), (12, 3), (13, 4), (14, 5), (15, 6)):
        tone_slot[raw_tone == tone_id] = slot

    is_chinese = (phone >= 3) & (phone <= 100)
    merged_phone[is_chinese] = (merged_phone[is_chinese] - 3) * 6 + 200 + tone_slot[is_chinese]

    if not pad_bos_eos:
        return merged_phone
    with_bos = F.pad(merged_phone, (1, 0), mode='constant', value=798)
    return F.pad(with_bos, (0, 1), mode='constant', value=799)
|
|
38
|
+
|
|
39
|
+
def split_ph_timestamp(ph_timestamp):
    """Split an interleaved ``[token, timestamp, token, timestamp, ...]`` tensor.

    NOTE: *ph_timestamp* is modified in place. Every entry >= 800 has 800 subtracted;
    per the original comment, this maps timestamps back to frame units.

    Token decoding:

    - A token in 200..788 is unpacked into phone ``(t - 201) // 6 + 3`` and a tone
      slot ``(t - 201) % 6 + 1``. Slot 1 becomes tone id 4; slots 2..6 become 11..15.
    - Any other token is kept as the phone, with tone id 3.

    Each duration is the difference between consecutive timestamps; the first is
    measured from 0.

    Returns:
        ``(ph_seq, tone_seq, dur_seq, last_entry)``. The three sequences are
        ``LongTensor``s; ``last_entry`` is the final element of the adjusted input.

    Raises:
        AssertionError: the numbers of tokens and timestamps differ.
    """
    ph_timestamp[ph_timestamp >= 800] -= 800

    phones, tones, durations = [], [], []
    previous_end = 0
    for position, entry in enumerate(ph_timestamp):
        if position % 2 == 1:
            durations.append(entry - previous_end)
            previous_end = entry
            continue
        if 200 <= entry <= 788:
            offset = entry - 200 - 1
            phone = offset // 6 + 3
            slot = offset % 6 + 1
            tone = 4 if slot == 1 else slot + 9
        else:
            phone, tone = entry, 3
        phones.append(phone)
        tones.append(tone)
    assert len(phones) == len(durations), f"{len(phones)}, {len(durations)}"
    return torch.LongTensor(phones), torch.LongTensor(tones), torch.LongTensor(durations), ph_timestamp[-1]
|
|
71
|
+
|
|
72
|
+
def split_ph(ph_seq):
    """Split merged phone tokens back into separate phone and tone sequences.

    Input: *ph_seq*, an iterable of merged token ids (not a timestamp sequence).

    Token decoding:

    - A token in 200..788 is unpacked into phone ``(t - 201) // 6 + 3`` and a tone
      slot ``(t - 201) % 6 + 1``. Slot 1 maps to tone id 4; slots 2..6 map to 11..15.
    - Any other token is kept as the phone, with tone id 3.

    Returns:
        ``(ph_seq, tone_seq)`` as ``LongTensor``s of equal length.
    """
    phones, tones = [], []
    for token in ph_seq:
        if 200 <= token <= 788:
            offset = token - 200 - 1
            phone = offset // 6 + 3
            slot = offset % 6 + 1
            tone = 4 if slot == 1 else slot + 9
        else:
            phone, tone = token, 3
        phones.append(phone)
        tones.append(tone)
    assert len(phones) == len(tones)
    return torch.LongTensor(phones), torch.LongTensor(tones)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# Copyright 2025 ByteDance and/or its affiliates.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
def chunk_text_chinese(text, limit=60):
    """Split Chinese text into chunks of about ``limit`` CJK characters.

    Characters are accumulated until the current chunk contains ``limit``
    CJK ideographs (U+4E00..U+9FFF). Non-CJK characters are kept but not
    counted. The chunk is then cut right after the last punctuation mark it
    contains. If it contains none, the chunk is extended up to and including
    the next punctuation mark in ``text``. If no later punctuation exists
    either, the rest of ``text`` becomes the final chunk. Chunks may
    therefore hold more than ``limit`` CJK characters.

    Args:
        text (str): The text to split.
        limit (int): Number of CJK characters that triggers a split.

    Returns:
        List[str]: The chunks in order; concatenated they reproduce ``text``.
    """
    # Matches a single character from the CJK Unified Ideographs block.
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    # Characters a chunk may be cut after. (The previous literal spelled the
    # dot as "\." in a non-raw string: an invalid escape that also made a
    # backslash count as punctuation.)
    punctuation = ",。!?;:,.!?;"

    result = []          # finished chunks
    current_chunk = []   # characters of the chunk being built
    chinese_count = 0    # CJK characters currently in current_chunk

    n = len(text)
    i = 0
    while i < n:
        char = text[i]
        current_chunk.append(char)
        if chinese_pattern.match(char):
            chinese_count += 1

        if chinese_count >= limit:
            # Prefer cutting after the last punctuation already in the chunk.
            cut = next(
                (j for j in range(len(current_chunk) - 1, -1, -1)
                 if current_chunk[j] in punctuation),
                -1,
            )
            if cut >= 0:
                result.append(''.join(current_chunk[:cut + 1]))
                current_chunk = current_chunk[cut + 1:]
                chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
            else:
                # Otherwise extend the chunk forward to the next punctuation mark.
                nxt = next((k for k in range(i + 1, n) if text[k] in punctuation), -1)
                if nxt < 0:
                    # No punctuation in the chunk or anywhere after it, so no
                    # further cut is possible. Keep the rest of the text as one
                    # chunk instead of rescanning on every remaining character.
                    current_chunk.append(text[i + 1:])
                    break
                result.append(''.join(current_chunk) + text[i + 1:nxt + 1])
                current_chunk = []
                chinese_count = 0
                i = nxt
        i += 1

    # Flush whatever is left over.
    if current_chunk:
        result.append(''.join(current_chunk))

    return result
|
|
58
|
+
|
|
59
|
+
def chunk_text_english(text, max_chars=130):
    """Greedily pack punctuation-delimited sentences into chunks.

    The text is split right after sentence punctuation (such as ``, . ! ? ; :``
    and ``。``). Whitespace that follows ``;:,.!?`` is dropped by the split.
    Sentences are then appended to the current chunk while its size stays
    within ``max_chars``. Sizes are measured in UTF-8 bytes, not characters.
    A single sentence larger than ``max_chars`` is emitted as its own
    oversized chunk. A sentence ending in a single-byte character is
    followed by a space; other sentences are joined directly.

    Args:
        text (str): The text to be split.
        max_chars (int): Size budget per chunk, in UTF-8 bytes.

    Returns:
        List[str]: The whitespace-stripped chunks, in order.
    """
    def _with_sep(piece):
        # Only pieces ending in a single-byte (ASCII) character get a space.
        if piece and len(piece[-1].encode("utf-8")) == 1:
            return piece + " "
        return piece

    pieces = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)

    chunks = []
    buffer = ""
    for piece in pieces:
        fits = len(buffer.encode("utf-8")) + len(piece.encode("utf-8")) <= max_chars
        if not fits:
            # Close the current chunk (if any) and start a fresh one.
            if buffer:
                chunks.append(buffer.strip())
            buffer = ""
        buffer += _with_sep(piece)

    if buffer:
        chunks.append(buffer.strip())
    return chunks
|
|
87
|
+
|
|
88
|
+
if __name__ == '__main__':
    # Manual smoke demo: print the chunks produced for one Chinese and one
    # English sample passage.
    sample_zh = "哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"
    sample_en = "Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."
    print(chunk_text_chinese(sample_zh))
    print(chunk_text_english(sample_en))
|