xinference 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (104) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +413 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +447 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/core/worker.py CHANGED
@@ -22,9 +22,20 @@ import signal
22
22
  import threading
23
23
  import time
24
24
  from collections import defaultdict
25
- from dataclasses import dataclass
25
+ from dataclasses import dataclass, field
26
26
  from logging import getLogger
27
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, no_type_check
27
+ from typing import (
28
+ TYPE_CHECKING,
29
+ Any,
30
+ Dict,
31
+ List,
32
+ Literal,
33
+ Optional,
34
+ Set,
35
+ Tuple,
36
+ Union,
37
+ no_type_check,
38
+ )
28
39
 
29
40
  import xoscar as xo
30
41
  from async_timeout import timeout
@@ -34,13 +45,17 @@ from ..constants import (
34
45
  XINFERENCE_CACHE_DIR,
35
46
  XINFERENCE_DISABLE_HEALTH_CHECK,
36
47
  XINFERENCE_DISABLE_METRICS,
48
+ XINFERENCE_ENABLE_VIRTUAL_ENV,
37
49
  XINFERENCE_HEALTH_CHECK_INTERVAL,
50
+ XINFERENCE_VIRTUAL_ENV_DIR,
38
51
  )
39
52
  from ..core.model import ModelActor
40
53
  from ..core.status_guard import LaunchStatus
41
54
  from ..device_utils import get_available_device_env_name, gpu_count
42
- from ..model.core import ModelDescription, create_model_instance
55
+ from ..model.core import ModelDescription, VirtualEnvSettings, create_model_instance
56
+ from ..model.utils import CancellableDownloader
43
57
  from ..types import PeftModelConfig
58
+ from ..utils import get_pip_config_args, get_real_path
44
59
  from .cache_tracker import CacheTrackerActor
45
60
  from .event import Event, EventCollectorActor, EventType
46
61
  from .metrics import launch_metrics_export_server, record_metrics
@@ -48,6 +63,14 @@ from .resource import gather_node_info
48
63
  from .status_guard import StatusGuardActor
49
64
  from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
50
65
 
66
+ try:
67
+ from xoscar.virtualenv import VirtualEnvManager
68
+ except ImportError:
69
+ VirtualEnvManager = None
70
+
71
+ if TYPE_CHECKING:
72
+ from .progress_tracker import Progressor
73
+
51
74
  logger = getLogger(__name__)
52
75
 
53
76
 
@@ -64,6 +87,17 @@ class ModelStatus:
64
87
  last_error: str = ""
65
88
 
66
89
 
90
+ @dataclass
91
+ class LaunchInfo:
92
+ cancel_event: threading.Event = field(default_factory=threading.Event)
93
+ # virtualenv manager
94
+ virtual_env_manager: Optional["VirtualEnvManager"] = None
95
+ # downloader, report progress or cancel entire download
96
+ downloader: Optional[CancellableDownloader] = None
97
+ # sub pools created for the model
98
+ sub_pools: Optional[List[str]] = None
99
+
100
+
67
101
  class WorkerActor(xo.StatelessActor):
68
102
  def __init__(
69
103
  self,
@@ -92,7 +126,7 @@ class WorkerActor(xo.StatelessActor):
92
126
 
93
127
  # internal states.
94
128
  # temporary placeholder during model launch process:
95
- self._model_uid_launching_guard: Dict[str, bool] = {}
129
+ self._model_uid_launching_guard: Dict[str, LaunchInfo] = {}
96
130
  # attributes maintained after model launched:
97
131
  self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
98
132
  self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
@@ -352,6 +386,7 @@ class WorkerActor(xo.StatelessActor):
352
386
  self._cache_tracker_ref = await xo.actor_ref(
353
387
  address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
354
388
  )
389
+ self._progress_tracker_ref = None
355
390
  # cache_tracker is on supervisor
356
391
  from ..model.audio import get_audio_model_descriptions
357
392
  from ..model.embedding import get_embedding_model_descriptions
@@ -548,8 +583,9 @@ class WorkerActor(xo.StatelessActor):
548
583
  model_type: Optional[str] = None,
549
584
  n_gpu: Optional[Union[int, str]] = "auto",
550
585
  gpu_idx: Optional[List[int]] = None,
586
+ env: Optional[Dict[str, str]] = None,
551
587
  ) -> Tuple[str, List[str]]:
