xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; more details are available via the link on the original page.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  from functools import partial
15
+ import json
15
16
  import onnxruntime
16
17
  import torch
17
18
  import numpy as np
@@ -50,9 +51,13 @@ class CosyVoiceFrontEnd:
50
51
  option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
51
52
  option.intra_op_num_threads = 1
52
53
  self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
53
- self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
54
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
55
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
56
+ "CPUExecutionProvider"])
54
57
  if os.path.exists(spk2info):
55
58
  self.spk2info = torch.load(spk2info, map_location=self.device)
59
+ else:
60
+ self.spk2info = {}
56
61
  self.instruct = instruct
57
62
  self.allowed_special = allowed_special
58
63
  self.inflect_parser = inflect.engine()
@@ -60,10 +65,9 @@ class CosyVoiceFrontEnd:
60
65
  if self.use_ttsfrd:
61
66
  self.frd = ttsfrd.TtsFrontendEngine()
62
67
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
63
- assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
64
- self.frd.set_lang_type('pinyin')
65
- self.frd.enable_pinyin_mix(True)
66
- self.frd.set_breakmodel_index(1)
68
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
69
+ 'failed to initialize ttsfrd resource'
70
+ self.frd.set_lang_type('pinyinvg')
67
71
  else:
68
72
  self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
69
73
  self.en_tn_model = EnNormalizer()
@@ -75,9 +79,13 @@ class CosyVoiceFrontEnd:
75
79
  return text_token, text_token_len
76
80
 
77
81
  def _extract_speech_token(self, speech):
82
+ assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
78
83
  feat = whisper.log_mel_spectrogram(speech, n_mels=128)
79
- speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
80
- self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
84
+ speech_token = self.speech_tokenizer_session.run(None,
85
+ {self.speech_tokenizer_session.get_inputs()[0].name:
86
+ feat.detach().cpu().numpy(),
87
+ self.speech_tokenizer_session.get_inputs()[1].name:
88
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
81
89
  speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
82
90
  speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
83
91
  return speech_token, speech_token_len
@@ -88,7 +96,8 @@ class CosyVoiceFrontEnd:
88
96
  dither=0,
89
97
  sample_frequency=16000)
90
98
  feat = feat - feat.mean(dim=0, keepdim=True)
91
- embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
99
+ embedding = self.campplus_session.run(None,
100
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
92
101
  embedding = torch.tensor([embedding]).to(self.device)
93
102
  return embedding
94
103
 
@@ -98,32 +107,34 @@ class CosyVoiceFrontEnd:
98
107
  speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
99
108
  return speech_feat, speech_feat_len
100
109
 
101
- def text_normalize(self, text, split=True):
110
+ def text_normalize(self, text, split=True, text_frontend=True):
111
+ if text_frontend is False:
112
+ return [text] if split is True else text
102
113
  text = text.strip()
103
114
  if contains_chinese(text):
104
115
  if self.use_ttsfrd:
105
- text = self.frd.get_frd_extra_info(text, 'input')
116
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
117
+ text = ''.join(texts)
106
118
  else:
107
119
  text = self.zh_tn_model.normalize(text)
108
- text = text.replace("\n", "")
109
- text = replace_blank(text)
110
- text = replace_corner_mark(text)
111
- text = text.replace(".", "")
112
- text = text.replace(" - ", ",")
113
- text = remove_bracket(text)
114
- text = re.sub(r'[,,]+$', '。', text)
115
- texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
116
- token_min_n=60, merge_len=20,
117
- comma_split=False)]
120
+ text = text.replace("\n", "")
121
+ text = replace_blank(text)
122
+ text = replace_corner_mark(text)
123
+ text = text.replace(".", "")
124
+ text = text.replace(" - ", ",")
125
+ text = remove_bracket(text)
126
+ text = re.sub(r'[,,、]+$', '。', text)
127
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
128
+ token_min_n=60, merge_len=20, comma_split=False))
118
129
  else:
119
130
  if self.use_ttsfrd:
120
- text = self.frd.get_frd_extra_info(text, 'input')
131
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
132
+ text = ''.join(texts)
121
133
  else:
122
134
  text = self.en_tn_model.normalize(text)
123
- text = spell_out_number(text, self.inflect_parser)
124
- texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
125
- token_min_n=60, merge_len=20,
126
- comma_split=False)]
135
+ text = spell_out_number(text, self.inflect_parser)
136
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
137
+ token_min_n=60, merge_len=20, comma_split=False))
127
138
  if split is False:
128
139
  return text
129
140
  return texts
@@ -134,12 +145,17 @@ class CosyVoiceFrontEnd:
134
145
  model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
135
146
  return model_input
136
147
 
137
- def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
148
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
138
149
  tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
139
150
  prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
140
- prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
141
- speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
151
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
152
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
142
153
  speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
154
+ if resample_rate == 24000:
155
+ # cosyvoice2, force speech_feat % speech_token = 2
156
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
157
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
158
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
143
159
  embedding = self._extract_spk_embedding(prompt_speech_16k)
144
160
  model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
145
161
  'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
@@ -149,8 +165,8 @@ class CosyVoiceFrontEnd:
149
165
  'llm_embedding': embedding, 'flow_embedding': embedding}
150
166
  return model_input
151
167
 
152
- def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
153
- model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
168
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
169
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
154
170
  # in cross lingual mode, we remove prompt in llm
155
171
  del model_input['prompt_text']
156
172
  del model_input['prompt_text_len']
@@ -166,3 +182,34 @@ class CosyVoiceFrontEnd:
166
182
  model_input['prompt_text'] = instruct_text_token
167
183
  model_input['prompt_text_len'] = instruct_text_token_len
168
184
  return model_input
185
+
186
+ def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
187
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
188
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
189
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
190
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
191
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
192
+ if resample_rate == 24000:
193
+ # cosyvoice2, force speech_feat % speech_token = 2
194
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
195
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
196
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
197
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
198
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
199
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
200
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
201
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
202
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
203
+ return model_input
204
+
205
+ def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
206
+ prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
207
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
208
+ prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
209
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
210
+ source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
211
+ model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
212
+ 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
213
+ 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
214
+ 'flow_embedding': embedding}
215
+ return model_input