xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (104) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,171 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import glob
17
+ import os
18
+ import re
19
+ import subprocess
20
+ import traceback
21
+
22
+ import torch
23
+ from torch.nn.parallel import DistributedDataParallel
24
+ import torch.distributed as dist
25
+
26
+
27
@contextlib.contextmanager
def dist_load(path):
    """Yield a path from which ``path`` can be loaded by every local rank.

    When ``torch.distributed`` is not initialised, the world size is 1, or
    ``path`` already resolves to somewhere under ``/dev/shm``, ``path`` itself
    is yielded unchanged.

    Otherwise the process whose ``LOCAL_RANK`` is 0 copies ``path`` to
    ``/dev/shm/<hparams["exp_name"]>/<basename(path)>``, all ranks wait on a
    barrier, the copy's path is yielded, and after a second barrier the same
    process removes the copy again.

    The copy/remove commands are passed to ``subprocess`` as argument lists
    (no shell), so paths containing spaces or shell metacharacters are handled
    literally.

    Raises:
        AssertionError: if ``path`` has an empty basename (e.g. ends with a
            path separator) in the staging branch.
        subprocess.CalledProcessError: if ``cp`` or ``rm`` fails.
    """
    if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
        yield path
    else:
        # Imported lazily: only the multi-process branch needs these.
        from tts.utils.commons.hparams import hparams
        from tts.utils.commons.trainer import LOCAL_RANK
        tmpdir = '/dev/shm'
        if len(os.path.basename(path)) == 0:
            # An empty basename would make the staged path equal the experiment
            # directory itself, which is later removed recursively.
            raise AssertionError(f"checkpoint path has no basename: {path!r}")
        shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
        if LOCAL_RANK == 0:
            os.makedirs(os.path.dirname(shm_ckpt_path), exist_ok=True)
            # -L dereferences symlinks, -r also allows directory checkpoints.
            subprocess.check_call(['cp', '-Lr', '--', path, shm_ckpt_path])
        dist.barrier()
        yield shm_ckpt_path
        dist.barrier()
        if LOCAL_RANK == 0:
            subprocess.check_call(['rm', '-rf', '--', shm_ckpt_path])
46
+
47
+
48
def torch_load_dist(path, map_location='cpu'):
    """Load a checkpoint with ``torch.load`` via the ``dist_load`` context.

    Args:
        path: checkpoint location handed to ``dist_load``.
        map_location: forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the path yielded by ``dist_load``.
    """
    with dist_load(path) as staged_path:
        return torch.load(staged_path, map_location=map_location)