552
- env = {}
588
+ env = {} if env is None else env
553
589
  devices = []
554
590
  env_name = get_available_device_env_name() or "CUDA_VISIBLE_DEVICES"
555
591
  if gpu_idx is None:
@@ -778,6 +814,96 @@ class WorkerActor(xo.StatelessActor):
778
814
  version_info["model_file_location"],
779
815
  )
780
816
 
817
+ @classmethod
818
+ def _create_virtual_env_manager(
819
+ cls,
820
+ enable_virtual_env: Optional[bool],
821
+ virtual_env_name: Optional[str],
822
+ env_path: str,
823
+ ) -> Optional[VirtualEnvManager]:
824
+ if enable_virtual_env is None:
825
+ enable_virtual_env = XINFERENCE_ENABLE_VIRTUAL_ENV
826
+
827
+ if not enable_virtual_env:
828
+ # skip preparing virtualenv
829
+ return None
830
+
831
+ from xoscar.virtualenv import get_virtual_env_manager
832
+
833
+ virtual_env_manager: VirtualEnvManager = get_virtual_env_manager(
834
+ virtual_env_name or "uv", env_path
835
+ )
836
+ return virtual_env_manager
837
+
838
+ @classmethod
839
+ def _prepare_virtual_env(
840
+ cls,
841
+ virtual_env_manager: "VirtualEnvManager",
842
+ settings: Optional[VirtualEnvSettings],
843
+ ):
844
+ if not settings or not settings.packages:
845
+ # no settings or no packages
846
+ return
847
+
848
+ # create env
849
+ virtual_env_manager.create_env()
850
+
851
+ if settings.inherit_pip_config:
852
+ # inherit pip config
853
+ pip_config = get_pip_config_args()
854
+ for k, v in pip_config.items():
855
+ if hasattr(settings, k) and not getattr(settings, k):
856
+ setattr(settings, k, v)
857
+
858
+ packages = settings.packages
859
+ index_url = settings.index_url
860
+ extra_index_url = settings.extra_index_url
861
+ find_links = settings.find_links
862
+ trusted_host = settings.trusted_host
863
+
864
+ logger.info(
865
+ "Installing packages %s in virtual env %s, with settings(index_url=%s)",
866
+ packages,
867
+ virtual_env_manager.env_path,
868
+ index_url,
869
+ )
870
+ virtual_env_manager.install_packages(
871
+ packages,
872
+ index_url=index_url,
873
+ extra_index_url=extra_index_url,
874
+ find_links=find_links,
875
+ trusted_host=trusted_host,
876
+ )
877
+
878
+ async def _get_progressor(self, request_id: str):
879
+ from .progress_tracker import Progressor, ProgressTrackerActor
880
+
881
+ progress_tracker_ref = self._progress_tracker_ref
882
+ if progress_tracker_ref is None:
883
+ progress_tracker_ref = self._progress_tracker_ref = await xo.actor_ref(
884
+ address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
885
+ )
886
+
887
+ progressor = Progressor(
888
+ request_id,
889
+ progress_tracker_ref,
890
+ asyncio.get_running_loop(),
891
+ )
892
+ await progressor.start()
893
+ progressor.set_progress(0.0, "start to launch model")
894
+ return progressor
895
+
896
+ @classmethod
897
+ def _upload_download_progress(
898
+ cls, progressor: "Progressor", downloader: CancellableDownloader
899
+ ):
900
+ while not downloader.done:
901
+ progress = downloader.get_progress()
902
+ progressor.set_progress(progress)
903
+ downloader.wait(1)
904
+
905
+ progressor.set_progress(1.0, "Start to load model")
906
+
781
907
  @log_async(logger=logger, level=logging.INFO)
