xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (132)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +54 -1
  4. xinference/client/restful/restful_client.py +82 -2
  5. xinference/constants.py +3 -0
  6. xinference/core/chat_interface.py +297 -83
  7. xinference/core/model.py +24 -3
  8. xinference/core/progress_tracker.py +16 -8
  9. xinference/core/supervisor.py +51 -1
  10. xinference/core/worker.py +315 -47
  11. xinference/deploy/cmdline.py +33 -1
  12. xinference/model/audio/core.py +11 -1
  13. xinference/model/audio/megatts.py +105 -0
  14. xinference/model/audio/model_spec.json +24 -1
  15. xinference/model/audio/model_spec_modelscope.json +26 -1
  16. xinference/model/core.py +14 -0
  17. xinference/model/embedding/core.py +6 -1
  18. xinference/model/flexible/core.py +6 -1
  19. xinference/model/image/core.py +6 -1
  20. xinference/model/image/model_spec.json +17 -1
  21. xinference/model/image/model_spec_modelscope.json +17 -1
  22. xinference/model/llm/__init__.py +4 -6
  23. xinference/model/llm/core.py +5 -0
  24. xinference/model/llm/llama_cpp/core.py +46 -17
  25. xinference/model/llm/llm_family.json +530 -85
  26. xinference/model/llm/llm_family.py +24 -1
  27. xinference/model/llm/llm_family_modelscope.json +572 -1
  28. xinference/model/llm/mlx/core.py +16 -2
  29. xinference/model/llm/reasoning_parser.py +3 -3
  30. xinference/model/llm/sglang/core.py +111 -13
  31. xinference/model/llm/transformers/__init__.py +14 -0
  32. xinference/model/llm/transformers/core.py +31 -6
  33. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  34. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  35. xinference/model/llm/transformers/gemma3.py +17 -2
  36. xinference/model/llm/transformers/intern_vl.py +28 -18
  37. xinference/model/llm/transformers/minicpmv26.py +21 -2
  38. xinference/model/llm/transformers/qwen-omni.py +308 -0
  39. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  40. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  41. xinference/model/llm/utils.py +37 -15
  42. xinference/model/llm/vllm/core.py +184 -8
  43. xinference/model/llm/vllm/distributed_executor.py +320 -0
  44. xinference/model/rerank/core.py +22 -12
  45. xinference/model/utils.py +118 -1
  46. xinference/model/video/core.py +6 -1
  47. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  48. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  49. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  50. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  51. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  52. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  53. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  54. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  55. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  56. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  57. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  58. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  59. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  60. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  61. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  62. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  63. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  64. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  65. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  66. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  67. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  68. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  69. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  70. xinference/thirdparty/megatts3/__init__.py +0 -0
  71. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  72. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  73. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  74. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  75. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  76. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  77. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  78. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  79. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  80. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  81. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  82. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  83. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  84. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  85. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  86. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  87. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  88. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  89. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  90. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  91. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  92. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  93. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  94. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  95. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  96. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  97. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  98. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  99. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  100. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  101. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  102. xinference/types.py +10 -0
  103. xinference/utils.py +54 -0
  104. xinference/web/ui/build/asset-manifest.json +6 -6
  105. xinference/web/ui/build/index.html +1 -1
  106. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  107. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  108. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  109. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  112. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  113. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  114. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  116. xinference/web/ui/src/locales/en.json +2 -1
  117. xinference/web/ui/src/locales/zh.json +2 -1
  118. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
  119. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
  120. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  121. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  122. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  123. xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
  124. xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  129. /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  130. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  131. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  132. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/core/supervisor.py CHANGED
@@ -18,11 +18,13 @@ import os
18
18
  import signal
19
19
  import time
20
20
  import typing
21
- from dataclasses import dataclass
21
+ from collections import defaultdict
22
+ from dataclasses import dataclass, field
22
23
  from logging import getLogger
