xinference 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (53)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +22 -7
  3. xinference/client/restful/restful_client.py +10 -0
  4. xinference/constants.py +14 -4
  5. xinference/core/chat_interface.py +8 -1
  6. xinference/core/resource.py +19 -12
  7. xinference/core/supervisor.py +94 -30
  8. xinference/core/utils.py +29 -1
  9. xinference/core/worker.py +18 -3
  10. xinference/deploy/local.py +2 -2
  11. xinference/deploy/supervisor.py +2 -2
  12. xinference/model/audio/model_spec.json +29 -1
  13. xinference/model/embedding/model_spec.json +24 -0
  14. xinference/model/embedding/model_spec_modelscope.json +24 -0
  15. xinference/model/llm/__init__.py +2 -0
  16. xinference/model/llm/core.py +2 -0
  17. xinference/model/llm/ggml/chatglm.py +15 -6
  18. xinference/model/llm/llm_family.json +56 -0
  19. xinference/model/llm/llm_family_modelscope.json +56 -0
  20. xinference/model/llm/pytorch/chatglm.py +3 -3
  21. xinference/model/llm/pytorch/core.py +1 -0
  22. xinference/model/llm/pytorch/utils.py +21 -9
  23. xinference/model/llm/pytorch/yi_vl.py +246 -0
  24. xinference/model/rerank/core.py +1 -1
  25. xinference/model/rerank/model_spec.json +6 -0
  26. xinference/model/rerank/model_spec_modelscope.json +7 -0
  27. xinference/thirdparty/__init__.py +0 -0
  28. xinference/thirdparty/llava/__init__.py +1 -0
  29. xinference/thirdparty/llava/conversation.py +205 -0
  30. xinference/thirdparty/llava/mm_utils.py +122 -0
  31. xinference/thirdparty/llava/model/__init__.py +1 -0
  32. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  33. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  34. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  35. xinference/thirdparty/llava/model/constants.py +6 -0
  36. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  37. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  38. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  39. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  40. xinference/types.py +1 -1
  41. xinference/web/ui/build/asset-manifest.json +3 -3
  42. xinference/web/ui/build/index.html +1 -1
  43. xinference/web/ui/build/static/js/{main.abedc3c9.js → main.15822aeb.js} +3 -3
  44. xinference/web/ui/build/static/js/{main.abedc3c9.js.map → main.15822aeb.js.map} +1 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  46. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/METADATA +21 -18
  47. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/RECORD +52 -38
  48. xinference/web/ui/node_modules/.cache/babel-loader/c157e34990b23834b7ad4c13c42962209942c60f8130978c1514f3d085cfaea0.json +0 -1
  49. /xinference/web/ui/build/static/js/{main.abedc3c9.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  50. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  51. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  52. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  53. {xinference-0.8.2.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json

  version_json = '''
  {
- "date": "2024-01-26T16:31:51+0800",
+ "date": "2024-02-02T12:27:24+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "6fa3ee0d57378ca0bbe038a84bd6a3df2010703d",
- "version": "0.8.2"
+ "full-revisionid": "749ef3ff298a94b88c1e67415819fae4fb1de75c",
+ "version": "0.8.3"
  }
  ''' # END VERSION_JSON

xinference/api/restful_api.py CHANGED
@@ -62,6 +62,7 @@ from ..types import (
  CreateChatCompletion,
  CreateCompletion,
  ImageList,
+ max_tokens_field,
  )
  from .oauth2.auth_service import AuthService
  from .oauth2.types import LoginUserForm
@@ -216,6 +217,9 @@ class RESTfulAPI:
  self._router.add_api_route(
  "/v1/models/families", self._get_builtin_families, methods=["GET"]
  )
+ self._router.add_api_route(
+ "/v1/cluster/info", self.get_cluster_device_info, methods=["GET"]
+ )
  self._router.add_api_route(
  "/v1/cluster/devices", self._get_devices_count, methods=["GET"]
  )
@@ -791,6 +795,9 @@
  }
  kwargs = body.dict(exclude_unset=True, exclude=exclude)

+ if body.max_tokens is None:
+ kwargs["max_tokens"] = max_tokens_field.default
+
  if body.logit_bias is not None:
  raise HTTPException(status_code=501, detail="Not implemented")

@@ -1079,6 +1086,9 @@
  }
  kwargs = body.dict(exclude_unset=True, exclude=exclude)

+ if body.max_tokens is None:
+ kwargs["max_tokens"] = max_tokens_field.default
+
  if body.logit_bias is not None:
  raise HTTPException(status_code=501, detail="Not implemented")

@@ -1147,16 +1157,13 @@
  raise HTTPException(status_code=500, detail=str(e))

  model_name = desc.get("model_name", "")
- is_chatglm_ggml = (
- desc.get("model_format") == "ggmlv3" and "chatglm" in model_name
- )
  function_call_models = ["chatglm3", "gorilla-openfunctions-v1", "qwen-chat"]

  is_qwen = desc.get("model_format") == "ggmlv3" and "qwen" in model_name

- if (is_chatglm_ggml or is_qwen) and system_prompt is not None:
+ if is_qwen and system_prompt is not None:
  raise HTTPException(
- status_code=400, detail="ChatGLM ggml does not have system prompt"
+ status_code=400, detail="Qwen ggml does not have system prompt"
  )

  if not any(name in model_name for name in function_call_models):
@@ -1181,7 +1188,7 @@
  iterator = None
  try:
  try:
- if is_chatglm_ggml or is_qwen:
+ if is_qwen:
  iterator = await model.chat(prompt, chat_history, kwargs)
  else:
  iterator = await model.chat(
@@ -1201,7 +1208,7 @@
  return EventSourceResponse(stream_results())
  else:
  try:
- if is_chatglm_ggml or is_qwen:
+ if is_qwen:
  data = await model.chat(prompt, chat_history, kwargs)
  else:
  data = await model.chat(prompt, system_prompt, chat_history, kwargs)
@@ -1285,6 +1292,14 @@
  logger.error(e, exc_info=True)
  raise HTTPException(status_code=500, detail=str(e))

+ async def get_cluster_device_info(self) -> JSONResponse:
+ try:
+ data = await (await self._get_supervisor_ref()).get_cluster_device_info()
+ return JSONResponse(content=data)
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+

  def run(
  supervisor_address: str,
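
For reference, the new /v1/cluster/info route can be queried with any HTTP client once a cluster is up; a minimal sketch, assuming the default local endpoint on port 9997 (the response shape follows get_cluster_device_info in the supervisor change further below):

    # Hedged usage sketch for the new /v1/cluster/info route; the endpoint is a placeholder.
    import requests

    resp = requests.get("http://127.0.0.1:9997/v1/cluster/info")
    resp.raise_for_status()

    # Expected: one "Supervisor" entry followed by one entry per worker, each with
    # node_type, ip_address, gpu_count and gpu_vram_total.
    for node in resp.json():
        print(node["node_type"], node["ip_address"], node["gpu_count"], node["gpu_vram_total"])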
xinference/client/restful/restful_client.py CHANGED
@@ -404,6 +404,7 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
  def chat(
  self,
  prompt: str,
+ system_prompt: Optional[str] = None,
  chat_history: Optional[List["ChatCompletionMessage"]] = None,
  tools: Optional[List[Dict]] = None,
  generate_config: Optional["ChatglmCppGenerateConfig"] = None,
@@ -415,6 +416,8 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
  ----------
  prompt: str
  The user's input.
+ system_prompt: Optional[str]
+ The system context provide to Model prior to any chats.
  chat_history: Optional[List["ChatCompletionMessage"]]
  A list of messages comprising the conversation so far.
  tools: Optional[List[Dict]]
@@ -441,6 +444,13 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
  if chat_history is None:
  chat_history = []

+ if chat_history and chat_history[0]["role"] == "system":
+ if system_prompt is not None:
+ chat_history[0]["content"] = system_prompt
+ else:
+ if system_prompt is not None:
+ chat_history.insert(0, {"role": "system", "content": system_prompt})
+
  chat_history.append({"role": "user", "content": prompt})

  request_body: Dict[str, Any] = {
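
A hedged sketch of the extended client-side chat call; the endpoint and model uid are placeholders, and it assumes a chatglm.cpp (ggml) chat model has already been launched so that get_model() returns a RESTfulChatglmCppChatModelHandle:

    # Illustrative only: model uid and endpoint are placeholders.
    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("my-chatglm3-uid")  # a launched chatglm ggml chat model

    # The new system_prompt argument is inserted at the head of chat_history
    # (or replaces an existing system message) before the request is sent.
    completion = model.chat(
        prompt="Summarize this release in one sentence.",
        system_prompt="You are a terse release-notes assistant.",
    )
    print(completion)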
xinference/constants.py CHANGED
@@ -18,8 +18,12 @@ from pathlib import Path
  XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
  XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
  XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
- XINFERENCE_ENV_HEALTH_CHECK_ATTEMPTS = "XINFERENCE_HEALTH_CHECK_ATTEMPTS"
+ XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
+ "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
+ )
  XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
+ XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
+ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
  XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"


@@ -47,10 +51,16 @@ XINFERENCE_DEFAULT_ENDPOINT_PORT = 9997
  XINFERENCE_DEFAULT_LOG_FILE_NAME = "xinference.log"
  XINFERENCE_LOG_MAX_BYTES = 100 * 1024 * 1024
  XINFERENCE_LOG_BACKUP_COUNT = 30
- XINFERENCE_HEALTH_CHECK_ATTEMPTS = int(
- os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_ATTEMPTS, 3)
+ XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD = int(
+ os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD, 5)
  )
  XINFERENCE_HEALTH_CHECK_INTERVAL = int(
- os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_INTERVAL, 3)
+ os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_INTERVAL, 5)
+ )
+ XINFERENCE_HEALTH_CHECK_TIMEOUT = int(
+ os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT, 10)
+ )
+ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
+ int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
  )
  XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
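
Since constants.py reads these values once at import time, the new health-check knobs have to be set in the environment before xinference starts; a small sketch using the defaults shown above:

    # Sketch: override the health-check knobs via environment variables before
    # the xinference process (or import) starts; defaults are noted in comments.
    import os

    os.environ["XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"] = "3"  # default 5
    os.environ["XINFERENCE_HEALTH_CHECK_INTERVAL"] = "10"          # seconds, default 5
    os.environ["XINFERENCE_HEALTH_CHECK_TIMEOUT"] = "15"           # seconds, default 10
    os.environ["XINFERENCE_DISABLE_HEALTH_CHECK"] = "0"            # "1" disables the check

    from xinference import constants  # picks up the values set above

    print(constants.XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD)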
xinference/core/chat_interface.py CHANGED
@@ -98,9 +98,16 @@ class GradioInterface:
  return flat_list

  def to_chat(lst: List[str]) -> List[ChatCompletionMessage]:
+ from ..model.llm import BUILTIN_LLM_PROMPT_STYLE
+
  res = []
+ prompt_style = BUILTIN_LLM_PROMPT_STYLE.get(self.model_name)
+ if prompt_style is None:
+ roles = ["assistant", "user"]
+ else:
+ roles = prompt_style.roles
  for i in range(len(lst)):
- role = "assistant" if i % 2 == 1 else "user"
+ role = roles[0] if i % 2 == 1 else roles[1]
  res.append(ChatCompletionMessage(role=role, content=lst[i]))
  return res

xinference/core/resource.py CHANGED
@@ -13,10 +13,12 @@
  # limitations under the License.

  from dataclasses import dataclass
- from typing import Dict
+ from typing import Dict, Union

  import psutil

+ from .utils import get_nvidia_gpu_info
+

  @dataclass
  class ResourceStatus:
@@ -26,7 +28,14 @@ class ResourceStatus:
  memory_total: float


- def gather_node_info() -> Dict[str, ResourceStatus]:
+ @dataclass
+ class GPUStatus:
+ mem_total: float
+ mem_free: float
+ mem_used: float
+
+
+ def gather_node_info() -> Dict[str, Union[ResourceStatus, GPUStatus]]:
  node_resource = dict()
  mem_info = psutil.virtual_memory()
  node_resource["cpu"] = ResourceStatus(
@@ -35,13 +44,11 @@ def gather_node_info() -> Dict[str, ResourceStatus]:
  memory_available=mem_info.available,
  memory_total=mem_info.total,
  )
- # TODO: record GPU stats
- # for idx, gpu_card_stat in enumerate(resource.cuda_card_stats()):
- # node_resource[f"gpu-{idx}"] = ResourceStatus(
- # available=gpu_card_stat.gpu_usage / 100.0,
- # total=1,
- # memory_available=gpu_card_stat.fb_mem_info.available,
- # memory_total=gpu_card_stat.fb_mem_info.total,
- # )
-
- return node_resource
+ for gpu_idx, gpu_info in get_nvidia_gpu_info().items():
+ node_resource[gpu_idx] = GPUStatus( # type: ignore
+ mem_total=gpu_info["total"],
+ mem_used=gpu_info["used"],
+ mem_free=gpu_info["free"],
+ )
+
+ return node_resource # type: ignore
xinference/core/supervisor.py CHANGED
@@ -21,10 +21,16 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Un

  import xoscar as xo

+ from ..constants import (
+ XINFERENCE_DISABLE_HEALTH_CHECK,
+ XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
+ XINFERENCE_HEALTH_CHECK_INTERVAL,
+ XINFERENCE_HEALTH_CHECK_TIMEOUT,
+ )
  from ..core import ModelActor
  from ..core.status_guard import InstanceInfo, LaunchStatus
  from .metrics import record_metrics
- from .resource import ResourceStatus
+ from .resource import GPUStatus, ResourceStatus
  from .utils import (
  build_replica_model_uid,
  gen_random_string,
@@ -48,7 +54,6 @@ if TYPE_CHECKING:
  logger = getLogger(__name__)


- DEFAULT_NODE_TIMEOUT = 60
  ASYNC_LAUNCH_TASKS = {} # type: ignore


@@ -60,7 +65,8 @@ def callback_for_async_launch(model_uid: str):
  @dataclass
  class WorkerStatus:
  update_time: float
- status: Dict[str, ResourceStatus]
+ failure_remaining_count: int
+ status: Dict[str, Union[ResourceStatus, GPUStatus]]


  @dataclass
@@ -87,8 +93,15 @@ class SupervisorActor(xo.StatelessActor):

  async def __post_create__(self):
  self._uptime = time.time()
- # comment this line to avoid worker lost
- # self._check_dead_nodes_task = asyncio.create_task(self._check_dead_nodes())
+ if not XINFERENCE_DISABLE_HEALTH_CHECK:
+ # Run _check_dead_nodes() in a dedicated thread.
+ from ..isolation import Isolation
+
+ self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
+ self._isolation.start()
+ asyncio.run_coroutine_threadsafe(
+ self._check_dead_nodes(), loop=self._isolation.loop
+ )
  logger.info(f"Xinference supervisor {self.address} started")
  from .cache_tracker import CacheTrackerActor
  from .status_guard import StatusGuardActor
@@ -166,6 +179,30 @@ class SupervisorActor(xo.StatelessActor):
  model_version_infos, self.address
  )

+ async def get_cluster_device_info(self) -> List:
+ supervisor_device_info = {
+ "ip_address": self.address.split(":")[0],
+ "gpu_count": 0,
+ "gpu_vram_total": 0,
+ }
+ res = [{"node_type": "Supervisor", **supervisor_device_info}]
+ for worker_addr, worker_status in self._worker_status.items():
+ vram_total: float = sum(
+ [v.mem_total for k, v in worker_status.status.items() if k != "cpu"] # type: ignore
+ )
+ total = (
+ vram_total if vram_total == 0 else f"{int(vram_total / 1024 / 1024)}MiB"
+ )
+ res.append(
+ {
+ "node_type": "Worker",
+ "ip_address": worker_addr.split(":")[0],
+ "gpu_count": len(worker_status.status) - 1,
+ "gpu_vram_total": total,
+ }
+ )
+ return res
+
  @staticmethod
  async def get_builtin_prompts() -> Dict[str, Any]:
  from ..model.llm.llm_family import BUILTIN_LLM_PROMPT_STYLE
@@ -752,27 +789,48 @@

  async def _check_dead_nodes(self):
  while True:
- dead_nodes = []
- for address, status in self._worker_status.items():
- if time.time() - status.update_time > DEFAULT_NODE_TIMEOUT:
- dead_models = []
- for model_uid in self._replica_model_uid_to_worker:
- if (
- self._replica_model_uid_to_worker[model_uid].address
- == address
- ):
- dead_models.append(model_uid)
- logger.error(
- "Worker timeout. address: %s, influenced models: %s",
- address,
- dead_models,
- )
- dead_nodes.append(address)
-
- for address in dead_nodes:
- self._worker_status.pop(address)
- self._worker_address_to_worker.pop(address)
- await asyncio.sleep(5)
+ try:
+ dead_nodes = []
+ for address, status in self._worker_status.items():
+ if (
+ time.time() - status.update_time
+ > XINFERENCE_HEALTH_CHECK_TIMEOUT
+ ):
+ status.failure_remaining_count -= 1
+ else:
+ status.failure_remaining_count = (
+ XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD
+ )
+
+ if status.failure_remaining_count <= 0:
+ dead_models = []
+ for model_uid in self._replica_model_uid_to_worker:
+ if (
+ self._replica_model_uid_to_worker[model_uid].address
+ == address
+ ):
+ dead_models.append(model_uid)
+ logger.error(
+ "Worker dead. address: %s, influenced models: %s",
+ address,
+ dead_models,
+ )
+ dead_nodes.append(address)
+ elif (
+ status.failure_remaining_count
+ != XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD
+ ):
+ logger.error(
+ "Worker timeout. address: %s, check count remaining %s...",
+ address,
+ status.failure_remaining_count,
+ )
+
+ for address in dead_nodes:
+ self._worker_status.pop(address, None)
+ self._worker_address_to_worker.pop(address, None)
+ finally:
+ await asyncio.sleep(XINFERENCE_HEALTH_CHECK_INTERVAL)

  @log_async(logger=logger)
  async def terminate_model(self, model_uid: str, suppress_exception=False):
@@ -871,13 +929,19 @@
  )

  async def report_worker_status(
- self, worker_address: str, status: Dict[str, ResourceStatus]
+ self, worker_address: str, status: Dict[str, Union[ResourceStatus, GPUStatus]]
  ):
  if worker_address not in self._worker_status:
  logger.debug("Worker %s resources: %s", worker_address, status)
- self._worker_status[worker_address] = WorkerStatus(
- update_time=time.time(), status=status
- )
+ self._worker_status[worker_address] = WorkerStatus(
+ update_time=time.time(),
+ failure_remaining_count=XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
+ status=status,
+ )
+ else:
+ worker_status = self._worker_status[worker_address]
+ worker_status.update_time = time.time()
+ worker_status.status = status

  @staticmethod
  def record_metrics(name, op, kwargs):
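
The supervisor now counts consecutive missed reports instead of relying on a single 60-second timeout, and it runs the checker on a dedicated event-loop thread via xinference's Isolation helper. A standalone sketch of that thread-isolation pattern (not the actual Isolation class):

    # Generic sketch of the pattern used above: run a periodic watchdog coroutine
    # on its own event loop in a background thread so it keeps ticking even when
    # the main loop is busy. This is not xinference's Isolation implementation.
    import asyncio
    import threading

    CHECK_INTERVAL = 5  # seconds, mirrors XINFERENCE_HEALTH_CHECK_INTERVAL

    async def check_dead_nodes():
        while True:
            # ... inspect worker heartbeats and decrement failure_remaining_count ...
            await asyncio.sleep(CHECK_INTERVAL)

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    # Schedule the watchdog on the background loop from the calling thread.
    future = asyncio.run_coroutine_threadsafe(check_dead_nodes(), loop)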
xinference/core/utils.py CHANGED
@@ -16,10 +16,11 @@ import logging
  import os
  import random
  import string
- from typing import Generator, List, Tuple, Union
+ from typing import Dict, Generator, List, Tuple, Union

  import orjson
  from pydantic import BaseModel
+ from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown

  logger = logging.getLogger(__name__)

@@ -162,3 +163,30 @@ def parse_model_version(model_version: str, model_type: str) -> Tuple:
  return tuple(results)
  else:
  raise ValueError(f"Not supported model_type: {model_type}")
+
+
+ def _get_nvidia_gpu_mem_info(gpu_id: int) -> Dict[str, float]:
+ from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+
+ handler = nvmlDeviceGetHandleByIndex(gpu_id)
+ mem_info = nvmlDeviceGetMemoryInfo(handler)
+ return {"total": mem_info.total, "used": mem_info.used, "free": mem_info.free}
+
+
+ def get_nvidia_gpu_info() -> Dict:
+ try:
+ nvmlInit()
+ device_count = nvmlDeviceGetCount()
+ res = {}
+ for i in range(device_count):
+ res[f"gpu-{i}"] = _get_nvidia_gpu_mem_info(i)
+ return res
+ except:
+ # TODO: add log here
+ # logger.debug(f"Cannot init nvml. Maybe due to lack of NVIDIA GPUs or incorrect installation of CUDA.")
+ return {}
+ finally:
+ try:
+ nvmlShutdown()
+ except:
+ pass
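
get_nvidia_gpu_info() is a thin wrapper over NVML and degrades to an empty dict when no NVIDIA GPU or driver is present; a minimal sketch of the underlying pynvml calls it relies on:

    # Sketch of the NVML calls used by get_nvidia_gpu_info(); requires the pynvml
    # package and an NVIDIA driver. Memory figures are reported in bytes.
    from pynvml import (
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )

    nvmlInit()
    try:
        for i in range(nvmlDeviceGetCount()):
            mem = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
            print(f"gpu-{i}: total={mem.total} used={mem.used} free={mem.free}")
    finally:
        nvmlShutdown()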
xinference/core/worker.py CHANGED
@@ -24,6 +24,7 @@ from logging import getLogger
  from typing import Any, Dict, List, Optional, Set, Tuple, Union

  import xoscar as xo
+ from async_timeout import timeout
  from xoscar import MainActorPoolType

  from ..constants import XINFERENCE_CACHE_DIR
@@ -152,6 +153,7 @@ class WorkerActor(xo.StatelessActor):
  return "worker"

  async def __post_create__(self):
+ from ..isolation import Isolation
  from .cache_tracker import CacheTrackerActor
  from .status_guard import StatusGuardActor
  from .supervisor import SupervisorActor
@@ -175,7 +177,12 @@
  address=self._supervisor_address, uid=SupervisorActor.uid()
  )
  await self._supervisor_ref.add_worker(self.address)
- self._upload_task = asyncio.create_task(self._periodical_report_status())
+ # Run _periodical_report_status() in a dedicated thread.
+ self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
+ self._isolation.start()
+ asyncio.run_coroutine_threadsafe(
+ self._periodical_report_status(), loop=self._isolation.loop
+ )
  logger.info(f"Xinference worker {self.address} started")
  logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
  purge_dir(XINFERENCE_CACHE_DIR)
@@ -233,7 +240,7 @@
  )

  async def __pre_destroy__(self):
- self._upload_task.cancel()
+ self._isolation.stop()

  @staticmethod
  def get_devices_count():
@@ -628,7 +635,15 @@
  return model_desc.to_dict()

  async def report_status(self):
- status = await asyncio.to_thread(gather_node_info)
+ status = dict()
+ try:
+ # asyncio.timeout is only available in Python >= 3.11
+ async with timeout(2):
+ status = await asyncio.to_thread(gather_node_info)
+ except asyncio.CancelledError:
+ raise
+ except Exception:
+ logger.exception("Report status got error.")
  await self._supervisor_ref.report_worker_status(self.address, status)

  async def _periodical_report_status(self):
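
report_status() now bounds the potentially blocking gather_node_info() call with async_timeout, which also works on Python versions before 3.11 where asyncio.timeout is unavailable; a self-contained sketch of the same pattern:

    # Sketch of bounding a blocking call from async code, as report_status() does:
    # offload it with asyncio.to_thread and cap it with async_timeout.timeout.
    import asyncio
    import time

    from async_timeout import timeout  # backport of asyncio.timeout for Python < 3.11

    def slow_probe() -> dict:
        time.sleep(1)  # stand-in for gather_node_info()
        return {"cpu": "ok"}

    async def report_status() -> dict:
        status: dict = {}
        try:
            async with timeout(2):  # give the probe at most 2 seconds
                status = await asyncio.to_thread(slow_probe)
        except Exception:
            pass  # a failed or timed-out probe still reports an empty status
        return status

    print(asyncio.run(report_status()))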
xinference/deploy/local.py CHANGED
@@ -23,7 +23,7 @@ import xoscar as xo
  from xoscar.utils import get_next_port

  from ..constants import (
- XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+ XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
  XINFERENCE_HEALTH_CHECK_INTERVAL,
  )
  from ..core.supervisor import SupervisorActor
@@ -116,7 +116,7 @@ def main(

  if not health_check(
  address=supervisor_address,
- max_attempts=XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+ max_attempts=XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
  sleep_interval=XINFERENCE_HEALTH_CHECK_INTERVAL,
  ):
  raise RuntimeError("Cluster is not available after multiple attempts")
xinference/deploy/supervisor.py CHANGED
@@ -23,7 +23,7 @@ import xoscar as xo
  from xoscar.utils import get_next_port

  from ..constants import (
- XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+ XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
  XINFERENCE_HEALTH_CHECK_INTERVAL,
  )
  from ..core.supervisor import SupervisorActor
@@ -82,7 +82,7 @@ def main(

  if not health_check(
  address=supervisor_address,
- max_attempts=XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+ max_attempts=XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
  sleep_interval=XINFERENCE_HEALTH_CHECK_INTERVAL,
  ):
  raise RuntimeError("Supervisor is not available after multiple attempts")
xinference/model/audio/model_spec.json CHANGED
@@ -27,6 +27,20 @@
  "model_revision": "911407f4214e0e1d82085af863093ec0b66f9cd6",
  "multilingual": false
  },
+ {
+ "model_name": "whisper-small",
+ "model_family": "whisper",
+ "model_id": "openai/whisper-small",
+ "model_revision": "998cb1a777c20db53d6033a61b977ed4c3792cac",
+ "multilingual": true
+ },
+ {
+ "model_name": "whisper-small.en",
+ "model_family": "whisper",
+ "model_id": "openai/whisper-small.en",
+ "model_revision": "e8727524f962ee844a7319d92be39ac1bd25655a",
+ "multilingual": false
+ },
  {
  "model_name": "whisper-medium",
  "model_family": "whisper",
@@ -47,5 +61,19 @@
  "model_id": "openai/whisper-large-v3",
  "model_revision": "6cdf07a7e3ec3806e5d55f787915b85d4cd020b1",
  "multilingual": true
+ },
+ {
+ "model_name": "Belle-distilwhisper-large-v2-zh",
+ "model_family": "whisper",
+ "model_id": "BELLE-2/Belle-distilwhisper-large-v2-zh",
+ "model_revision": "ed25d13498fa5bac758b2fc479435b698532dfe8",
+ "multilingual": false
+ },
+ {
+ "model_name": "Belle-whisper-large-v2-zh",
+ "model_family": "whisper",
+ "model_id": "BELLE-2/Belle-whisper-large-v2-zh",
+ "model_revision": "ec5bd5d78598545b7585814edde86dac2002b5b9",
+ "multilingual": false
  }
- ]
+ ]
xinference/model/embedding/model_spec.json CHANGED
@@ -143,6 +143,14 @@
  "model_id": "jinaai/jina-embeddings-v2-base-en",
  "model_revision": "7302ac470bed880590f9344bfeee32ff8722d0e5"
  },
+ {
+ "model_name": "jina-embeddings-v2-base-zh",
+ "dimensions": 768,
+ "max_tokens": 8192,
+ "language": ["zh", "en"],
+ "model_id": "jinaai/jina-embeddings-v2-base-zh",
+ "model_revision": "67974cbef5cf50562eadd745de8afc661c52c96f"
+ },
  {
  "model_name": "text2vec-large-chinese",
  "dimensions": 1024,
@@ -182,5 +190,21 @@
  "language": ["zh"],
  "model_id": "shibing624/text2vec-base-multilingual",
  "model_revision": "f241877385fa56ebcc75f04d1850e1579cfa661d"
+ },
+ {
+ "model_name": "bge-m3",
+ "dimensions": 1024,
+ "max_tokens": 8192,
+ "language": ["zh", "en"],
+ "model_id": "BAAI/bge-m3",
+ "model_revision": "73a15ad29ab604f3bdc31601849a9defe86d563f"
+ },
+ {
+ "model_name": "bce-embedding-base_v1",
+ "dimensions": 768,
+ "max_tokens": 512,
+ "language": ["zh", "en"],
+ "model_id": "maidalun1020/bce-embedding-base_v1",
+ "model_revision": "236d9024fc1b4046f03848723f934521a66a9323"
  }
  ]
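
Once 0.8.3 is installed, the newly registered embedding models can be launched by name like any other builtin; a hedged sketch using bge-m3 (endpoint and input texts are placeholders):

    # Illustrative sketch: launch one of the embedding models added in this release
    # and embed two sentences. Endpoint and inputs are placeholders.
    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model_uid = client.launch_model(model_name="bge-m3", model_type="embedding")

    model = client.get_model(model_uid)
    result = model.create_embedding(["release notes", "diff between 0.8.2 and 0.8.3"])
    print(len(result["data"][0]["embedding"]))  # bge-m3 reports 1024 dimensions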
xinference/model/embedding/model_spec_modelscope.json CHANGED
@@ -161,6 +161,14 @@
  "model_revision": "v0.0.1",
  "model_hub": "modelscope"
  },
+ {
+ "model_name": "jina-embeddings-v2-base-zh",
+ "dimensions": 768,
+ "max_tokens": 8192,
+ "language": ["zh", "en"],
+ "model_id": "jinaai/jina-embeddings-v2-base-zh",
+ "model_hub": "modelscope"
+ },
  {
  "model_name": "text2vec-large-chinese",
  "dimensions": 1024,
@@ -184,5 +192,21 @@
  "language": ["zh"],
  "model_id": "mwei23/text2vec-base-chinese-paraphrase",
  "model_hub": "modelscope"
+ },
+ {
+ "model_name": "bge-m3",
+ "dimensions": 1024,
+ "max_tokens": 8192,
+ "language": ["zh", "en"],
+ "model_id": "Xorbits/bge-m3",
+ "model_hub": "modelscope"
+ },
+ {
+ "model_name": "bce-embedding-base_v1",
+ "dimensions": 768,
+ "max_tokens": 512,
+ "language": ["zh", "en"],
+ "model_id": "maidalun/bce-embedding-base_v1",
+ "model_hub": "modelscope"
  }
  ]