xinference 1.5.1__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/bin/inference.py

@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.dataset.dataset import Dataset
 
 
@@ -33,6 +33,7 @@ def get_args():
     parser.add_argument('--prompt_data', required=True, help='prompt data file')
     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
     parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--llm_model', required=True, help='llm model file')
     parser.add_argument('--flow_model', required=True, help='flow model file')
     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
@@ -59,16 +60,25 @@ def main():
     # Init cosyvoice models from configs
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
     device = torch.device('cuda' if use_cuda else 'cpu')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f)
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+    except Exception:
+        try:
+            with open(args.config, 'r') as f:
+                configs = load_hyperpyyaml(f)
+            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        except Exception:
+            raise TypeError('no valid model_type!')
 
-    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                            tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
+    sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
@@ -104,7 +114,7 @@ def main():
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

xinference/thirdparty/cosyvoice/bin/train.py

@@ -46,6 +46,7 @@ def get_args():
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--model_dir', required=True, help='save model dir')
     parser.add_argument('--tensorboard_dir',
@@ -97,8 +98,12 @@ def main():
     override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
     if gan is True:
         override_dict.pop('hift')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f, overrides=override_dict)
+    try:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+    except Exception:
+        with open(args.config, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides=override_dict)
     if gan is True:
         configs['train_conf'] = configs['train_conf_gan']
     configs['train_conf'].update(vars(args))

xinference/thirdparty/cosyvoice/cli/cosyvoice.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -20,45 +21,62 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type
 
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
-        instruct = True if '-Instruct' in model_dir else False
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+        hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
-            load_jit = False
-            fp16 = False
-            logging.warning('cpu do not support fp16 and jit, force set to False')
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir),
-                                '{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs
 
-    def list_avaliable_spks(self):
+    def list_available_spks(self):
        spks = list(self.frontend.spk2info.keys())
        return spks
 
+    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
+        assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
+        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+        del model_input['text']
+        del model_input['text_len']
+        self.frontend.spk2info[zero_shot_spk_id] = model_input
+        return True
+
+    def save_spkinfo(self):
+        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
@@ -70,12 +88,12 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -84,11 +102,9 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -98,8 +114,8 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.frontend.instruct is False:
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
@@ -112,22 +128,10 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-                start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
-        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
+        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
@@ -136,35 +140,51 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
-        instruct = True if '-Instruct' in model_dir else False
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
-        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and load_jit is True:
-            load_jit = False
-            logging.warning('cpu do not support jit, force set to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
        self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_trt is True and load_onnx is True:
-            load_onnx = False
-            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
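
The cosyvoice.py hunks above introduce a reusable zero-shot speaker workflow (add_zero_shot_spk, save_spkinfo, and the new zero_shot_spk_id argument). A minimal usage sketch under assumed inputs — the model directory, prompt wav, and transcript below are placeholders, not part of this diff:

    import torchaudio
    from cosyvoice.cli.cosyvoice import CosyVoice2

    # Assumed local CosyVoice2 checkpoint and a 16 kHz mono prompt clip.
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
    prompt_speech_16k, _ = torchaudio.load('prompt_16k.wav')

    # Register the prompt once under a speaker id, then persist spk2info.pt into the model dir.
    cosyvoice.add_zero_shot_spk('transcript of the prompt clip', prompt_speech_16k, 'my_spk')
    cosyvoice.save_spkinfo()

    # Later calls can pass the saved id instead of re-sending the prompt audio and text.
    for i, out in enumerate(cosyvoice.inference_zero_shot('Text to synthesize.', '', '', zero_shot_spk_id='my_spk', stream=False)):
        torchaudio.save('zero_shot_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)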

xinference/thirdparty/cosyvoice/cli/frontend.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,7 +32,8 @@ except ImportError:
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
-from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
 class CosyVoiceFrontEnd:
@@ -42,7 +44,6 @@ class CosyVoiceFrontEnd:
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 instruct: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -58,9 +59,7 @@ class CosyVoiceFrontEnd:
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
             self.spk2info = {}
-        self.instruct = instruct
         self.allowed_special = allowed_special
-        self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
@@ -69,14 +68,26 @@ class CosyVoiceFrontEnd:
                 'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyinvg')
         else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
+            self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -108,14 +119,17 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
-        if text_frontend is False:
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
+        if text_frontend is False or text == '':
             return [text] if split is True else text
         text = text.strip()
-        if contains_chinese(text):
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
                 text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
@@ -126,18 +140,13 @@ class CosyVoiceFrontEnd:
                 text = re.sub(r'[,，、]+$', '。', text)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        else:
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        if split is False:
-            return text
-        return texts
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text
 
     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
@@ -145,28 +154,32 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        if zero_shot_spk_id == '':
+            prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            if resample_rate == 24000:
+                # cosyvoice2, force speech_feat % speech_token = 2
+                token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+                speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+                speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                           'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                           'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                           'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                           'llm_embedding': embedding, 'flow_embedding': embedding}
+        else:
+            model_input = self.spk2info[zero_shot_spk_id]
+        model_input['text'] = tts_text_token
+        model_input['text_len'] = tts_text_token_len
         return model_input
 
-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -183,23 +196,10 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
-    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
         return model_input
 
     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
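
The Generator branches added to _extract_text_token and text_normalize above allow tts_text to be an iterator of text chunks (for example, tokens streamed out of an LLM) instead of a complete string. A hedged sketch of that usage — the model directory, prompt file, and transcript are placeholders, and the prompt clip is assumed to already be 16 kHz mono:

    import torchaudio
    from cosyvoice.cli.cosyvoice import CosyVoice2

    def text_chunks():
        # Stand-in for text arriving incrementally from an upstream generator.
        yield 'The quick brown fox '
        yield 'jumps over the lazy dog.'

    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
    prompt_speech_16k, _ = torchaudio.load('prompt_16k.wav')

    # Passing a generator skips text normalization and tokenizes chunks on the fly.
    for i, out in enumerate(cosyvoice.inference_zero_shot(text_chunks(), 'transcript of the prompt clip', prompt_speech_16k, stream=False)):
        torchaudio.save('streamed_{}.wav'.format(i), out['tts_speech'], cosyvoice.sample_rate)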