xinference 1.4.1__py3-none-any.whl → 1.5.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +415 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +449 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/METADATA +129 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.post1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py
CHANGED
|
@@ -22,9 +22,20 @@ import signal
|
|
|
22
22
|
import threading
|
|
23
23
|
import time
|
|
24
24
|
from collections import defaultdict
|
|
25
|
-
from dataclasses import dataclass
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
26
|
from logging import getLogger
|
|
27
|
-
from typing import
|
|
27
|
+
from typing import (
|
|
28
|
+
TYPE_CHECKING,
|
|
29
|
+
Any,
|
|
30
|
+
Dict,
|
|
31
|
+
List,
|
|
32
|
+
Literal,
|
|
33
|
+
Optional,
|
|
34
|
+
Set,
|
|
35
|
+
Tuple,
|
|
36
|
+
Union,
|
|
37
|
+
no_type_check,
|
|
38
|
+
)
|
|
28
39
|
|
|
29
40
|
import xoscar as xo
|
|
30
41
|
from async_timeout import timeout
|
|
@@ -34,13 +45,17 @@ from ..constants import (
|
|
|
34
45
|
XINFERENCE_CACHE_DIR,
|
|
35
46
|
XINFERENCE_DISABLE_HEALTH_CHECK,
|
|
36
47
|
XINFERENCE_DISABLE_METRICS,
|
|
48
|
+
XINFERENCE_ENABLE_VIRTUAL_ENV,
|
|
37
49
|
XINFERENCE_HEALTH_CHECK_INTERVAL,
|
|
50
|
+
XINFERENCE_VIRTUAL_ENV_DIR,
|
|
38
51
|
)
|
|
39
52
|
from ..core.model import ModelActor
|
|
40
53
|
from ..core.status_guard import LaunchStatus
|
|
41
54
|
from ..device_utils import get_available_device_env_name, gpu_count
|
|
42
|
-
from ..model.core import ModelDescription, create_model_instance
|
|
55
|
+
from ..model.core import ModelDescription, VirtualEnvSettings, create_model_instance
|
|
56
|
+
from ..model.utils import CancellableDownloader
|
|
43
57
|
from ..types import PeftModelConfig
|
|
58
|
+
from ..utils import get_pip_config_args, get_real_path
|
|
44
59
|
from .cache_tracker import CacheTrackerActor
|
|
45
60
|
from .event import Event, EventCollectorActor, EventType
|
|
46
61
|
from .metrics import launch_metrics_export_server, record_metrics
|
|
@@ -48,6 +63,14 @@ from .resource import gather_node_info
|
|
|
48
63
|
from .status_guard import StatusGuardActor
|
|
49
64
|
from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
|
|
50
65
|
|
|
66
|
+
try:
|
|
67
|
+
from xoscar.virtualenv import VirtualEnvManager
|
|
68
|
+
except ImportError:
|
|
69
|
+
VirtualEnvManager = None
|
|
70
|
+
|
|
71
|
+
if TYPE_CHECKING:
|
|
72
|
+
from .progress_tracker import Progressor
|
|
73
|
+
|
|
51
74
|
logger = getLogger(__name__)
|
|
52
75
|
|
|
53
76
|
|
|
@@ -64,6 +87,17 @@ class ModelStatus:
|
|
|
64
87
|
last_error: str = ""
|
|
65
88
|
|
|
66
89
|
|
|
90
|
+
@dataclass
class LaunchInfo:
    """Per-launch state kept in the worker's launching guard.

    One instance is stored per model uid while that model is being launched,
    so that a later cancel request can reach every moving part of the launch.
    """

    # Event set when the launch is cancelled; the same event is handed to the
    # downloader so a very early cancel is still observed by it.
    cancel_event: threading.Event = field(default_factory=threading.Event)
    # Virtualenv manager of the launch, if any; used to cancel an ongoing
    # package installation.
    virtual_env_manager: Optional["VirtualEnvManager"] = None
    # Downloader of the launch; reports progress and can cancel the whole
    # download.
    downloader: Optional[CancellableDownloader] = None
    # Addresses of the sub pools created for the model; they have to be
    # removed when the launch is cancelled.
    sub_pools: Optional[List[str]] = None