782
908
  async def launch_builtin_model(
783
909
  self,
@@ -870,9 +996,27 @@ class WorkerActor(xo.StatelessActor):
870
996
  raise ValueError(f"{model_uid} is running")
871
997
 
872
998
  try:
873
- self._model_uid_launching_guard[model_uid] = True
999
+ self._model_uid_launching_guard[model_uid] = launch_info = LaunchInfo()
1000
+
1001
+ # virtualenv
1002
+ enable_virtual_env = kwargs.pop("enable_virtual_env", None)
1003
+ virtual_env_name = kwargs.pop("virtual_env_name", None)
1004
+ virtual_env_path = os.path.join(XINFERENCE_VIRTUAL_ENV_DIR, model_name)
1005
+ virtual_env_manager = await asyncio.to_thread(
1006
+ self._create_virtual_env_manager,
1007
+ enable_virtual_env,
1008
+ virtual_env_name,
1009
+ virtual_env_path,
1010
+ )
1011
+ # setting os.environ if virtualenv created
1012
+ env = (
1013
+ {"PYTHONPATH": virtual_env_manager.get_lib_path()}
1014
+ if virtual_env_manager
1015
+ else None
1016
+ )
1017
+
874
1018
  subpool_address, devices = await self._create_subpool(
875
- model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
1019
+ model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx, env=env
876
1020
  )
877
1021
  all_subpool_addresses = [subpool_address]
878
1022
  try:
@@ -891,23 +1035,62 @@ class WorkerActor(xo.StatelessActor):
891
1035
  driver_info=driver_info,
892
1036
  )
893
1037
  )
