xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (132)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +54 -1
  4. xinference/client/restful/restful_client.py +82 -2
  5. xinference/constants.py +3 -0
  6. xinference/core/chat_interface.py +297 -83
  7. xinference/core/model.py +24 -3
  8. xinference/core/progress_tracker.py +16 -8
  9. xinference/core/supervisor.py +51 -1
  10. xinference/core/worker.py +315 -47
  11. xinference/deploy/cmdline.py +33 -1
  12. xinference/model/audio/core.py +11 -1
  13. xinference/model/audio/megatts.py +105 -0
  14. xinference/model/audio/model_spec.json +24 -1
  15. xinference/model/audio/model_spec_modelscope.json +26 -1
  16. xinference/model/core.py +14 -0
  17. xinference/model/embedding/core.py +6 -1
  18. xinference/model/flexible/core.py +6 -1
  19. xinference/model/image/core.py +6 -1
  20. xinference/model/image/model_spec.json +17 -1
  21. xinference/model/image/model_spec_modelscope.json +17 -1
  22. xinference/model/llm/__init__.py +4 -6
  23. xinference/model/llm/core.py +5 -0
  24. xinference/model/llm/llama_cpp/core.py +46 -17
  25. xinference/model/llm/llm_family.json +530 -85
  26. xinference/model/llm/llm_family.py +24 -1
  27. xinference/model/llm/llm_family_modelscope.json +572 -1
  28. xinference/model/llm/mlx/core.py +16 -2
  29. xinference/model/llm/reasoning_parser.py +3 -3
  30. xinference/model/llm/sglang/core.py +111 -13
  31. xinference/model/llm/transformers/__init__.py +14 -0
  32. xinference/model/llm/transformers/core.py +31 -6
  33. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  34. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  35. xinference/model/llm/transformers/gemma3.py +17 -2
  36. xinference/model/llm/transformers/intern_vl.py +28 -18
  37. xinference/model/llm/transformers/minicpmv26.py +21 -2
  38. xinference/model/llm/transformers/qwen-omni.py +308 -0
  39. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  40. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  41. xinference/model/llm/utils.py +37 -15
  42. xinference/model/llm/vllm/core.py +184 -8
  43. xinference/model/llm/vllm/distributed_executor.py +320 -0
  44. xinference/model/rerank/core.py +22 -12
  45. xinference/model/utils.py +118 -1
  46. xinference/model/video/core.py +6 -1
  47. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  48. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  49. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  50. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  51. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  52. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  53. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  54. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  55. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  56. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  57. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  58. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  59. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  60. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  61. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  62. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  63. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  64. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  65. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  66. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  67. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  68. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  69. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  70. xinference/thirdparty/megatts3/__init__.py +0 -0
  71. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  72. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  73. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  74. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  75. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  76. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  77. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  78. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  79. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  80. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  81. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  82. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  83. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  84. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  85. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  86. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  87. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  88. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  89. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  90. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  91. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  92. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  93. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  94. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  95. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  96. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  97. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  98. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  99. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  100. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  101. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  102. xinference/types.py +10 -0
  103. xinference/utils.py +54 -0
  104. xinference/web/ui/build/asset-manifest.json +6 -6
  105. xinference/web/ui/build/index.html +1 -1
  106. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  107. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  108. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  109. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  112. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  113. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  114. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  116. xinference/web/ui/src/locales/en.json +2 -1
  117. xinference/web/ui/src/locales/zh.json +2 -1
  118. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
  119. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
  120. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  121. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  122. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  123. xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
  124. xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  129. /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  130. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  131. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  132. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
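
The most user-visible additions in 1.5.0 are the vendored DeepSeek-VL2 and MegaTTS 3 integrations (new packages under xinference/thirdparty/, plus xinference/model/audio/megatts.py and the expanded audio and LLM model specs). A rough sketch of how the new TTS backend might be driven through the existing client API follows; the registered model name "MegaTTS3" and the speech() arguments are assumptions, not confirmed by this diff.

# Hypothetical usage sketch. The model name "MegaTTS3" and the speech()
# parameters are assumptions; check model_spec.json in this release for the
# registered name and supported arguments.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="MegaTTS3", model_type="audio")
model = client.get_model(model_uid)

# Audio model handles expose an OpenAI-style speech() call returning audio bytes.
audio_bytes = model.speech("Hello from xinference 1.5.0")
with open("output.wav", "wb") as f:
    f.write(audio_bytes)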
xinference/thirdparty/deepseek_vl2/utils/io.py
@@ -0,0 +1,80 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ import json
+ from typing import Dict, List
+
+ import PIL.Image
+ import torch
+ from transformers import AutoModelForCausalLM
+
+
+ def load_pretrained_model(model_path: str):
+
+     from deepseek_vl2.models.processing_deepseek_vl_v2 import DeepseekVLV2Processor
+     from deepseek_vl2.models.modeling_deepseek_vl_v2 import DeepseekVLV2ForCausalLM
+
+     vl_chat_processor = DeepseekVLV2Processor.from_pretrained(model_path)
+     tokenizer = vl_chat_processor.tokenizer
+
+     vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+         model_path, trust_remote_code=True
+     )
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+     return tokenizer, vl_chat_processor, vl_gpt
+
+
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
+     """
+
+     Args:
+         conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
+             [
+                 {
+                     "role": "User",
+                     "content": "<image>\nExtract all information from this image and convert them into markdown format.",
+                     "images": ["./examples/table_datasets.png"]
+                 },
+                 {"role": "Assistant", "content": ""},
+             ]
+
+     Returns:
+         pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+     """
+
+     pil_images = []
+
+     for message in conversations:
+         if "images" not in message:
+             continue
+
+         for image_path in message["images"]:
+             pil_img = PIL.Image.open(image_path)
+             pil_img = pil_img.convert("RGB")
+             pil_images.append(pil_img)
+
+     return pil_images
+
+
+ def load_json(filepath):
+     with open(filepath, "r") as f:
+         data = json.load(f)
+     return data
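
For orientation, the io helpers above feed the new transformers integration in xinference/model/llm/transformers/deepseek_vl2.py listed earlier. A minimal sketch of using them directly, assuming a checkpoint path such as "deepseek-ai/deepseek-vl2-tiny" (path and prompt are illustrative, and the loader requires a CUDA device since it calls .cuda()):

# Illustrative only; the checkpoint path and conversation content are assumptions.
conversation = [
    {
        "role": "User",
        "content": "<image>\nDescribe this image.",
        "images": ["./examples/table_datasets.png"],
    },
    {"role": "Assistant", "content": ""},
]

tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model("deepseek-ai/deepseek-vl2-tiny")
pil_images = load_pil_images(conversation)  # RGB-converted PIL.Image objects
# From here the processor turns text plus images into model inputs; within
# xinference that wiring lives in the transformers deepseek_vl2 model class.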
xinference/thirdparty/megatts3/__init__.py: File without changes
xinference/thirdparty/megatts3/tts/frontend_function.py
@@ -0,0 +1,175 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn.functional as F
+ import whisper
+ import librosa
+ from copy import deepcopy
+ from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
+ from tts.utils.audio_utils.align import mel2token_to_dur
+
+ ''' Graphme to phoneme function '''
+ def g2p(self, text_inp):
+     # prepare inputs
+     txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
+     input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device)
+
+     # model forward
+     with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+         outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx)
+
+     # process outputs
+     ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx
+     ph_pred, tone_pred = split_ph(ph_tokens[0])
+     ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
+     return ph_pred, tone_pred
+
+ ''' Get phoneme2mel align of prompt speech '''
+ def align(self, wav):
+     with torch.inference_mode():
+         whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
+         mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
+         prompt_max_frame = mel.size(2) // self.fm * self.fm
+         mel = mel[:, :, :prompt_max_frame]
+         token = torch.LongTensor([[798]]).to(self.device)
+         audio_features = self.aligner_lm.embed_audio(mel)
+         for i in range(768):
+             with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+                 logits = self.aligner_lm.logits(token, audio_features, None)
+                 token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
+                 token = torch.cat([token, token_pred], dim=1)
+             if token_pred[0] == 799:
+                 break
+         alignment_tokens = token
+
+         ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
+         ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
+         tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
+         if dur_ref.sum() < prompt_max_frame:
+             dur_ref[-1] += prompt_max_frame - dur_ref.sum()
+         elif dur_ref.sum() > prompt_max_frame:
+             len_diff = dur_ref.sum() - prompt_max_frame
+             while True:
+                 for i in range(len(dur_ref)):
+                     dur_ref[i] -= 1
+                     len_diff -= 1
+                     if len_diff == 0:
+                         break
+                 if len_diff == 0:
+                     break
+         mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
+         mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
+         return ph_ref, tone_ref, mel2ph_ref
+
+ ''' Duration Prompting '''
+ def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
+     dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
+         max=self.hp_dur_model['dur_code_size'] - 1) + 1
+
+     ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
+     txt_tokens_flat_ = ph_ref.flatten(0, 1)
+     ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]
+
+     last_dur_pos_prompt = ctx_dur_tokens.shape[1]
+     dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
+     dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
+     with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+         _, incremental_state_dur_prompt = self.dur_model.infer(
+             ph_ref, {'tone': tone_ref}, None, None, None,
+             ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
+     return incremental_state_dur_prompt, ctx_dur_tokens
+
+ ''' Duration Prediction '''
+ def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
+     last_dur_token = ctx_dur_tokens[:, -1:]
+     last_dur_pos_prompt = ctx_dur_tokens.shape[1]
+     incremental_state_dur = deepcopy(incremental_state_dur_prompt)
+     txt_len = ph_pred.shape[1]
+     dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
+     dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
+     last_dur_pos_prompt = last_dur_pos_prompt + txt_len
+
+     with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+         dur_pred = self.dur_model.infer(
+             ph_pred, {'tone': tone_pred}, None, None, None,
+             incremental_state=incremental_state_dur,
+             first_decoder_inp=last_dur_token,
+             spk_pos_ids_flat=dur_spk_pos_ids_flat,
+         )
+
+     dur_pred = dur_pred - 1
+     dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
+     # if is_final:
+     #     dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
+     # else:
+     #     dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
+     # if seg_i > 0:
+     #     dur_pred[:, 0] = 0
+     # ['。', '!', '?', 'sil']
+     for sil_token in [148, 153, 166, 145]:
+         dur_pred[ph_pred==sil_token].clamp_min(32)
+     # [',', ';']
+     for sil_token in [163, 165]:
+         dur_pred[ph_pred==sil_token].clamp_min(16)
+     if not is_final:
+         # add 0.32ms for crossfade
+         dur_pred[:, -1] = dur_pred[:, -1] + 32
+     else:
+         dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
+
+     ''' DiT target speech generation '''
+     dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
+     dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
+     dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
+         dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
+     dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
+     if is_first:
+         dur_pred[:, 0] = 8
+
+     dur_sum = dur_pred.sum()
+     npad = self.fm - dur_sum % self.fm
+     if npad < self.fm:
+         dur_pred[:, -1] += npad
+     mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
+     return mel2ph_pred
+
+ def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
+     # Prepare duration token
+     mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1)
+     mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1)
+     # Prepare phone and tone token
+     ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
+     tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
+     # Disable the English tone (set them to 3)"""
+     en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
+     tone_pred[en_tone_idx] = 3
+
+     # Prepare cfg inputs
+     ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
+     tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
+     target_size = mel2ph_pred.size(1)//self.vae_stride
+     vae_latent_ = vae_latent.repeat(3, 1, 1)
+     ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
+     vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
+     vae_latent_[1:] = 0.0
+     ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
+
+     return {
+         'phone': ph_seq,
+         'tone': tone_seq,
+         "lat_ctx": vae_latent_ * ctx_mask,
+         "ctx_mask": ctx_mask,
+         "dur": mel2ph_pred,
+     }
xinference/thirdparty/megatts3/tts/gradio_api.py
@@ -0,0 +1,93 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import multiprocessing as mp
+ import torch
+ import os
+ from functools import partial
+ import gradio as gr
+ import traceback
+ from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
+
+
+ def model_worker(input_queue, output_queue, device_id):
+     device = None
+     if device_id is not None:
+         device = torch.device(f'cuda:{device_id}')
+     infer_pipe = MegaTTS3DiTInfer(device=device)
+
+     while True:
+         task = input_queue.get()
+         inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
+         try:
+             convert_to_wav(inp_audio_path)
+             wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
+             cut_wav(wav_path, max_len=28)
+             with open(wav_path, 'rb') as file:
+                 file_content = file.read()
+             resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
+             wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+             output_queue.put(wav_bytes)
+         except Exception as e:
+             traceback.print_exc()
+             print(task, str(e))
+             output_queue.put(None)
+
+
+ def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+     print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
+     input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
+     res = output_queue.get()
+     if res is not None:
+         return res
+     else:
+         print("")
+         return None
+
+
+ if __name__ == '__main__':
+     mp.set_start_method('spawn', force=True)
+     mp_manager = mp.Manager()
+
+     devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+     if devices != '':
+         devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
+     else:
+         devices = None
+
+     num_workers = 1
+     input_queue = mp_manager.Queue()
+     output_queue = mp_manager.Queue()
+     processes = []
+
+     print("Start open workers")
+     for i in range(num_workers):
+         p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
+         p.start()
+         processes.append(p)
+
+     api_interface = gr.Interface(fn=
+         partial(main, processes=processes, input_queue=input_queue,
+                 output_queue=output_queue),
+         inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
+                 gr.Number(label="infer timestep", value=32),
+                 gr.Number(label="Intelligibility Weight", value=1.4),
+                 gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
+         title="MegaTTS3",
+         description="Upload a speech clip as a reference for timbre, " +
+                     "upload the pre-extracted latent file, "+
+                     "input the target text, and receive the cloned voice.", concurrency_limit=1)
+     api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
+     for p in processes:
+         p.join()
xinference/thirdparty/megatts3/tts/infer_cli.py
@@ -0,0 +1,277 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+ import argparse
+ import librosa
+ import numpy as np
+ import torch
+
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
+ from tn.english.normalizer import Normalizer as EnNormalizer
+ from langdetect import detect as classify_language
+ from pydub import AudioSegment
+ import pyloudnorm as pyln
+
+ from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
+ from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
+ from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
+ from tts.utils.commons.ckpt_utils import load_ckpt
+ from tts.utils.commons.hparams import set_hparams, hparams
+ from tts.utils.text_utils.text_encoder import TokenTextEncoder
+ from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english
+ from tts.utils.commons.hparams import hparams, set_hparams
+
+
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ def convert_to_wav(wav_path):
+     # Check if the file exists
+     if not os.path.exists(wav_path):
+         print(f"The file '{wav_path}' does not exist.")
+         return
+
+     # Check if the file already has a .wav extension
+     if not wav_path.endswith(".wav"):
+         # Define the output path with a .wav extension
+         out_path = os.path.splitext(wav_path)[0] + ".wav"
+
+         # Load the audio file using pydub and convert it to WAV
+         audio = AudioSegment.from_file(wav_path)
+         audio.export(out_path, format="wav")
+
+         print(f"Converted '{wav_path}' to '{out_path}'")
+
+
+ def cut_wav(wav_path, max_len=28):
+     audio = AudioSegment.from_file(wav_path)
+     audio = audio[:int(max_len * 1000)]
+     audio.export(wav_path, format="wav")
+
+ class MegaTTS3DiTInfer():
+     def __init__(
+         self,
+         device=None,
+         ckpt_root='./checkpoints',
+         dit_exp_name='diffusion_transformer',
+         frontend_exp_name='aligner_lm',
+         wavvae_exp_name='wavvae',
+         dur_ckpt_path='duration_lm',
+         g2p_exp_name='g2p',
+         precision=torch.float16,
+         **kwargs
+     ):
+         self.sr = 24000
+         self.fm = 8
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         self.precision = precision
+
+         # build models
+         self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
+         self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
+         self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
+         self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
+         self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
+         self.build_model(self.device)
+
+         # init text normalizer
+         self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
+         self.en_normalizer = EnNormalizer(overwrite_cache=False)
+         # loudness meter
+         self.loudness_meter = pyln.Meter(self.sr)
+
+     def build_model(self, device):
+         set_hparams(exp_name=self.dit_exp_name, print_hparams=False)
+
+         ''' Load Dict '''
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
+         self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
+         self.token_encoder = token_encoder = self.ling_dict['phone']
+         ph_dict_size = len(token_encoder)
+
+         ''' Load Duration LM '''
+         from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
+         hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
+         hp_dur_model['frames_multiple'] = hparams['frames_multiple']
+         self.dur_model = ARDurPredictor(
+             hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
+             hp_dur_model['dur_model_layers'], ph_dict_size,
+             hp_dur_model['dur_code_size'],
+             use_rot_embed=hp_dur_model.get('use_rot_embed', False))
+         self.length_regulator = LengthRegulator()
+         load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
+         self.dur_model.eval()
+         self.dur_model.to(device)
+
+         ''' Load Diffusion Transformer '''
+         from tts.modules.llm_dit.dit import Diffusion
+         self.dit = Diffusion()
+         load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
+         self.dit.eval()
+         self.dit.to(device)
+         self.cfg_mask_token_phone = 302 - 1
+         self.cfg_mask_token_tone = 32 - 1
+
+         ''' Load Frontend LM '''
+         from tts.modules.aligner.whisper_small import Whisper
+         self.aligner_lm = Whisper()
+         load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
+         self.aligner_lm.eval()
+         self.aligner_lm.to(device)
+         self.kv_cache = None
+         self.hooks = None
+
+         ''' Load G2P LM'''
+         from transformers import AutoTokenizer, AutoModelForCausalLM
+         g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
+         g2p_tokenizer.padding_side = "right"
+         self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
+         self.g2p_tokenizer = g2p_tokenizer
+         self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]
+
+         ''' Wav VAE '''
+         self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
+         from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
+         self.wavvae = WavVAE_V3(hparams=hp_wavvae)
+         if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
+             load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
+             self.has_vae_encoder = True
+         else:
+             load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
+             self.has_vae_encoder = False
+         self.wavvae.eval()
+         self.wavvae.to(device)
+         self.vae_stride = hp_wavvae.get('vae_stride', 4)
+         self.hop_size = hp_wavvae.get('hop_size', 4)
+
+     def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
+         wav_bytes = convert_to_wav_bytes(audio_bytes)
+
+         ''' Load wav '''
+         wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
+         # Pad wav if necessary
+         ws = hparams['win_size']
+         if len(wav) % ws < ws - 1:
+             wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
+         wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
+         self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
+
+         ''' obtain alignments with aligner_lm '''
+         ph_ref, tone_ref, mel2ph_ref = align(self, wav)
+
+         with torch.inference_mode():
+             ''' Forward WaveVAE to obtain: prompt latent '''
+             if self.has_vae_encoder:
+                 wav = torch.FloatTensor(wav)[None].to(self.device)
+                 vae_latent = self.wavvae.encode_latent(wav)
+                 vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
+             else:
+                 assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
+                 vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
+                 vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
+
+             ''' Duration Prompting '''
+             self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
+             incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)
+
+         return {
+             'ph_ref': ph_ref,
+             'tone_ref': tone_ref,
+             'mel2ph_ref': mel2ph_ref,
+             'vae_latent': vae_latent,
+             'incremental_state_dur_prompt': incremental_state_dur_prompt,
+             'ctx_dur_tokens': ctx_dur_tokens,
+         }
+     def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
+         device = self.device
+
+         ph_ref = resource_context['ph_ref'].to(device)
+         tone_ref = resource_context['tone_ref'].to(device)
+         mel2ph_ref = resource_context['mel2ph_ref'].to(device)
+         vae_latent = resource_context['vae_latent'].to(device)
+         ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
+         incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
+
+         with torch.inference_mode():
+             ''' Generating '''
+             wav_pred_ = []
+             language_type = classify_language(input_text)
+             if language_type == 'en':
+                 input_text = self.en_normalizer.normalize(input_text)
+                 text_segs = chunk_text_english(input_text, max_chars=130)
+             else:
+                 input_text = self.zh_normalizer.normalize(input_text)
+                 text_segs = chunk_text_chinese(input_text, limit=60)
+
+             for seg_i, text in enumerate(text_segs):
+                 ''' G2P '''
+                 ph_pred, tone_pred = g2p(self, text)
+
+                 ''' Duration Prediction '''
+                 mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
+
+                 inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
+                 # Speech dit inference
+                 with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+                     x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
+
+                 # WavVAE decode
+                 x[:, :vae_latent.size(1)] = vae_latent
+                 wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32)
+
+                 ''' Post-processing '''
+                 # Trim prompt wav
+                 wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
+                 # Norm generated wav to prompt wav's level
+                 meter = pyln.Meter(self.sr)  # create BS.1770 meter
+                 loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
+                 wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
+                 if np.abs(wav_pred).max() >= 1:
+                     wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95
+
+                 # Apply hamming window
+                 wav_pred_.append(wav_pred)
+
+             return combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input_wav', type=str)
+     parser.add_argument('--input_text', type=str)
+     parser.add_argument('--output_dir', type=str)
+     parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
+     parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
+     parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
+     args = parser.parse_args()
+     wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w
+
+     infer_ins = MegaTTS3DiTInfer()
+
+     with open(wav_path, 'rb') as file:
+         file_content = file.read()
+
+     print(f"| Start processing {wav_path}+{input_text}")
+     resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
+     wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
+
+     print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
+     os.makedirs(out_path, exist_ok=True)
+     save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')