|
|
99
|
+
|
|
100
|
+
|
|
67
101
|
class WorkerActor(xo.StatelessActor):
|
|
68
102
|
def __init__(
|
|
69
103
|
self,
|
|
@@ -92,7 +126,7 @@ class WorkerActor(xo.StatelessActor):
|
|
|
92
126
|
|
|
93
127
|
# internal states.
|
|
94
128
|
# temporary placeholder during model launch process:
|
|
95
|
-
self._model_uid_launching_guard: Dict[str,
|
|
129
|
+
self._model_uid_launching_guard: Dict[str, LaunchInfo] = {}
|
|
96
130
|
# attributes maintained after model launched:
|
|
97
131
|
self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
|
|
98
132
|
self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
|
|
@@ -352,6 +386,7 @@ class WorkerActor(xo.StatelessActor):
|
|
|
352
386
|
self._cache_tracker_ref = await xo.actor_ref(
|
|
353
387
|
address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
|
|
354
388
|
)
|
|
389
|
+
self._progress_tracker_ref = None
|
|
355
390
|
# cache_tracker is on supervisor
|
|
356
391
|
from ..model.audio import get_audio_model_descriptions
|
|
357
392
|
from ..model.embedding import get_embedding_model_descriptions
|
|
@@ -548,8 +583,9 @@ class WorkerActor(xo.StatelessActor):
|
|
|
548
583
|
model_type: Optional[str] = None,
|
|
549
584
|
n_gpu: Optional[Union[int, str]] = "auto",
|
|
550
585
|
gpu_idx: Optional[List[int]] = None,
|
|
586
|
+
env: Optional[Dict[str, str]] = None,
|
|
551
587
|
) -> Tuple[str, List[str]]:
|
|
552
|
-
env = {}
|
|
588
|
+
env = {} if env is None else env
|
|
553
589
|
devices = []
|
|
554
590
|
env_name = get_available_device_env_name() or "CUDA_VISIBLE_DEVICES"
|
|
555
591
|
if gpu_idx is None:
|
|
@@ -778,6 +814,96 @@ class WorkerActor(xo.StatelessActor):
|
|
|
778
814
|
version_info["model_file_location"],
|
|
779
815
|
)
|
|
780
816
|
|
|
817
|
+
@classmethod
def _create_virtual_env_manager(
    cls,
    enable_virtual_env: Optional[bool],
    virtual_env_name: Optional[str],
    env_path: str,
) -> Optional[VirtualEnvManager]:
    """Build the virtualenv manager for a launch, or return ``None``.

    Args:
        enable_virtual_env: explicit switch; ``None`` falls back to
            ``XINFERENCE_ENABLE_VIRTUAL_ENV``.
        virtual_env_name: kind of manager to request; ``"uv"`` is used when
            empty or ``None``.
        env_path: location passed to the manager as its environment path.

    Returns:
        The manager returned by ``xoscar.virtualenv.get_virtual_env_manager``,
        or ``None`` when virtual environments are disabled.
    """
    use_virtual_env = (
        XINFERENCE_ENABLE_VIRTUAL_ENV
        if enable_virtual_env is None
        else enable_virtual_env
    )
    if use_virtual_env:
        # imported lazily: only needed when virtualenv support is switched on
        from xoscar.virtualenv import get_virtual_env_manager

        manager_kind = virtual_env_name or "uv"
        return get_virtual_env_manager(manager_kind, env_path)

    # virtualenv disabled, nothing to prepare
    return None
|
|
837
|
+
|
|
838
|
+
@classmethod
def _prepare_virtual_env(
    cls,
    virtual_env_manager: "VirtualEnvManager",
    settings: Optional[VirtualEnvSettings],
):
    """Create the virtual env and install the packages listed in ``settings``.

    Does nothing when ``settings`` is missing or lists no packages. When
    ``settings.inherit_pip_config`` is set, every pip option returned by
    ``get_pip_config_args()`` that exists on ``settings`` but is still unset
    is written onto ``settings`` before installing.
    """
    wants_packages = bool(settings) and bool(settings.packages)
    if not wants_packages:
        # no settings or no packages
        return

    virtual_env_manager.create_env()

    if settings.inherit_pip_config:
        # fill the gaps in the settings from the inherited pip config
        for option, inherited_value in get_pip_config_args().items():
            if not hasattr(settings, option):
                continue
            if getattr(settings, option):
                continue
            setattr(settings, option, inherited_value)

    packages = settings.packages
    install_options = {
        "index_url": settings.index_url,
        "extra_index_url": settings.extra_index_url,
        "find_links": settings.find_links,
        "trusted_host": settings.trusted_host,
    }

    logger.info(
        "Installing packages %s in virtual env %s, with settings(index_url=%s)",
        packages,
        virtual_env_manager.env_path,
        install_options["index_url"],
    )
    virtual_env_manager.install_packages(packages, **install_options)