894
- model, model_description = await asyncio.to_thread(
895
- create_model_instance,
896
- subpool_address,
897
- devices,
898
- model_uid,
899
- model_type,
900
- model_name,
901
- model_engine,
902
- model_format,
903
- model_size_in_billions,
904
- quantization,
905
- peft_model_config,
906
- download_hub,
907
- model_path,
908
- **model_kwargs,
909
- )
910
- await self.update_cache_status(model_name, model_description)
1038
+
1039
+ with CancellableDownloader(
1040
+ cancelled_event=launch_info.cancel_event
1041
+ ) as downloader:
1042
+ launch_info.downloader = downloader
1043
+ progressor = await self._get_progressor("launching-" + model_uid)
1044
+ # split into download and launch
1045
+ progressor.split_stages(2, stage_weight=[0, 0.8, 1.0])
1046
+ with progressor:
1047
+ upload_progress_task = asyncio.create_task(
1048
+ asyncio.to_thread(
1049
+ self._upload_download_progress, progressor, downloader
1050
+ )
1051
+ )
1052
+ model, model_description = await asyncio.to_thread(
1053
+ create_model_instance,
1054
+ subpool_address,
1055
+ devices,
1056
+ model_uid,
1057
+ model_type,
1058
+ model_name,
1059
+ model_engine,
1060
+ model_format,
1061
+ model_size_in_billions,
1062
+ quantization,
1063
+ peft_model_config,
1064
+ download_hub,
1065
+ model_path,
1066
+ **model_kwargs,
1067
+ )
1068
+ await self.update_cache_status(model_name, model_description)
1069
+
1070
+ def check_cancel():
1071
+ # check downloader first, sometimes download finished
1072
+ # cancelled already
1073
+ if downloader.cancelled:
1074
+ with progressor:
1075
+ # just report progress
1076
+ pass
1077
+ downloader.raise_error(error_msg="Launch cancelled")
1078
+
1079
+ # check cancel before prepare virtual env
1080
+ check_cancel()
1081
+
1082
+ # install packages in virtual env
1083
+ if virtual_env_manager:
1084
+ await asyncio.to_thread(
1085
+ self._prepare_virtual_env,
1086
+ virtual_env_manager,
1087
+ model_description.spec.virtualenv,
1088
+ )
1089
+ launch_info.virtual_env_manager = virtual_env_manager
1090
+
1091
+ # check before creating model actor
1092
+ check_cancel()
1093
+
911
1094
  model_ref = await xo.create_actor(
912
1095
  ModelActor,
913
1096
  address=subpool_address,
@@ -939,12 +1122,28 @@ class WorkerActor(xo.StatelessActor):
939
1122
  pool_addresses = await asyncio.gather(*coros)
940
1123
  all_subpool_addresses.extend(pool_addresses)
941
1124
  await model_ref.set_pool_addresses(pool_addresses)
942
- await model_ref.load()
1125
+
1126
+ # check before loading
1127
+ check_cancel()
1128
+
1129
+ # set all subpool addresses
1130
+ # when cancelled, all subpool addresses need to be destroyed
1131
+ launch_info.sub_pools = all_subpool_addresses
1132
+
1133
+ with progressor:
1134
+ try:
1135
+ await model_ref.load()
1136
+ except xo.ServerClosed:
1137
+ check_cancel()
1138
+ raise
943
1139
  except:
944
1140
  logger.error(f"Failed to load model {model_uid}", exc_info=True)
945
1141
  self.release_devices(model_uid=model_uid)
946
1142
  for addr in all_subpool_addresses:
947
- await self._main_pool.remove_sub_pool(addr)
1143
+ try:
1144
+ await self._main_pool.remove_sub_pool(addr)
1145
+ except KeyError:
1146
+ continue
948
1147
  raise
949
1148
  self._model_uid_to_model[model_uid] = model_ref
950
1149
  self._model_uid_to_model_spec[model_uid] = model_description
@@ -978,6 +1177,39 @@ class WorkerActor(xo.StatelessActor):
978
1177
  model_ref = self._model_uid_to_model[model_uid]
979
1178
  await model_ref.wait_for_load()
980
1179
 
1180
+ @log_sync(logger=logger, level=logging.INFO)
1181
+ async def cancel_launch_model(self, model_uid: str):
1182
+ try:
1183
+ launch_info = self._model_uid_launching_guard[model_uid]
1184
+
1185
+ # downloader shared same cancel event
1186
+ # sometimes cancel happens very early before downloader
1187
+ # even if users cancel at this time,
1188
+ # downloader will know and stop everything
1189
+ launch_info.cancel_event.set()
1190
+
1191
+ if launch_info.downloader:
1192
+ logger.debug("Try to cancel download, %s")
1193
+ launch_info.downloader.cancel()
1194
+ if launch_info.virtual_env_manager:
1195
+ launch_info.virtual_env_manager.cancel_install()
1196
+ if launch_info.sub_pools:
1197
+ logger.debug("Try to stop sub pools: %s", launch_info.sub_pools)
1198
+ coros = []
1199
+ for addr in launch_info.sub_pools:
1200
+ coros.append(self._main_pool.remove_sub_pool(addr, force=True))
1201
+ await asyncio.gather(*coros)
1202
+ if self._status_guard_ref is not None:
1203
+ await self._status_guard_ref.update_instance_info(
1204
+ parse_replica_model_uid(model_uid)[0],
1205
+ {"status": LaunchStatus.ERROR.name},
1206
+ )
1207
+ except KeyError:
1208
+ logger.error("Fail to cancel launching", exc_info=True)
1209
+ raise RuntimeError(
1210
+ "Model is not launching, may have launched or not launched yet"
1211
+ )
1212
+
981
1213
  @log_async(logger=logger, level=logging.INFO)
982
1214
  async def terminate_model(self, model_uid: str, is_model_die=False):
983
1215
  # Terminate model while its launching is not allow
@@ -1157,16 +1389,9 @@ class WorkerActor(xo.StatelessActor):
1157
1389
  }
