xinference 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +413 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +447 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/model/utils.py
CHANGED
```diff
@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import asyncio
 import json
 import logging
 import os
 import random
+import threading
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
 
 import huggingface_hub
 import numpy as np
 import torch
+from tqdm.auto import tqdm
 
 from ..constants import (
     XINFERENCE_CACHE_DIR,
```
```diff
@@ -343,3 +347,116 @@ def set_all_random_seed(seed: int):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+class CancellableDownloader:
+    def __init__(
+        self,
+        cancel_error_cls: Type[BaseException] = asyncio.CancelledError,
+        cancelled_event: Optional[threading.Event] = None,
+    ):
+        self._cancelled = cancelled_event
+        if self._cancelled is None:
+            self._cancelled = threading.Event()
+        self._done_event = threading.Event()
+        self._cancel_error_cls = cancel_error_cls
+        self._original_update = None
+        # progress for tqdm that is main
+        self._main_progresses: Set[tqdm] = set()
+        # progress for file downloader
+        # mainly when tqdm unit is set
+        self._download_progresses: Set[tqdm] = set()
+        # tqdm original update
+        self._original_tqdm_update = None
+
+    def reset(self):
+        self._main_progresses.clear()
+        self._download_progresses.clear()
+
+    def get_progress(self) -> float:
+        if self.cancelled or self.done:
+            # directly return 1.0 when cancelled or finished
+            return 1.0
+
+        tasks = finished_tasks = 0
+        for main_progress in self._main_progresses:
+            tasks += main_progress.total or 0
+            finished_tasks += main_progress.n
+
+        if tasks == 0:
+            # we assumed at least 1 task
+            tasks = 1
+
+        finished_ratio = finished_tasks / tasks
+
+        all_download_progress = finished_download_progress = 0
+        for download_progress in self._download_progresses:
+            # we skip finished download
+            if download_progress.n == download_progress.total:
+                continue
+            all_download_progress += download_progress.total or (
+                download_progress.n * 10
+            )
+            finished_download_progress += download_progress.n
+
+        if all_download_progress > 0:
+            rest_ratio = (
+                (tasks - finished_tasks)
+                / tasks
+                * (finished_download_progress / all_download_progress)
+            )
+            return finished_ratio + rest_ratio
+        else:
+            return finished_ratio
+
+    def cancel(self):
+        self._cancelled.set()
+
+    @property
+    def cancelled(self):
+        return self._cancelled.is_set()
+
+    @property
+    def done(self):
+        return self._done_event.is_set()
+
+    def wait(self, timeout: float):
+        self._done_event.wait(timeout)
+
+    def raise_error(self, error_msg: str = "Download cancelled"):
+        raise self._cancel_error_cls(error_msg)
+
+    def patch_tqdm(self):
+        # patch tqdm
+        # raise error if cancelled
+        self._original_update = original_update = tqdm.update
+        downloader = self
+
+        def patched_update(self, n):
+            if downloader.cancelled:
+                downloader.raise_error()
+            if not self.disable:
+                progresses = (
+                    downloader._main_progresses
+                    if getattr(self, "unit", "it") == "it"
+                    else downloader._download_progresses
+                )
+                progresses.add(self)
+            return original_update(self, n)
+
+        tqdm.update = patched_update
+
+    def unpatch_tqdm(self):
+        from tqdm.auto import tqdm
+
+        if self._original_update:
+            tqdm.update = self._original_update
+
+    def __enter__(self):
+        self.patch_tqdm()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.unpatch_tqdm()
+        self._done_event.set()
+        self.reset()
```
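Taken together, `patch_tqdm` and the context-manager protocol make any tqdm-driven download that runs inside the `with` block observable and cancellable from another thread: `get_progress()` combines the main task bars with the in-flight download bars (e.g. 4 of 10 tasks done and the remaining downloads half complete gives 4/10 + 6/10 × 0.5 = 0.7). A minimal sketch of the intended flow; the `snapshot_download` workload and the thread wiring here are illustrative, not taken from this diff:

```python
import threading

import huggingface_hub

from xinference.model.utils import CancellableDownloader

downloader = CancellableDownloader(cancel_error_cls=RuntimeError)

def run_download():
    # while the context manager is active, every tqdm.update() call is
    # intercepted: the bar is registered for progress accounting, and
    # cancel_error_cls is raised as soon as cancel() has been requested
    with downloader:
        huggingface_hub.snapshot_download("some-org/some-model")

thread = threading.Thread(target=run_download)
thread.start()
print(f"progress: {downloader.get_progress():.0%}")
downloader.cancel()  # the next tqdm.update() inside the thread raises
thread.join()
```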
xinference/model/video/core.py
CHANGED
```diff
@@ -17,7 +17,7 @@ from collections import defaultdict
 from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from ...constants import XINFERENCE_CACHE_DIR
-from ..core import CacheableModelSpec, ModelDescription
+from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
 from ..utils import valid_model_revision
 from .diffusers import DiffUsersVideoModel
 
@@ -44,6 +44,7 @@ class VideoModelFamilyV1(CacheableModelSpec):
     model_ability: Optional[List[str]]
     default_model_config: Optional[Dict[str, Any]]
     default_generate_config: Optional[Dict[str, Any]]
+    virtualenv: Optional[VirtualEnvSettings]
 
 
 class VideoModelDescription(ModelDescription):
@@ -57,6 +58,10 @@ class VideoModelDescription(ModelDescription):
         super().__init__(address, devices, model_path=model_path)
         self._model_spec = model_spec
 
+    @property
+    def spec(self):
+        return self._model_spec
+
     def to_dict(self):
         return {
             "model_type": "video",
```
xinference/thirdparty/megatts3/__init__.py
File without changes
xinference/thirdparty/megatts3/tts/frontend_function.py
ADDED
```diff
@@ -0,0 +1,175 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+import whisper
+import librosa
+from copy import deepcopy
+from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
+from tts.utils.audio_utils.align import mel2token_to_dur
+
+''' Grapheme to phoneme function '''
+def g2p(self, text_inp):
+    # prepare inputs
+    txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
+    input_ids = torch.LongTensor([txt_token + [145 + self.speech_start_idx]]).to(self.device)
+
+    # model forward
+    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+        outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800 + 1 + self.speech_start_idx)
+
+    # process outputs
+    ph_tokens = outputs[:, len(txt_token):-1] - self.speech_start_idx
+    ph_pred, tone_pred = split_ph(ph_tokens[0])
+    ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
+    return ph_pred, tone_pred
+
+''' Get phoneme2mel align of prompt speech '''
+def align(self, wav):
+    with torch.inference_mode():
+        whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
+        mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1, 2)
+        prompt_max_frame = mel.size(2) // self.fm * self.fm
+        mel = mel[:, :, :prompt_max_frame]
+        token = torch.LongTensor([[798]]).to(self.device)
+        audio_features = self.aligner_lm.embed_audio(mel)
+        for i in range(768):
+            with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+                logits = self.aligner_lm.logits(token, audio_features, None)
+            token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
+            token = torch.cat([token, token_pred], dim=1)
+            if token_pred[0] == 799:
+                break
+        alignment_tokens = token
+
+    ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
+    ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
+    tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
+    if dur_ref.sum() < prompt_max_frame:
+        dur_ref[-1] += prompt_max_frame - dur_ref.sum()
+    elif dur_ref.sum() > prompt_max_frame:
+        len_diff = dur_ref.sum() - prompt_max_frame
+        while True:
+            for i in range(len(dur_ref)):
+                dur_ref[i] -= 1
+                len_diff -= 1
+                if len_diff == 0:
+                    break
+            if len_diff == 0:
+                break
+    mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
+    mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1) // self.fm * self.fm]
+    return ph_ref, tone_ref, mel2ph_ref
+
+''' Duration Prompting '''
+def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
+    dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
+        max=self.hp_dur_model['dur_code_size'] - 1) + 1
+
+    ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
+    txt_tokens_flat_ = ph_ref.flatten(0, 1)
+    ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]
+
+    last_dur_pos_prompt = ctx_dur_tokens.shape[1]
+    dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
+    dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
+    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+        _, incremental_state_dur_prompt = self.dur_model.infer(
+            ph_ref, {'tone': tone_ref}, None, None, None,
+            ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
+    return incremental_state_dur_prompt, ctx_dur_tokens
+
+''' Duration Prediction '''
+def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
+    last_dur_token = ctx_dur_tokens[:, -1:]
+    last_dur_pos_prompt = ctx_dur_tokens.shape[1]
+    incremental_state_dur = deepcopy(incremental_state_dur_prompt)
+    txt_len = ph_pred.shape[1]
+    dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
+    dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
+    last_dur_pos_prompt = last_dur_pos_prompt + txt_len
+
+    with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+        dur_pred = self.dur_model.infer(
+            ph_pred, {'tone': tone_pred}, None, None, None,
+            incremental_state=incremental_state_dur,
+            first_decoder_inp=last_dur_token,
+            spk_pos_ids_flat=dur_spk_pos_ids_flat,
+        )
+
+    dur_pred = dur_pred - 1
+    dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
+    # if is_final:
+    #     dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
+    # else:
+    #     dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
+    # if seg_i > 0:
+    #     dur_pred[:, 0] = 0
+    # ['。', '!', '?', 'sil']
+    for sil_token in [148, 153, 166, 145]:
+        dur_pred[ph_pred == sil_token].clamp_min(32)
+    # [',', ';']
+    for sil_token in [163, 165]:
+        dur_pred[ph_pred == sil_token].clamp_min(16)
+    if not is_final:
+        # add 0.32ms for crossfade
+        dur_pred[:, -1] = dur_pred[:, -1] + 32
+    else:
+        dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
+
+    ''' DiT target speech generation '''
+    dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
+    dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
+    dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
+        dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
+    dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
+    if is_first:
+        dur_pred[:, 0] = 8
+
+    dur_sum = dur_pred.sum()
+    npad = self.fm - dur_sum % self.fm
+    if npad < self.fm:
+        dur_pred[:, -1] += npad
+    mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
+    return mel2ph_pred
+
+def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
+    # Prepare duration token
+    mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred + ph_ref.size(1)), dim=1)
+    mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1) // self.fm * self.fm].repeat(3, 1)
+    # Prepare phone and tone token
+    ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
+    tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
+    # Disable the English tone (set them to 3)
+    en_tone_idx = ~((tone_pred == 4) | ((11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
+    tone_pred[en_tone_idx] = 3
+
+    # Prepare cfg inputs
+    ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
+    tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
+    target_size = mel2ph_pred.size(1) // self.vae_stride
+    vae_latent_ = vae_latent.repeat(3, 1, 1)
+    ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
+    vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
+    vae_latent_[1:] = 0.0
+    ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
+
+    return {
+        'phone': ph_seq,
+        'tone': tone_seq,
+        "lat_ctx": vae_latent_ * ctx_mask,
+        "ctx_mask": ctx_mask,
+        "dur": mel2ph_pred,
+    }
```
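One step worth unpacking in `dur_pred` is the duration-disturbance block: with probability 0.5 each predicted duration is multiplied by a random factor r in [1, 1 + dur_disturb], and otherwise divided by it, which jitters phoneme timing roughly symmetrically around the prediction. A standalone toy illustration of just that arithmetic:

```python
import torch

dur_pred = torch.tensor([[10.0, 20.0, 30.0]])
dur_disturb = 0.1  # up to ~10% jitter per phoneme

choice = (torch.rand_like(dur_pred) > 0.5).float()   # 1 = stretch, 0 = shrink
r = 1 + torch.rand_like(dur_pred) * dur_disturb      # factor in [1, 1.1]
disturbed = dur_pred * r * choice + dur_pred / r * (1 - choice)
print(disturbed)  # e.g. tensor([[10.6, 18.4, 31.9]])
```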
xinference/thirdparty/megatts3/tts/gradio_api.py
ADDED

```diff
@@ -0,0 +1,93 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing as mp
+import torch
+import os
+from functools import partial
+import gradio as gr
+import traceback
+from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
+
+
+def model_worker(input_queue, output_queue, device_id):
+    device = None
+    if device_id is not None:
+        device = torch.device(f'cuda:{device_id}')
+    infer_pipe = MegaTTS3DiTInfer(device=device)
+
+    while True:
+        task = input_queue.get()
+        inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
+        try:
+            convert_to_wav(inp_audio_path)
+            wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
+            cut_wav(wav_path, max_len=28)
+            with open(wav_path, 'rb') as file:
+                file_content = file.read()
+            resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
+            wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
+            output_queue.put(wav_bytes)
+        except Exception as e:
+            traceback.print_exc()
+            print(task, str(e))
+            output_queue.put(None)
+
+
+def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
+    print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
+    input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
+    res = output_queue.get()
+    if res is not None:
+        return res
+    else:
+        print("")
+        return None
+
+
+if __name__ == '__main__':
+    mp.set_start_method('spawn', force=True)
+    mp_manager = mp.Manager()
+
+    devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+    if devices != '':
+        devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
+    else:
+        devices = None
+
+    num_workers = 1
+    input_queue = mp_manager.Queue()
+    output_queue = mp_manager.Queue()
+    processes = []
+
+    print("Start open workers")
+    for i in range(num_workers):
+        p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
+        p.start()
+        processes.append(p)
+
+    api_interface = gr.Interface(
+        fn=partial(main, processes=processes, input_queue=input_queue,
+                   output_queue=output_queue),
+        inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
+                gr.Number(label="infer timestep", value=32),
+                gr.Number(label="Intelligibility Weight", value=1.4),
+                gr.Number(label="Similarity Weight", value=3.0)],
+        outputs=[gr.Audio(label="Synthesized Audio")],
+        title="MegaTTS3",
+        description="Upload a speech clip as a reference for timbre, " +
+                    "upload the pre-extracted latent file, " +
+                    "input the target text, and receive the cloned voice.", concurrency_limit=1)
+    api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
+    for p in processes:
+        p.join()
```