xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +4 -7
- xinference/client/handlers.py +3 -0
- xinference/core/chat_interface.py +6 -1
- xinference/core/model.py +2 -0
- xinference/core/scheduler.py +4 -7
- xinference/core/supervisor.py +114 -23
- xinference/core/worker.py +70 -4
- xinference/deploy/local.py +2 -1
- xinference/model/audio/core.py +11 -0
- xinference/model/audio/cosyvoice.py +16 -5
- xinference/model/audio/kokoro.py +139 -0
- xinference/model/audio/melotts.py +110 -0
- xinference/model/audio/model_spec.json +80 -0
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/whisper.py +35 -10
- xinference/model/llm/llama_cpp/core.py +21 -14
- xinference/model/llm/llm_family.json +527 -1
- xinference/model/llm/llm_family.py +4 -1
- xinference/model/llm/llm_family_modelscope.json +495 -3
- xinference/model/llm/memory.py +1 -1
- xinference/model/llm/mlx/core.py +24 -6
- xinference/model/llm/transformers/core.py +9 -1
- xinference/model/llm/transformers/qwen2_audio.py +3 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -3
- xinference/model/llm/transformers/utils.py +22 -11
- xinference/model/llm/utils.py +115 -1
- xinference/model/llm/vllm/core.py +14 -4
- xinference/model/llm/vllm/xavier/block.py +3 -4
- xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
- xinference/model/llm/vllm/xavier/collective.py +74 -0
- xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
- xinference/model/llm/vllm/xavier/executor.py +18 -16
- xinference/model/llm/vllm/xavier/scheduler.py +79 -63
- xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
- xinference/model/llm/vllm/xavier/transfer.py +53 -32
- xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
- xinference/thirdparty/melo/__init__.py +0 -0
- xinference/thirdparty/melo/api.py +135 -0
- xinference/thirdparty/melo/app.py +61 -0
- xinference/thirdparty/melo/attentions.py +459 -0
- xinference/thirdparty/melo/commons.py +160 -0
- xinference/thirdparty/melo/configs/config.json +94 -0
- xinference/thirdparty/melo/data/example/metadata.list +20 -0
- xinference/thirdparty/melo/data_utils.py +413 -0
- xinference/thirdparty/melo/download_utils.py +67 -0
- xinference/thirdparty/melo/infer.py +25 -0
- xinference/thirdparty/melo/init_downloads.py +14 -0
- xinference/thirdparty/melo/losses.py +58 -0
- xinference/thirdparty/melo/main.py +36 -0
- xinference/thirdparty/melo/mel_processing.py +174 -0
- xinference/thirdparty/melo/models.py +1030 -0
- xinference/thirdparty/melo/modules.py +598 -0
- xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
- xinference/thirdparty/melo/monotonic_align/core.py +46 -0
- xinference/thirdparty/melo/preprocess_text.py +135 -0
- xinference/thirdparty/melo/split_utils.py +174 -0
- xinference/thirdparty/melo/text/__init__.py +35 -0
- xinference/thirdparty/melo/text/chinese.py +199 -0
- xinference/thirdparty/melo/text/chinese_bert.py +107 -0
- xinference/thirdparty/melo/text/chinese_mix.py +253 -0
- xinference/thirdparty/melo/text/cleaner.py +36 -0
- xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
- xinference/thirdparty/melo/text/cmudict.rep +129530 -0
- xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
- xinference/thirdparty/melo/text/english.py +284 -0
- xinference/thirdparty/melo/text/english_bert.py +39 -0
- xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
- xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
- xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
- xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
- xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
- xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
- xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
- xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
- xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
- xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
- xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
- xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
- xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
- xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
- xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
- xinference/thirdparty/melo/text/french.py +94 -0
- xinference/thirdparty/melo/text/french_bert.py +39 -0
- xinference/thirdparty/melo/text/japanese.py +647 -0
- xinference/thirdparty/melo/text/japanese_bert.py +49 -0
- xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
- xinference/thirdparty/melo/text/korean.py +192 -0
- xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
- xinference/thirdparty/melo/text/spanish.py +122 -0
- xinference/thirdparty/melo/text/spanish_bert.py +39 -0
- xinference/thirdparty/melo/text/symbols.py +290 -0
- xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
- xinference/thirdparty/melo/train.py +635 -0
- xinference/thirdparty/melo/train.sh +19 -0
- xinference/thirdparty/melo/transforms.py +209 -0
- xinference/thirdparty/melo/utils.py +424 -0
- xinference/types.py +2 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
- xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
- /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
- {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/xavier/transfer.py

@@ -23,6 +23,8 @@ from vllm.core.scheduler import Scheduler
 from vllm.utils import TORCH_DTYPE_TO_NUMPY_DTYPE, Device
 from vllm.worker.cache_engine import CacheEngine
 
+from .collective import CollectiveRank
+
 logger = logging.getLogger(__name__)
 
 
@@ -89,7 +91,7 @@ class BufferTransferMixin:
         return TypeMappingGloo[TORCH_DTYPE_TO_NUMPY_DTYPE[input_dtype]]
 
 
-class TransferActor(xo.StatelessActor, BufferTransferMixin):
+class TransferActor(xo.StatelessActor, BufferTransferMixin, CollectiveRank):
     @classmethod
     def default_uid(cls):
         return f"vllm-transfer-actor"
@@ -104,38 +106,21 @@ class TransferActor(xo.StatelessActor, BufferTransferMixin):
         world_addresses: List[str],
     ):
         super().__init__()
-
-
-
-
-
-
-
+        CollectiveRank.__init__(
+            self,
+            rank,
+            world_size,
+            rank_address,
+            store_address,
+            store_port,
+            world_addresses,
+        )
         self._cache_engine: Optional[List[CacheEngine]] = None
         self._scheduler: Optional[List[Scheduler]] = None
         self._swap_stream = torch.cuda.Stream()
 
     async def __post_create__(self):
-
-
-        context = xp.rendezvous.Context(self._rank, self._world_size)
-
-        attr = xp.transport.tcp.attr(self._rank_address.split(":")[0])
-        dev = xp.transport.tcp.CreateDevice(attr)
-
-        opt = xp.rendezvous.TCPStoreOptions()
-        opt.port = self._store_port
-        opt.numWorkers = self._world_size
-        opt.isServer = self._rank == 0
-
-        store = xp.rendezvous.TCPStore(self._store_address, opt)
-        store = xp.rendezvous.PrefixStore(str(self._world_size), store)
-
-        context.connectFullMesh(store, dev)
-        self._context = context
-        logger.debug(
-            f"Rank {self._rank} arrives successfully, world addresses: {self._world_addresses}"
-        )
+        self.init_rank()
 
     def setup(
         self,
@@ -153,6 +138,9 @@ class TransferActor(xo.StatelessActor, BufferTransferMixin):
             num_buffer, buffer_shape, buffer_dtype, buffer_device, pin_memory
         )
 
+    async def __pre_destroy__(self):
+        self._context.closeConnections()
+
     def _get_cache_engine(self, virtual_engine: int) -> CacheEngine:
         return self._cache_engine[virtual_engine]  # type: ignore
 
@@ -281,18 +269,51 @@ class TransferActor(xo.StatelessActor, BufferTransferMixin):
         self.free_buffer_index(cpu_buf_index)
 
     async def recv(
-        self, virtual_engine: int,
+        self, virtual_engine: int, from_rank: int, src_to_dst: Dict[int, int]
     ):
         """
        This is the external entry point for the call.
        The transfer logic is as follows:
        the receiver requests the sender to send the data directly to itself in a point-to-point manner.
        """
-
+        from_address = self._world_addresses[from_rank]
         sender_ref = await xo.actor_ref(
-            address=from_address, uid=f"{TransferActor.default_uid()}-{
+            address=from_address, uid=f"{TransferActor.default_uid()}-{from_rank}"
         )
         await asyncio.gather(
             sender_ref.do_send(virtual_engine, self._rank, src_to_dst),
-            self.do_recv(virtual_engine,
+            self.do_recv(virtual_engine, from_rank, src_to_dst),
         )
+
+
+class Rank0TransferActor(xo.StatelessActor, CollectiveRank):
+    """
+    The Rank 0 transfer actor is only used for constructing the collective communication world,
+    so it only needs to inherit the `CollectiveWorld` class.
+    """
+
+    @classmethod
+    def default_uid(cls):
+        return f"vllm-transfer-actor"
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        rank_address: str,
+        store_address: str,
+        store_port: int,
+        world_addresses: List[str],
+    ):
+        CollectiveRank.__init__(
+            self,
+            rank,
+            world_size,
+            rank_address,
+            store_address,
+            store_port,
+            world_addresses,
+        )
+
+    async def __post_create__(self):
+        self.init_rank()
xinference/thirdparty/cosyvoice/bin/spk2info.pt: binary file, contents not shown.

xinference/thirdparty/melo/__init__.py: file without changes.
xinference/thirdparty/melo/api.py

@@ -0,0 +1,135 @@
+import os
+import re
+import json
+import torch
+import librosa
+import soundfile
+import torchaudio
+import numpy as np
+import torch.nn as nn
+from tqdm import tqdm
+import torch
+
+from . import utils
+from . import commons
+from .models import SynthesizerTrn
+from .split_utils import split_sentence
+from .mel_processing import spectrogram_torch, spectrogram_torch_conv
+from .download_utils import load_or_download_config, load_or_download_model
+
+class TTS(nn.Module):
+    def __init__(self,
+                language,
+                device='auto',
+                use_hf=True,
+                config_path=None,
+                ckpt_path=None):
+        super().__init__()
+        if device == 'auto':
+            device = 'cpu'
+            if torch.cuda.is_available(): device = 'cuda'
+            if torch.backends.mps.is_available(): device = 'mps'
+        if 'cuda' in device:
+            assert torch.cuda.is_available()
+
+        # config_path =
+        hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
+
+        num_languages = hps.num_languages
+        num_tones = hps.num_tones
+        symbols = hps.symbols
+
+        model = SynthesizerTrn(
+            len(symbols),
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            num_tones=num_tones,
+            num_languages=num_languages,
+            **hps.model,
+        ).to(device)
+
+        model.eval()
+        self.model = model
+        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
+        self.hps = hps
+        self.device = device
+
+        # load state_dict
+        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
+        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
+
+        language = language.split('_')[0]
+        self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
+
+    @staticmethod
+    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+        audio_segments = []
+        for segment_data in segment_data_list:
+            audio_segments += segment_data.reshape(-1).tolist()
+            audio_segments += [0] * int((sr * 0.05) / speed)
+        audio_segments = np.array(audio_segments).astype(np.float32)
+        return audio_segments
+
+    @staticmethod
+    def split_sentences_into_pieces(text, language, quiet=False):
+        texts = split_sentence(text, language_str=language)
+        if not quiet:
+            print(" > Text split to sentences.")
+            print('\n'.join(texts))
+            print(" > ===========================")
+        return texts
+
+    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+        audio_list = []
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+        for t in tx:
+            if language in ['EN', 'ZH_MIX_EN']:
+                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1. / speed,
+                    )[0][0, 0].data.cpu().float().numpy()
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+                #
+            audio_list.append(audio)
+        torch.cuda.empty_cache()
+        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+        if output_path is None:
+            return audio
+        else:
+            if format:
+                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+            else:
+                soundfile.write(output_path, audio, self.hps.data.sampling_rate)
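For reference, the vendored `TTS` class above is what xinference's new MeloTTS audio model (see `xinference/model/audio/melotts.py` in the listing) builds on. Below is a minimal hedged usage sketch, assuming the English checkpoint can be downloaded via `use_hf=True` and that `'EN-US'` is present in `hps.data.spk2id` (the WebUI below makes the same assumption); the output path is hypothetical.

# Hedged sketch, not part of the diff: synthesizing speech with the vendored API.
from xinference.thirdparty.melo.api import TTS

model = TTS(language='EN', device='auto')      # 'auto' picks cuda/mps/cpu
speaker_id = model.hps.data.spk2id['EN-US']    # speaker table ships with the config
model.tts_to_file(
    'The field of text-to-speech has seen rapid development recently.',
    speaker_id,
    'en_us.wav',   # hypothetical output path; pass None to get a numpy array back
    speed=1.0,
)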
xinference/thirdparty/melo/app.py

@@ -0,0 +1,61 @@
+# WebUI by mrfakename <X @realmrfakename / HF @mrfakename>
+# Demo also available on HF Spaces: https://huggingface.co/spaces/mrfakename/MeloTTS
+import gradio as gr
+import os, torch, io
+# os.system('python -m unidic download')
+print("Make sure you've downloaded unidic (python -m unidic download) for this WebUI to work.")
+from melo.api import TTS
+speed = 1.0
+import tempfile
+import click
+device = 'auto'
+models = {
+    'EN': TTS(language='EN', device=device),
+    'ES': TTS(language='ES', device=device),
+    'FR': TTS(language='FR', device=device),
+    'ZH': TTS(language='ZH', device=device),
+    'JP': TTS(language='JP', device=device),
+    'KR': TTS(language='KR', device=device),
+}
+speaker_ids = models['EN'].hps.data.spk2id
+
+default_text_dict = {
+    'EN': 'The field of text-to-speech has seen rapid development recently.',
+    'ES': 'El campo de la conversión de texto a voz ha experimentado un rápido desarrollo recientemente.',
+    'FR': 'Le domaine de la synthèse vocale a connu un développement rapide récemment',
+    'ZH': 'text-to-speech 领域近年来发展迅速',
+    'JP': 'テキスト読み上げの分野は最近急速な発展を遂げています',
+    'KR': '최근 텍스트 음성 변환 분야가 급속도로 발전하고 있습니다.',
+}
+
+def synthesize(speaker, text, speed, language, progress=gr.Progress()):
+    bio = io.BytesIO()
+    models[language].tts_to_file(text, models[language].hps.data.spk2id[speaker], bio, speed=speed, pbar=progress.tqdm, format='wav')
+    return bio.getvalue()
+def load_speakers(language, text):
+    if text in list(default_text_dict.values()):
+        newtext = default_text_dict[language]
+    else:
+        newtext = text
+    return gr.update(value=list(models[language].hps.data.spk2id.keys())[0], choices=list(models[language].hps.data.spk2id.keys())), newtext
+with gr.Blocks() as demo:
+    gr.Markdown('# MeloTTS WebUI\n\nA WebUI for MeloTTS.')
+    with gr.Group():
+        speaker = gr.Dropdown(speaker_ids.keys(), interactive=True, value='EN-US', label='Speaker')
+        language = gr.Radio(['EN', 'ES', 'FR', 'ZH', 'JP', 'KR'], label='Language', value='EN')
+        speed = gr.Slider(label='Speed', minimum=0.1, maximum=10.0, value=1.0, interactive=True, step=0.1)
+        text = gr.Textbox(label="Text to speak", value=default_text_dict['EN'])
+        language.input(load_speakers, inputs=[language, text], outputs=[speaker, text])
+        btn = gr.Button('Synthesize', variant='primary')
+        aud = gr.Audio(interactive=False)
+        btn.click(synthesize, inputs=[speaker, text, speed, language], outputs=[aud])
+    gr.Markdown('WebUI by [mrfakename](https://twitter.com/realmrfakename).')
+@click.command()
+@click.option('--share', '-s', is_flag=True, show_default=True, default=False, help="Expose a publicly-accessible shared Gradio link usable by anyone with the link. Only share the link with people you trust.")
+@click.option('--host', '-h', default=None)
+@click.option('--port', '-p', type=int, default=None)
+def main(share, host, port):
+    demo.queue(api_open=False).launch(show_api=False, share=share, server_name=host, server_port=port)
+
+if __name__ == "__main__":
+    main()