|
|
877
|
+
|
|
878
|
+
async def _get_progressor(self, request_id: str):
    """Create and start a ``Progressor`` reporting under ``request_id``.

    The reference to the progress tracker actor is looked up at the
    supervisor address on first use and cached on the worker afterwards.
    The returned progressor is already started and set to 0.0 progress.
    """
    from .progress_tracker import Progressor, ProgressTrackerActor

    if self._progress_tracker_ref is None:
        # first call: resolve the tracker actor and remember it
        self._progress_tracker_ref = await xo.actor_ref(
            address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
        )
    tracker_ref = self._progress_tracker_ref

    loop = asyncio.get_running_loop()
    progressor = Progressor(request_id, tracker_ref, loop)
    await progressor.start()
    progressor.set_progress(0.0, "start to launch model")
    return progressor
|
|
895
|
+
|
|
896
|
+
@classmethod
def _upload_download_progress(
    cls, progressor: "Progressor", downloader: CancellableDownloader
):
    """Forward download progress to ``progressor`` until the download is done.

    Blocking loop: publishes ``downloader.get_progress()``, then calls
    ``downloader.wait(1)`` before checking ``downloader.done`` again. Once the
    downloader reports done, progress is set to 1.0.
    """
    while True:
        if downloader.done:
            break
        progressor.set_progress(downloader.get_progress())
        downloader.wait(1)

    progressor.set_progress(1.0, "Start to load model")