52
+
53
+
54
def get_last_checkpoint(work_dir, steps=None):
    """Load the first checkpoint reported by ``get_all_ckpts``.

    Args:
        work_dir: directory searched for checkpoint files.
        steps: forwarded to ``get_all_ckpts`` to select a specific step.

    Returns:
        ``(checkpoint, path)``; both are ``None`` when no file matched.
        The checkpoint is loaded onto the CPU.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not len(candidates) > 0:
        return None, None
    newest_path = candidates[0]
    return torch_load_dist(newest_path, map_location='cpu'), newest_path
62
+
63
+
64
def get_all_ckpts(work_dir, steps=None):
    """List ``model_ckpt_steps_<N>.ckpt`` files in ``work_dir``, highest step first.

    Args:
        work_dir: directory to search (not recursive).
        steps: when ``None`` or ``0`` every step is listed; otherwise only the
            file for exactly that step is considered.

    Returns:
        A list of paths sorted by descending step number. Files whose step
        suffix is not a plain integer (for example ``model_ckpt_steps_best.ckpt``)
        are skipped instead of raising ``IndexError``.
    """
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # Raw string: the previous non-raw pattern relied on invalid escape
    # sequences (\_ \d \.), which newer Python versions warn about.
    step_pattern = re.compile(r'model_ckpt_steps_(\d+)\.ckpt')
    numbered = []
    for ckpt_path in glob.glob(ckpt_path_pattern):
        found = step_pattern.fullmatch(os.path.basename(ckpt_path))
        if found is not None:
            numbered.append((int(found.group(1)), ckpt_path))
    numbered.sort(key=lambda pair: -pair[0])
    return [ckpt_path for _, ckpt_path in numbered]
71
+
72
+
73
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
    """Load model (and optionally optimizer) state from a checkpoint.

    Args:
        cur_model: a module or a list of modules. ``DistributedDataParallel``
            wrappers are unwrapped before loading.
        ckpt_base_dir: a checkpoint file, or a directory to search. Ignored
            when ``checkpoint`` is given.
        model_name: key (or list of keys, parallel to ``cur_model``) inside
            ``checkpoint["state_dict"]``. A dotted name ``base.rest`` selects
            the entries of ``state_dict[base]`` whose key starts with ``rest.``.
        force: when no checkpoint is found, raise ``AssertionError`` if true,
            otherwise only print a message.
        strict: forwarded to ``load_state_dict``.
        silent: suppress the "loaded ..." messages.
        load_opt: also restore ``checkpoint['optimizer_states']`` into ``opts``
            and, for directory lookups, search step checkpoints instead of
            ``model_only_last.ckpt``.
        opts: optimizers matching ``checkpoint['optimizer_states']`` one to one.
        steps: specific step forwarded to ``get_last_checkpoint``.
        checkpoint: an already-loaded checkpoint dict; skips all file lookup.
        ckpt_path: only used in log messages when ``checkpoint`` is given.
        delete_unmatch: with ``strict=False``, first attempt a strict load and,
            if it fails, drop entries whose shape differs from the model's
            before the final non-strict load.

    Returns:
        ``checkpoint.get('global_step', 0)`` on success. ``None`` when no
        checkpoint was found and ``force`` is false, or when a ``None``
        optimizer is encountered while restoring optimizer state.

    Raises:
        AssertionError: no checkpoint found and ``force`` is true, or the
            number of optimizers does not match the stored states.
    """
    if checkpoint is None:
        if os.path.isfile(ckpt_base_dir):
            base_dir = os.path.dirname(ckpt_base_dir)
            ckpt_path = ckpt_base_dir
            checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
        else:
            base_dir = ckpt_base_dir
            if load_opt:
                checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
            else:
                # Prefer the weights-only snapshot; fall back to step checkpoints.
                ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
                if os.path.exists(ckpt_path):
                    checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
                else:
                    checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
    if checkpoint is not None:
        # Strip wrapper prefixes ('module.', '_orig_mod.') from the keys.
        state_dict_all = {
            k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
        if not isinstance(cur_model, list):
            cur_models = [cur_model]
            model_names = [model_name]
        else:
            cur_models = cur_model
            model_names = model_name
        for model_name, cur_model in zip(model_names, cur_models):
            if isinstance(cur_model, DistributedDataParallel):
                cur_model = cur_model.module
            # Remember where the model lives so it can be moved back after loading.
            device = next(cur_model.parameters()).device
            if '.' not in model_name:
                state_dict = state_dict_all[model_name]
            else:
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
            state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            if not strict and delete_unmatch:
                try:
                    cur_model.load_state_dict(state_dict, strict=True)
                    if not silent:
                        print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
                except RuntimeError:
                    # load_state_dict reports key/shape problems as RuntimeError;
                    # drop shape-mismatched tensors so the non-strict load below works.
                    cur_model_state_dict = cur_model.state_dict()
                    cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
                                            cur_model_state_dict.items()}
                    unmatched_keys = []
                    for key, param in state_dict.items():
                        if key in cur_model_state_dict:
                            new_param = cur_model_state_dict[key]
                            if new_param.shape != param.shape:
                                unmatched_keys.append(key)
                                print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
                                      "ckpt model: ", param.shape)
                    for key in unmatched_keys:
                        del state_dict[key]
            load_results = cur_model.load_state_dict(state_dict, strict=strict)
            cur_model.to(device)
            if not silent:
                print(f"| loaded '{model_name}' from '{ckpt_path}'.")
                missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
                print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        if load_opt:
            optimizer_states = checkpoint['optimizer_states']
            assert len(opts) == len(optimizer_states)
            for optimizer, opt_state in zip(opts, optimizer_states):
                opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
                if optimizer is None:
                    return
                try:
                    optimizer.load_state_dict(opt_state)
                    # Move restored optimizer tensors to the device of the last loaded model.
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)
                except ValueError:
                    print(f"| WARNING: optimizer {optimizer} parameters not match !!!")
        return checkpoint.get('global_step', 0)
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise keeps the AssertionError even under `python -O`.
            raise AssertionError(e_msg)
        else:
            print(e_msg)
159
+
160
+
161
def load_with_size_mismatch(model, state_dict, prefix=""):
    """Load the entries of ``state_dict`` whose shapes match ``model``.

    Args:
        model: module to load into.
        state_dict: mapping of parameter name to tensor.
        prefix: leading string stripped from each key before it is compared
            with the model's own keys. Only a *leading* occurrence is removed;
            later occurrences inside the key are left intact.

    Entries whose (stripped) name is unknown to the model are ignored. Entries
    with a different size are reported as mismatches and not loaded. The
    remaining entries are loaded with ``strict=False``. Returns ``None``.
    """
    current_model_dict = model.state_dict()

    def _strip_prefix(key):
        # Remove the prefix only where it actually starts the key.
        if prefix and key.startswith(prefix):
            return key[len(prefix):]
        return key

    mismatch_keys = set()
    new_state_dict = {}
    for key, value in state_dict.items():
        name = _strip_prefix(key)
        if name not in current_model_dict:
            continue
        if value.size() == current_model_dict[name].size():
            new_state_dict[name] = value
        else:
            mismatch_keys.add(name)
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys: ", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")
@@ -0,0 +1,215 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import re
19
+
20
+ import yaml
21
+
22
# One-shot flag: set_hparams() prints the merged hparams only while this is
# True and sets it to False after the first print.
global_print_hparams = True
# Module-wide hyper-parameter dict. set_hparams(global_hparams=True) clears it
# and refills it in place, so references imported elsewhere stay valid.
hparams = {}
24
+
25
+
26
class Args:
    """Plain attribute container: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