1158
1390
  path = list.get("model_file_location")
1159
1391
  cached_model["path"] = path
1160
- # parsing soft links
1161
- if os.path.isdir(path):
1162
- files = os.listdir(path)
1163
- # dir has files
1164
- if files:
1165
- resolved_file = os.path.realpath(os.path.join(path, files[0]))
1166
- if resolved_file:
1167
- cached_model["real_path"] = os.path.dirname(resolved_file)
1168
- else:
1169
- cached_model["real_path"] = os.path.realpath(path)
1392
+ real_path = get_real_path(path)
1393
+ if real_path:
1394
+ cached_model["real_path"] = real_path
1170
1395
  cached_model["actor_ip_address"] = self.address
1171
1396
  cached_models.append(cached_model)
1172
1397
  return cached_models
@@ -1267,7 +1492,7 @@ class WorkerActor(xo.StatelessActor):
1267
1492
  # Note that `store_port` needs to be generated on the worker,
1268
1493
  # as the TCP store is on rank 0, not on the supervisor.
1269
1494
  store_port = xo.utils.get_next_port()
1270
- self._model_uid_launching_guard[rep_model_uid] = True
1495
+ self._model_uid_launching_guard[rep_model_uid] = LaunchInfo()
1271
1496
  try:
1272
1497
  try:
1273
1498
  xavier_config["rank_address"] = subpool_address
@@ -16,10 +16,12 @@ import asyncio
16
16
  import logging
17
17
  import os
18
18
  import sys
19
+ import time
19
20
  import warnings
20
21
  from typing import Dict, List, Optional, Sequence, Tuple, Union
21
22
 
22
23
  import click
24
+ from tqdm.auto import tqdm
23
25
  from xoscar.utils import get_next_port
24
26
 
25
27
  from .. import __version__