|
|
906
|
+
|
|
781
907
|
@log_async(logger=logger, level=logging.INFO)
|
|
782
908
|
async def launch_builtin_model(
|
|
783
909
|
self,
|
|
@@ -870,9 +996,27 @@ class WorkerActor(xo.StatelessActor):
|
|
|
870
996
|
raise ValueError(f"{model_uid} is running")
|
|
871
997
|
|
|
872
998
|
try:
|
|
873
|
-
self._model_uid_launching_guard[model_uid] =
|
|
999
|
+
self._model_uid_launching_guard[model_uid] = launch_info = LaunchInfo()
|
|
1000
|
+
|
|
1001
|
+
# virtualenv
|
|
1002
|
+
enable_virtual_env = kwargs.pop("enable_virtual_env", None)
|
|
1003
|
+
virtual_env_name = kwargs.pop("virtual_env_name", None)
|
|
1004
|
+
virtual_env_path = os.path.join(XINFERENCE_VIRTUAL_ENV_DIR, model_name)
|
|
1005
|
+
virtual_env_manager = await asyncio.to_thread(
|
|
1006
|
+
self._create_virtual_env_manager,
|
|
1007
|
+
enable_virtual_env,
|
|
1008
|
+
virtual_env_name,
|
|
1009
|
+
virtual_env_path,
|
|
1010
|
+
)
|
|
1011
|
+
# setting os.environ if virtualenv created
|
|
1012
|
+
env = (
|
|
1013
|
+
{"PYTHONPATH": virtual_env_manager.get_lib_path()}
|
|
1014
|
+
if virtual_env_manager
|
|
1015
|
+
else None
|
|
1016
|
+
)
|
|
1017
|
+
|
|
874
1018
|
subpool_address, devices = await self._create_subpool(
|
|
875
|
-
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
|
|
1019
|
+
model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx, env=env
|
|
876
1020
|
)
|
|
877
1021
|
all_subpool_addresses = [subpool_address]
|
|
878
1022
|
try:
|
|
@@ -891,23 +1035,62 @@ class WorkerActor(xo.StatelessActor):
|
|
|
891
1035
|
driver_info=driver_info,
|
|
892
1036
|
)
|
|
893
1037
|
)
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
1038
|
+
|
|
1039
|
+
with CancellableDownloader(
|
|
1040
|
+
cancelled_event=launch_info.cancel_event
|
|
1041
|
+
) as downloader:
|
|
1042
|
+
launch_info.downloader = downloader
|
|
1043
|
+
progressor = await self._get_progressor("launching-" + model_uid)
|
|
1044
|
+
# split into download and launch
|
|
1045
|
+
progressor.split_stages(2, stage_weight=[0, 0.8, 1.0])
|
|
1046
|
+
with progressor:
|
|
1047
|
+
upload_progress_task = asyncio.create_task(
|
|
1048
|
+
asyncio.to_thread(
|
|
1049
|
+
self._upload_download_progress, progressor, downloader
|
|
1050
|
+
)
|
|
1051
|
+
)
|
|
1052
|
+
model, model_description = await asyncio.to_thread(
|
|
1053
|
+
create_model_instance,
|
|
1054
|
+
subpool_address,
|
|
1055
|
+
devices,
|
|
1056
|
+
model_uid,
|
|
1057
|
+
model_type,
|
|
1058
|
+
model_name,
|
|
1059
|
+
model_engine,
|
|
1060
|
+
model_format,
|
|
1061
|
+
model_size_in_billions,
|
|
1062
|
+
quantization,
|
|
1063
|
+
peft_model_config,
|
|
1064
|
+
download_hub,
|
|
1065
|
+
model_path,
|
|
1066
|
+
**model_kwargs,
|
|
1067
|
+
)
|
|
1068
|
+
await self.update_cache_status(model_name, model_description)
|
|
1069
|
+
|
|
1070
|
+
def check_cancel():
|
|
1071
|
+
# check downloader first, sometimes download finished
|
|
1072
|
+
# cancelled already
|
|
1073
|
+
if downloader.cancelled:
|
|
1074
|
+
with progressor:
|
|
1075
|
+
# just report progress
|
|
1076
|
+
pass
|
|
1077
|
+
downloader.raise_error(error_msg="Launch cancelled")
|
|
1078
|
+
|
|
1079
|
+
# check cancel before prepare virtual env
|
|
1080
|
+
check_cancel()
|
|
1081
|
+
|
|
1082
|
+
# install packages in virtual env
|
|
1083
|
+
if virtual_env_manager:
|
|
1084
|
+
await asyncio.to_thread(
|
|
1085
|
+
self._prepare_virtual_env,
|
|
1086
|
+
virtual_env_manager,
|
|
1087
|
+
model_description.spec.virtualenv,
|
|
1088
|
+
)
|
|
1089
|
+
launch_info.virtual_env_manager = virtual_env_manager
|
|
1090
|
+
|
|
1091
|
+
# check before creating model actor
|
|
1092
|
+
check_cancel()
|
|
1093
|
+
|
|
911
1094
|
model_ref = await xo.create_actor(
|
|
912
1095
|
ModelActor,
|
|
913
1096
|
address=subpool_address,
|
|
@@ -939,12 +1122,28 @@ class WorkerActor(xo.StatelessActor):
|
|
|
939
1122
|
pool_addresses = await asyncio.gather(*coros)
|
|
940
1123
|
all_subpool_addresses.extend(pool_addresses)
|
|
941
1124
|
await model_ref.set_pool_addresses(pool_addresses)
|
|
942
|
-
|
|
1125
|
+
|
|
1126
|
+
# check before loading
|
|
1127
|
+
check_cancel()
|
|
1128
|
+
|
|
1129
|
+
# set all subpool addresses
|
|
1130
|
+
# when cancelled, all subpool addresses need to be destroyed
|
|
1131
|
+
launch_info.sub_pools = all_subpool_addresses
|
|
1132
|
+
|
|
1133
|
+
with progressor:
|
|
1134
|
+
try:
|
|
1135
|
+
await model_ref.load()
|
|
1136
|
+
except xo.ServerClosed:
|
|
1137
|
+
check_cancel()
|
|
1138
|
+
raise
|
|
943
1139
|
except:
|
|
944
1140
|
logger.error(f"Failed to load model {model_uid}", exc_info=True)
|
|
945
1141
|
self.release_devices(model_uid=model_uid)
|
|
946
1142
|
for addr in all_subpool_addresses:
|
|
947
|
-
|
|
1143
|
+
try:
|
|
1144
|
+
await self._main_pool.remove_sub_pool(addr)
|
|
1145
|
+
except KeyError:
|
|
1146
|
+
continue
|
|
948
1147
|
raise
|
|
949
1148
|
self._model_uid_to_model[model_uid] = model_ref
|
|
950
1149
|
self._model_uid_to_model_spec[model_uid] = model_description
|
|
@@ -978,6 +1177,39 @@ class WorkerActor(xo.StatelessActor):
|
|
|
978
1177
|
model_ref = self._model_uid_to_model[model_uid]
|
|
979
1178
|
await model_ref.wait_for_load()
|
|
980
1179
|
|
|
1180
|
+
@log_async(logger=logger, level=logging.INFO)
async def cancel_launch_model(self, model_uid: str):
    """Cancel a model launch that is still in progress.

    Sets the launch's cancel event, cancels the download and any package
    installation that has started, force-removes the sub pools recorded for
    the launch, and marks the instance as ``ERROR`` in the status guard.

    Args:
        model_uid: replica model uid of the launch to cancel.

    Raises:
        RuntimeError: if a ``KeyError`` is raised while cancelling, e.g.
            because ``model_uid`` has no entry in the launching guard.
    """
    try:
        launch_info = self._model_uid_launching_guard[model_uid]

        # downloader shared same cancel event
        # sometimes cancel happens very early before downloader
        # even if users cancel at this time,
        # downloader will know and stop everything
        launch_info.cancel_event.set()

        if launch_info.downloader:
            # pass the uid so the %s placeholder is actually filled in
            logger.debug("Try to cancel download, %s", model_uid)
            launch_info.downloader.cancel()
        if launch_info.virtual_env_manager:
            launch_info.virtual_env_manager.cancel_install()
        if launch_info.sub_pools:
            logger.debug("Try to stop sub pools: %s", launch_info.sub_pools)
            coros = []
            for addr in launch_info.sub_pools:
                coros.append(self._main_pool.remove_sub_pool(addr, force=True))
            await asyncio.gather(*coros)
        if self._status_guard_ref is not None:
            await self._status_guard_ref.update_instance_info(
                parse_replica_model_uid(model_uid)[0],
                {"status": LaunchStatus.ERROR.name},
            )
    except KeyError:
        logger.error("Fail to cancel launching", exc_info=True)
        raise RuntimeError(
            "Model is not launching, may have launched or not launched yet"
        )
|
|
1212
|
+
|
|
981
1213
|
@log_async(logger=logger, level=logging.INFO)
|
|
982
1214
|
async def terminate_model(self, model_uid: str, is_model_die=False):
|
|
983
1215
|
# Terminate model while its launching is not allow
|
|
@@ -1157,16 +1389,9 @@ class WorkerActor(xo.StatelessActor):
|
|
|
1157
1389
|
}
|
|
1158
1390
|
path = list.get("model_file_location")
|
|
1159
1391
|
cached_model["path"] = path
|
|
1160
|
-
|
|
1161
|
-
if
|
|
1162
|
-
|
|
1163
|
-
# dir has files
|
|
1164
|
-
if files:
|
|
1165
|
-
resolved_file = os.path.realpath(os.path.join(path, files[0]))
|
|
1166
|
-
if resolved_file:
|
|
1167
|
-
cached_model["real_path"] = os.path.dirname(resolved_file)
|
|
1168
|
-
else:
|
|
1169
|
-
cached_model["real_path"] = os.path.realpath(path)
|
|
1392
|
+
real_path = get_real_path(path)
|
|
1393
|
+
if real_path:
|
|
1394
|
+
cached_model["real_path"] = real_path
|
|
1170
1395
|
cached_model["actor_ip_address"] = self.address
|
|
1171
1396
|
cached_models.append(cached_model)
|
|
1172
1397
|
return cached_models
|
|
@@ -1267,7 +1492,7 @@ class WorkerActor(xo.StatelessActor):
|
|
|
1267
1492
|
# Note that `store_port` needs to be generated on the worker,
|
|
1268
1493
|
# as the TCP store is on rank 0, not on the supervisor.
|
|
1269
1494
|
store_port = xo.utils.get_next_port()
|
|
1270
|
-
self._model_uid_launching_guard[rep_model_uid] =
|
|
1495
|
+
self._model_uid_launching_guard[rep_model_uid] = LaunchInfo()
|
|
1271
1496
|
try:
|
|
1272
1497
|
try:
|
|
1273
1498
|
xavier_config["rank_address"] = subpool_address
|
xinference/deploy/cmdline.py
CHANGED
|
@@ -16,10 +16,12 @@ import asyncio
|
|
|
16
16
|
import logging
|
|
17
17
|
import os
|
|
18
18
|
import sys
|
|
19
|
+
import time
|
|
19
20
|
import warnings
|
|
20
21
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
21
22
|
|
|
22
23
|
import click
|
|
24
|
+
from tqdm.auto import tqdm
|
|
23
25
|
from xoscar.utils import get_next_port
|
|
24
26
|
|
|
25
27
|
from .. import __version__
|
|
@@ -925,6 +927,9 @@ def model_launch(
|
|
|
925
927
|
if api_key is None:
|
|
926
928
|
client._set_token(get_stored_token(endpoint, client))
|
|
927
929
|
|
|
930
|
+
# do not wait for launching.
|
|
931
|
+
kwargs["wait_ready"] = False
|
|
932
|
+
|
|
928
933
|
model_uid = client.launch_model(
|
|
929
934
|
model_name=model_name,
|
|
930
935
|
model_type=model_type,
|
|
@@ -943,8 +948,35 @@ def model_launch(
|
|
|
943
948
|
model_path=model_path,
|
|
944
949
|
**kwargs,
|
|
945
950
|
)
|
|
951
|
+
try:
|
|
952
|
+
with tqdm(
|
|
953
|
+
total=100, desc="Launching model", bar_format="{l_bar}{bar} | {n:.1f}%"
|
|
954
|
+
) as pbar:
|
|
955
|
+
while True:
|
|
956
|
+
status = client.get_instance_info(model_name, model_uid)
|
|
957
|
+
if all(s["status"] in ["READY", "ERROR", "TERMINATED"] for s in status):
|
|
958
|
+
break
|
|
959
|
+
|
|
960
|
+
progress = client.get_launch_model_progress(model_uid)["progress"]
|
|
961
|
+
percent = max(round(progress * 100, 1), pbar.n)
|
|
946
962
|
|
|
947
|
-
|
|
963
|
+
pbar.update(percent - pbar.n)
|
|
964
|
+
|
|
965
|
+
time.sleep(0.5)
|
|
966
|
+
|
|
967
|
+
# setting to 100%
|
|
968
|
+
pbar.update(pbar.total - pbar.n)
|
|
969
|
+
|
|
970
|
+
print(f"Model uid: {model_uid}", file=sys.stderr)
|
|
971
|
+
except KeyboardInterrupt:
|
|
972
|
+
user_input = (
|
|
973
|
+
input("Do you want to cancel model launching? (y/[n]): ").strip().lower()
|
|
974
|
+
)
|
|
975
|
+
if user_input == "y":
|
|
976
|
+
client.cancel_launch_model(model_uid)
|
|
977
|
+
print(f"Cancel request sent: {model_uid}")
|
|
978
|
+
else:
|
|
979
|
+
print("Skip cancel, launching model will be running still.")
|
|
948
980
|
|
|
949
981
|
|
|
950
982
|
@cli.command(
|
xinference/model/audio/core.py
CHANGED
|
@@ -17,7 +17,7 @@ from collections import defaultdict
|
|
|
17
17
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
18
18
|
|
|
19
19
|
from ...constants import XINFERENCE_CACHE_DIR
|
|
20
|
-
from ..core import CacheableModelSpec, ModelDescription
|
|
20
|
+
from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
|
|
21
21
|
from ..utils import valid_model_revision
|
|
22
22
|
from .chattts import ChatTTSModel
|
|
23
23
|
from .cosyvoice import CosyVoiceModel
|
|
@@ -26,6 +26,7 @@ from .f5tts_mlx import F5TTSMLXModel
|
|
|
26
26
|
from .fish_speech import FishSpeechModel
|
|
27
27
|
from .funasr import FunASRModel
|
|
28
28
|
from .kokoro import KokoroModel
|
|
29
|
+
from .megatts import MegaTTSModel
|
|
29
30
|
from .melotts import MeloTTSModel
|
|
30
31
|
from .whisper import WhisperModel
|
|
31
32
|
from .whisper_mlx import WhisperMLXModel
|
|
@@ -55,6 +56,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
|
|
|
55
56
|
default_model_config: Optional[Dict[str, Any]]
|
|
56
57
|
default_transcription_config: Optional[Dict[str, Any]]
|
|
57
58
|
engine: Optional[str]
|
|
59
|
+
virtualenv: Optional[VirtualEnvSettings]
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
class AudioModelDescription(ModelDescription):
|
|
@@ -68,6 +70,10 @@ class AudioModelDescription(ModelDescription):
|
|
|
68
70
|
super().__init__(address, devices, model_path=model_path)
|
|
69
71
|
self._model_spec = model_spec
|
|
70
72
|
|
|
73
|
+
@property
|
|
74
|
+
def spec(self):
|
|
75
|
+
return self._model_spec
|
|
76
|
+
|
|
71
77
|
def to_dict(self):
|
|
72
78
|
return {
|
|
73
79
|
"model_type": "audio",
|
|
@@ -178,6 +184,7 @@ def create_audio_model_instance(
|
|
|
178
184
|
F5TTSMLXModel,
|
|
179
185
|
MeloTTSModel,
|
|
180
186
|
KokoroModel,
|
|
187
|
+
MegaTTSModel,
|
|
181
188
|
],
|
|
182
189
|
AudioModelDescription,
|
|
183
190
|
]:
|
|
@@ -195,6 +202,7 @@ def create_audio_model_instance(
|
|
|
195
202
|
F5TTSMLXModel,
|
|
196
203
|
MeloTTSModel,
|
|
197
204
|
KokoroModel,
|
|
205
|
+
MegaTTSModel,
|
|
198
206
|
]
|
|
199
207
|
if model_spec.model_family == "whisper":
|
|
200
208
|
if not model_spec.engine:
|
|
@@ -217,6 +225,8 @@ def create_audio_model_instance(
|
|
|
217
225
|
model = MeloTTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
218
226
|
elif model_spec.model_family == "Kokoro":
|
|
219
227
|
model = KokoroModel(model_uid, model_path, model_spec, **kwargs)
|
|
228
|
+
elif model_spec.model_family == "MegaTTS":
|
|
229
|
+
model = MegaTTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
220
230
|
else:
|
|
221
231
|
raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
|
|
222
232
|
model_description = AudioModelDescription(
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import io
|
|
15
|
+
import logging
|
|
16
|
+
from io import BytesIO
|
|
17
|
+
from typing import TYPE_CHECKING, Optional
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from .core import AudioModelFamilyV1
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MegaTTSModel:
    """Audio model wrapper that drives the vendored MegaTTS3 inference code."""

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        """Store the launch parameters; the model itself is built in ``load``."""
        self._kwargs = kwargs
        self._device = device
        self._model_spec = model_spec
        self._model_path = model_path
        self._model_uid = model_uid
        # populated by load()
        self._model = None
        self._vocoder = None

    @property
    def model_ability(self):
        """Ability declared by the model spec."""
        return self._model_spec.model_ability

    def load(self):
        """Put the vendored packages on ``sys.path`` and build the inferencer."""
        import os
        import sys

        here = os.path.dirname(__file__)
        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
        megatts_root = os.path.join(here, "../../thirdparty/megatts3")
        sys.path.insert(0, megatts_root)
        # For whisper
        thirdparty_root = os.path.join(here, "../../thirdparty")
        sys.path.insert(0, thirdparty_root)

        from tts.infer_cli import MegaTTS3DiTInfer

        self._model = MegaTTS3DiTInfer(ckpt_root=self._model_path)

    def speech(
        self,
        input: str,
        voice: str,
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        **kwargs,
    ):
        """Synthesize ``input`` and return the encoded audio as bytes.

        ``prompt_speech`` and ``prompt_latent`` must be supplied through
        ``kwargs``; ``time_step``, ``p_w`` and ``t_w`` may be supplied there
        too and default to 32, 1.6 and 2.5. ``speed`` is accepted but not used.

        Raises:
            Exception: if ``stream`` or ``voice`` is set, or if either prompt
                is missing.
        """
        import soundfile

        if stream:
            raise Exception("MegaTTS3 does not support stream generation.")
        if voice:
            raise Exception(
                "MegaTTS3 does not support voice, please specify prompt_speech and prompt_latent."
            )

        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
        prompt_latent: Optional[bytes] = kwargs.pop("prompt_latent", None)
        if not prompt_speech:
            raise Exception("Please set prompt_speech for MegaTTS3.")
        if not prompt_latent:
            raise Exception("Please set prompt_latent for MegaTTS3.")

        assert self._model is not None
        with io.BytesIO(prompt_latent) as latent_stream:
            context = self._model.preprocess(prompt_speech, latent_file=latent_stream)
            generation_options = {
                "time_step": kwargs.get("time_step", 32),
                "p_w": kwargs.get("p_w", 1.6),
                "t_w": kwargs.get("t_w", 2.5),
            }
            waveform = self._model.forward(context, input, **generation_options)

        # Encode the generated audio into an in-memory file
        with BytesIO() as buffer:
            audio_format = response_format.upper()
            with soundfile.SoundFile(
                buffer, "w", self._model.sr, 1, format=audio_format
            ) as writer:
                writer.write(waveform)
            # read the buffer only after the writer has been closed
            return buffer.getvalue()
|
|
@@ -203,6 +203,21 @@
|
|
|
203
203
|
"merge_length_s": 15
|
|
204
204
|
}
|
|
205
205
|
},
|
|
206
|
+
{
|
|
207
|
+
"model_name": "paraformer-zh",
|
|
208
|
+
"model_family": "funasr",
|
|
209
|
+
"model_id": "funasr/paraformer-zh",
|
|
210
|
+
"model_revision": "5ed094cdfc8f6a9b6b022bd08bc904ef862bc79e",
|
|
211
|
+
"model_ability": "audio-to-text",
|
|
212
|
+
"multilingual": false,
|
|
213
|
+
"default_model_config": {
|
|
214
|
+
"vad_model": "fsmn-vad",
|
|
215
|
+
"punc_model": "ct-punc"
|
|
216
|
+
},
|
|
217
|
+
"default_transcription_config": {
|
|
218
|
+
"batch_size_s": 300
|
|
219
|
+
}
|
|
220
|
+
},
|
|
206
221
|
{
|
|
207
222
|
"model_name": "ChatTTS",
|
|
208
223
|
"model_family": "ChatTTS",
|
|
@@ -216,7 +231,7 @@
|
|
|
216
231
|
"model_family": "CosyVoice",
|
|
217
232
|
"model_id": "FunAudioLLM/CosyVoice-300M",
|
|
218
233
|
"model_revision": "39c4e13d46bd4dfb840d214547623e5fcd2428e2",
|
|
219
|
-
"model_ability": "
|
|
234
|
+
"model_ability": "text-to-audio",
|
|
220
235
|
"multilingual": true
|
|
221
236
|
},
|
|
222
237
|
{
|
|
@@ -346,5 +361,13 @@
|
|
|
346
361
|
"model_revision": "7884269d6fd3f9beabc271b6f1308e5699281fa9",
|
|
347
362
|
"model_ability": "text-to-audio",
|
|
348
363
|
"multilingual": true
|
|
364
|
+
},
|
|
365
|
+
{
|
|
366
|
+
"model_name": "MegaTTS3",
|
|
367
|
+
"model_family": "MegaTTS",
|
|
368
|
+
"model_id": "ByteDance/MegaTTS3",
|
|
369
|
+
"model_revision": "409a7002b006d80f0730fca6f80441b08c10e738",
|
|
370
|
+
"model_ability": "text-to-audio",
|
|
371
|
+
"multilingual": true
|
|
349
372
|
}
|
|
350
373
|
]
|