30
+
31
+
32
def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Args:
        old_config: dict that is updated.
        new_config: dict whose entries take precedence. If it contains a truthy
            ``'__replace'`` entry, ``old_config`` is emptied before merging.

    Nested dicts are merged key by key only when both sides hold a dict for
    that key. In every other case (including a dict replacing a scalar) the new
    value simply overwrites the old one.
    """
    if new_config.get('__replace', False):
        old_config.clear()
    for k, v in new_config.items():
        # Recurse only when there is an existing dict to merge into; a scalar
        # on the old side cannot be merged and is replaced instead.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v
40
+
41
+
42
def traverse_dict(d, func, ctx):
    """Replace every non-dict value in ``d`` (recursively) with ``func(value, ctx)``.

    The dict is modified in place; nested dicts are descended into rather than
    passed to ``func``. Returns ``None``.
    """
    # Iterate over a snapshot of the keys; values are read at visit time.
    for key in tuple(d.keys()):
        value = d[key]
        if not isinstance(value, dict):
            d[key] = func(value, ctx)
            continue
        traverse_dict(value, func, ctx)
49
+
50
+
51
def parse_config(v, context=None):
    """Resolve the special string forms allowed in a config value.

    * Non-string values are returned unchanged.
    * A string starting with ``^`` is treated as a config file path (without
      the ``^``) and replaced by the result of ``load_config``.
    * A string starting with ``${...}`` is replaced by the result of evaluating
      the enclosed expression with ``context`` as the local namespace.
    * Any other string is returned unchanged.
    """
    if not isinstance(v, str):
        return v
    if v.startswith('^'):
        return load_config(v[1:], [], set())

    namespace = {} if context is None else context
    found = re.match(r"\${(.*)}", v)
    if found is None:
        return v
    # NOTE(security): eval() executes arbitrary Python taken from the config
    # text; only load configuration files from trusted sources.
    return eval(found.group(1), {}, namespace)
64
+
65
+
66
def remove_meta_key(d):
    """Delete, in place and recursively, entries whose key starts with ``__``.

    Only entries with a non-dict value are removed; a dict value is descended
    into instead, whatever its own key looks like. Returns ``None``.
    """
    for key in tuple(d.keys()):
        value = d[key]
        if isinstance(value, dict):
            remove_meta_key(value)
        elif key[:2] == '__':
            del d[key]
74
+
75
+
76
def load_config(config_fn, config_chains, loaded_configs):
    """Load a YAML config file, merging in the files named by ``base_config``.

    Args:
        config_fn: path of the YAML file to load.
        config_chains: list that receives, in completion order, every file
            that was loaded (base files before the files that include them).
        loaded_configs: set of paths already visited; a base file found in this
            set is not loaded a second time.

    Returns:
        The merged config dict. ``{}`` is returned (after printing a warning)
        when ``config_fn`` does not exist, and an empty YAML file is treated as
        an empty config instead of raising ``TypeError``.

    Base files are merged depth-first with ``override_config``, then the
    file's own entries override them. A base path starting with ``.`` is
    resolved relative to the directory of ``config_fn``.
    """
    # deep first inheritance and avoid the second visit of one node
    if not os.path.exists(config_fn):
        print(f"| WARN: {config_fn} not exist.", )
        return {}
    with open(config_fn) as f:
        hparams_ = yaml.safe_load(f)
    if hparams_ is None:
        # safe_load returns None for an empty document.
        hparams_ = {}
    loaded_configs.add(config_fn)

    if 'base_config' in hparams_:
        ret_hparams = {}
        if not isinstance(hparams_['base_config'], list):
            hparams_['base_config'] = [hparams_['base_config']]
        for c in hparams_['base_config']:
            if c.startswith('.'):
                c = f'{os.path.dirname(config_fn)}/{c}'
                c = os.path.normpath(c)
            if c not in loaded_configs:
                override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
        override_config(ret_hparams, hparams_)
    else:
        ret_hparams = hparams_

    config_chains.append(config_fn)
    return ret_hparams
101
+
102
+
103
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Build the hyper-parameter dict from config files, a saved run and overrides.

    Args:
        config: path of the YAML config. When both ``config`` and ``exp_name``
            are empty, all settings are read from the command line instead.
        exp_name: experiment directory; also used as ``work_dir``. A
            ``config.yaml`` found inside it is merged over the loaded config.
        hparams_str: comma-separated ``key=value`` overrides, e.g.
            ``"a=1,b.c=2,d=[1 1 1]"``. Dotted keys address nested dicts.
        print_hparams: allow informational printing.
        global_hparams: when true, the module-level ``hparams`` dict is cleared
            and refilled with the result.

    Returns:
        The resulting hyper-parameter dict.

    Raises:
        AssertionError: neither a config nor an experiment name was supplied,
            or the config path does not exist.

    Side effects: creates ``work_dir`` unless ``--infer`` is set, and may
    update the module-level ``hparams`` / ``global_print_hparams``.
    """
    if config == '' and exp_name == '':
        parser = argparse.ArgumentParser(description='')
        parser.add_argument('--config', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('-hp', '--hparams', type=str, default='',
                            help='location of the data corpus')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--remove', action='store_true', help='remove old ckpt')
        parser.add_argument('--debug', action='store_true', help='debug')
        parser.add_argument('--start_rank', type=int, default=-1,
                            help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
        parser.add_argument('--world_size', type=int, default=-1,
                            help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
        parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
        parser.add_argument('--master_addr', type=str, default='', help='')
        parser.add_argument('--ddp_dir', type=str, default='', help='')

        args, unknown = parser.parse_known_args()
        if print_hparams:
            print("| set_hparams Unknow hparams: ", unknown)
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False, remove=False,
                    start_rank=-1, world_size=-1, init_method='tcp', ddp_dir='', master_addr='')
    global hparams
    assert args.config != '' or args.exp_name != ''
    if args.config != '':
        assert os.path.exists(args.config), f"{args.config} not exists"

    # Hparams stored with a previous run of this experiment, if any.
    saved_hparams = {}
    args_work_dir = ''
    if args.exp_name != '':
        args_work_dir = f'{args.exp_name}'
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            with open(ckpt_config_path) as f:
                saved_hparams_ = yaml.safe_load(f)
            if saved_hparams_ is not None:
                saved_hparams.update(saved_hparams_)
    hparams_ = {}
    config_chains = []
    if args.config != '':
        hparams_.update(load_config(args.config, config_chains, set()))
        if len(config_chains) > 1 and print_hparams:
            print('| Hparams chains: ', config_chains)
    if not args.reset:
        # Saved values win over the config file unless --reset is given.
        hparams_.update(saved_hparams)
    # Resolve '^file' includes and '${expr}' expressions in the values.
    traverse_dict(hparams_, parse_config, hparams_)
    hparams_['work_dir'] = args_work_dir

    # Support config overriding in command line. Support list type config overriding.
    # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
    if args.hparams != "":
        for new_hparam in args.hparams.split(","):
            # Split on the first '=' only so values may themselves contain '='.
            k, v = new_hparam.split("=", 1)
            v = v.strip("\'\" ")
            config_node = hparams_
            for k_ in k.split(".")[:-1]:
                config_node = config_node[k_]
            k = k.split(".")[-1]
            if k in config_node and config_node[k] is not None:
                if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
                    if type(config_node[k]) is list:
                        # "[1 1 1]" -> "[1,1,1]"; '^' stands in for a double quote.
                        v = v.replace(" ", ",").replace('^', "\"")
                        if '|' in v:
                            # "a|b|c" shorthand: reuse the element type of the current list.
                            tp = type(config_node[k][0]) if len(config_node[k]) else str
                            config_node[k] = [tp(x) for x in v.split("|") if x != '']
                            continue
                    # NOTE(security): eval() runs arbitrary Python from the
                    # override string; only pass trusted --hparams values.
                    config_node[k] = eval(v)
                else:
                    # Coerce to the type of the value being overridden.
                    config_node[k] = type(config_node[k])(v)
            else:
                # Unknown key (or existing value None): infer float -> int -> bool.
                config_node[k] = v
                try:
                    config_node[k] = float(v)
                except ValueError:
                    pass
                try:
                    config_node[k] = int(v)
                except ValueError:
                    pass
                if v.lower() in ['false', 'true']:
                    config_node[k] = v.lower() == 'true'

    if args_work_dir != '' and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)

    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    hparams_['exp_name'] = args.exp_name

    hparams_['start_rank'] = args.start_rank  # useful for multi-machine training
    hparams_['world_size'] = args.world_size
    hparams_['init_method'] = args.init_method
    hparams_['ddp_dir'] = args.ddp_dir
    hparams_['master_addr'] = args.master_addr

    remove_meta_key(hparams_)
    global global_print_hparams
    if global_hparams:
        # Update in place so existing references to `hparams` see the new values.
        hparams.clear()
        hparams.update(hparams_)
    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
        # Print the full dict only once per process.
        global_print_hparams = False
    return hparams_
@@ -0,0 +1 @@
1
+ {"phone": ["C0a", "C0ai", "C0air", "C0an", "C0ang", "C0angr", "C0anr", "C0ao", "C0aor", "C0ar", "C0b", "C0c", "C0ch", "C0d", "C0e", "C0ei", "C0eir", "C0en", "C0eng", "C0engr", "C0enr", "C0er", "C0f", "C0g", "C0h", "C0i", "C0ia", "C0ian", "C0iang", "C0iangr", "C0ianr", "C0iao", "C0iaor", "C0iar", "C0ie", "C0ier", "C0ii", "C0iii", "C0iiir", "C0iir", "C0in", "C0ing", "C0ingr", "C0inr", "C0io", "C0iong", "C0iongr", "C0iou", "C0iour", "C0ir", "C0j", "C0k", "C0l", "C0m", "C0n", "C0ng", "C0o", "C0ong", "C0ongr", "C0or", "C0ou", "C0our", "C0p", "C0q", "C0r", "C0s", "C0sh", "C0t", "C0u", "C0ua", "C0uai", "C0uair", "C0uan", "C0uang", "C0uangr", "C0uanr", "C0uar", "C0uei", "C0ueir", "C0uen", "C0ueng", "C0uengr", "C0uenr", "C0uo", "C0uor", "C0ur", "C0v", "C0van", "C0vanr", "C0ve", "C0ver", "C0vn", "C0vnr", "C0vr", "C0x", "C0z", "C0zh", "C0_", "E0aa", "E0ae", "E0ah", "E0ao", "E0aw", "E0ax", "E0ay", "E0b", "E0ch", "E0d", "E0dh", "E0eh", "E0ehr", "E0er", "E0ey", "E0f", "E0g", "E0hh", "E0ih", "E0iy", "E0iyr", "E0jh", "E0k", "E0l", "E0m", "E0n", "E0ng", "E0oh", "E0ow", "E0oy", "E0p", "E0r", "E0s", "E0sh", "E0t", "E0th", "E0uh", "E0uw", "E0uwr", "E0v", "E0w", "E0y", "E0z", "E0zh", "sil", "…", "、", "。", "《", "》", "【", "】", "!", "\"", "#", "$", "%", "'", "''", "(", ")", "*", ",", ":", ";", "?", "\\", "^", "_", "`", "{", "}", "~"], "tone": ["0", "1", "10", "11", "12", "13", "15", "17", "2", "3", "4", "5", "6", "7", "8", "9"], "wordCategory": ["0", "B", "E", "M", "S"], "prosody": ["0", "1", "2", "3", "4"], "focus": ["0", "1"], "intonation": ["0", "1", "2"], "phraseAccent": ["0", "H-", "L-"], "boundaryTone": ["0", "H%", "L%"], "accentType": ["!H*", "0", "H*", "L*", "L*+H", "L+H*"]}
@@ -0,0 +1,94 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
def map_phone_to_tokendict(item, pad_bos_eos=True):
    """Merge phone ids and tone ids into a single token sequence.

    Args:
        item: mapping with tensors ``'txt_token'`` (phone ids) and ``'tone'``
            (tone ids) of the same shape. Neither tensor is modified.
        pad_bos_eos: when true, prepend token 798 and append token 799.

    Returns:
        A tensor in which every phone id in ``3..100`` is replaced by
        ``(phone - 3) * 6 + 200 + tone`` (tone first remapped as listed below)
        and every other phone id is kept as is.
    """
    # Merge Chinese phone and tone (Original dict ends at 173, i.e., ph_dict_size=173). 146~173 is punctuations.
    raw_phone = item['txt_token'].clone()
    merged = item['txt_token'].clone()
    raw_tone = item['tone'].clone()

    # In tone_dict, tone_1 is 4, tone_2 is 11, tone_3 is 12, tone_4 is 13, tone_5 is 14, tone_6 is 15
    tone = raw_tone.clone()
    for dict_id, tone_number in ((4, 1), (11, 2), (12, 3), (13, 4), (14, 5), (15, 6)):
        # Masks are taken on the untouched ids, so remapped values never chain.
        tone[raw_tone == dict_id] = tone_number

    # Chinese phones lie in 3~100 in the phone_dict, we map them to 200~788
    is_chinese = (raw_phone >= 3) & (raw_phone <= 100)
    merged[is_chinese] = (merged[is_chinese] - 3) * 6 + 200 + tone[is_chinese]

    if not pad_bos_eos:
        return merged
    with_bos = F.pad(merged, (1, 0), mode='constant', value=798)
    return F.pad(with_bos, (0, 1), mode='constant', value=799)
38
+
39
def split_ph_timestamp(ph_timestamp):
    """Split an interleaved ``[phone, timestamp, phone, timestamp, ...]`` sequence.

    Input: ph_timestamp, shape [T] (1-D tensor of integer ids, T even).

    Values ``>= 800`` are first reduced by 800. Even positions are merged
    phone tokens, decoded back to ``(phone, tone)``: tokens in ``200..788`` are
    un-merged, any other token is kept as the phone with tone ``3``. Odd
    positions are cumulative timestamps, converted to per-phone durations.

    The caller's ``ph_timestamp`` is left untouched; the function works on a
    copy.

    Returns:
        ``(ph_seq, tone_seq, dur_seq, last)``: three ``LongTensor``s of length
        ``T / 2`` and the last element of the adjusted sequence.

    Raises:
        AssertionError: if the numbers of phones and durations differ.
    """
    # Copy first: the in-place subtraction below must not leak to the caller.
    ph_timestamp = ph_timestamp.clone() if torch.is_tensor(ph_timestamp) else ph_timestamp.copy()

    # Map the timestamp of each phone back to its original frame-level lengths
    ph_timestamp[ph_timestamp >= 800] -= 800

    ph_list = []
    tone_list = []
    dur_list = []
    cur_timestamp = 0
    for idx, item in enumerate(int(value) for value in ph_timestamp.tolist()):
        if idx % 2 == 0:
            # Map Chinese phones back to its original phone_dict
            if 200 <= item <= 788:
                ph = (item - 200 - 1) // 6 + 3
                tone = (item - 200 - 1) % 6 + 1
                # Tone number -> tone_dict id: 1 -> 4, 2..6 -> 11..15.
                tone = 4 if tone == 1 else tone + 9
            # Set English tone to '3'
            else:
                ph = item
                tone = 3
            ph_list.append(ph)
            tone_list.append(tone)
        else:
            dur_list.append(item - cur_timestamp)
            cur_timestamp = item
    if len(ph_list) != len(dur_list):
        raise AssertionError(f"{len(ph_list)}, {len(dur_list)}")
    ph_seq, tone_seq, dur_seq = torch.LongTensor(ph_list), torch.LongTensor(tone_list), torch.LongTensor(dur_list)
    return ph_seq, tone_seq, dur_seq, ph_timestamp[-1]
71
+
72
def split_ph(ph_seq):
    """Decode merged phone tokens back into phone ids and tone ids.

    Input: ph_seq, shape [T].

    Tokens in ``200..788`` are un-merged into ``(phone, tone)``; any other
    token is kept as the phone and given tone ``3``.

    Returns:
        ``(ph_seq, tone_seq)`` as two ``LongTensor``s of length ``T``.
    """
    phones = []
    tones = []
    for token in ph_seq:
        if 200 <= token <= 788:
            # Map Chinese phones back to its original phone_dict
            offset = token - 200 - 1
            phone = offset // 6 + 3
            tone = offset % 6 + 1
            if tone == 1:
                tone = 4
            else:
                tone = tone + 9
        else:
            # Set English tone to '3'
            phone, tone = token, 3
        phones.append(phone)
        tones.append(tone)

    assert len(phones) == len(tones)
    return torch.LongTensor(phones), torch.LongTensor(tones)
@@ -0,0 +1,90 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
def chunk_text_chinese(text, limit=60):
    """Split ``text`` into chunks of roughly ``limit`` Chinese characters.

    Characters are accumulated until the chunk contains ``limit`` characters
    in the range U+4E00..U+9FFF. The chunk is then cut at the closest
    punctuation mark before that point or, if there is none, at the next
    punctuation mark after it. Both full-width and ASCII punctuation marks are
    recognised. Text without any punctuation is returned as a single chunk.

    Args:
        text: the text to split.
        limit: number of Chinese characters that triggers a cut.

    Returns:
        List of chunks; concatenating them reproduces ``text``.
    """
    # Matches a single Chinese character.
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    # Cut points: full-width marks followed by their ASCII counterparts.
    punctuation = ",。!?;:" + ",.!?;:"

    result = []  # finished chunks
    current_chunk = []  # characters of the chunk being built
    chinese_count = 0  # Chinese characters in current_chunk

    i = 0
    while i < len(text):
        char = text[i]
        current_chunk.append(char)
        if chinese_pattern.match(char):
            chinese_count += 1

        if chinese_count >= limit:  # limit reached
            # Look backwards for the nearest punctuation mark.
            for j in range(len(current_chunk) - 1, -1, -1):
                if current_chunk[j] in punctuation:
                    result.append(''.join(current_chunk[:j + 1]))
                    current_chunk = current_chunk[j + 1:]
                    chinese_count = sum(1 for c in current_chunk if chinese_pattern.match(c))
                    break
            else:
                # Nothing before: extend the chunk to the next punctuation mark.
                for k in range(i + 1, len(text)):
                    if text[k] in punctuation:
                        result.append(''.join(current_chunk) + text[i + 1:k + 1])
                        current_chunk = []
                        chinese_count = 0
                        i = k
                        break
        i += 1

    # Whatever is left forms the last chunk.
    if current_chunk:
        result.append(''.join(current_chunk))

    return result
58
+
59
+ def chunk_text_english(text, max_chars=130):
60
+ """
61
+ Splits the input text into chunks, each with a maximum number of characters.
62
+
63
+ Args:
64
+ text (str): The text to be split.
65
+ max_chars (int): The maximum number of characters per chunk.
66
+
67
+ Returns:
68
+ List[str]: A list of text chunks.
69
+ """
70
+ chunks = []
71
+ current_chunk = ""
72
+ # Split the text into sentences based on punctuation followed by whitespace
73
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
74
+
75
+ for sentence in sentences:
76
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
77
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
78
+ else:
79
+ if current_chunk:
80
+ chunks.append(current_chunk.strip())
81
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
82
+
83
+ if current_chunk:
84
+ chunks.append(current_chunk.strip())
85
+
86
+ return chunks
87
+
88
# Manual smoke test: print the chunks produced for one Chinese and one English sample.
if __name__ == '__main__':
    print(chunk_text_chinese("哇塞!家人们,你们太好运了。我居然发现了一个宝藏零食大礼包,简直适合所有人的口味!有香辣的,让你舌尖跳舞;有盐焗的,咸香可口;还有五香的,香气四溢。就连怀孕的姐妹都吃得津津有味!整整三十包啊!什么手撕蟹柳、辣子鸡、嫩豆干、手撕素肉、鹌鹑蛋、小肉枣肠、猪肉腐、魔芋、魔芋丝等等,应有尽有。香辣土豆爽辣过瘾,各种素肉嚼劲十足,鹌鹑蛋营养美味,真的太多太多啦,...家人们,现在价格太划算了,赶紧下单。"))
    print(chunk_text_english("Washington CNN When President Donald Trump declared in the House Chamber this week that executives at the nation’s top automakers were “so excited” about their prospects amid his new tariff regime, it did not entirely reflect the conversation he’d held with them earlier that day."))