xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +107 -11
- xinference/client/restful/restful_client.py +51 -11
- xinference/constants.py +5 -1
- xinference/core/media_interface.py +758 -0
- xinference/core/model.py +49 -9
- xinference/core/supervisor.py +1 -1
- xinference/core/utils.py +1 -1
- xinference/core/worker.py +33 -39
- xinference/deploy/cmdline.py +17 -0
- xinference/deploy/utils.py +0 -3
- xinference/model/audio/__init__.py +16 -27
- xinference/model/audio/core.py +2 -1
- xinference/model/audio/cosyvoice.py +4 -2
- xinference/model/audio/model_spec.json +63 -46
- xinference/model/audio/model_spec_modelscope.json +31 -14
- xinference/model/embedding/__init__.py +16 -24
- xinference/model/image/__init__.py +15 -25
- xinference/model/llm/__init__.py +40 -115
- xinference/model/llm/core.py +29 -6
- xinference/model/llm/llama_cpp/core.py +30 -347
- xinference/model/llm/llm_family.json +1674 -2203
- xinference/model/llm/llm_family.py +71 -7
- xinference/model/llm/llm_family_csghub.json +0 -32
- xinference/model/llm/llm_family_modelscope.json +1838 -2016
- xinference/model/llm/llm_family_openmind_hub.json +19 -325
- xinference/model/llm/lmdeploy/core.py +7 -2
- xinference/model/llm/mlx/core.py +23 -7
- xinference/model/llm/reasoning_parser.py +281 -5
- xinference/model/llm/sglang/core.py +39 -11
- xinference/model/llm/transformers/chatglm.py +9 -2
- xinference/model/llm/transformers/cogagent.py +10 -12
- xinference/model/llm/transformers/cogvlm2.py +6 -3
- xinference/model/llm/transformers/cogvlm2_video.py +3 -6
- xinference/model/llm/transformers/core.py +58 -60
- xinference/model/llm/transformers/deepseek_v2.py +4 -2
- xinference/model/llm/transformers/deepseek_vl.py +10 -4
- xinference/model/llm/transformers/deepseek_vl2.py +9 -4
- xinference/model/llm/transformers/gemma3.py +4 -5
- xinference/model/llm/transformers/glm4v.py +3 -21
- xinference/model/llm/transformers/glm_edge_v.py +3 -20
- xinference/model/llm/transformers/intern_vl.py +3 -6
- xinference/model/llm/transformers/internlm2.py +1 -1
- xinference/model/llm/transformers/minicpmv25.py +4 -2
- xinference/model/llm/transformers/minicpmv26.py +5 -3
- xinference/model/llm/transformers/omnilmm.py +1 -1
- xinference/model/llm/transformers/opt.py +1 -1
- xinference/model/llm/transformers/ovis2.py +302 -0
- xinference/model/llm/transformers/qwen-omni.py +8 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +5 -1
- xinference/model/llm/transformers/qwen_vl.py +5 -2
- xinference/model/llm/utils.py +96 -45
- xinference/model/llm/vllm/core.py +108 -24
- xinference/model/llm/vllm/distributed_executor.py +8 -7
- xinference/model/llm/vllm/xavier/allocator.py +1 -1
- xinference/model/llm/vllm/xavier/block_manager.py +1 -1
- xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
- xinference/model/llm/vllm/xavier/executor.py +1 -1
- xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
- xinference/model/rerank/__init__.py +13 -24
- xinference/model/video/__init__.py +15 -25
- xinference/model/video/core.py +3 -3
- xinference/model/video/diffusers.py +157 -13
- xinference/model/video/model_spec.json +100 -0
- xinference/model/video/model_spec_modelscope.json +104 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
- xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
- xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
- xinference/thirdparty/cosyvoice/bin/train.py +7 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
- xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
- xinference/thirdparty/cosyvoice/cli/model.py +140 -155
- xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
- xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
- xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
- xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
- xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
- xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
- xinference/thirdparty/cosyvoice/utils/common.py +1 -1
- xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
- xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
- xinference/types.py +2 -71
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
- xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
- xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
- xinference/web/ui/src/locales/en.json +7 -4
- xinference/web/ui/src/locales/zh.json +7 -4
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
- xinference/core/image_interface.py +0 -377
- xinference/model/llm/transformers/compression.py +0 -258
- xinference/model/llm/transformers/yi_vl.py +0 -239
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
- xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
- xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
- xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
- /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/bin/export_jit.py

@@ -23,7 +23,8 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging


 def get_args():
@@ -37,6 +38,16 @@ def get_args():
     return args


+def get_optimized_script(model, preserved_attrs=[]):
+    script = torch.jit.script(model)
+    if preserved_attrs != []:
+        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+    else:
+        script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    return script
+
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -46,28 +57,47 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)

-
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+        except Exception:
+            raise TypeError('no valid model_type!')

-
-
-
-
-
-
+    if not isinstance(model, CosyVoice2):
+        # 1. export llm text_encoder
+        llm_text_encoder = model.model.llm.text_encoder
+        script = get_optimized_script(llm_text_encoder)
+        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_text_encoder.half())
+        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_text_encoder')

-
-
-
-
-
-
+        # 2. export llm llm
+        llm_llm = model.model.llm.llm
+        script = get_optimized_script(llm_llm, ['forward_chunk'])
+        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_llm')

-
-
-
-
-
-
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    else:
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder, ['forward_chunk'])
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')


 if __name__ == '__main__':
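The new get_optimized_script helper centralizes the TorchScript export pipeline (script, then freeze, then optimize_for_inference) that the old code repeated inline for each submodule. Below is a minimal sketch of the same pipeline on a toy module; TinyEncoder is illustrative and not part of the package, and note that torch.jit.freeze requires the module to be in eval mode:

    import torch

    class TinyEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)

        def forward(self, x):
            return torch.relu(self.proj(x))

    # script -> freeze -> optimize_for_inference, as in get_optimized_script above
    model = TinyEncoder().eval()               # freeze() requires eval mode
    script = torch.jit.script(model)
    script = torch.jit.freeze(script)          # inlines weights, drops training-only attributes
    script = torch.jit.optimize_for_inference(script)
    script.save('tiny_encoder.fp32.zip')

Passing preserved_attrs (as the diff does with 'forward_chunk') keeps extra methods from being stripped by freeze, which is why the CosyVoice2 branch exports the chunked-inference entry points.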
xinference/thirdparty/cosyvoice/bin/export_onnx.py

@@ -27,7 +27,8 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging


 def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -51,61 +52,145 @@ def get_args():
     return args


+@torch.no_grad()
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')

-
-
-
-
-
-
-
-
-
-
-estimator
-
-
-
-
-
-
-
-
-
-
-'
-
-
-
-'
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+        except Exception:
+            raise TypeError('no valid model_type!')
+
+    if not isinstance(model, CosyVoice2):
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+            output_names=['estimator_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'estimator_out': {2: 'seq_len'},
+            }
+        )
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for _ in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            output_pytorch = estimator(x, mask, mu, t, spks, cond)
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+            torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
+    else:
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.forward = estimator.forward_chunk
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        cache = model.model.init_flow_cache()['decoder_cache']
+        cache.pop('offset')
+        cache = {k: v[0] for k, v in cache.items()}
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond,
+             cache['down_blocks_conv_cache'],
+             cache['down_blocks_kv_cache'],
+             cache['mid_blocks_conv_cache'],
+             cache['mid_blocks_kv_cache'],
+             cache['up_blocks_conv_cache'],
+             cache['up_blocks_kv_cache'],
+             cache['final_blocks_conv_cache']),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
+                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
+            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
+                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'down_blocks_kv_cache': {3: 'cache_in_len'},
+                'mid_blocks_kv_cache': {3: 'cache_in_len'},
+                'up_blocks_kv_cache': {3: 'cache_in_len'},
+                'estimator_out': {2: 'seq_len'},
+                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
+            }
+        )
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for iter in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            cache = model.model.init_flow_cache()['decoder_cache']
+            cache.pop('offset')
+            cache = {k: v[0] for k, v in cache.items()}
+            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+            }
+            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
+            if iter == 0:
+                # NOTE why can not pass first iteration check?
+                continue
+            for i, j in zip(output_pytorch, output_onnx):
+                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')


 if __name__ == "__main__":
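Both export branches follow the same pattern: export the estimator with seq_len declared as a dynamic axis, then replay random-length inputs through onnxruntime and compare against the PyTorch output. A condensed sketch of that pattern with a stand-in module; TinyEstimator is illustrative, and only the tolerances mirror the diff:

    import torch
    import onnxruntime

    class TinyEstimator(torch.nn.Module):
        def forward(self, x):                  # x: (batch, channels, seq_len)
            return torch.tanh(x) * 0.5

    model = TinyEstimator().eval()
    x = torch.randn(2, 4, 256)
    torch.onnx.export(model, (x,), 'tiny_estimator.fp32.onnx',
                      input_names=['x'], output_names=['out'],
                      dynamic_axes={'x': {2: 'seq_len'}, 'out': {2: 'seq_len'}},
                      opset_version=18)

    sess = onnxruntime.InferenceSession('tiny_estimator.fp32.onnx',
                                        providers=['CPUExecutionProvider'])
    x = torch.randn(2, 4, 100)                 # a different seq_len exercises the dynamic axis
    out_onnx = sess.run(None, {'x': x.numpy()})[0]
    torch.testing.assert_close(model(x), torch.from_numpy(out_onnx), rtol=1e-2, atol=1e-4)

The CosyVoice2 branch additionally threads conv and KV caches through extra inputs and outputs with their own dynamic cache-length axes, which is what makes the exported graph usable for streaming chunked inference.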
xinference/thirdparty/cosyvoice/bin/inference.py

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.dataset.dataset import Dataset


@@ -33,6 +33,7 @@ def get_args():
     parser.add_argument('--prompt_data', required=True, help='prompt data file')
     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
     parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--llm_model', required=True, help='llm model file')
     parser.add_argument('--flow_model', required=True, help='flow model file')
     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
@@ -59,16 +60,25 @@ def main():
     # Init cosyvoice models from configs
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
     device = torch.device('cuda' if use_cuda else 'cpu')
-
-
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+    except Exception:
+        try:
+            with open(args.config, 'r') as f:
+                configs = load_hyperpyyaml(f)
+            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        except Exception:
+            raise TypeError('no valid model_type!')

-    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)

     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                            tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

+    sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
@@ -104,7 +114,7 @@ def main():
         tts_speeches = torch.concat(tts_speeches, dim=1)
         tts_key = '{}_{}'.format(utts[0], tts_index[0])
         tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-        torchaudio.save(tts_fn, tts_speeches, sample_rate=
+        torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
         f.write('{} {}\n'.format(tts_key, tts_fn))
         f.flush()
     f.close()
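inference.py now captures sample_rate from the config before `del configs` and passes it to torchaudio.save explicitly, pinning the soundfile backend. A tiny sketch of that save call; the waveform is a placeholder, and the backend= argument assumes torchaudio >= 2.1 with the soundfile package installed:

    import torch
    import torchaudio

    sample_rate = 24000                        # stand-in for configs['sample_rate']
    tts_speech = torch.zeros(1, sample_rate)   # placeholder 1-second mono waveform
    torchaudio.save('demo.wav', tts_speech, sample_rate=sample_rate, backend='soundfile')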
xinference/thirdparty/cosyvoice/bin/train.py

@@ -46,6 +46,7 @@ def get_args():
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--model_dir', required=True, help='save model dir')
     parser.add_argument('--tensorboard_dir',
@@ -97,8 +98,12 @@ def main():
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
     if gan is True:
         override_dict.pop('hift')
-
-
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+    except Exception:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))
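Like inference.py above, train.py now detects the model family by attempting the CosyVoice2-style load first, with the qwen_pretrain_path override, and falling back to the plain load when that raises. A condensed restatement of the fallback as a named helper, assuming only that an incompatible config raises before configs is assigned:

    from hyperpyyaml import load_hyperpyyaml

    def load_configs(config_path, overrides, qwen_pretrain_path=None):
        # Try the CosyVoice2 config layout first; fall back to the CosyVoice layout.
        try:
            with open(config_path, 'r') as f:
                return load_hyperpyyaml(f, overrides={**overrides, 'qwen_pretrain_path': qwen_pretrain_path})
        except Exception:
            with open(config_path, 'r') as f:
                return load_hyperpyyaml(f, overrides=overrides)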
xinference/thirdparty/cosyvoice/cli/cosyvoice.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -20,45 +21,62 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type


 class CosyVoice:

-    def __init__(self, model_dir, load_jit=
-        instruct = True if '-Instruct' in model_dir else False
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-
+        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (
-            load_jit = False
-            fp16
-            logging.warning('cpu do not support fp16 and jit, force set to False')
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.
-                                '{}/llm.llm.
-                                '{}/flow.encoder.
-        if
-            self.model.
+            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

-    def
+    def list_available_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks

+    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
+        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
+        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+        del model_input['text']
+        del model_input['text_len']
+        self.frontend.spk2info[zero_shot_spk_id] = model_input
+        return True
+
+    def save_spkinfo(self):
+        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
@@ -70,12 +88,12 @@ class CosyVoice:
                 yield model_output
             start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -84,11 +102,9 @@ class CosyVoice:
                 yield model_output
             start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -98,8 +114,8 @@ class CosyVoice:
             start_time = time.time()

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
@@ -112,22 +128,10 @@ class CosyVoice:
                 yield model_output
             start_time = time.time()

-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-            start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
-        for model_output in self.model.
+        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
@@ -136,35 +140,51 @@ class CosyVoice:

 class CosyVoice2(CosyVoice):

-    def __init__(self, model_dir, load_jit=False,
-        instruct = True if '-Instruct' in model_dir else False
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-
+        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and load_jit is True:
-            load_jit = False
-            logging.warning('
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/flow.encoder.
-        if load_trt is True and load_onnx is True:
-            load_onnx = False
-            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+            start_time = time.time()