23
24
  from typing import (
24
25
  TYPE_CHECKING,
25
26
  Any,
27
+ DefaultDict,
26
28
  Dict,
27
29
  Iterator,
28
30
  List,
@@ -91,6 +93,9 @@ class WorkerStatus:
91
93
  class ReplicaInfo:
92
94
  replica: int
93
95
  scheduler: Iterator
96
+ replica_to_worker_refs: DefaultDict[
97
+ int, List[xo.ActorRefType["WorkerActor"]]
98
+ ] = field(default_factory=lambda: defaultdict(list))
94
99
 
95
100
 
96
101
  class SupervisorActor(xo.StatelessActor):
@@ -1097,6 +1102,7 @@ class SupervisorActor(xo.StatelessActor):
1097
1102
  xavier_config=xavier_config,
1098
1103
  **kwargs,
1099
1104
  )
1105
+ await worker_ref.wait_for_load(_replica_model_uid)
1100
1106
  self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
1101
1107
  return subpool_address
1102
1108
 
@@ -1112,6 +1118,9 @@ class SupervisorActor(xo.StatelessActor):
1112
1118
  if target_ip_worker_ref is not None
1113
1119
  else await self._choose_worker()
1114
1120
  )
1121
+ self._model_uid_to_replica_info[model_uid].replica_to_worker_refs[
1122
+ _idx
1123
+ ].append(worker_ref)
1115
1124
  if enable_xavier and _idx == 0:
1116
1125
  """
1117
1126
  Start the rank 0 model actor on the worker that holds the rank 1 replica,
@@ -1242,6 +1251,11 @@ class SupervisorActor(xo.StatelessActor):
1242
1251
  available_workers.append(worker_ip)
1243
1252
 
1244
1253
  async def _launch_model():
1254
+ # Validation of n_worker, intercept if it is greater than the available workers.
1255
+ if n_worker > len(available_workers):
1256
+ raise ValueError(
1257
+ "n_worker cannot be larger than the number of available workers."
1258
+ )
1245
1259
  try:
1246
1260
  for _idx, rep_model_uid in enumerate(
1247
1261
  iter_replica_model_uid(model_uid, replica)
@@ -1254,6 +1268,9 @@ class SupervisorActor(xo.StatelessActor):
1254
1268
  driver_info = None
1255
1269
  for i_worker in range(n_worker):
1256
1270
  worker_ref = await self._choose_worker(available_workers)
1271
+ self._model_uid_to_replica_info[
1272
+ model_uid
1273
+ ].replica_to_worker_refs[_idx].append(worker_ref)
1257
1274
  nonlocal model_type
1258
1275
  model_type = model_type or "LLM"
1259
1276
  if i_worker > 1:
@@ -1338,6 +1355,39 @@ class SupervisorActor(xo.StatelessActor):
1338
1355
  task.add_done_callback(lambda _: callback_for_async_launch(model_uid)) # type: ignore
1339
1356
  return model_uid
1340
1357
 
1358
+ async def get_launch_builtin_model_progress(self, model_uid: str) -> float:
1359
+ info = self._model_uid_to_replica_info[model_uid]
1360
+ all_progress = 0.0
1361
+ i = 0
1362
+ for rep_model_uid in iter_replica_model_uid(model_uid, info.replica):
1363
+ request_id = f"launching-{rep_model_uid}"
1364
+ try:
1365
+ all_progress += await self._progress_tracker.get_progress(request_id)
1366
+ i += 1
1367
+ except KeyError:
1368
+ continue
1369
+
1370
+ return all_progress / i if i > 0 else 0.0
1371
+
1372
+ async def cancel_launch_builtin_model(self, model_uid: str):
1373
+ info = self._model_uid_to_replica_info[model_uid]
1374
+ coros = []
1375
+ for i, rep_model_uid in enumerate(
1376
+ iter_replica_model_uid(model_uid, info.replica)
1377
+ ):
1378
+ worker_refs = self._model_uid_to_replica_info[
1379
+ model_uid
1380
+ ].replica_to_worker_refs[i]
1381
+ for worker_ref in worker_refs:
1382
+ coros.append(worker_ref.cancel_launch_model(rep_model_uid))
1383
+ try:
1384
+ await asyncio.gather(*coros)
1385
+ except RuntimeError:
1386
+ # some may have finished
1387
+ pass
1388
+ # remove replica info
1389
+ self._model_uid_to_replica_info.pop(model_uid, None)
1390
+
1341
1391
  async def get_instance_info(
1342
1392
  self, model_name: Optional[str], model_uid: Optional[str]
1343
1393
  ) -> List[Dict]:
xinference/core/worker.py CHANGED
@@ -22,9 +22,20 @@ import signal
22
22
  import threading
23
23
  import time
24
24
  from collections import defaultdict
25
- from dataclasses import dataclass
25
+ from dataclasses import dataclass, field
26
26
  from logging import getLogger
27
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, no_type_check
27
+ from typing import (
28
+ TYPE_CHECKING,
29
+ Any,
30
+ Dict,
31
+ List,
32
+ Literal,
33
+ Optional,
34
+ Set,
35
+ Tuple,
36
+ Union,
37
+ no_type_check,
38
+ )
28
39
 
29
40
  import xoscar as xo
30
41
  from async_timeout import timeout
@@ -34,13 +45,17 @@ from ..constants import (
34
45
  XINFERENCE_CACHE_DIR,
35
46
  XINFERENCE_DISABLE_HEALTH_CHECK,
36
47
  XINFERENCE_DISABLE_METRICS,
48
+ XINFERENCE_ENABLE_VIRTUAL_ENV,
37
49
  XINFERENCE_HEALTH_CHECK_INTERVAL,
50
+ XINFERENCE_VIRTUAL_ENV_DIR,
38
51
  )
39
52
  from ..core.model import ModelActor
40
53
  from ..core.status_guard import LaunchStatus
41
54
  from ..device_utils import get_available_device_env_name, gpu_count
42
- from ..model.core import ModelDescription, create_model_instance
55
+ from ..model.core import ModelDescription, VirtualEnvSettings, create_model_instance
56
+ from ..model.utils import CancellableDownloader
43
57
  from ..types import PeftModelConfig
58
+ from ..utils import get_pip_config_args, get_real_path
44
59
  from .cache_tracker import CacheTrackerActor
45
60
  from .event import Event, EventCollectorActor, EventType
46
61
  from .metrics import launch_metrics_export_server, record_metrics
@@ -48,6 +63,14 @@ from .resource import gather_node_info
48
63
  from .status_guard import StatusGuardActor
49
64
  from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
50
65
 
66
+ try:
67
+ from xoscar.virtualenv import VirtualEnvManager
68
+ except ImportError:
69
+ VirtualEnvManager = None
70
+
71
+ if TYPE_CHECKING:
72
+ from .progress_tracker import Progressor
73
+
51
74
  logger = getLogger(__name__)
52
75
 
53
76
 
@@ -64,6 +87,17 @@ class ModelStatus:
64
87
  last_error: str = ""
65
88
 
66
89
 
90
+ @dataclass
91
+ class LaunchInfo:
92
+ cancel_event: threading.Event = field(default_factory=threading.Event)
93
+ # virtualenv manager
94
+ virtual_env_manager: Optional["VirtualEnvManager"] = None
95
+ # downloader, report progress or cancel entire download
96
+ downloader: Optional[CancellableDownloader] = None
97
+ # sub pools created for the model
98
+ sub_pools: Optional[List[str]] = None
99
+
100
+
67
101
  class WorkerActor(xo.StatelessActor):
68
102
  def __init__(
69
103
  self,
@@ -92,7 +126,7 @@ class WorkerActor(xo.StatelessActor):
92
126
 
93
127
  # internal states.
94
128
  # temporary placeholder during model launch process:
95
- self._model_uid_launching_guard: Dict[str, bool] = {}
129
+ self._model_uid_launching_guard: Dict[str, LaunchInfo] = {}
96
130
  # attributes maintained after model launched:
97
131
  self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {}
98
132
  self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
@@ -352,6 +386,7 @@ class WorkerActor(xo.StatelessActor):
352
386
  self._cache_tracker_ref = await xo.actor_ref(
353
387
  address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
354
388
  )
389
+ self._progress_tracker_ref = None
355
390
  # cache_tracker is on supervisor
356
391
  from ..model.audio import get_audio_model_descriptions
357
392
  from ..model.embedding import get_embedding_model_descriptions
@@ -548,8 +583,9 @@ class WorkerActor(xo.StatelessActor):
548
583
  model_type: Optional[str] = None,
549
584
  n_gpu: Optional[Union[int, str]] = "auto",
550
585
  gpu_idx: Optional[List[int]] = None,
586
+ env: Optional[Dict[str, str]] = None,
551
587
  ) -> Tuple[str, List[str]]:
552
- env = {}
588
+ env = {} if env is None else env
553
589
  devices = []
554
590
  env_name = get_available_device_env_name() or "CUDA_VISIBLE_DEVICES"
555
591
  if gpu_idx is None:
@@ -778,6 +814,96 @@ class WorkerActor(xo.StatelessActor):
778
814
  version_info["model_file_location"],
779
815
  )
780
816
 
817
+ @classmethod
818
+ def _create_virtual_env_manager(
819
+ cls,
820
+ enable_virtual_env: Optional[bool],
821
+ virtual_env_name: Optional[str],
822
+ env_path: str,
823
+ ) -> Optional[VirtualEnvManager]:
824
+ if enable_virtual_env is None:
825
+ enable_virtual_env = XINFERENCE_ENABLE_VIRTUAL_ENV
826
+
827
+ if not enable_virtual_env:
828
+ # skip preparing virtualenv
829
+ return None
830
+
831
+ from xoscar.virtualenv import get_virtual_env_manager
832
+
833
+ virtual_env_manager: VirtualEnvManager = get_virtual_env_manager(
834
+ virtual_env_name or "uv", env_path
835
+ )
836
+ return virtual_env_manager
837
+
838
+ @classmethod
839
+ def _prepare_virtual_env(
840
+ cls,
841
+ virtual_env_manager: "VirtualEnvManager",
842
+ settings: Optional[VirtualEnvSettings],
843
+ ):
844
+ if not settings or not settings.packages:
845
+ # no settings or no packages
846
+ return
847
+
848
+ # create env
849
+ virtual_env_manager.create_env()
850
+
851
+ if settings.inherit_pip_config:
852
+ # inherit pip config
853
+ pip_config = get_pip_config_args()
854
+ for k, v in pip_config.items():
855
+ if hasattr(settings, k) and not getattr(settings, k):
856
+ setattr(settings, k, v)
857
+
858
+ packages = settings.packages
859
+ index_url = settings.index_url
860
+ extra_index_url = settings.extra_index_url
861
+ find_links = settings.find_links
862
+ trusted_host = settings.trusted_host
863
+
864
+ logger.info(
865
+ "Installing packages %s in virtual env %s, with settings(index_url=%s)",
866
+ packages,
867
+ virtual_env_manager.env_path,
868
+ index_url,
869
+ )
870
+ virtual_env_manager.install_packages(
871
+ packages,
872
+ index_url=index_url,
873
+ extra_index_url=extra_index_url,
874
+ find_links=find_links,
875
+ trusted_host=trusted_host,
876
+ )
877
+
878
+ async def _get_progressor(self, request_id: str):
879
+ from .progress_tracker import Progressor, ProgressTrackerActor
880
+
881
+ progress_tracker_ref = self._progress_tracker_ref
882
+ if progress_tracker_ref is None:
883
+ progress_tracker_ref = self._progress_tracker_ref = await xo.actor_ref(
884
+ address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
885
+ )
886
+
887
+ progressor = Progressor(
888
+ request_id,
889
+ progress_tracker_ref,
890
+ asyncio.get_running_loop(),
891
+ )
892
+ await progressor.start()
893
+ progressor.set_progress(0.0, "start to launch model")
894
+ return progressor
895
+
896
+ @classmethod
897
+ def _upload_download_progress(
898
+ cls, progressor: "Progressor", downloader: CancellableDownloader
899
+ ):
900
+ while not downloader.done:
901
+ progress = downloader.get_progress()
902
+ progressor.set_progress(progress)
903
+ downloader.wait(1)
904
+
905
+ progressor.set_progress(1.0, "Start to load model")
906
+
781
907
  @log_async(logger=logger, level=logging.INFO)
782
908
  async def launch_builtin_model(
783
909
  self,
@@ -870,11 +996,29 @@ class WorkerActor(xo.StatelessActor):
870
996
  raise ValueError(f"{model_uid} is running")
871
997
 
872
998
  try:
873
- self._model_uid_launching_guard[model_uid] = True
874
- subpool_address, devices = await self._create_subpool(
875
- model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
999
+ self._model_uid_launching_guard[model_uid] = launch_info = LaunchInfo()
1000
+
1001
+ # virtualenv
1002
+ enable_virtual_env = kwargs.pop("enable_virtual_env", None)
1003
+ virtual_env_name = kwargs.pop("virtual_env_name", None)
1004
+ virtual_env_path = os.path.join(XINFERENCE_VIRTUAL_ENV_DIR, model_name)
1005
+ virtual_env_manager = await asyncio.to_thread(
1006
+ self._create_virtual_env_manager,
1007
+ enable_virtual_env,
1008
+ virtual_env_name,
1009
+ virtual_env_path,
1010
+ )
1011
+ # setting os.environ if virtualenv created
1012
+ env = (
1013
+ {"PYTHONPATH": virtual_env_manager.get_lib_path()}
1014
+ if virtual_env_manager
1015
+ else None
876
1016
  )
877
1017
 
1018
+ subpool_address, devices = await self._create_subpool(
1019
+ model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx, env=env
1020
+ )
1021
+ all_subpool_addresses = [subpool_address]
878
1022
  try:
879
1023
  xavier_config: Optional[Dict] = kwargs.pop("xavier_config", None)
880
1024
  if xavier_config is not None:
@@ -885,29 +1029,68 @@ class WorkerActor(xo.StatelessActor):
885
1029
  # add a few kwargs
886
1030
  model_kwargs.update(
887
1031
  dict(
888
- address=self.address,
1032
+ address=subpool_address,
889
1033
  n_worker=n_worker,
890
1034
  shard=shard,
891
1035
  driver_info=driver_info,
892
1036
  )
893
1037
  )
894
- model, model_description = await asyncio.to_thread(
895
- create_model_instance,
896
- subpool_address,
897
- devices,
898
- model_uid,
899
- model_type,
900
- model_name,
901
- model_engine,
902
- model_format,
903
- model_size_in_billions,
904
- quantization,
905
- peft_model_config,
906
- download_hub,
907
- model_path,
908
- **model_kwargs,
909
- )
910
- await self.update_cache_status(model_name, model_description)
1038
+
1039
+ with CancellableDownloader(
1040
+ cancelled_event=launch_info.cancel_event
1041
+ ) as downloader:
1042
+ launch_info.downloader = downloader
1043
+ progressor = await self._get_progressor("launching-" + model_uid)
1044
+ # split into download and launch
1045
+ progressor.split_stages(2, stage_weight=[0, 0.8, 1.0])
1046
+ with progressor:
1047
+ upload_progress_task = asyncio.create_task(
1048
+ asyncio.to_thread(
1049
+ self._upload_download_progress, progressor, downloader
1050
+ )
1051
+ )
1052
+ model, model_description = await asyncio.to_thread(
1053
+ create_model_instance,
1054
+ subpool_address,
1055
+ devices,
1056
+ model_uid,
1057
+ model_type,
1058
+ model_name,
1059
+ model_engine,
1060
+ model_format,
1061
+ model_size_in_billions,
1062
+ quantization,
1063
+ peft_model_config,
1064
+ download_hub,
1065
+ model_path,
1066
+ **model_kwargs,
1067
+ )
1068
+ await self.update_cache_status(model_name, model_description)
1069
+
1070
+ def check_cancel():
1071
+ # check downloader first, sometimes download finished
1072
+ # cancelled already
1073
+ if downloader.cancelled:
1074
+ with progressor:
1075
+ # just report progress
1076
+ pass
1077
+ downloader.raise_error(error_msg="Launch cancelled")
1078
+
1079
+ # check cancel before prepare virtual env
1080
+ check_cancel()
1081
+
1082
+ # install packages in virtual env
1083
+ if virtual_env_manager:
1084
+ await asyncio.to_thread(
1085
+ self._prepare_virtual_env,
1086
+ virtual_env_manager,
1087
+ model_description.spec.virtualenv,
1088
+ )
1089
+ launch_info.virtual_env_manager = virtual_env_manager
1090
+
1091
+ # check before creating model actor
1092
+ check_cancel()
1093
+
911
1094
  model_ref = await xo.create_actor(
912
1095
  ModelActor,
913
1096
  address=subpool_address,
@@ -923,11 +1106,44 @@ class WorkerActor(xo.StatelessActor):
923
1106
  shard=shard,
924
1107
  driver_info=driver_info,
925
1108
  )
926
- await model_ref.load()
1109
+ if await model_ref.need_create_pools() and (
1110
+ len(devices) > 1 or n_worker > 1 # type: ignore
1111
+ ):
1112
+ coros = []
1113
+ env_name = get_available_device_env_name() or "CUDA_VISIBLE_DEVICES"
1114
+ env_value = ",".join(devices)
1115
+ for device in devices:
1116
+ coros.append(
1117
+ self._main_pool.append_sub_pool(
1118
+ env={env_name: env_value},
1119
+ start_method=self._get_start_method(),
1120
+ )
1121
+ )
1122
+ pool_addresses = await asyncio.gather(*coros)
1123
+ all_subpool_addresses.extend(pool_addresses)
1124
+ await model_ref.set_pool_addresses(pool_addresses)
1125
+
1126
+ # check before loading
1127
+ check_cancel()
1128
+
1129
+ # set all subpool addresses
1130
+ # when cancelled, all subpool addresses need to be destroyed
1131
+ launch_info.sub_pools = all_subpool_addresses
1132
+
1133
+ with progressor:
1134
+ try:
1135
+ await model_ref.load()
1136
+ except xo.ServerClosed:
1137
+ check_cancel()
1138
+ raise
927
1139
  except:
928
1140
  logger.error(f"Failed to load model {model_uid}", exc_info=True)
929
1141
  self.release_devices(model_uid=model_uid)
930
- await self._main_pool.remove_sub_pool(subpool_address)
1142
+ for addr in all_subpool_addresses:
1143
+ try:
1144
+ await self._main_pool.remove_sub_pool(addr)
1145
+ except KeyError:
1146
+ continue
931
1147
  raise
932
1148
  self._model_uid_to_model[model_uid] = model_ref
933
1149
  self._model_uid_to_model_spec[model_uid] = model_description
@@ -961,6 +1177,39 @@ class WorkerActor(xo.StatelessActor):
961
1177
  model_ref = self._model_uid_to_model[model_uid]
962
1178
  await model_ref.wait_for_load()
963
1179
 
1180
+ @log_sync(logger=logger, level=logging.INFO)
1181
+ async def cancel_launch_model(self, model_uid: str):
1182
+ try:
1183
+ launch_info = self._model_uid_launching_guard[model_uid]
1184
+
1185
+ # downloader shared same cancel event
1186
+ # sometimes cancel happens very early before downloader
1187
+ # even if users cancel at this time,
1188
+ # downloader will know and stop everything
1189
+ launch_info.cancel_event.set()
1190
+
1191
+ if launch_info.downloader:
1192
+ logger.debug("Try to cancel download, %s")
1193
+ launch_info.downloader.cancel()
1194
+ if launch_info.virtual_env_manager:
1195
+ launch_info.virtual_env_manager.cancel_install()
1196
+ if launch_info.sub_pools:
1197
+ logger.debug("Try to stop sub pools: %s", launch_info.sub_pools)
1198
+ coros = []
1199
+ for addr in launch_info.sub_pools:
1200
+ coros.append(self._main_pool.remove_sub_pool(addr, force=True))
1201
+ await asyncio.gather(*coros)
1202
+ if self._status_guard_ref is not None:
1203
+ await self._status_guard_ref.update_instance_info(
1204
+ parse_replica_model_uid(model_uid)[0],
1205
+ {"status": LaunchStatus.ERROR.name},
1206
+ )
1207
+ except KeyError:
1208
+ logger.error("Fail to cancel launching", exc_info=True)
1209
+ raise RuntimeError(
1210
+ "Model is not launching, may have launched or not launched yet"
1211
+ )
1212
+
964
1213
  @log_async(logger=logger, level=logging.INFO)
965
1214
  async def terminate_model(self, model_uid: str, is_model_die=False):
966
1215
  # Terminate model while its launching is not allow
@@ -994,15 +1243,36 @@ class WorkerActor(xo.StatelessActor):
994
1243
  if model_ref is None:
995
1244
  logger.debug("Model not found, uid: %s", model_uid)
996
1245
 
1246
+ pool_addresses = None
1247
+ if model_ref is not None:
1248
+ try:
1249
+ # pool addresses if model.need_create_pools()
1250
+ pool_addresses = await model_ref.get_pool_addresses()
1251
+ except Exception as e:
1252
+ # process may disappear, we just ignore it.
1253
+ logger.debug("Fail to get pool addresses, error: %s", e)
1254
+
997
1255
  try:
998
- await xo.destroy_actor(model_ref)
1256
+ logger.debug("Start to destroy model actor: %s", model_ref)
1257
+ coro = xo.destroy_actor(model_ref)
1258
+ await asyncio.wait_for(coro, timeout=5)
999
1259
  except Exception as e:
1000
1260
  logger.debug(
1001
1261
  "Destroy model actor failed, model uid: %s, error: %s", model_uid, e
1002
1262
  )
1003
1263
  try:
1264
+ to_remove_addresses = []
1004
1265
  subpool_address = self._model_uid_to_addr[model_uid]
1005
- await self._main_pool.remove_sub_pool(subpool_address, force=True)
1266
+ to_remove_addresses.append(subpool_address)
1267
+ if pool_addresses:
1268
+ to_remove_addresses.extend(pool_addresses)
1269
+ logger.debug("Remove sub pools: %s", to_remove_addresses)
1270
+ coros = []
1271
+ for to_remove_addr in to_remove_addresses:
1272
+ coros.append(
1273
+ self._main_pool.remove_sub_pool(to_remove_addr, force=True)
1274
+ )
1275
+ await asyncio.gather(*coros)
1006
1276
  except Exception as e:
1007
1277
  logger.debug(
1008
1278
  "Remove sub pool failed, model uid: %s, error: %s", model_uid, e
@@ -1119,16 +1389,9 @@ class WorkerActor(xo.StatelessActor):
1119
1389
  }
1120
1390
  path = list.get("model_file_location")
1121
1391
  cached_model["path"] = path
1122
- # parsing soft links
1123
- if os.path.isdir(path):
1124
- files = os.listdir(path)
1125
- # dir has files
1126
- if files:
1127
- resolved_file = os.path.realpath(os.path.join(path, files[0]))
1128
- if resolved_file:
1129
- cached_model["real_path"] = os.path.dirname(resolved_file)
1130
- else:
1131
- cached_model["real_path"] = os.path.realpath(path)
1392
+ real_path = get_real_path(path)
1393
+ if real_path:
1394
+ cached_model["real_path"] = real_path
1132
1395
  cached_model["actor_ip_address"] = self.address
1133
1396
  cached_models.append(cached_model)
1134
1397
  return cached_models
@@ -1204,18 +1467,23 @@ class WorkerActor(xo.StatelessActor):
1204
1467
  model_ref = self._model_uid_to_model[rep_model_uid]
1205
1468
  await model_ref.start_transfer_for_vllm(rank_addresses)
1206
1469
 
1207
- @log_async(logger=logger, level=logging.INFO)
1208
- async def launch_rank0_model(
1209
- self, rep_model_uid: str, xavier_config: Dict[str, Any]
1210
- ) -> Tuple[str, int]:
1211
- from ..model.llm.vllm.xavier.collective_manager import Rank0ModelActor
1212
-
1470
+ @staticmethod
1471
+ def _get_start_method():
1213
1472
  if os.name != "nt" and platform.system() != "Darwin":
1214
1473
  # Linux
1215
1474
  start_method = "forkserver"
1216
1475
  else:
1217
1476
  # Windows and macOS
1218
1477
  start_method = "spawn"
1478
+ return start_method
1479
+
1480
+ @log_async(logger=logger, level=logging.INFO)
1481
+ async def launch_rank0_model(
1482
+ self, rep_model_uid: str, xavier_config: Dict[str, Any]
1483
+ ) -> Tuple[str, int]:
1484
+ from ..model.llm.vllm.xavier.collective_manager import Rank0ModelActor
1485
+
1486
+ start_method = self._get_start_method()
1219
1487
  subpool_address = await self._main_pool.append_sub_pool(
1220
1488
  start_method=start_method
1221
1489
  )
@@ -1224,7 +1492,7 @@ class WorkerActor(xo.StatelessActor):
1224
1492
  # Note that `store_port` needs to be generated on the worker,
1225
1493
  # as the TCP store is on rank 0, not on the supervisor.
1226
1494
  store_port = xo.utils.get_next_port()
1227
- self._model_uid_launching_guard[rep_model_uid] = True
1495
+ self._model_uid_launching_guard[rep_model_uid] = LaunchInfo()
1228
1496
  try:
1229
1497
  try:
1230
1498
  xavier_config["rank_address"] = subpool_address
xinference/deploy/cmdline.py CHANGED
@@ -16,10 +16,12 @@ import asyncio
16
16
  import logging
17
17
  import os
18
18
  import sys
19
+ import time
19
20
  import warnings
20
21
  from typing import Dict, List, Optional, Sequence, Tuple, Union
21
22
 
22
23
  import click
24
+ from tqdm.auto import tqdm
23
25
  from xoscar.utils import get_next_port
24
26
 
25
27
  from .. import __version__
@@ -925,6 +927,9 @@ def model_launch(
925
927
  if api_key is None:
926
928
  client._set_token(get_stored_token(endpoint, client))
927
929
 
930
+ # do not wait for launching.
931
+ kwargs["wait_ready"] = False
932
+
928
933
  model_uid = client.launch_model(
929
934
  model_name=model_name,
930
935
  model_type=model_type,
@@ -943,8 +948,35 @@ def model_launch(
943
948
  model_path=model_path,
944
949
  **kwargs,
945
950
  )
951
+ try:
952
+ with tqdm(
953
+ total=100, desc="Launching model", bar_format="{l_bar}{bar} | {n:.1f}%"
954
+ ) as pbar:
955
+ while True:
956
+ status = client.get_instance_info(model_name, model_uid)
957
+ if all(s["status"] in ["READY", "ERROR", "TERMINATED"] for s in status):
958
+ break
959
+
960
+ progress = client.get_launch_model_progress(model_uid)["progress"]
961
+ percent = max(round(progress * 100, 1), pbar.n)
946
962
 
947
- print(f"Model uid: {model_uid}", file=sys.stderr)
963
+ pbar.update(percent - pbar.n)
964
+
965
+ time.sleep(0.5)
966
+
967
+ # setting to 100%
968
+ pbar.update(pbar.total - pbar.n)
969
+
970
+ print(f"Model uid: {model_uid}", file=sys.stderr)
971
+ except KeyboardInterrupt:
972
+ user_input = (
973
+ input("Do you want to cancel model launching? (y/[n]): ").strip().lower()
974
+ )
975
+ if user_input == "y":
976
+ client.cancel_launch_model(model_uid)
977
+ print(f"Cancel request sent: {model_uid}")
978
+ else:
979
+ print("Skip cancel, launching model will be running still.")
948
980
 
949
981
 
950
982
  @cli.command(