xinference 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +16 -11
  5. xinference/client/restful/restful_client.py +12 -2
  6. xinference/conftest.py +13 -2
  7. xinference/constants.py +2 -0
  8. xinference/core/supervisor.py +32 -1
  9. xinference/core/worker.py +139 -20
  10. xinference/deploy/cmdline.py +119 -20
  11. xinference/model/llm/__init__.py +6 -0
  12. xinference/model/llm/llm_family.json +711 -10
  13. xinference/model/llm/llm_family_modelscope.json +557 -7
  14. xinference/model/llm/pytorch/chatglm.py +2 -1
  15. xinference/model/llm/pytorch/core.py +2 -0
  16. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  17. xinference/model/llm/pytorch/internlm2.py +2 -1
  18. xinference/model/llm/pytorch/omnilmm.py +153 -0
  19. xinference/model/llm/sglang/__init__.py +13 -0
  20. xinference/model/llm/sglang/core.py +365 -0
  21. xinference/model/llm/utils.py +46 -13
  22. xinference/model/llm/vllm/core.py +10 -0
  23. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  24. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  25. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  26. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  27. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  28. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  29. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  30. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  31. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  32. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  33. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  34. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  35. xinference/thirdparty/omnilmm/__init__.py +0 -0
  36. xinference/thirdparty/omnilmm/chat.py +216 -0
  37. xinference/thirdparty/omnilmm/constants.py +4 -0
  38. xinference/thirdparty/omnilmm/conversation.py +332 -0
  39. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  40. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  41. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  42. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  43. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  44. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  45. xinference/thirdparty/omnilmm/utils.py +134 -0
  46. xinference/web/ui/build/asset-manifest.json +3 -3
  47. xinference/web/ui/build/index.html +1 -1
  48. xinference/web/ui/build/static/js/main.98516614.js +3 -0
  49. xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
  54. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/METADATA +21 -5
  55. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/RECORD +60 -31
  56. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  57. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  60. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
  61. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
  62. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
  63. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
  64. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-03-15T11:03:04+0800",
+ "date": "2024-03-29T12:46:14+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "60f098c78adb7e28dd564278a7741b56eaf062d8",
- "version": "0.9.3"
+ "full-revisionid": "2857ec497afbd2a6895d3658384ff3b4022b2840",
+ "version": "0.10.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/oauth2/auth_service.py CHANGED
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from datetime import timedelta
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 from fastapi import Depends, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer, SecurityScopes
@@ -40,13 +41,30 @@ class AuthService:
     def config(self):
         return self._config
 
+    @staticmethod
+    def is_legal_api_key(key: str) -> bool:
+        pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
+        return re.match(pattern, key) is not None
+
     def init_auth_config(self):
         if self._auth_config_file:
             config: AuthStartupConfig = parse_file_as(
                 path=self._auth_config_file, type_=AuthStartupConfig
             )
+            all_api_keys = set()
             for user in config.user_config:
                 user.password = get_password_hash(user.password)
+                for api_key in user.api_keys:
+                    if not self.is_legal_api_key(api_key):
+                        raise ValueError(
+                            "Api-Key should be a string started with 'sk-' with a total length of 16"
+                        )
+                    if api_key in all_api_keys:
+                        raise ValueError(
+                            "Duplicate api-keys exists, please check your configuration"
+                        )
+                    else:
+                        all_api_keys.add(api_key)
             return config
 
     def __call__(
@@ -67,28 +85,30 @@ class AuthService:
                 headers={"WWW-Authenticate": authenticate_value},
             )
 
-        try:
-            assert self._config is not None
-            payload = jwt.decode(
-                token,
-                self._config.auth_config.secret_key,
-                algorithms=[self._config.auth_config.algorithm],
-                options={"verify_exp": False},  # TODO: supports token expiration
-            )
-            username: str = payload.get("sub")
-            if username is None:
-                raise credentials_exception
-            token_scopes = payload.get("scopes", [])
-            token_data = TokenData(scopes=token_scopes, username=username)
-        except (JWTError, ValidationError):
-            raise credentials_exception
-        user = self.get_user(token_data.username)
+        if self.is_legal_api_key(token):
+            user, token_scopes = self.get_user_and_scopes_with_api_key(token)
+        else:
+            try:
+                assert self._config is not None
+                payload = jwt.decode(
+                    token,
+                    self._config.auth_config.secret_key,
+                    algorithms=[self._config.auth_config.algorithm],
+                    options={"verify_exp": False},  # TODO: supports token expiration
+                )
+                username: str = payload.get("sub")
+                if username is None:
+                    raise credentials_exception
+                token_scopes = payload.get("scopes", [])
+                user = self.get_user(username)
+            except (JWTError, ValidationError):
+                raise credentials_exception
         if user is None:
             raise credentials_exception
-        if "admin" in token_data.scopes:
+        if "admin" in token_scopes:
             return user
         for scope in security_scopes.scopes:
-            if scope not in token_data.scopes:
+            if scope not in token_scopes:
                 raise HTTPException(
                     status_code=status.HTTP_403_FORBIDDEN,
                     detail="Not enough permissions",
@@ -102,6 +122,15 @@ class AuthService:
             return user
         return None
 
+    def get_user_and_scopes_with_api_key(
+        self, api_key: str
+    ) -> Tuple[Optional[User], List]:
+        for user in self._config.user_config:
+            for key in user.api_keys:
+                if api_key == key:
+                    return user, user.permissions
+        return None, []
+
     def authenticate_user(self, username: str, password: str):
         user = self.get_user(username)
         if not user:
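
The new is_legal_api_key check accepts only keys of the form "sk-" followed by exactly 13 alphanumeric characters, i.e. 16 characters in total. A quick standalone check of that pattern, using one of the keys from the test fixtures further below:

import re

pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
assert pattern.match("sk-3sjLbdwqAhhAF") is not None  # 'sk-' + 13 alphanumerics
assert pattern.match("sk-too-short") is None          # wrong length and characters
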
xinference/api/oauth2/types.py CHANGED
@@ -23,6 +23,7 @@ class LoginUserForm(BaseModel):
 
 class User(LoginUserForm):
     permissions: List[str]
+    api_keys: List[str]
 
 
 class AuthConfig(BaseModel):
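
Since api_keys is a new required field on User, existing auth config files must be updated to include it. A minimal sketch of constructing a user entry under the new schema; the names and key below are placeholders:

from xinference.api.oauth2.types import User

user = User(
    username="alice",
    password="secret",
    permissions=["models:list", "models:read"],
    api_keys=["sk-ABCDEFGHIJKLM"],  # placeholder key matching ^sk-[a-zA-Z0-9]{13}$
)
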
xinference/api/restful_api.py CHANGED
@@ -59,6 +59,7 @@ from ..core.utils import json_dumps
 from ..types import (
     SPECIAL_TOOL_PROMPT,
     ChatCompletion,
+    ChatCompletionMessage,
     Completion,
     CreateChatCompletion,
     CreateCompletion,
@@ -88,7 +89,9 @@ class CreateCompletionRequest(CreateCompletion):
 
 class CreateEmbeddingRequest(BaseModel):
     model: str
-    input: Union[str, List[str]] = Field(description="The input to embed.")
+    input: Union[str, List[str], List[int], List[List[int]]] = Field(
+        description="The input to embed."
+    )
     user: Optional[str] = None
 
     class Config:
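
With the widened input type, the embeddings endpoint now also accepts pre-tokenized input in addition to plain strings. A sketch of the request shapes the new schema admits; the model name and token ids are placeholders:

payloads = [
    {"model": "my-embedding-model", "input": "hello"},             # str
    {"model": "my-embedding-model", "input": ["hello", "world"]},  # List[str]
    {"model": "my-embedding-model", "input": [15339, 1917]},       # List[int]
    {"model": "my-embedding-model", "input": [[15339], [1917]]},   # List[List[int]]
]
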
@@ -692,6 +695,8 @@ class RESTfulAPI:
         peft_model_path = payload.get("peft_model_path", None)
         image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
         image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
+        worker_ip = payload.get("worker_ip", None)
+        gpu_idx = payload.get("gpu_idx", None)
 
         exclude_keys = {
             "model_uid",
@@ -706,6 +711,8 @@
             "peft_model_path",
             "image_lora_load_kwargs",
             "image_lora_fuse_kwargs",
+            "worker_ip",
+            "gpu_idx",
         }
 
         kwargs = {
@@ -733,6 +740,8 @@
             peft_model_path=peft_model_path,
             image_lora_load_kwargs=image_lora_load_kwargs,
             image_lora_fuse_kwargs=image_lora_fuse_kwargs,
+            worker_ip=worker_ip,
+            gpu_idx=gpu_idx,
             **kwargs,
         )
 
@@ -1258,25 +1267,21 @@
                 status_code=400, detail="Invalid input. Please specify the prompt."
             )
 
-        system_messages = []
+        system_messages: List["ChatCompletionMessage"] = []
+        system_messages_contents = []
         non_system_messages = []
         for msg in messages:
             assert (
                 msg.get("content") != SPECIAL_TOOL_PROMPT
             ), f"Invalid message content {SPECIAL_TOOL_PROMPT}"
             if msg["role"] == "system":
-                system_messages.append(msg)
+                system_messages_contents.append(msg["content"])
             else:
                 non_system_messages.append(msg)
-        if len(system_messages) > 1:
-            raise HTTPException(
-                status_code=400, detail="Multiple system messages are not supported."
-            )
-        if len(system_messages) == 1 and messages[0]["role"] != "system":
-            raise HTTPException(
-                status_code=400, detail="System message should be the first one."
-            )
+        system_messages.append(
+            {"role": "system", "content": ". ".join(system_messages_contents)}
+        )
 
         assert non_system_messages
 
         has_tool_message = messages[-1].get("role") == "tool"
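
This is a behavior change: instead of rejecting requests that contain multiple system messages (or a system message that is not first), the API now joins all system message contents with ". " into a single leading system message. A standalone illustration of the merge; the messages are made up:

messages = [
    {"role": "system", "content": "You are terse"},
    {"role": "user", "content": "Hi"},
    {"role": "system", "content": "Answer in English"},
]
contents = [m["content"] for m in messages if m["role"] == "system"]
merged = {"role": "system", "content": ". ".join(contents)}
assert merged["content"] == "You are terse. Answer in English"
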
xinference/client/restful/restful_client.py CHANGED
@@ -651,11 +651,13 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
 
 
 class Client:
-    def __init__(self, base_url):
+    def __init__(self, base_url, api_key: Optional[str] = None):
         self.base_url = base_url
-        self._headers = {}
+        self._headers: Dict[str, str] = {}
         self._cluster_authed = False
         self._check_cluster_authenticated()
+        if api_key is not None and self._cluster_authed:
+            self._headers["Authorization"] = f"Bearer {api_key}"
 
     def _set_token(self, token: Optional[str]):
         if not self._cluster_authed or token is None:
@@ -795,6 +797,8 @@ class Client:
         peft_model_path: Optional[str] = None,
         image_lora_load_kwargs: Optional[Dict] = None,
         image_lora_fuse_kwargs: Optional[Dict] = None,
+        worker_ip: Optional[str] = None,
+        gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
     ) -> str:
         """
@@ -828,6 +832,10 @@
             lora load parameters for image model
         image_lora_fuse_kwargs: Optional[Dict]
             lora fuse parameters for image model
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        gpu_idx: Optional[Union[int, List[int]]]
+            Specify the GPU index where the model is located.
         **kwargs:
             Any other parameters been specified.
 
@@ -853,6 +861,8 @@
             "peft_model_path": peft_model_path,
             "image_lora_load_kwargs": image_lora_load_kwargs,
             "image_lora_fuse_kwargs": image_lora_fuse_kwargs,
+            "worker_ip": worker_ip,
+            "gpu_idx": gpu_idx,
         }
 
         for key, value in kwargs.items():
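
Taken together, a client can now authenticate with an API key instead of logging in, and pin a model to a particular worker and GPUs at launch time. A usage sketch, assuming the method whose signature is shown above is Client.launch_model; the URL, key, model name, worker address, and GPU indexes are placeholders:

from xinference.client import Client

client = Client("http://127.0.0.1:9997", api_key="sk-3sjLbdwqAhhAF")
model_uid = client.launch_model(
    model_name="llama-2-chat",
    worker_ip="192.168.0.12",  # only meaningful in a distributed deployment
    gpu_idx=[0, 1],            # a single int is also accepted, e.g. gpu_idx=0
)
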
xinference/conftest.py CHANGED
@@ -261,12 +261,23 @@ def setup_with_auth():
     if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
         raise RuntimeError("Cluster is not available after multiple attempts")
 
-    user1 = User(username="user1", password="pass1", permissions=["admin"])
-    user2 = User(username="user2", password="pass2", permissions=["models:list"])
+    user1 = User(
+        username="user1",
+        password="pass1",
+        permissions=["admin"],
+        api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"],
+    )
+    user2 = User(
+        username="user2",
+        password="pass2",
+        permissions=["models:list"],
+        api_keys=["sk-72tkvudyGLPMi"],
+    )
     user3 = User(
         username="user3",
         password="pass3",
         permissions=["models:list", "models:read", "models:start"],
+        api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"],
     )
     auth_config = AuthConfig(
         algorithm="HS256",
xinference/constants.py CHANGED
@@ -25,6 +25,7 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
+XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 
 
 def get_xinference_home() -> str:
@@ -64,3 +65,4 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
 )
 XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
+XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
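
SGLang support is therefore opt-in, parsed the same way as the existing XINFERENCE_DISABLE_VLLM switch: any non-zero integer string enables it. A minimal sketch of turning it on from Python before importing xinference (exporting the variable in a shell works equally well):

import os

os.environ["XINFERENCE_ENABLE_SGLANG"] = "1"
assert bool(int(os.environ.get("XINFERENCE_ENABLE_SGLANG", 0)))
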
xinference/core/supervisor.py CHANGED
@@ -92,6 +92,15 @@ class SupervisorActor(xo.StatelessActor):
     def uid(cls) -> str:
         return "supervisor"
 
+    def _get_worker_ref_by_ip(
+        self, ip: str
+    ) -> Optional[xo.ActorRefType["WorkerActor"]]:
+        for addr, ref in self._worker_address_to_worker.items():
+            existing_ip = addr.split(":")[0]
+            if existing_ip == ip:
+                return ref
+        return None
+
     async def __post_create__(self):
         self._uptime = time.time()
         if not XINFERENCE_DISABLE_HEALTH_CHECK:
@@ -717,8 +726,25 @@
         peft_model_path: Optional[str] = None,
         image_lora_load_kwargs: Optional[Dict] = None,
         image_lora_fuse_kwargs: Optional[Dict] = None,
+        worker_ip: Optional[str] = None,
+        gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
     ) -> str:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+        if worker_ip is not None and self.is_local_deployment():
+            logger.warning(
+                f"You specified the worker ip: {worker_ip} in local mode, "
+                f"xinference will ignore this option."
+            )
+
         if model_uid is None:
             model_uid = self._gen_model_uid(model_name)
@@ -735,7 +761,11 @@
             )
 
             nonlocal model_type
-            worker_ref = await self._choose_worker()
+            worker_ref = (
+                target_ip_worker_ref
+                if target_ip_worker_ref is not None
+                else await self._choose_worker()
+            )
             # LLM as default for compatibility
             model_type = model_type or "LLM"
             await worker_ref.launch_builtin_model(
@@ -750,6 +780,7 @@
                 peft_model_path=peft_model_path,
                 image_lora_load_kwargs=image_lora_load_kwargs,
                 image_lora_fuse_kwargs=image_lora_fuse_kwargs,
+                gpu_idx=gpu_idx,
                 **kwargs,
             )
             self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
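
The new _get_worker_ref_by_ip helper matches on the IP portion of each registered ip:port worker address, which is what makes the worker_ip launch option work. A toy illustration of that lookup with made-up addresses:

workers = {"192.168.0.12:37522": "worker-a", "192.168.0.13:41000": "worker-b"}

def get_worker_ref_by_ip(ip):
    # first worker whose address starts with the requested IP wins
    for addr, ref in workers.items():
        if addr.split(":")[0] == ip:
            return ref
    return None

assert get_worker_ref_by_ip("192.168.0.13") == "worker-b"
assert get_worker_ref_by_ip("10.0.0.1") is None
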
xinference/core/worker.py CHANGED
@@ -74,6 +74,10 @@ class WorkerActor(xo.StatelessActor):
         self._model_uid_to_model_spec: Dict[str, ModelDescription] = {}
         self._gpu_to_model_uid: Dict[int, str] = {}
         self._gpu_to_embedding_model_uids: Dict[int, Set[str]] = defaultdict(set)
+        # Dict structure: gpu_index: {(replica_model_uid, model_type)}
+        self._user_specified_gpu_to_model_uids: Dict[
+            int, Set[Tuple[str, str]]
+        ] = defaultdict(set)
         self._model_uid_to_addr: Dict[str, str] = {}
         self._model_uid_to_recover_count: Dict[str, int] = {}
         self._model_uid_to_launch_args: Dict[str, Dict] = {}
@@ -268,12 +272,27 @@
         """
         candidates = []
         for _dev in self._total_gpu_devices:
-            if _dev not in self._gpu_to_model_uid:
+            if (
+                _dev not in self._gpu_to_model_uid
+                and _dev not in self._user_specified_gpu_to_model_uids
+            ):  # no possible vllm model on it, add it to candidates
                 candidates.append(_dev)
-            else:
-                existing_model_uid = self._gpu_to_model_uid[_dev]
-                is_vllm_model = await self.is_model_vllm_backend(existing_model_uid)
-                if not is_vllm_model:
+            else:  # need to judge that whether to have vllm model on this device
+                has_vllm_model = False
+                if _dev in self._gpu_to_model_uid:
+                    existing_model_uid = self._gpu_to_model_uid[_dev]
+                    has_vllm_model = await self.is_model_vllm_backend(
+                        existing_model_uid
+                    )
+                if (
+                    not has_vllm_model
+                    and _dev in self._user_specified_gpu_to_model_uids
+                ):
+                    for rep_uid, _ in self._user_specified_gpu_to_model_uids[_dev]:
+                        has_vllm_model = await self.is_model_vllm_backend(rep_uid)
+                        if has_vllm_model:
+                            break
+                if not has_vllm_model:
                     candidates.append(_dev)
 
         if len(candidates) == 0:
@@ -285,9 +304,13 @@
         device, min_cnt = -1, -1
         # Pick the device with the fewest existing models among all the candidate devices.
         for _dev in candidates:
-            existing_cnt = len(self._gpu_to_embedding_model_uids[_dev])
+            existing_cnt = 0
+            if _dev in self._gpu_to_embedding_model_uids:
+                existing_cnt += len(self._gpu_to_embedding_model_uids[_dev])
             if _dev in self._gpu_to_model_uid:
                 existing_cnt += 1
+            if _dev in self._user_specified_gpu_to_model_uids:
+                existing_cnt += len(self._user_specified_gpu_to_model_uids[_dev])
             if min_cnt == -1 or existing_cnt < min_cnt:
                 device, min_cnt = _dev, existing_cnt
 
@@ -295,17 +318,82 @@
         return device
 
     def allocate_devices(self, model_uid: str, n_gpu: int) -> List[int]:
-        if n_gpu > len(self._total_gpu_devices) - len(self._gpu_to_model_uid):
+        user_specified_allocated_devices: Set[int] = set()
+        for dev, model_infos in self._user_specified_gpu_to_model_uids.items():
+            allocated_non_embedding_rerank_models = False
+            for _, model_type in model_infos:
+                allocated_non_embedding_rerank_models = model_type not in [
+                    "embedding",
+                    "rerank",
+                ]
+                if allocated_non_embedding_rerank_models:
+                    break
+            if allocated_non_embedding_rerank_models:
+                user_specified_allocated_devices.add(dev)
+        allocated_devices = set(self._gpu_to_model_uid.keys()).union(
+            user_specified_allocated_devices
+        )
+        if n_gpu > len(self._total_gpu_devices) - len(allocated_devices):
             raise RuntimeError("No available slot found for the model")
 
         devices: List[int] = [
-            dev for dev in self._total_gpu_devices if dev not in self._gpu_to_model_uid
+            dev
+            for dev in self._total_gpu_devices
+            if dev not in self._gpu_to_model_uid
+            and dev not in user_specified_allocated_devices
         ][:n_gpu]
         for dev in devices:
            self._gpu_to_model_uid[int(dev)] = model_uid
 
         return sorted(devices)
 
+    async def allocate_devices_with_gpu_idx(
+        self, model_uid: str, model_type: str, gpu_idx: List[int]
+    ) -> List[int]:
+        """
+        When user specifies the gpu_idx, allocate models on user-specified GPUs whenever possible
+        """
+        # must be subset of total devices visible to this worker
+        if not set(gpu_idx) <= set(self._total_gpu_devices):
+            raise ValueError(
+                f"Worker {self.address} cannot use the GPUs with these indexes: {gpu_idx}. "
+                f"Worker {self.address} can only see these GPUs: {self._total_gpu_devices}."
+            )
+        # currently just report a warning log when there are already models on these GPUs
+        for idx in gpu_idx:
+            existing_model_uids = []
+            if idx in self._gpu_to_model_uid:
+                rep_uid = self._gpu_to_model_uid[idx]
+                is_vllm_model = await self.is_model_vllm_backend(rep_uid)
+                if is_vllm_model:
+                    raise RuntimeError(
+                        f"GPU index {idx} has been occupied with a vLLM model: {rep_uid}, "
+                        f"therefore cannot allocate GPU memory for a new model."
+                    )
+                existing_model_uids.append(rep_uid)
+            if idx in self._gpu_to_embedding_model_uids:
+                existing_model_uids.extend(self._gpu_to_embedding_model_uids[idx])
+            # If user has run the vLLM model on the GPU that was forced to be specified,
+            # it is not possible to force this GPU to be allocated again
+            if idx in self._user_specified_gpu_to_model_uids:
+                for rep_uid, _ in self._user_specified_gpu_to_model_uids[idx]:
+                    is_vllm_model = await self.is_model_vllm_backend(rep_uid)
+                    if is_vllm_model:
+                        raise RuntimeError(
+                            f"User specified GPU index {idx} has been occupied with a vLLM model: {rep_uid}, "
+                            f"therefore cannot allocate GPU memory for a new model."
+                        )
+
+            if existing_model_uids:
+                logger.warning(
+                    f"WARNING!!! GPU index {idx} has been occupied "
+                    f"with these models on it: {existing_model_uids}"
+                )
+
+        for idx in gpu_idx:
+            self._user_specified_gpu_to_model_uids[idx].add((model_uid, model_type))
+        return sorted(gpu_idx)
+
     def release_devices(self, model_uid: str):
         devices = [
             dev
@@ -320,27 +408,46 @@
             if model_uid in self._gpu_to_embedding_model_uids[dev]:
                 self._gpu_to_embedding_model_uids[dev].remove(model_uid)
 
+        # check user-specified slots
+        for dev in self._user_specified_gpu_to_model_uids:
+            model_infos = list(
+                filter(
+                    lambda x: x[0] == model_uid,
+                    self._user_specified_gpu_to_model_uids[dev],
+                )
+            )
+            for model_info in model_infos:
+                self._user_specified_gpu_to_model_uids[dev].remove(model_info)
+
     async def _create_subpool(
         self,
         model_uid: str,
         model_type: Optional[str] = None,
         n_gpu: Optional[Union[int, str]] = "auto",
+        gpu_idx: Optional[List[int]] = None,
     ) -> Tuple[str, List[str]]:
         env = {}
         devices = []
-        if isinstance(n_gpu, int) or (n_gpu == "auto" and gpu_count() > 0):
-            # Currently, n_gpu=auto means using 1 GPU
-            gpu_cnt = n_gpu if isinstance(n_gpu, int) else 1
-            devices = (
-                [await self.allocate_devices_for_embedding(model_uid)]
-                if model_type in ["embedding", "rerank"]
-                else self.allocate_devices(model_uid=model_uid, n_gpu=gpu_cnt)
-            )
-            env["CUDA_VISIBLE_DEVICES"] = ",".join([str(dev) for dev in devices])
-            logger.debug(f"GPU selected: {devices} for model {model_uid}")
-        if n_gpu is None:
-            env["CUDA_VISIBLE_DEVICES"] = "-1"
-            logger.debug(f"GPU disabled for model {model_uid}")
+        if gpu_idx is None:
+            if isinstance(n_gpu, int) or (n_gpu == "auto" and gpu_count() > 0):
+                # Currently, n_gpu=auto means using 1 GPU
+                gpu_cnt = n_gpu if isinstance(n_gpu, int) else 1
+                devices = (
+                    [await self.allocate_devices_for_embedding(model_uid)]
+                    if model_type in ["embedding", "rerank"]
+                    else self.allocate_devices(model_uid=model_uid, n_gpu=gpu_cnt)
+                )
+                env["CUDA_VISIBLE_DEVICES"] = ",".join([str(dev) for dev in devices])
+                logger.debug(f"GPU selected: {devices} for model {model_uid}")
+            if n_gpu is None:
+                env["CUDA_VISIBLE_DEVICES"] = "-1"
+                logger.debug(f"GPU disabled for model {model_uid}")
+        else:
+            assert isinstance(gpu_idx, list)
+            devices = await self.allocate_devices_with_gpu_idx(
+                model_uid, model_type, gpu_idx  # type: ignore
+            )
+            env["CUDA_VISIBLE_DEVICES"] = ",".join([str(dev) for dev in devices])
 
         if os.name != "nt" and platform.system() != "Darwin":
             # Linux
@@ -495,6 +602,7 @@
         image_lora_load_kwargs: Optional[Dict] = None,
         image_lora_fuse_kwargs: Optional[Dict] = None,
         request_limits: Optional[int] = None,
+        gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
     ):
         event_model_uid, _, __ = parse_replica_model_uid(model_uid)
@@ -510,6 +618,17 @@
         launch_args.pop("self")
         launch_args.pop("kwargs")
         launch_args.update(kwargs)
+
+        if gpu_idx is not None:
+            logger.info(
+                f"You specify to launch the model: {model_name} on GPU index: {gpu_idx} "
+                f"of the worker: {self.address}, "
+                f"xinference will automatically ignore the `n_gpu` option."
+            )
+            if isinstance(gpu_idx, int):
+                gpu_idx = [gpu_idx]
+            assert isinstance(gpu_idx, list)
+
         if n_gpu is not None:
             if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > gpu_count()):
                 raise ValueError(
@@ -535,7 +654,7 @@
         is_local_deployment = await self._supervisor_ref.is_local_deployment()
 
         subpool_address, devices = await self._create_subpool(
-            model_uid, model_type, n_gpu=n_gpu
+            model_uid, model_type, n_gpu=n_gpu, gpu_idx=gpu_idx
         )
 
         try:
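
On the worker side, gpu_idx overrides n_gpu: an int is normalized to a one-element list, the indexes are validated against the GPUs the worker can see, and the sorted result becomes CUDA_VISIBLE_DEVICES for the model's subpool. A standalone sketch of that normalization, not the actor code itself:

from typing import Dict, List, Union

def gpu_env(gpu_idx: Union[int, List[int]], visible: List[int]) -> Dict[str, str]:
    # mirrors the diff above: int -> [int], validate, then sort and join
    if isinstance(gpu_idx, int):
        gpu_idx = [gpu_idx]
    if not set(gpu_idx) <= set(visible):
        raise ValueError(f"cannot use GPU indexes {gpu_idx}; worker sees {visible}")
    return {"CUDA_VISIBLE_DEVICES": ",".join(str(d) for d in sorted(gpu_idx))}

assert gpu_env(1, [0, 1]) == {"CUDA_VISIBLE_DEVICES": "1"}
assert gpu_env([1, 0], [0, 1]) == {"CUDA_VISIBLE_DEVICES": "0,1"}
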