xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +415 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +449 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/model/utils.py CHANGED
@@ -11,17 +11,21 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+
15
+ import asyncio
14
16
  import json
15
17
  import logging
16
18
  import os
17
19
  import random
20
+ import threading
18
21
  from json import JSONDecodeError
19
22
  from pathlib import Path
20
- from typing import Any, Callable, Dict, Optional, Tuple, Union
23
+ from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
21
24
 
22
25
  import huggingface_hub
23
26
  import numpy as np
24
27
  import torch
28
+ from tqdm.auto import tqdm
25
29
 
26
30
  from ..constants import (
27
31
  XINFERENCE_CACHE_DIR,
@@ -343,3 +347,116 @@ def set_all_random_seed(seed: int):
343
347
  np.random.seed(seed)
344
348
  torch.manual_seed(seed)
345
349
  torch.cuda.manual_seed_all(seed)
350
+
351
+
352
class CancellableDownloader:
    """Make tqdm-reported downloads cancellable and observable.

    Intended to be used as a context manager around a download call that
    reports progress through ``tqdm.auto.tqdm`` bars. While the context is
    active, ``tqdm.update`` is replaced so that every bar update:

    * raises ``cancel_error_cls`` once :meth:`cancel` has been called, which
      aborts the download from inside the downloader's own code path, and
    * registers the bar so :meth:`get_progress` can estimate overall progress.

    On exit the original ``tqdm.update`` is restored, the download is marked
    done and the registered bars are forgotten.

    Caveat: the patch is installed on the ``tqdm`` class itself, so it is
    process-wide. Every non-disabled bar updated while the context is active
    is attributed to this instance. Downloaders whose contexts overlap
    without being properly nested can restore a stale ``update``.
    """

    def __init__(
        self,
        cancel_error_cls: Type[BaseException] = asyncio.CancelledError,
        cancelled_event: Optional[threading.Event] = None,
    ):
        """
        Args:
            cancel_error_cls: exception type raised from a patched
                ``tqdm.update`` after :meth:`cancel` has been called.
            cancelled_event: optional event used as the cancellation flag
                (e.g. one shared with the caller). A private event is
                created when omitted.
        """
        self._cancelled = (
            threading.Event() if cancelled_event is None else cancelled_event
        )
        self._done_event = threading.Event()
        self._cancel_error_cls = cancel_error_cls
        # ``tqdm.update`` captured by patch_tqdm(); restored by unpatch_tqdm().
        self._original_update = None
        # Bars counted in tqdm's default "it" unit: the outer per-task progress.
        self._main_progresses: Set[tqdm] = set()
        # Bars using any other unit (presumably per-file byte transfers).
        self._download_progresses: Set[tqdm] = set()
        # Not used by this class; kept so existing attribute access still works.
        self._original_tqdm_update = None

    def reset(self):
        """Forget every progress bar registered so far."""
        self._main_progresses.clear()
        self._download_progresses.clear()

    def get_progress(self) -> float:
        """Estimate overall progress as a fraction in ``[0.0, 1.0]``.

        Returns 1.0 once cancelled or finished. Otherwise the completed share
        of the "it"-unit bars is the base. The share still pending is filled
        in proportionally to how far the unfinished download bars have got.
        """
        if self.cancelled or self.done:
            # Report completion so pollers stop waiting.
            return 1.0

        # Snapshot the sets first: the patched ``update`` adds bars to them,
        # possibly from download worker threads, while this is being polled.
        main_progresses = tuple(self._main_progresses)
        download_progresses = tuple(self._download_progresses)

        tasks = sum(p.total or 0 for p in main_progresses)
        finished_tasks = sum(p.n for p in main_progresses)
        if tasks == 0:
            # Total unknown (or no bar yet): assume a single task.
            tasks = 1
        progress = finished_tasks / tasks

        all_download = finished_download = 0
        for p in download_progresses:
            if p.n == p.total:
                # Skip finished downloads; presumably already counted by
                # the main bars.
                continue
            # Unknown total: treat the bar as 10% done (total = 10 * n).
            all_download += p.total or (p.n * 10)
            finished_download += p.n

        if all_download > 0:
            pending_share = (tasks - finished_tasks) / tasks
            progress += pending_share * (finished_download / all_download)

        # Counters can exceed an unknown/zero total; keep the estimate bounded.
        return min(max(progress, 0.0), 1.0)

    def cancel(self):
        """Request cancellation; the next patched tqdm update raises."""
        self._cancelled.set()

    @property
    def cancelled(self):
        """True once :meth:`cancel` was called (or the shared event was set)."""
        return self._cancelled.is_set()

    @property
    def done(self):
        """True once the context manager has exited."""
        return self._done_event.is_set()

    def wait(self, timeout: float) -> bool:
        """Block up to *timeout* seconds for the context to exit.

        Returns:
            True if the download finished within the timeout, else False.
        """
        return self._done_event.wait(timeout)

    def raise_error(self, error_msg: str = "Download cancelled"):
        """Raise the configured cancellation exception with *error_msg*."""
        raise self._cancel_error_cls(error_msg)

    def patch_tqdm(self):
        """Replace ``tqdm.update`` to honour cancellation and record bars."""
        self._original_update = original_update = tqdm.update
        downloader = self

        def patched_update(bar, n=1):
            # ``n`` defaults to 1 exactly like ``tqdm.update`` so that the
            # common ``bar.update()`` call keeps working while patched.
            if downloader.cancelled:
                downloader.raise_error()
            if not bar.disable:
                is_main = getattr(bar, "unit", "it") == "it"
                bucket = (
                    downloader._main_progresses
                    if is_main
                    else downloader._download_progresses
                )
                bucket.add(bar)
            return original_update(bar, n)

        tqdm.update = patched_update

    def unpatch_tqdm(self):
        """Restore the ``tqdm.update`` captured by :meth:`patch_tqdm`."""
        # ``tqdm`` is the module-level ``tqdm.auto.tqdm`` import.
        if self._original_update:
            tqdm.update = self._original_update

    def __enter__(self):
        self.patch_tqdm()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore tqdm first. Returning None lets any exception (including
        # the cancellation error) propagate to the caller.
        self.unpatch_tqdm()
        self._done_event.set()
        self.reset()
@@ -17,7 +17,7 @@ from collections import defaultdict
17
17
  from typing import Any, Dict, List, Literal, Optional, Tuple
18
18
 
19
19
  from ...constants import XINFERENCE_CACHE_DIR
20
- from ..core import CacheableModelSpec, ModelDescription
20
+ from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
21
21
  from ..utils import valid_model_revision
22
22
  from .diffusers import DiffUsersVideoModel
23
23
 
@@ -44,6 +44,7 @@ class VideoModelFamilyV1(CacheableModelSpec):
44
44
  model_ability: Optional[List[str]]
45
45
  default_model_config: Optional[Dict[str, Any]]
46
46
  default_generate_config: Optional[Dict[str, Any]]
47
+ virtualenv: Optional[VirtualEnvSettings]
47
48
 
48
49
 
49
50
  class VideoModelDescription(ModelDescription):
@@ -57,6 +58,10 @@ class VideoModelDescription(ModelDescription):
57
58
  super().__init__(address, devices, model_path=model_path)
58
59
  self._model_spec = model_spec
59
60
 
61
@property
def spec(self):
    """The model spec this description was constructed with.

    Read-only view of ``self._model_spec``, which ``__init__`` assigns from
    its ``model_spec`` argument.
    """
    return self._model_spec
64
+
60
65
  def to_dict(self):
61
66
  return {
62
67
  "model_type": "video",
File without changes
@@ -0,0 +1,175 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import whisper
18
+ import librosa
19
+ from copy import deepcopy
20
+ from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
21
+ from tts.utils.audio_utils.align import mel2token_to_dur
22
+
23
def g2p(self, text_inp):
    """Grapheme-to-phoneme conversion with the autoregressive G2P model.

    The text is wrapped in ``<BOT>``/``<BOS>`` markers, tokenized and
    followed by a start token. The G2P model then generates phone tokens
    greedily (``top_k=1``), stopping at ``801 + speech_start_idx``.

    Args:
        text_inp: raw text to convert.

    Returns:
        ``(ph_pred, tone_pred)``: phone ids and tone ids from ``split_ph``,
        each with a leading batch dimension of 1, on ``self.device``.
    """
    speech_offset = self.speech_start_idx
    # 145 appears to be the silence phone id (shifted into speech-token space).
    start_token = 145 + speech_offset
    eos_token = 800 + 1 + speech_offset

    prompt_ids = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
    model_input = torch.LongTensor([prompt_ids + [start_token]]).to(self.device)

    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        generated = self.g2p_model.generate(
            model_input,
            max_new_tokens=256,
            do_sample=True,
            top_k=1,
            eos_token_id=eos_token,
        )

    # Drop the text prompt and the final token (normally the EOS), then
    # shift back from speech-token ids to phone ids.
    phone_tokens = generated[:, len(prompt_ids):-1] - speech_offset
    phones, tones = split_ph(phone_tokens[0])
    return phones[None, :].to(self.device), tones[None, :].to(self.device)
38
+
39
''' Get phoneme2mel align of prompt speech '''
def align(self, wav):
    """Get the phoneme-to-mel alignment of the prompt speech.

    Greedily decodes ``self.aligner_lm`` over the prompt's Whisper log-mel
    spectrogram and splits the emitted token stream into phone ids, tone ids
    and per-phone durations. Durations are then adjusted to cover exactly
    the usable prompt frames.

    Args:
        wav: prompt waveform sampled at ``self.sr`` (presumably a 1-D numpy
            array, as accepted by ``librosa.resample`` -- TODO confirm).

    Returns:
        ``(ph_ref, tone_ref, mel2ph_ref)``:
        - ``ph_ref`` / ``tone_ref``: phone and tone ids built with
          ``torch.Tensor`` (default float dtype) with a leading batch dim of 1.
        - ``mel2ph_ref``: the frame-to-phone map from
          ``self.length_regulator``, truncated to a multiple of ``self.fm``.
    """
    with torch.inference_mode():
        # The Whisper front end expects 16 kHz audio.
        whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
        # Add a batch dim; afterwards dim 2 is the frame axis (see size(2) below).
        mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
        # Keep only a whole number of self.fm-frame groups.
        prompt_max_frame = mel.size(2) // self.fm * self.fm
        mel = mel[:, :, :prompt_max_frame]
        # Decoding starts from token 798 and stops at token 799 or after 768 steps.
        token = torch.LongTensor([[798]]).to(self.device)
        audio_features = self.aligner_lm.embed_audio(mel)
        for i in range(768):  # ``i`` is unused; this is only a step limit
            with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
                logits = self.aligner_lm.logits(token, audio_features, None)
                # Greedy choice (softmax does not change the argmax).
                token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
                token = torch.cat([token, token_pred], dim=1)
                if token_pred[0] == 799:
                    break
        alignment_tokens = token

    # Strip the first (start) and last token before splitting. NOTE(review):
    # if the 768-step limit is hit, the last token is a real prediction and is
    # dropped as well.
    ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
    ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
    tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
    # Make the durations sum to exactly prompt_max_frame frames.
    if dur_ref.sum() < prompt_max_frame:
        # Too short: give the missing frames to the last phone.
        dur_ref[-1] += prompt_max_frame - dur_ref.sum()
    elif dur_ref.sum() > prompt_max_frame:
        # Too long: remove one frame per phone, round-robin, until the excess
        # is gone. NOTE(review): there is no lower bound, so short phones can
        # end up with zero or negative durations.
        len_diff = dur_ref.sum() - prompt_max_frame
        while True:
            for i in range(len(dur_ref)):
                dur_ref[i] -= 1
                len_diff -= 1
                if len_diff == 0:
                    break
            if len_diff == 0:
                break
    mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
    mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
    return ph_ref, tone_ref, mel2ph_ref
75
+
76
def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
    """Duration prompting: prime the duration model with the prompt speech.

    The prompt alignment is converted to per-phone duration codes (clamped
    to the duration codebook and shifted by +1). The duration model is then
    run over the prompt phones to obtain its incremental decoding state.

    Args:
        mel2ph_ref: frame-to-phone map of the prompt.
        ph_ref: prompt phone ids (2-D, batch first).
        tone_ref: prompt tone ids, same layout as ``ph_ref``.

    Returns:
        ``(incremental_state_dur_prompt, ctx_dur_tokens)``. The second item
        holds the prompt duration codes at positions whose phone id is > 0,
        with a leading batch dimension of 1.
    """
    prompt_codes = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
        max=self.hp_dur_model['dur_code_size'] - 1) + 1

    flat_codes = prompt_codes.clone().flatten(0, 1).to(self.device)
    flat_phones = ph_ref.flatten(0, 1)
    # Keep only entries whose phone id is positive (0 presumably = padding).
    ctx_dur_tokens = flat_codes[flat_phones > 0][None]

    n_ctx = ctx_dur_tokens.shape[1]
    position_ids = torch.LongTensor([range(0, n_ctx)]).to(self.device)
    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        _, prompt_state = self.dur_model.infer(
            ph_ref,
            {'tone': tone_ref},
            None,
            None,
            None,
            ctx_vqcodes=ctx_dur_tokens,
            spk_pos_ids_flat=position_ids,
            return_state=True,
        )
    return prompt_state, ctx_dur_tokens
93
+
94
def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
    """Duration prediction for one text segment.

    Continues the duration model from the prompt decoding state. That state
    is deep-copied, so the same prompt state can be reused for every segment.
    The predicted per-phone durations are post-processed and then expanded
    with ``self.length_regulator``.

    Args:
        ctx_dur_tokens: prompt duration tokens, 2-D ``(batch, n_ctx)``; the
            last one seeds the decoder.
        incremental_state_dur_prompt: duration-model state after the prompt
            (presumably from ``make_dur_prompt``).
        ph_pred: phone ids of this segment, 2-D ``(batch, n_phones)``.
        tone_pred: tone ids, same layout as ``ph_pred``.
        seg_i: segment index; currently unused.
        dur_disturb: each duration is randomly multiplied or divided by a
            factor drawn from ``[1, 1 + dur_disturb)``; 0 disables this.
        dur_alpha: global duration scale applied after the perturbation.
        is_first: force the first phone's duration to 8.
        is_final: clamp the last phone to [64, 128] instead of extending it
            by 32 frames for the crossfade with the next segment.

    Returns:
        Frame-to-phone map from ``self.length_regulator``, on ``self.device``.
    """
    last_dur_token = ctx_dur_tokens[:, -1:]
    n_ctx = ctx_dur_tokens.shape[1]
    incremental_state_dur = deepcopy(incremental_state_dur_prompt)
    txt_len = ph_pred.shape[1]
    # Position ids continue right after the prompt's.
    dur_spk_pos_ids_flat = torch.LongTensor([range(n_ctx, n_ctx + txt_len)]).to(self.device)

    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
        durs = self.dur_model.infer(
            ph_pred, {'tone': tone_pred}, None, None, None,
            incremental_state=incremental_state_dur,
            first_decoder_inp=last_dur_token,
            spk_pos_ids_flat=dur_spk_pos_ids_flat,
        )

    # Undo the +1 code shift and keep codes inside the codebook.
    durs = (durs - 1).clamp(0, self.hp_dur_model['dur_code_size'] - 1)

    # Enforce minimum pause lengths. ``clamp_min`` is out-of-place, so the
    # result must be written back; without the assignment this is a no-op.
    # ['。', '!', '?', 'sil']
    for sil_token in (148, 153, 166, 145):
        mask = ph_pred == sil_token
        durs[mask] = durs[mask].clamp_min(32)
    # [',', ';']
    for sil_token in (163, 165):
        mask = ph_pred == sil_token
        durs[mask] = durs[mask].clamp_min(16)

    if not is_final:
        # Extend the last phone by 32 frames reserved for the crossfade.
        durs[:, -1] = durs[:, -1] + 32
    else:
        durs[:, -1] = durs[:, -1].clamp(64, 128)

    # Random duration perturbation: each phone is either stretched or
    # shrunk (50/50) by a factor in [1, 1 + dur_disturb).
    disturb_choice = (torch.rand_like(durs.float()) > 0.5).float()
    disturb_ratio = 1 + torch.rand_like(durs.float()) * dur_disturb
    durs = durs * disturb_ratio * disturb_choice + \
        durs / disturb_ratio * (1 - disturb_choice)
    durs = torch.round(durs * dur_alpha).clamp(0, 127)
    if is_first:
        durs[:, 0] = 8

    # Pad the last phone so the total frame count is a multiple of self.fm.
    npad = self.fm - durs.sum() % self.fm
    if npad < self.fm:
        durs[:, -1] += npad
    mel2ph_pred = self.length_regulator(durs).to(self.device)
    return mel2ph_pred
147
+
148
def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
    """Build the 3-way classifier-free-guidance batch consumed by the DiT.

    Prompt and target are concatenated along time. Every input is tiled to a
    batch of three:
    - row 0 keeps the prompt latent as context;
    - row 1 has its latent context zeroed;
    - row 2 additionally replaces phones/tones with the CFG mask tokens.

    Returns:
        dict with ``phone``, ``tone``, ``lat_ctx`` (latents masked to the
        prompt frames), ``ctx_mask`` (1 on prompt frames, 0 on padding) and
        ``dur`` (the frame-to-phone map).
    """
    # Shift the target alignment so it indexes phones after the prompt's.
    n_ref_phones = ph_ref.size(1)
    mel2ph_all = torch.cat((mel2ph_ref, mel2ph_pred + n_ref_phones), dim=1)
    usable_frames = mel2ph_all.size(1) // self.fm * self.fm
    dur = mel2ph_all[:, :usable_frames].repeat(3, 1)

    phones = torch.cat((ph_ref, ph_pred), dim=1)
    tones = torch.cat((tone_ref, tone_pred), dim=1)
    # Tones other than 0, 4 and 11..15 (the English tones) are flattened to 3.
    keep_tone = (tones == 0) | (tones == 4) | ((tones >= 11) & (tones <= 15))
    tones[~keep_tone] = 3

    phone_uncond = torch.full(phones.size(), self.cfg_mask_token_phone, device=self.device)
    tone_uncond = torch.full(tones.size(), self.cfg_mask_token_tone, device=self.device)
    ph_seq = torch.cat([phones, phones, phone_uncond], 0)
    tone_seq = torch.cat([tones, tones, tone_uncond], 0)

    target_size = dur.size(1) // self.vae_stride
    tiled_latent = vae_latent.repeat(3, 1, 1)
    ctx_mask = torch.ones_like(tiled_latent[:, :, 0:1])
    time_pad = (0, 0, 0, target_size - vae_latent.size(1))
    latents = F.pad(tiled_latent, time_pad, mode='constant', value=0)
    # Only the first CFG row sees the prompt latents.
    latents[1:] = 0.0
    ctx_mask = F.pad(ctx_mask, time_pad, mode='constant', value=0)

    return {
        'phone': ph_seq,
        'tone': tone_seq,
        "lat_ctx": latents * ctx_mask,
        "ctx_mask": ctx_mask,
        "dur": dur,
    }
@@ -0,0 +1,93 @@
1
+ # Copyright 2025 ByteDance and/or its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import multiprocessing as mp
16
+ import torch
17
+ import os
18
+ from functools import partial
19
+ import gradio as gr
20
+ import traceback
21
+ from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
22
+
23
+
24
def model_worker(input_queue, output_queue, device_id):
    """Serve synthesis requests from *input_queue* forever.

    Each task is an ``(audio_path, npy_path, text, timestep, p_w, t_w)``
    tuple. The prompt audio is converted to ``.wav``, cut to at most 28
    (``max_len``) and run through ``MegaTTS3DiTInfer``. The synthesized bytes
    are put on *output_queue*. If any step raises, the traceback is printed
    and ``None`` is put instead.

    Args:
        input_queue: queue of request tuples.
        output_queue: queue receiving results (bytes or ``None``).
        device_id: CUDA device index, or ``None`` to let the pipeline choose.
    """
    device = torch.device(f'cuda:{device_id}') if device_id is not None else None
    pipeline = MegaTTS3DiTInfer(device=device)

    while True:
        task = input_queue.get()
        audio_path, latent_path, text, timestep, p_w, t_w = task
        try:
            convert_to_wav(audio_path)
            wav_path = os.path.splitext(audio_path)[0] + '.wav'
            cut_wav(wav_path, max_len=28)
            with open(wav_path, 'rb') as wav_file:
                wav_content = wav_file.read()
            context = pipeline.preprocess(wav_content, latent_file=latent_path)
            output_queue.put(
                pipeline.forward(context, text, time_step=timestep, p_w=p_w, t_w=t_w)
            )
        except Exception as err:
            # Report the failure but keep the worker alive for the next task.
            traceback.print_exc()
            print(task, str(err))
            output_queue.put(None)
46
+
47
+
48
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
    """Submit one synthesis request to the workers and block for its result.

    Args:
        inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w: request
            fields, forwarded as a single tuple on *input_queue*.
        processes: worker processes; not used by this function.
        input_queue: queue the request is pushed to.
        output_queue: queue the result is read from.

    Returns:
        The worker's result, or ``None`` if the worker reported a failure.
    """
    request = (inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
    print("Push task to the inp queue |", *request)
    input_queue.put(request)
    result = output_queue.get()
    if result is None:
        print("")
    return result
57
+
58
+
59
if __name__ == '__main__':
    # Worker processes are started with 'spawn' (commonly required when the
    # children use CUDA); queues come from a Manager so they can be shared.
    mp.set_start_method('spawn', force=True)
    mp_manager = mp.Manager()

    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    devices = visible_devices.split(",") if visible_devices != '' else None

    num_workers = 1
    input_queue = mp_manager.Queue()
    output_queue = mp_manager.Queue()
    processes = []

    print("Start open workers")
    for worker_idx in range(num_workers):
        # Device ids are indices into the visible devices, round-robin.
        device_id = worker_idx % len(devices) if devices is not None else None
        worker = mp.Process(
            target=model_worker, args=(input_queue, output_queue, device_id)
        )
        worker.start()
        processes.append(worker)

    description = (
        "Upload a speech clip as a reference for timbre, "
        "upload the pre-extracted latent file, "
        "input the target text, and receive the cloned voice."
    )
    api_interface = gr.Interface(
        fn=partial(main, processes=processes, input_queue=input_queue,
                   output_queue=output_queue),
        inputs=[
            gr.Audio(type="filepath", label="Upload .wav"),
            gr.File(type="filepath", label="Upload .npy"),
            "text",
            gr.Number(label="infer timestep", value=32),
            gr.Number(label="Intelligibility Weight", value=1.4),
            gr.Number(label="Similarity Weight", value=3.0),
        ],
        outputs=[gr.Audio(label="Synthesized Audio")],
        title="MegaTTS3",
        description=description,
        concurrency_limit=1,
    )
    # NOTE: server_name='0.0.0.0' listens on all network interfaces.
    api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
    for worker in processes:
        worker.join()