xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/bin/average_model.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import glob
+
+import yaml
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument('--src_path',
+                        required=True,
+                        help='src model path for average')
+    parser.add_argument('--val_best',
+                        action="store_true",
+                        help='averaged model')
+    parser.add_argument('--num',
+                        default=5,
+                        type=int,
+                        help='nums for averaged model')
+
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    val_scores = []
+    if args.val_best:
+        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+        yamls = [
+            f for f in yamls
+            if not (os.path.basename(f).startswith('train')
+                    or os.path.basename(f).startswith('init'))
+        ]
+        for y in yamls:
+            with open(y, 'r') as f:
+                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+                loss = float(dic_yaml['loss_dict']['loss'])
+                epoch = int(dic_yaml['epoch'])
+                step = int(dic_yaml['step'])
+                tag = dic_yaml['tag']
+                val_scores += [[epoch, step, loss, tag]]
+        sorted_val_scores = sorted(val_scores,
+                                   key=lambda x: x[2],
+                                   reverse=False)
+        print("best val (epoch, step, loss, tag) = " +
+              str(sorted_val_scores[:args.num]))
+        path_list = [
+            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+            for score in sorted_val_scores[:args.num]
+        ]
+        print(path_list)
+    avg = {}
+    num = args.num
+    assert num == len(path_list)
+    for path in path_list:
+        print('Processing {}'.format(path))
+        states = torch.load(path, map_location=torch.device('cpu'))
+        for k in states.keys():
+            if k not in avg.keys():
+                avg[k] = states[k].clone()
+            else:
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            # pytorch 1.6 use true_divide instead of /=
+            avg[k] = torch.true_divide(avg[k], num)
+    print('Saving to {}'.format(args.dst_model))
+    torch.save(avg, args.dst_model)
+
+
+if __name__ == '__main__':
+    main()
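
The new average_model.py is the standard WeNet-style checkpoint-averaging utility: it picks the N best checkpoints by validation loss (read from the per-checkpoint YAML files) and writes the element-wise mean of their state dicts. A self-contained sketch of the same averaging technique on a toy module (the toy nn.Linear and in-memory "checkpoints" are illustrative only, not part of this release):

import torch
import torch.nn as nn

# Stand-ins for the epoch_{N}_whole.pt state dicts the script would load from disk.
checkpoints = [nn.Linear(4, 2).state_dict() for _ in range(3)]

avg = {}
for states in checkpoints:
    for k, v in states.items():
        avg[k] = v.clone() if k not in avg else avg[k] + v
for k in avg:
    avg[k] = torch.true_divide(avg[k], len(checkpoints))  # same division op as the script

model = nn.Linear(4, 2)
model.load_state_dict(avg)  # averaged weights load into the same architecture unchanged
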
--- a/xinference/thirdparty/cosyvoice/bin/export_jit.py
+++ b/xinference/thirdparty/cosyvoice/bin/export_jit.py
@@ -19,12 +19,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
+import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-import torch
 from cosyvoice.cli.cosyvoice import CosyVoice
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='export your model for deployment')
     parser.add_argument('--model_dir',
@@ -35,6 +36,7 @@ def get_args():
     print(args)
     return args
 
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -44,7 +46,7 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
 
     # 1. export llm text_encoder
     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
@@ -60,5 +62,13 @@
     script = torch.jit.optimize_for_inference(script)
     script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
 
+    # 3. export flow encoder
+    flow_encoder = cosyvoice.model.flow.encoder
+    script = torch.jit.script(flow_encoder)
+    script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
+
 if __name__ == '__main__':
     main()
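
export_jit.py now also scripts the flow encoder, so three TorchScript archives sit next to the checkpoints: llm.text_encoder.fp16.zip, llm.llm.fp16.zip and flow.encoder.fp32.zip. They are ordinary torch.jit archives; a minimal loading sketch (the model_dir path is a placeholder, and in this release the actual loading is done by CosyVoiceModel.load_jit in the cosyvoice.py change further down):

import torch

model_dir = 'pretrained_models/CosyVoice-300M'  # placeholder path
llm_text_encoder = torch.jit.load('{}/llm.text_encoder.fp16.zip'.format(model_dir))
llm_llm = torch.jit.load('{}/llm.llm.fp16.zip'.format(model_dir))
flow_encoder = torch.jit.load('{}/flow.encoder.fp32.zip'.format(model_dir))
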
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/bin/export_onnx.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    t = torch.rand((batch_size), dtype=torch.float32, device=device)
+    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    return x, mask, mu, t, spks, cond
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+    # 1. export flow decoder estimator
+    estimator = cosyvoice.model.flow.decoder.estimator
+
+    device = cosyvoice.model.device
+    batch_size, seq_len = 1, 256
+    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {0: 'batch_size', 2: 'seq_len'},
+            'mask': {0: 'batch_size', 2: 'seq_len'},
+            'mu': {0: 'batch_size', 2: 'seq_len'},
+            'cond': {0: 'batch_size', 2: 'seq_len'},
+            't': {0: 'batch_size'},
+            'spks': {0: 'batch_size'},
+            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+
+if __name__ == "__main__":
+    main()
--- /dev/null
+++ b/xinference/thirdparty/cosyvoice/bin/export_trt.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
+# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
+TRT_DIR=<YOUR_TRT_DIR>
+MODEL_DIR=<COSYVOICE2_MODEL_DIR>
+
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
+$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
--- a/xinference/thirdparty/cosyvoice/bin/inference.py
+++ b/xinference/thirdparty/cosyvoice/bin/inference.py
@@ -18,16 +18,15 @@ import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-
 import torch
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
 from cosyvoice.cli.model import CosyVoiceModel
-
 from cosyvoice.dataset.dataset import Dataset
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='inference with your model')
     parser.add_argument('--config', required=True, help='config file')
@@ -66,7 +65,8 @@ def main():
     model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     del configs
@@ -74,13 +74,11 @@
     fn = os.path.join(args.result_dir, 'wav.scp')
     f = open(fn, 'w')
     with torch.no_grad():
-        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+        for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
             assert len(utts) == 1, "inference mode only support batchsize 1"
-            text = batch["text"]
             text_token = batch["text_token"].to(device)
             text_token_len = batch["text_token_len"].to(device)
-            tts_text = batch["tts_text"]
             tts_index = batch["tts_index"]
             tts_text_token = batch["tts_text_token"].to(device)
             tts_text_token_len = batch["tts_text_token_len"].to(device)
@@ -101,7 +99,7 @@
                            'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                            'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
             tts_speeches = []
-            for model_output in model.inference(**model_input):
+            for model_output in model.tts(**model_input):
                 tts_speeches.append(model_output['tts_speech'])
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
--- a/xinference/thirdparty/cosyvoice/bin/train.py
+++ b/xinference/thirdparty/cosyvoice/bin/train.py
@@ -18,6 +18,7 @@ import datetime
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 from copy import deepcopy
+import os
 import torch
 import torch.distributed as dist
 import deepspeed
@@ -67,13 +68,17 @@
                         action='store_true',
                         default=False,
                         help='Use pinned memory buffers used for reading')
+    parser.add_argument('--use_amp',
+                        action='store_true',
+                        default=False,
+                        help='Use automatic mixed precision training')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
                         choices=['model_only', 'model+optimizer'],
                         help='save model/optimizer states')
     parser.add_argument('--timeout',
-                        default=30,
+                        default=60,
                         type=int,
                         help='timeout (in seconds) of cosyvoice_join.')
     parser = deepspeed.add_config_arguments(parser)
@@ -86,10 +91,16 @@
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False
 
-    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
+    if gan is True:
+        configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))
 
     # Init env for ddp
@@ -97,7 +108,7 @@
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs)
+        init_dataset_and_dataloader(args, configs, gan)
 
     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -107,30 +118,53 @@
 
     # load checkpoint
     model = configs[args.model]
+    start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
-        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+        if os.path.exists(args.checkpoint):
+            state_dict = torch.load(args.checkpoint, map_location='cpu')
+            model.load_state_dict(state_dict, strict=False)
+            if 'step' in state_dict:
+                start_step = state_dict['step']
+            if 'epoch' in state_dict:
+                start_epoch = state_dict['epoch']
+        else:
+            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
     # Dispatch model from cpu to gpu
     model = wrap_cuda_model(args, model)
 
     # Get optimizer & scheduler
-    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+    scheduler.set_step(start_step)
+    if scheduler_d is not None:
+        scheduler_d.set_step(start_step)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
+    info_dict['step'] = start_step
+    info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
     # Get executor
-    executor = Executor()
+    executor = Executor(gan=gan)
+    executor.step = start_step
 
+    # Init scaler, used for pytorch amp mixed precision training
+    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+    print('start step {} start epoch {}'.format(start_step, start_epoch))
     # Start training loop
-    for epoch in range(info_dict['max_epoch']):
+    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
         train_dataset.set_epoch(epoch)
         dist.barrier()
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                        writer, info_dict, scaler, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
        dist.destroy_process_group(group_join)
 
+
 if __name__ == '__main__':
     main()
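
The new --use_amp flag only constructs a torch.cuda.amp.GradScaler here and hands it to the executor; the training step that consumes it lives in utils/executor.py and utils/train_utils.py (also updated in this release) and is not shown in this hunk. For reference, the standard AMP pattern such a scaler enables looks roughly like this (a generic sketch that requires a CUDA device; it is not the executor's exact code):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device='cuda')
target = torch.randn(4, 1, device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():  # run forward/loss in mixed precision
    loss = torch.nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()    # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)           # unscales gradients, then steps the optimizer
scaler.update()
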
--- a/xinference/thirdparty/cosyvoice/cli/cosyvoice.py
+++ b/xinference/thirdparty/cosyvoice/cli/cosyvoice.py
@@ -13,15 +13,18 @@
 # limitations under the License.
 import os
 import time
+from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
+import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 
+
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -35,65 +38,133 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
-        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('cpu do not support fp16 and jit, force set to False')
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir))
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs
 
     def list_avaliable_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def inference_sft(self, tts_text, spk_id, stream=False):
-        for i in self.frontend.text_normalize(tts_text, split=True):
+    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
-        prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            if len(i) < 0.5 * len(prompt_text):
+                logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
-        if self.frontend.instruct is True:
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        for i in self.frontend.text_normalize(tts_text, split=True):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoiceModel)
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
-        instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+        start_time = time.time()
+        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+            yield model_output
+            start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and load_jit is True:
+            load_jit = False
+            logging.warning('cpu do not support jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_trt is True and load_onnx is True:
+            load_onnx = False
+            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+        del configs
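
Taken together, the hunk above renames model.inference to model.tts, adds speed/text_frontend controls, and introduces the CosyVoice2 class plus the inference_instruct2 and inference_vc generators. A usage sketch built only from these signatures (the model directory, wav paths, and prompt/instruct strings are placeholders, and the prompt audio is expected to be 16 kHz):

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',   # placeholder path
                       load_jit=False, load_onnx=False, load_trt=False)
prompt_speech_16k, _ = torchaudio.load('prompt_16k.wav')      # placeholder 16 kHz prompt

# zero-shot voice cloning; outputs are yielded chunk by chunk as dicts with 'tts_speech'
for idx, out in enumerate(cosyvoice.inference_zero_shot(
        'Text to synthesize.', 'Transcript of the prompt audio.',
        prompt_speech_16k, stream=False, speed=1.0)):
    torchaudio.save('zero_shot_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)

# natural-language style control, new in CosyVoice2
for idx, out in enumerate(cosyvoice.inference_instruct2(
        'Text to synthesize.', 'Speak with a cheerful tone.', prompt_speech_16k)):
    torchaudio.save('instruct2_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)
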