xinference 1.2.2__py3-none-any.whl → 1.3.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of xinference has been flagged as a potentially problematic release.

Files changed (68):
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +9 -1
  3. xinference/core/model.py +19 -0
  4. xinference/core/resource.py +7 -1
  5. xinference/core/status_guard.py +1 -0
  6. xinference/core/supervisor.py +228 -19
  7. xinference/core/utils.py +1 -29
  8. xinference/core/worker.py +28 -2
  9. xinference/deploy/cmdline.py +33 -3
  10. xinference/deploy/test/test_cmdline.py +32 -0
  11. xinference/device_utils.py +43 -1
  12. xinference/model/audio/kokoro.py +19 -36
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/image/stable_diffusion/core.py +15 -6
  15. xinference/model/llm/llm_family.json +521 -6
  16. xinference/model/llm/llm_family.py +3 -1
  17. xinference/model/llm/llm_family_modelscope.json +559 -6
  18. xinference/model/llm/reasoning_parsers/__init__.py +13 -0
  19. xinference/model/llm/reasoning_parsers/abs_reasoning_parsers.py +98 -0
  20. xinference/model/llm/reasoning_parsers/deepseek_r1_reasoning_parser.py +140 -0
  21. xinference/model/llm/sglang/core.py +99 -11
  22. xinference/model/llm/transformers/intern_vl.py +23 -14
  23. xinference/model/llm/utils.py +55 -18
  24. xinference/model/llm/vllm/core.py +23 -2
  25. xinference/model/llm/vllm/xavier/executor.py +2 -2
  26. xinference/model/llm/vllm/xavier/scheduler.py +3 -3
  27. xinference/thirdparty/internvl/conversation.py +26 -17
  28. xinference/types.py +2 -0
  29. xinference/web/ui/build/asset-manifest.json +6 -6
  30. xinference/web/ui/build/index.html +1 -1
  31. xinference/web/ui/build/static/css/main.f8177338.css +2 -0
  32. xinference/web/ui/build/static/css/main.f8177338.css.map +1 -0
  33. xinference/web/ui/build/static/js/main.ad42919c.js +3 -0
  34. xinference/web/ui/build/static/js/main.ad42919c.js.map +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/074a42304bbbaa79e1bfc3b28502457a390df55708de9006f4cc8e35c60aea87.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/0acb065326560592b10888234242f94f67efe28458b90f273d4d4fba9daa0cd2.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/279ace390216236a82b3d8995c78eca4d637ac9a523e9f521a2d9c76607a43d7.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/630a7bd592596cc6e291fc32238ce7c08238038a64ed8ccee0eb0c13c9902910.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/6cb9f6c62ab4042f0b11c5d75e51187188e9d6f5f08b1d63e796e051bafdb457.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/914c33e91c1012e3bcd3e96f3a25884cbef148290632d0266dab972b8cc1e95f.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/b7939cd3a48adf12fccfdd0803019b5cc235ff7de3a297dae70ce635e0eea13e.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/fecf076bcd198a458c2a6ab0e85e40dc1c99994c353164e79c469be162cb74c9.json +1 -0
  47. xinference/web/ui/src/locales/en.json +14 -1
  48. xinference/web/ui/src/locales/zh.json +14 -1
  49. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/METADATA +11 -11
  50. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/RECORD +55 -49
  51. xinference/web/ui/build/static/css/main.51a587ff.css +0 -2
  52. xinference/web/ui/build/static/css/main.51a587ff.css.map +0 -1
  53. xinference/web/ui/build/static/js/main.b0936c54.js +0 -3
  54. xinference/web/ui/build/static/js/main.b0936c54.js.map +0 -1
  55. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +0 -1
  56. xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/a7f1a71f6580dfe810c685a9c1d68e318f71e1fa258fbe50b87a6ac37cc0a598.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +0 -1
  64. /xinference/web/ui/build/static/js/{main.b0936c54.js.LICENSE.txt → main.ad42919c.js.LICENSE.txt} +0 -0
  65. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/LICENSE +0 -0
  66. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/WHEEL +0 -0
  67. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/entry_points.txt +0 -0
  68. {xinference-1.2.2.dist-info → xinference-1.3.0.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-02-08T17:06:47+0800",
+ "date": "2025-02-22T00:10:55+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "ac97a13a831de6debda52e6fdb8c1bf9366be57c",
- "version": "1.2.2"
+ "full-revisionid": "b2004d49ddeda17dc6404473b1f25f8769911e18",
+ "version": "1.3.0.post1"
 }
 ''' # END VERSION_JSON
 
xinference/client/restful/restful_client.py CHANGED
@@ -917,11 +917,13 @@ class Client:
         model_format: Optional[str] = None,
         quantization: Optional[str] = None,
         replica: int = 1,
+        n_worker: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
         peft_model_config: Optional[Dict] = None,
         request_limits: Optional[int] = None,
         worker_ip: Optional[str] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
+        model_path: Optional[str] = None,
         **kwargs,
     ) -> str:
         """
@@ -945,8 +947,10 @@
             The quantization of model.
         replica: Optional[int]
             The replica of model, default is 1.
+        n_worker: int
+            Number of workers to run.
         n_gpu: Optional[Union[int, str]],
-            The number of GPUs used by the model, default is "auto".
+            The number of GPUs used by the model, default is "auto". If n_worker>1, means number of GPUs per worker.
             ``n_gpu=None`` means cpu only, ``n_gpu=auto`` lets the system automatically determine the best number of GPUs to use.
         peft_model_config: Optional[Dict]
             - "lora_list": A List of PEFT (Parameter-Efficient Fine-Tuning) model and path.
@@ -959,6 +963,8 @@
             Specify the worker ip where the model is located in a distributed scenario.
         gpu_idx: Optional[Union[int, List[int]]]
             Specify the GPU index where the model is located.
+        model_path: Optional[str]
+            Model path, if gguf format, should be the file path, otherwise, should be directory of the model.
         **kwargs:
             Any other parameters been specified.
 
@@ -985,10 +991,12 @@
             "model_format": model_format,
             "quantization": quantization,
             "replica": replica,
+            "n_worker": n_worker,
             "n_gpu": n_gpu,
             "request_limits": request_limits,
             "worker_ip": worker_ip,
             "gpu_idx": gpu_idx,
+            "model_path": model_path,
         }
 
         for key, value in kwargs.items():
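A hedged usage sketch for the two arguments added to `Client.launch_model` above (`n_worker` and `model_path`). The endpoint, model name, and engine are placeholders; parameters not visible in this diff follow the existing `launch_model` signature.

```python
# Hypothetical example; model name, engine and paths are placeholders.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="qwen2.5-instruct",
    model_engine="vllm",
    model_size_in_billions=72,
    n_worker=2,   # shard the model across two workers
    n_gpu=4,      # with n_worker > 1 this is GPUs *per worker*
    model_path="/data/models/qwen2.5-72b-instruct",  # directory (file path for gguf)
)
model = client.get_model(model_uid)
```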
xinference/core/model.py CHANGED
@@ -226,6 +226,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         model_description: Optional["ModelDescription"] = None,
         request_limits: Optional[int] = None,
         xavier_config: Optional[Dict] = None,
+        n_worker: Optional[int] = 1,
+        shard: Optional[int] = 0,
+        driver_info: Optional[dict] = None,  # for model across workers
     ):
         super().__init__()
         from ..model.llm.lmdeploy.core import LMDeployModel
@@ -263,6 +266,10 @@ class ModelActor(xo.StatelessActor, CancelMixin):
             "quantization": self._model_description.get("quantization", "none"),
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None
+        # model across workers
+        self._n_worker = n_worker
+        self._shard = shard
+        self._driver_info = driver_info
 
         self._scheduler_ref = None
         self._text_to_image_scheduler_ref = None
@@ -455,6 +462,8 @@
             i += 1
             try:
                 self._model.load()
+                if hasattr(self._model, "driver_info"):
+                    self._driver_info = self._model.driver_info
                 break
             except Exception as e:
                 if (
@@ -477,6 +486,10 @@
         )
         logger.info(f"{self} loaded")
 
+    async def wait_for_load(self):
+        if hasattr(self._model, "wait_for_load"):
+            self._model.wait_for_load()
+
     def model_uid(self):
         return (
             self._model.model_uid
@@ -488,6 +501,12 @@
             )
         )
 
+    def get_driver_info(self):
+        # driver info is used for model across workers,
+        # the driver model actor(always the first worker)
+        # will hold driver information includes dist store etc.
+        return self._driver_info
+
     async def _handle_oom_error(self, ex):
         error_message = (
             f"Model actor is out of memory, model id: {self.model_uid()}, error: {ex}"
xinference/core/resource.py CHANGED
@@ -17,7 +17,7 @@ from typing import Dict, Union
 
 import psutil
 
-from .utils import get_nvidia_gpu_info
+from ..device_utils import get_nvidia_gpu_info
 
 
 @dataclass
@@ -31,9 +31,12 @@ class ResourceStatus:
 
 @dataclass
 class GPUStatus:
+    name: str
     mem_total: float
     mem_free: float
     mem_used: float
+    mem_usage: float
+    gpu_util: float
 
 
 def gather_node_info() -> Dict[str, Union[ResourceStatus, GPUStatus]]:
@@ -48,9 +51,12 @@ def gather_node_info() -> Dict[str, Union[ResourceStatus, GPUStatus]]:
     )
     for gpu_idx, gpu_info in get_nvidia_gpu_info().items():
         node_resource[gpu_idx] = GPUStatus(  # type: ignore
+            name=gpu_info["name"],
             mem_total=gpu_info["total"],
             mem_used=gpu_info["used"],
             mem_free=gpu_info["free"],
+            mem_usage=gpu_info["used"] / gpu_info["total"],
+            gpu_util=gpu_info["util"],
         )
 
     return node_resource  # type: ignore
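For reference, a made-up example of the per-GPU entry `gather_node_info()` now builds. Field names come from the `GPUStatus` dataclass above; every value is illustrative.

```python
# Illustrative only; numbers are placeholders, not real measurements.
node_resource["gpu-0"] = GPUStatus(
    name="NVIDIA A100-SXM4-80GB",
    mem_total=85_899_345_920,                       # bytes
    mem_used=12_884_901_888,
    mem_free=73_014_444_032,
    mem_usage=12_884_901_888 / 85_899_345_920,      # fraction of memory in use, ~0.15
    gpu_util=37.0,                                  # utilization reported by the device_utils helper
)
```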
xinference/core/status_guard.py CHANGED
@@ -39,6 +39,7 @@ class InstanceInfo(BaseModel):
     replica: int
     status: str
     instance_created_ts: int
+    n_worker: Optional[int] = 1
 
     def update(self, **kwargs):
         for field, value in kwargs.items():
xinference/core/supervisor.py CHANGED
@@ -99,7 +99,11 @@ class SupervisorActor(xo.StatelessActor):
         self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {}  # type: ignore
         self._worker_status: Dict[str, WorkerStatus] = {}  # type: ignore
         self._replica_model_uid_to_worker: Dict[  # type: ignore
-            str, xo.ActorRefType["WorkerActor"]
+            str,
+            Union[
+                xo.ActorRefType["WorkerActor"],
+                Tuple[xo.ActorRefType["WorkerActor"], ...],
+            ],
         ] = {}
         self._model_uid_to_replica_info: Dict[str, ReplicaInfo] = {}  # type: ignore
         self._uptime = None
@@ -270,8 +274,8 @@
         from ..model.llm.vllm.xavier.block_tracker import VLLMBlockTracker
         from ..model.llm.vllm.xavier.collective_manager import CollectiveManager
 
-        self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}
-        self._collective_manager_mapping: Dict[
+        self._block_tracker_mapping: Dict[str, xo.ActorRefType[VLLMBlockTracker]] = {}  # type: ignore
+        self._collective_manager_mapping: Dict[  # type: ignore
             str, xo.ActorRefType[CollectiveManager]
         ] = {}
 
@@ -359,13 +363,16 @@
         worker_ref = await self._choose_worker()
         return await worker_ref.get_devices_count()
 
-    async def _choose_worker(self) -> xo.ActorRefType["WorkerActor"]:
+    async def _choose_worker(
+        self, available_workers: Optional[List[str]] = None
+    ) -> xo.ActorRefType["WorkerActor"]:
         # TODO: better allocation strategy.
         min_running_model_count = None
         target_worker = None
 
-        workers = list(self._worker_address_to_worker.values())
-        for worker in workers:
+        for worker_addr, worker in self._worker_address_to_worker.items():
+            if available_workers and worker_addr not in available_workers:
+                continue
             running_model_count = await worker.get_model_count()
             if (
                 min_running_model_count is None
@@ -911,6 +918,7 @@
         model_type: Optional[str],
         replica: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
+        n_worker: Optional[int] = 1,
         request_limits: Optional[int] = None,
         wait_ready: bool = True,
         model_version: Optional[str] = None,
@@ -921,6 +929,35 @@
         model_path: Optional[str] = None,
         **kwargs,
     ) -> str:
+        if self.is_local_deployment() and n_worker > 1:  # type: ignore
+            # ignore n_worker > 1 if local deployment
+            logger.warning("Local deployment, ignore n_worker(%s)", n_worker)
+            n_worker = 1
+
+        if n_worker > 1:  # type: ignore
+            # distributed inference
+            return await self._launch_builtin_sharded_model(
+                model_uid,
+                model_name,
+                model_size_in_billions,
+                model_format,
+                quantization,
+                model_engine,
+                model_type,
+                replica=replica,
+                n_gpu=n_gpu,
+                n_worker=n_worker,
+                request_limits=request_limits,
+                wait_ready=wait_ready,
+                model_version=model_version,
+                peft_model_config=peft_model_config,
+                worker_ip=worker_ip,
+                gpu_idx=gpu_idx,
+                download_hub=download_hub,
+                model_path=model_path,
+                **kwargs,
+            )
+
         # search in worker first
         if not self.is_local_deployment():
             workers = list(self._worker_address_to_worker.values())
@@ -1157,6 +1194,150 @@
             task.add_done_callback(lambda _: callback_for_async_launch(model_uid))  # type: ignore
         return model_uid
 
+    async def _launch_builtin_sharded_model(
+        self,
+        model_uid: Optional[str],
+        model_name: str,
+        model_size_in_billions: Optional[Union[int, str]],
+        model_format: Optional[str],
+        quantization: Optional[str],
+        model_engine: Optional[str],
+        model_type: Optional[str],
+        replica: int = 1,
+        n_gpu: Optional[Union[int, str]] = "auto",
+        n_worker: Optional[int] = 1,
+        request_limits: Optional[int] = None,
+        wait_ready: bool = True,
+        model_version: Optional[str] = None,
+        peft_model_config: Optional[PeftModelConfig] = None,
+        worker_ip: Optional[str] = None,
+        gpu_idx: Optional[Union[int, List[int]]] = None,
+        download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        model_path: Optional[str] = None,
+        **kwargs,
+    ):
+        available_workers = []
+        # search workers if registered
+        tasks = []
+        if not worker_ip:
+            all_workers = list(self._worker_address_to_worker)
+            for worker in all_workers:
+                tasks.append(
+                    self._worker_address_to_worker[worker].get_model_registration(
+                        model_type, model_name
+                    )
+                )
+            res = await asyncio.gather(*tasks)
+            for worker, res in zip(all_workers, res):
+                # check regi
+                if res:
+                    available_workers.append(worker)
+            if not available_workers:
+                # no registration, use all workers
+                available_workers = all_workers
+        else:
+            if isinstance(worker_ip, list):
+                available_workers.extend(worker_ip)
+            else:
+                available_workers.append(worker_ip)
+
+        async def _launch_model():
+            try:
+                for _idx, rep_model_uid in enumerate(
+                    iter_replica_model_uid(model_uid, replica)
+                ):
+                    replica_gpu_idx = assign_replica_gpu(
+                        rep_model_uid, replica, gpu_idx
+                    )
+                    # launch shard
+                    worker_refs = []
+                    driver_info = None
+                    for i_worker in range(n_worker):
+                        worker_ref = await self._choose_worker(available_workers)
+                        nonlocal model_type
+                        model_type = model_type or "LLM"
+                        if i_worker > 1:
+                            assert (
+                                driver_info is not None
+                            ), "driver info should be passed by first model shard"
+                        info = await worker_ref.launch_builtin_model(
+                            model_uid=rep_model_uid,
+                            model_name=model_name,
+                            model_size_in_billions=model_size_in_billions,
+                            model_format=model_format,
+                            quantization=quantization,
+                            model_engine=model_engine,
+                            model_type=model_type,
+                            n_gpu=n_gpu,
+                            request_limits=request_limits,
+                            peft_model_config=peft_model_config,
+                            gpu_idx=replica_gpu_idx,
+                            download_hub=download_hub,
+                            model_path=model_path,
+                            shard=i_worker,
+                            n_worker=n_worker,
+                            driver_info=driver_info,
+                            **kwargs,
+                        )
+                        if i_worker == 0:
+                            # info will be subpool address + driver info
+                            # for shard 0
+                            driver_info = info[1]
+                        worker_refs.append(worker_ref)
+                    self._replica_model_uid_to_worker[rep_model_uid] = worker_refs
+
+                    # for distributed inference,
+                    # launch will run asynchronously,
+                    # wait for load complete
+                    for worker_ref in worker_refs:
+                        await worker_ref.wait_for_load(rep_model_uid)
+            except:
+                # terminate_model will remove the replica info.
+                await self.terminate_model(model_uid, suppress_exception=True)
+                await self._status_guard_ref.update_instance_info(
+                    model_uid, {"status": LaunchStatus.ERROR.name}
+                )
+                raise
+
+        if model_uid is None:
+            model_uid = self._gen_model_uid(model_name)
+
+        if not is_valid_model_uid(model_uid):
+            raise ValueError(
+                "The model UID is invalid. Please specify the model UID by 0 < length <= 100."
+            )
+
+        if request_limits is not None and request_limits < 0:
+            raise ValueError(
+                "The `request_limits` parameter must be greater or equal than 0."
+            )
+
+        if model_uid in self._model_uid_to_replica_info:
+            raise ValueError(f"Model is already in the model list, uid: {model_uid}")
+
+        # Set replica info first for exception handler to terminate model.
+        self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
+            replica=replica, scheduler=itertools.cycle(range(replica))
+        )
+        instance_info = InstanceInfo(
+            model_name=model_name,
+            model_uid=model_uid,
+            model_version=model_version,
+            model_ability=[],
+            replica=replica,
+            n_worker=n_worker,
+            status=LaunchStatus.CREATING.name,
+            instance_created_ts=int(time.time()),
+        )
+        await self._status_guard_ref.set_instance_info(model_uid, instance_info)
+        if wait_ready:
+            await _launch_model()
+        else:
+            task = asyncio.create_task(_launch_model())
+            ASYNC_LAUNCH_TASKS[model_uid] = task
+            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))  # type: ignore
+        return model_uid
+
     async def get_instance_info(
         self, model_name: Optional[str], model_uid: Optional[str]
     ) -> List[Dict]:
@@ -1186,11 +1367,13 @@
             if status.failure_remaining_count <= 0:
                 dead_models = []
                 for model_uid in self._replica_model_uid_to_worker:
-                    if (
-                        self._replica_model_uid_to_worker[model_uid].address
-                        == address
-                    ):
-                        dead_models.append(model_uid)
+                    worker_refs = self._replica_model_uid_to_worker[model_uid]
+                    if not isinstance(worker_refs, list):
+                        worker_refs = [worker_refs]
+                    for worker_ref in worker_refs:
+                        model_address = worker_ref.address
+                        if model_address == address:
+                            dead_models.append(model_uid)
                 logger.error(
                     "Worker dead. address: %s, influenced models: %s",
                     address,
@@ -1222,13 +1405,18 @@
     @log_async(logger=logger)
     async def terminate_model(self, model_uid: str, suppress_exception=False):
         async def _terminate_one_model(_replica_model_uid):
-            worker_ref = self._replica_model_uid_to_worker.get(_replica_model_uid, None)
+            worker_refs = self._replica_model_uid_to_worker.get(
+                _replica_model_uid, None
+            )
+            if not isinstance(worker_refs, list):
+                worker_refs = [worker_refs]
 
-            if worker_ref is None:
-                raise ValueError(
-                    f"Model not found in the model list, uid: {_replica_model_uid}"
-                )
-            await worker_ref.terminate_model(model_uid=_replica_model_uid)
+            for worker_ref in worker_refs:
+                if worker_ref is None:
+                    raise ValueError(
+                        f"Model not found in the model list, uid: {_replica_model_uid}"
+                    )
+                await worker_ref.terminate_model(model_uid=_replica_model_uid)
             del self._replica_model_uid_to_worker[_replica_model_uid]
 
         replica_info = self._model_uid_to_replica_info.get(model_uid, None)
@@ -1290,6 +1478,9 @@
             raise ValueError(
                 f"Model not found in the model list, uid: {replica_model_uid}"
             )
+        if isinstance(worker_ref, list):
+            # get first worker to fetch information if model across workers
+            worker_ref = worker_ref[0]
         return await worker_ref.get_model(model_uid=replica_model_uid)
 
     @log_async(logger=logger)
@@ -1299,6 +1490,9 @@
             raise ValueError(
                 f"Model not found in the model list, uid: {replica_model_uid}"
            )
+        if isinstance(worker_ref, list):
+            # get status from first shard if model has multiple shards across workers
+            worker_ref = worker_ref[0]
         return await worker_ref.get_model_status(replica_model_uid)
 
     @log_async(logger=logger)
@@ -1314,6 +1508,9 @@
             raise ValueError(
                 f"Model not found in the model list, uid: {replica_model_uid}"
             )
+        if isinstance(worker_ref, list):
+            # get status from first shard if model has multiple shards across workers
+            worker_ref = worker_ref[0]
         info = await worker_ref.describe_model(model_uid=replica_model_uid)
         info["replica"] = replica_info.replica
         return info
@@ -1386,6 +1583,9 @@
             worker_ref = self._replica_model_uid_to_worker.get(rep_mid, None)
             if worker_ref is None:
                 continue
+            if isinstance(worker_ref, list):
+                # get status from first shard if model has multiple shards across workers
+                worker_ref = worker_ref[0]
             model_ref = await worker_ref.get_model(model_uid=rep_mid)
             result_info = await model_ref.abort_request(request_id, block_duration)
             res["msg"] = result_info
@@ -1415,8 +1615,17 @@
     async def remove_worker(self, worker_address: str):
         uids_to_remove = []
         for model_uid in self._replica_model_uid_to_worker:
-            if self._replica_model_uid_to_worker[model_uid].address == worker_address:
-                uids_to_remove.append(model_uid)
+            worker_refs = self._replica_model_uid_to_worker[model_uid]
+            if not isinstance(worker_refs, list):
+                worker_refs = [worker_refs]
+            for worker_ref in worker_refs:
+                model_address = worker_ref.address
+                if isinstance(model_address, str) and model_address == worker_address:
+                    uids_to_remove.append(model_uid)
+                elif (
+                    isinstance(model_address, list) and worker_address in model_address
+                ):
+                    uids_to_remove.append(model_uid)
 
         for replica_model_uid in uids_to_remove:
             model_uid, _ = parse_replica_model_uid(replica_model_uid)
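To make the roughly 150-line addition easier to follow, here is a hedged, condensed restatement of the per-replica shard launch order in `_launch_builtin_sharded_model`: shard 0 is launched first and returns `(subpool_address, driver_info)`; later shards receive that `driver_info` so they can join the same distributed group; the supervisor then waits until every shard has finished loading. `workers` and the helper name below are stand-ins, not actual supervisor code.

```python
import asyncio

async def launch_shards(workers, rep_model_uid: str, **launch_kwargs):
    """Sketch of the shard launch handshake; `workers` are pre-chosen WorkerActor refs."""
    driver_info = None
    for shard, worker in enumerate(workers):
        info = await worker.launch_builtin_model(
            model_uid=rep_model_uid,
            shard=shard,
            n_worker=len(workers),
            driver_info=driver_info,
            **launch_kwargs,
        )
        if shard == 0:
            driver_info = info[1]  # shard 0 returns (subpool_address, driver_info)
    # loading runs asynchronously on each worker; block until all shards are ready
    await asyncio.gather(*(w.wait_for_load(rep_model_uid) for w in workers))
```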
xinference/core/utils.py CHANGED
@@ -19,10 +19,9 @@ import string
 import uuid
 import weakref
 from enum import Enum
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 import orjson
-from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
 
 from .._compat import BaseModel
 from ..constants import (
@@ -248,33 +247,6 @@ def parse_model_version(model_version: str, model_type: str) -> Tuple:
     raise ValueError(f"Not supported model_type: {model_type}")
 
 
-def _get_nvidia_gpu_mem_info(gpu_id: int) -> Dict[str, float]:
-    from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
-
-    handler = nvmlDeviceGetHandleByIndex(gpu_id)
-    mem_info = nvmlDeviceGetMemoryInfo(handler)
-    return {"total": mem_info.total, "used": mem_info.used, "free": mem_info.free}
-
-
-def get_nvidia_gpu_info() -> Dict:
-    try:
-        nvmlInit()
-        device_count = nvmlDeviceGetCount()
-        res = {}
-        for i in range(device_count):
-            res[f"gpu-{i}"] = _get_nvidia_gpu_mem_info(i)
-        return res
-    except:
-        # TODO: add log here
-        # logger.debug(f"Cannot init nvml. Maybe due to lack of NVIDIA GPUs or incorrect installation of CUDA.")
-        return {}
-    finally:
-        try:
-            nvmlShutdown()
-        except:
-            pass
-
-
 def assign_replica_gpu(
     _replica_model_uid: str, replica: int, gpu_idx: Optional[Union[int, List[int]]]
 ) -> Optional[List[int]]:
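The NVML helpers removed here are not gone: per the file list, `get_nvidia_gpu_info` now lives in `xinference/device_utils.py` and additionally reports the device name and utilization consumed by `core/resource.py` above. A hedged sketch of what such a helper could look like, not the shipped implementation:

```python
# Sketch only; the real helper is in xinference/device_utils.py.
from pynvml import (
    nvmlInit, nvmlShutdown, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo, nvmlDeviceGetUtilizationRates, nvmlDeviceGetName,
)

def get_nvidia_gpu_info() -> dict:
    try:
        nvmlInit()
        res = {}
        for i in range(nvmlDeviceGetCount()):
            handle = nvmlDeviceGetHandleByIndex(i)
            mem = nvmlDeviceGetMemoryInfo(handle)
            util = nvmlDeviceGetUtilizationRates(handle)
            name = nvmlDeviceGetName(handle)
            res[f"gpu-{i}"] = {
                "name": name.decode() if isinstance(name, bytes) else name,
                "total": mem.total,
                "used": mem.used,
                "free": mem.free,
                "util": util.gpu,  # GPU utilization in percent
            }
        return res
    except Exception:
        # no NVIDIA GPU or NVML unavailable
        return {}
    finally:
        try:
            nvmlShutdown()
        except Exception:
            pass
```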
xinference/core/worker.py CHANGED
@@ -789,6 +789,9 @@ class WorkerActor(xo.StatelessActor):
         model_engine: Optional[str],
         model_type: str = "LLM",
         n_gpu: Optional[Union[int, str]] = "auto",
+        n_worker: Optional[int] = 1,
+        shard: Optional[int] = 0,
+        driver_info: Optional[dict] = None,
         peft_model_config: Optional[PeftModelConfig] = None,
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
@@ -876,6 +879,18 @@
         xavier_config: Optional[Dict] = kwargs.pop("xavier_config", None)
         if xavier_config is not None:
             xavier_config["rank_address"] = subpool_address
+        model_kwargs = kwargs.copy()
+        if n_worker > 1:  # type: ignore
+            # for model across workers,
+            # add a few kwargs
+            model_kwargs.update(
+                dict(
+                    address=self.address,
+                    n_worker=n_worker,
+                    shard=shard,
+                    driver_info=driver_info,
+                )
+            )
         model, model_description = await asyncio.to_thread(
             create_model_instance,
             subpool_address,
@@ -890,7 +905,7 @@
             peft_model_config,
             download_hub,
             model_path,
-            **kwargs,
+            **model_kwargs,
         )
         await self.update_cache_status(model_name, model_description)
         model_ref = await xo.create_actor(
@@ -904,6 +919,9 @@
             model_description=model_description,
             request_limits=request_limits,
             xavier_config=xavier_config,
+            n_worker=n_worker,
+            shard=shard,
+            driver_info=driver_info,
         )
         await model_ref.load()
     except:
@@ -933,7 +951,15 @@
             origin_uid,
             {"model_ability": abilities, "status": LaunchStatus.READY.name},
         )
-        return subpool_address
+        if n_worker > 1 and shard == 0:  # type: ignore
+            return subpool_address, await model_ref.get_driver_info()
+        else:
+            return subpool_address
+
+    @log_async(logger=logger, level=logging.INFO)
+    async def wait_for_load(self, model_uid: str):
+        model_ref = self._model_uid_to_model[model_uid]
+        await model_ref.wait_for_load()
 
     @log_async(logger=logger, level=logging.INFO)
     async def terminate_model(self, model_uid: str, is_model_die=False):