@@ -925,6 +927,9 @@ def model_launch(
925
927
  if api_key is None:
926
928
  client._set_token(get_stored_token(endpoint, client))
927
929
 
930
+ # do not wait for launching.
931
+ kwargs["wait_ready"] = False
932
+
928
933
  model_uid = client.launch_model(
929
934
  model_name=model_name,
930
935
  model_type=model_type,
@@ -943,8 +948,35 @@ def model_launch(
943
948
  model_path=model_path,
944
949
  **kwargs,
945
950
  )
951
+ try:
952
+ with tqdm(
953
+ total=100, desc="Launching model", bar_format="{l_bar}{bar} | {n:.1f}%"
954
+ ) as pbar:
955
+ while True:
956
+ status = client.get_instance_info(model_name, model_uid)
957
+ if all(s["status"] in ["READY", "ERROR", "TERMINATED"] for s in status):
958
+ break
959
+
960
+ progress = client.get_launch_model_progress(model_uid)["progress"]
961
+ percent = max(round(progress * 100, 1), pbar.n)
946
962
 
947
- print(f"Model uid: {model_uid}", file=sys.stderr)
963
+ pbar.update(percent - pbar.n)
964
+
965
+ time.sleep(0.5)
966
+
967
+ # setting to 100%
968
+ pbar.update(pbar.total - pbar.n)
969
+
970
+ print(f"Model uid: {model_uid}", file=sys.stderr)
971
+ except KeyboardInterrupt:
972
+ user_input = (
973
+ input("Do you want to cancel model launching? (y/[n]): ").strip().lower()
974
+ )
975
+ if user_input == "y":
976
+ client.cancel_launch_model(model_uid)
977
+ print(f"Cancel request sent: {model_uid}")
978
+ else:
979
+ print("Skip cancel, launching model will be running still.")
948
980
 
949
981
 
950
982
  @cli.command(
@@ -17,7 +17,7 @@ from collections import defaultdict
17
17
  from typing import Any, Dict, List, Literal, Optional, Tuple, Union
18
18
 
19
19
  from ...constants import XINFERENCE_CACHE_DIR
20
- from ..core import CacheableModelSpec, ModelDescription
20
+ from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
21
21
  from ..utils import valid_model_revision
22
22
  from .chattts import ChatTTSModel
23
23
  from .cosyvoice import CosyVoiceModel
@@ -26,6 +26,7 @@ from .f5tts_mlx import F5TTSMLXModel
26
26
  from .fish_speech import FishSpeechModel
27
27
  from .funasr import FunASRModel
28
28
  from .kokoro import KokoroModel
29
+ from .megatts import MegaTTSModel
29
30
  from .melotts import MeloTTSModel
30
31
  from .whisper import WhisperModel
31
32
  from .whisper_mlx import WhisperMLXModel
@@ -55,6 +56,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
55
56
  default_model_config: Optional[Dict[str, Any]]
56
57
  default_transcription_config: Optional[Dict[str, Any]]
57
58
  engine: Optional[str]
59
+ virtualenv: Optional[VirtualEnvSettings]
58
60
 
59
61
 
60
62
  class AudioModelDescription(ModelDescription):
@@ -68,6 +70,10 @@ class AudioModelDescription(ModelDescription):
68
70
  super().__init__(address, devices, model_path=model_path)
69
71
  self._model_spec = model_spec
70
72
 
73
+ @property
74
+ def spec(self):
75
+ return self._model_spec
76
+
71
77
  def to_dict(self):
72
78
  return {
73
79
  "model_type": "audio",
@@ -178,6 +184,7 @@ def create_audio_model_instance(
178
184
  F5TTSMLXModel,
179
185
  MeloTTSModel,
180
186
  KokoroModel,
187
+ MegaTTSModel,
181
188
  ],
182
189
  AudioModelDescription,
183
190
  ]:
@@ -195,6 +202,7 @@ def create_audio_model_instance(
195
202
  F5TTSMLXModel,
196
203
  MeloTTSModel,
197
204
  KokoroModel,
205
+ MegaTTSModel,
198
206
  ]
199
207
  if model_spec.model_family == "whisper":
200
208
  if not model_spec.engine:
@@ -217,6 +225,8 @@ def create_audio_model_instance(
217
225
  model = MeloTTSModel(model_uid, model_path, model_spec, **kwargs)
218
226
  elif model_spec.model_family == "Kokoro":
219
227
  model = KokoroModel(model_uid, model_path, model_spec, **kwargs)
228
+ elif model_spec.model_family == "MegaTTS":
229
+ model = MegaTTSModel(model_uid, model_path, model_spec, **kwargs)
220
230
  else:
221
231
  raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
222
232
  model_description = AudioModelDescription(
@@ -0,0 +1,105 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import io
15
+ import logging
16
+ from io import BytesIO
17
+ from typing import TYPE_CHECKING, Optional
18
+
19
+ if TYPE_CHECKING:
20
+ from .core import AudioModelFamilyV1
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class MegaTTSModel:
26
+ def __init__(
27
+ self,
28
+ model_uid: str,
29
+ model_path: str,
30
+ model_spec: "AudioModelFamilyV1",
31
+ device: Optional[str] = None,
32
+ **kwargs,
33
+ ):
34
+ self._model_uid = model_uid
35
+ self._model_path = model_path
36
+ self._model_spec = model_spec
37
+ self._device = device
38
+ self._model = None
39
+ self._vocoder = None
40
+ self._kwargs = kwargs
41
+
42
+ @property
43
+ def model_ability(self):
44
+ return self._model_spec.model_ability
45
+
46
+ def load(self):
47
+ import os
48
+ import sys
49
+
50
+ # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
51
+ sys.path.insert(
52
+ 0, os.path.join(os.path.dirname(__file__), "../../thirdparty/megatts3")
53
+ )
54
+ # For whisper
55
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
56
+
57
+ from tts.infer_cli import MegaTTS3DiTInfer
58
+
59
+ self._model = MegaTTS3DiTInfer(ckpt_root=self._model_path)
60
+
61
+ def speech(
62
+ self,
63
+ input: str,
64
+ voice: str,
65
+ response_format: str = "mp3",
66
+ speed: float = 1.0,
67
+ stream: bool = False,
68
+ **kwargs,
69
+ ):
70
+ import soundfile
71
+
72
+ if stream:
73
+ raise Exception("MegaTTS3 does not support stream generation.")
74
+ if voice:
75
+ raise Exception(
76
+ "MegaTTS3 does not support voice, please specify prompt_speech and prompt_latent."
77
+ )
78
+
79
+ prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
80
+ prompt_latent: Optional[bytes] = kwargs.pop("prompt_latent", None)
81
+ if not prompt_speech:
82
+ raise Exception("Please set prompt_speech for MegaTTS3.")
83
+ if not prompt_latent:
84
+ raise Exception("Please set prompt_latent for MegaTTS3.")
85
+
86
+ assert self._model is not None
87
+ with io.BytesIO(prompt_latent) as prompt_latent_io:
88
+ resource_context = self._model.preprocess(
89
+ prompt_speech, latent_file=prompt_latent_io
90
+ )
91
+ wav_bytes = self._model.forward(
92
+ resource_context,
93
+ input,
94
+ time_step=kwargs.get("time_step", 32),
95
+ p_w=kwargs.get("p_w", 1.6),
96
+ t_w=kwargs.get("t_w", 2.5),
97
+ )
98
+
99
+ # Save the generated audio
100
+ with BytesIO() as out:
101
+ with soundfile.SoundFile(
102
+ out, "w", self._model.sr, 1, format=response_format.upper()
103
+ ) as f:
104
+ f.write(wav_bytes)
105
+ return out.getvalue()
@@ -203,6 +203,21 @@
203
203
  "merge_length_s": 15
204
204
  }
205
205
  },
206
+ {
207
+ "model_name": "paraformer-zh",
208
+ "model_family": "funasr",
209
+ "model_id": "funasr/paraformer-zh",
210
+ "model_revision": "5ed094cdfc8f6a9b6b022bd08bc904ef862bc79e",
211
+ "model_ability": "audio-to-text",
212
+ "multilingual": false,
213
+ "default_model_config": {
214
+ "vad_model": "fsmn-vad",
215
+ "punc_model": "ct-punc"
216
+ },
217
+ "default_transcription_config": {
218
+ "batch_size_s": 300
219
+ }
220
+ },
206
221
  {
207
222
  "model_name": "ChatTTS",
208
223
  "model_family": "ChatTTS",
@@ -216,7 +231,7 @@
216
231
  "model_family": "CosyVoice",
217
232
  "model_id": "FunAudioLLM/CosyVoice-300M",
218
233
  "model_revision": "39c4e13d46bd4dfb840d214547623e5fcd2428e2",
219
- "model_ability": "audio-to-audio",
234
+ "model_ability": "text-to-audio",
220
235
  "multilingual": true
221
236
  },
222
237
  {
@@ -346,5 +361,13 @@
346
361
  "model_revision": "7884269d6fd3f9beabc271b6f1308e5699281fa9",
347
362
  "model_ability": "text-to-audio",
348
363
  "multilingual": true
364
+ },
365
+ {
366
+ "model_name": "MegaTTS3",
367
+ "model_family": "MegaTTS",
368
+ "model_id": "ByteDance/MegaTTS3",
369
+ "model_revision": "409a7002b006d80f0730fca6f80441b08c10e738",
370
+ "model_ability": "text-to-audio",
371
+ "multilingual": true
349
372
  }
350
373
  ]